(no title)
time_to_smile | 2 years ago
As an example: if you want to implement logistic regression in JAX, you need to optimize the weights. This is easy enough, since the parameters can be modeled as a single value: one matrix of weights. If you want to model a 2 layer MLP, now you have at least 2 weight matrices. You could treat these as two separate arguments to your function (which makes the derivatives more complicated to manage), or you could concatenate the weights and split them up, etc. Annoying, but manageable.
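A minimal sketch of the easy case, logistic regression with a single weight vector (the shapes and dummy data here are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Binary cross-entropy loss for logistic regression.
# All parameters live in a single array `w`, so jax.grad is straightforward.
def loss(w, x, y):
    preds = jax.nn.sigmoid(x @ w)
    return -jnp.mean(y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds))

w = jnp.zeros((3,))       # one parameter array: easy to differentiate
x = jnp.ones((5, 3))      # dummy inputs
y = jnp.ones(5)           # dummy labels

grad_w = jax.grad(loss)(w, x, y)  # gradient has the same shape as w
```

With more than one parameter array, `jax.grad` alone no longer keeps things this tidy, which is where the annoyance starts.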
When you get to something like a diffusion model, you now need to manage parameters for a variety of different, quite complex, models. It really helps if you can keep track of all these parameters in whatever data structure you like, yet still trivially call "grad" and get your model's derivatives with respect to its parameters.
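This is exactly what pytrees buy you: `jax.grad` accepts arbitrarily nested dicts/lists/tuples of arrays and returns gradients with the same structure. A sketch with a made-up 2 layer MLP stored as a nested dict:

```python
import jax
import jax.numpy as jnp

# Parameters for a 2-layer MLP kept in a nested dict (a pytree).
# The names and shapes here are arbitrary, just for illustration.
params = {
    "layer1": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)},
}

def mlp(params, x):
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    return h @ params["layer2"]["w"] + params["layer2"]["b"]

def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

x = jnp.ones((5, 3))
y = jnp.zeros((5, 2))

# grad differentiates with respect to the whole pytree at once;
# the result is a dict with the same nesting and array shapes as params.
grads = jax.grad(loss)(params, x, y)
```

No flattening, concatenating, or manual bookkeeping needed, regardless of how many sub-models the structure contains.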
Pytrees make this incredibly simple, and they are a major quality of life improvement in automatic differentiation.