Hypernetworks: Neural Networks for Hierarchical Data

97 points| mkmccjr | 24 days ago |blog.sturdystatistics.com

10 comments

QueensGambit|23 days ago

Factorization is key here. It separates dataset-level structure from observation-level computation so the model doesn't waste capacity rediscovering structure.

I've been arguing the same for code generation. LLMs flatten parse trees into token sequences, then burn compute reconstructing hierarchy as hidden states. Graph transformers could be a good solution for both: https://manidoraisamy.com/ai-mother-tongue.html

mkmccjr|15 days ago

Thank you for your comment, and I sincerely apologize for my slow response! "Rediscovering structure" is exactly the inefficiency I was trying to highlight.

In the physics/science cases I work with, the factorization is usually between the physical law (shared structure) and the experimental conditions (dataset-specific structure). If you don't separate them, the model wastes capacity trying to memorize the noise of the experimental conditions. (It's ineffective as well as wasteful.)

The analogy to code generation makes a lot of sense: flattening a tree into a sequence forces the model to infer syntax that was already explicit. Thank you for the link; I look forward to diving into it!

stephantul|24 days ago

What a good post! I loved the takeaways at the end of each section.

I think it would maybe get more traction if the code were in PyTorch or JAX. It’s been a long while since I’ve seen people use Keras.

mkmccjr|15 days ago

You are absolutely right about the code: I haven't worked with neural networks in a while and I guess my post outs me!

That said, I do like Keras's functional API, and in this case I think it maps nicely to the "math" of the hypernetwork.

I really appreciate your suggestion of more popular libraries, and I'll look into JAX.

joefourier|24 days ago

Odd that the author didn’t try giving a latent embedding to the standard neural network (or modulating the activations with a FiLM layer) and instead used static embeddings as the baseline. There’s no real advantage to using a hypernetwork: they tend to be more unstable and difficult to train, and they scale poorly unless you train a low-rank adaptation.
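
For readers who haven't met FiLM, here is a minimal NumPy sketch of the idea mentioned above: a shared dense layer whose activations are scaled and shifted per feature by a latent vector. The names (`film_layer`, `params`, `z`) and the layer sizes are illustrative assumptions, not code from the post.

```python
import numpy as np

rng = np.random.default_rng(0)

def film_layer(x, z, params):
    """Dense layer whose activations are scaled and shifted by a latent z."""
    # Shared weights: identical for every dataset.
    h = x @ params["W"] + params["b"]
    # Per-feature scale and shift, both computed from the latent vector z.
    gamma = z @ params["W_gamma"] + params["b_gamma"]
    beta = z @ params["W_beta"] + params["b_beta"]
    return np.tanh(gamma * h + beta)

# Hypothetical sizes, chosen only for illustration.
in_dim, hidden, latent_dim = 1, 8, 4
params = {
    "W": rng.normal(size=(in_dim, hidden)),
    "b": np.zeros(hidden),
    "W_gamma": rng.normal(size=(latent_dim, hidden)),
    "b_gamma": np.ones(hidden),   # scale starts at 1 (no modulation)
    "W_beta": rng.normal(size=(latent_dim, hidden)),
    "b_beta": np.zeros(hidden),   # shift starts at 0 (no modulation)
}

x = rng.normal(size=(5, in_dim))      # five observations from one dataset
z = rng.normal(size=(latent_dim,))    # that dataset's latent vector
out = film_layer(x, z, params)        # shape (5, hidden)
```

With `z` set to zeros the layer reduces to the plain shared layer, which makes the role of the latent easy to see: it only modulates activations, while the main weights stay shared.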

mkmccjr|24 days ago

Hello. I am the author of the post. The goal of this was to provide a pedagogical example of applying Bayesian hierarchical modeling principles to real-world datasets. These datasets often contain inherent structure that is important to model explicitly (e.g., clinical trials across multiple hospitals). Often a single pooled model cannot capture this over-dispersion, yet there is not enough data to fit each group separately (nor should you).

The idea behind hypernetworks is that they enable Gelman-style partial pooling, explicitly modeling the data-generating process while leveraging the flexibility of neural network tooling. I’m curious to read more about your recommendations: their connection to the problems described is not immediately obvious to me, but I would like to dig a bit deeper.
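
As a rough illustration of what partial pooling means here, the sketch below shrinks each group's mean toward a shared mean using the textbook precision-weighted formula for a normal hierarchical model with known variances. The function name, the example numbers, and the choice to plug in a size-weighted grand mean for the shared mean are all simplifying assumptions for illustration, not something taken from the post.

```python
import numpy as np

def partial_pool(group_means, group_sizes, sigma2, tau2):
    """Precision-weighted shrinkage of group means toward a shared mean.

    sigma2: within-group variance (assumed known)
    tau2:   between-group variance (assumed known)
    """
    group_means = np.asarray(group_means, dtype=float)
    group_sizes = np.asarray(group_sizes, dtype=float)
    data_precision = group_sizes / sigma2   # grows with the amount of data
    prior_precision = 1.0 / tau2            # strength of the pull to the shared mean
    # Simplification: use the size-weighted grand mean as the shared mean.
    mu = np.average(group_means, weights=group_sizes)
    return (data_precision * group_means + prior_precision * mu) / (
        data_precision + prior_precision
    )

# Hypothetical example: a large hospital and a very small one.
raw = np.array([0.0, 10.0])
pooled = partial_pool(raw, group_sizes=[100, 2], sigma2=4.0, tau2=1.0)
```

The group with 100 observations barely moves, while the group with 2 observations is pulled strongly toward the shared mean. That is the "neither fully pooled nor fully split" behaviour described above.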

I agree that hypernetworks have some challenges associated with them due to the fragility of maximum likelihood estimates. In the follow-up post, I dug into how explicit Bayesian sampling addresses these issues.

mkmccjr|15 days ago

Thank you for reading my post, and for your thoughtful critique. And I sincerely apologize for my slow response! You are right that there are other ways to inject latent structure, and FiLM is a great example.

I admit the "static embedding" baseline is a bit of a strawman, but I used it to illustrate the specific failure mode of models that can't adapt at inference time.

I then used the Hypernetwork specifically to demonstrate a "dataset-adaptive" architecture as a stepping stone toward the next post in the series. My goal was to show how even a flexible parameter-generating model eventually hits a wall with out-of-sample stability; this sets the stage for the Bayesian Hierarchical approach I cover later on.

I wasn't familiar with the FiLM literature before your comment, but looking at it now, the connection is spot on. Functionally, it seems similar to what I did here: conditioning the network on an external variable. In my case, I wanted to explicitly model the mapping E->θ to see if the network could learn the underlying physics (Planck's law) purely from data.
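
To make the E->θ mapping concrete, here is a minimal NumPy sketch, assuming a 4D dataset embedding, a single linear hypernetwork layer, and a one-hidden-layer main network. The sizes and names (`generate_theta`, `main_network`) are hypothetical and are not the Keras code from the post.

```python
import numpy as np

rng = np.random.default_rng(0)

EMBED_DIM, HIDDEN = 4, 8
# Parameter count of the main network: W1 (1 x HIDDEN), b1, W2 (HIDDEN x 1), b2.
N_THETA = HIDDEN + HIDDEN + HIDDEN + 1

# Hypernetwork parameters (a single linear map here, for brevity).
hyper_W = rng.normal(scale=0.5, size=(EMBED_DIM, N_THETA))
hyper_b = np.zeros(N_THETA)

def generate_theta(E):
    """Map a dataset embedding E to the flat parameter vector theta."""
    return E @ hyper_W + hyper_b

def main_network(x, theta):
    """Scalar-in, scalar-out MLP whose weights are read out of theta."""
    W1 = theta[:HIDDEN].reshape(1, HIDDEN)
    b1 = theta[HIDDEN:2 * HIDDEN]
    W2 = theta[2 * HIDDEN:3 * HIDDEN].reshape(HIDDEN, 1)
    b2 = theta[3 * HIDDEN:]
    return np.tanh(x @ W1 + b1) @ W2 + b2

x = np.linspace(-1.0, 1.0, 5).reshape(-1, 1)   # the same inputs for both datasets
E_a = rng.normal(size=EMBED_DIM)               # embedding for dataset A
E_b = rng.normal(size=EMBED_DIM)               # embedding for dataset B
y_a = main_network(x, generate_theta(E_a))
y_b = main_network(x, generate_theta(E_b))
```

The same inputs produce different curves for the two embeddings because every weight of the main network is a function of E, whereas FiLM-style conditioning only rescales and shifts activations of a shared network.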

As for stability, you are right that Hypernetworks can be tricky in high dimensions, but for this low-dimensional scalar problem (4D embedding), I found it converged reliably.

yobbo|24 days ago

I think a latent embedding is almost equivalent to the article's hypernetwork, which I assume is y = (Wh + c)v + b, where h is a dataset-specific trainable vector and v is the input. (The article uses multiple layers ...)
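
A small numerical check of the single-layer form above, assuming W is a 3-way tensor that maps h to a weight matrix, c is a shared weight matrix, and v is the input (these shapes are an interpretation of the formula, not something the article states). Expanding the product gives a term that is bilinear in h and v plus an ordinary shared linear layer.

```python
import numpy as np

rng = np.random.default_rng(0)
out_dim, in_dim, k = 3, 2, 4   # hypothetical sizes

W = rng.normal(size=(out_dim, in_dim, k))  # maps h to a generated weight matrix
c = rng.normal(size=(out_dim, in_dim))     # shared weight matrix
b = rng.normal(size=out_dim)               # shared bias
h = rng.normal(size=k)                     # dataset-specific trainable vector
v = rng.normal(size=in_dim)                # input

# Hypernetwork view: generate the weight matrix, then apply it to the input.
generated = np.einsum("oik,k->oi", W, h) + c
y_hyper = generated @ v + b

# Expanded view: bilinear interaction between h and v, plus a shared linear layer.
y_expanded = np.einsum("oik,k,i->o", W, h, v) + c @ v + b
```

The two results agree numerically. When h is zero the layer falls back to the shared map c v + b, so in this reading the dataset vector enters only through its multiplicative interaction with the input.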

keepamovin|23 days ago

This is actually the way to AGI, ngl. Come back when it lands and see that it's right.

mkmccjr|15 days ago

I appreciate the optimism! This specific example is just a pedagogical toy designed to be simple enough to analyze fully.

That said, I do agree with the intuition that static networks have a ceiling. If we want systems that can truly adapt to new contexts (like different hospitals or different physical laws) without retraining, we likely need dynamic architectures.