Am I the only one surprised that the author of einops is looking for work? In an era of an AI arms race between many big labs? If you’re rolling your own networks, I’d definitely reach out to this guy!
And for those of us just tuning in, what is an "autograd"? From skimming the article and some quick web searching, it looks like some sort of math function that gets used a lot in ML (and I assume other places) so there are implementations in a bunch of major libraries?
Reverse-mode automatic differentiation: it computes the gradient of a scalar with respect to some large vector of inputs the scalar was computed from. That's critically important for the most popular and most efficient mathematical optimization algorithms, such as Adam.
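To make that concrete, here is a minimal sketch of reverse-mode AD in plain Python: a toy `Var` class (not any particular library's API) records how each value was computed, and one backward sweep yields the gradient of the scalar output with respect to every input.

```python
# Toy reverse-mode autodiff. Each Var remembers its parents and the
# local derivative of itself with respect to each parent; backward()
# propagates gradients from the output down to every input.
class Var:
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return Var(self.value * other.value,
                   [(self, other.value), (other, self.value)])

    def backward(self):
        # Seed the output with grad 1, then push gradients to ancestors.
        # (For graphs with shared subexpressions a topological ordering
        # is needed; omitted here for brevity.)
        self.grad = 1.0
        stack = [self]
        while stack:
            node = stack.pop()
            for parent, local in node.parents:
                parent.grad += local * node.grad
                stack.append(parent)

xs = [Var(v) for v in (2.0, 3.0, 4.0)]
loss = xs[0] * xs[1] + xs[2]   # scalar loss = x0*x1 + x2
loss.backward()
print([x.grad for x in xs])    # d(loss)/dx = [x1, x0, 1] = [3.0, 2.0, 1.0]
```

One pass through the graph gives all three partial derivatives at once, which is why the technique scales to very large input vectors.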
You can't really implement it as a library in a normal programming language, because it has to introspect on how your program computes everything, which is not something a library can normally do. You have to implement it as a programming language: at least an embedded domain-specific language, and in some cases something like a Fortran compiler.
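The embedded-DSL route usually means operator overloading: the framework can't read your source, but it can record every operation you apply to its objects, which is roughly how tape-based autograd and tracing JITs gain visibility into your program. A toy sketch (hypothetical `Trace` class, not a real library):

```python
# Operator overloading as a tiny embedded DSL: user code looks like
# ordinary arithmetic, but every op is recorded on a tape the
# framework can later differentiate or compile.
class Trace:
    _tape = []  # shared recording of (result, op, lhs, rhs) tuples

    def __init__(self, name):
        self.name = name

    def __mul__(self, other):
        out = Trace(f"t{len(Trace._tape)}")
        Trace._tape.append((out.name, "mul", self.name, other.name))
        return out

    def __add__(self, other):
        out = Trace(f"t{len(Trace._tape)}")
        Trace._tape.append((out.name, "add", self.name, other.name))
        return out

def f(x, y):            # an ordinary-looking user function
    return x * y + x

f(Trace("x"), Trace("y"))
for step in Trace._tape:
    print(step)
# ('t0', 'mul', 'x', 'y')
# ('t1', 'add', 't0', 'x')
```

The captured tape is exactly the "introspection" the comment describes: the framework sees the computation graph without any compiler support from the host language.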
JAX unrolls Python for-loops during compilation. I don't know how canonical JAX code would perform on this example, but the comparison as it stands seems unfair.
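The unrolling happens because a tracing compiler replays your Python code at trace time and cannot see the loop structure, only the ops each iteration emits; JAX's structured control-flow primitives such as `jax.lax.fori_loop` and `jax.lax.scan` exist precisely to keep the loop inside the compiled program. A toy tracer (illustrative, not JAX's internals) shows how the graph grows with the trip count:

```python
# Toy tracer: like a tracing JIT, it records one graph node per
# executed op, so a Python loop of n iterations yields n nodes.
tape = []

class T:
    def __init__(self, name):
        self.name = name

    def __add__(self, other):
        out = T(f"n{len(tape)}")
        tape.append(("add", self.name, getattr(other, "name", other)))
        return out

x = T("x")
for _ in range(1000):   # a Python loop the tracer cannot see as a loop
    x = x + 1.0
print(len(tape))        # 1000 -- the "compiled" graph is fully unrolled
```

With 10k+ iterations, compile time and graph size blow up accordingly, which is why the loop primitive you choose matters for this benchmark.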
> 1. we test many computation graphs (graph is changing constantly)
> 2. many, many scalar operations, with roughly 10k–100k nodes in each graph
> 3. every graph should be compiled and run around 10k times, both forward and backward
Now I'm curious to know where that is useful. Some kind of meta-learning approach?
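The quoted workload sizes hint at why per-node overhead dominates here. A quick back-of-envelope (the ~10 ns per-node cost is an assumption for illustration, not a figure from the article):

```python
# Back-of-envelope for the workload above: ~10k nodes per graph,
# each graph run ~10k times, forward and backward.
nodes, runs, passes = 10_000, 10_000, 2   # passes: forward + backward
node_evals = nodes * runs * passes
print(node_evals)                          # 200000000 evaluations per graph
# At an *assumed* ~10 ns of dispatch overhead per node, that's about
# 2 seconds per graph spent on overhead alone, before any actual math:
print(f"{node_evals * 10e-9:.1f} s")
```

With the graph changing constantly, that cost can't even be amortized across graphs, which is what makes compile speed as important as execution speed in this setting.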
Automatic differentiation isn't just for machine learning; it's useful in several other applications where you have an optimization problem built of many variables interacting in differentiable ways. One possible guess could be structure-from-motion or SLAM, where the graph is often dynamic and several thousand nodes isn't uncommon (the nodes would be things such as camera poses and 3D landmarks, for example). However, in that case there are other frameworks built with that scenario in mind (Ceres, GTSAM) that would probably be better baselines.
Robotic planning of liquid handling. There are many potential alternative ways to achieve the same result, each resulting in its own computation graph.
I imagine any kind of trajectory planning for many agents faces similar challenges (e.g. robots in a factory).
And 10k ops is not a lot! SDXL has around 4k ops on the forward pass (from memory); going by this test, that means you'd spend 40ms per iteration on autodiff alone!
The author's workload is somewhat different from the usual ML workload: the author's expression tree is large (10k nodes), while a modern neural net has a relatively small expression tree, maybe fewer than 100 nodes even for the larger nets?
Another commenter mentioned Dr.Jit, which seems to be designed for this use case. This is a quote from their project page:
> Why did we create Dr.Jit, when dynamic derivative compilation is already possible using Python-based ML frameworks like JAX, Tensorflow, and PyTorch along with backends like XLA and TorchScript?
> The reason is related to the typical workloads: machine learning involves small-ish computation graphs that are, however, made of arithmetically intense operations like convolutions, matrix multiplications, etc. The application motivating Dr.Jit (differentiable rendering) creates giant and messy computation graphs consisting of 100K to millions of "trivial" nodes (elementary arithmetic operations). In our experience, ML compilation backends use internal representations and optimization passes that are too rich for this type of input, causing them to crash or time out during compilation. If you have encountered such issues, you may find Dr.Jit useful.
I recently had a similar need for my “real-time” Python program to do some autograd operations – but not on big tensors – and also ended up leveraging Rust as part of my solution. But I didn’t know about rustimport (or cppimport). That’s really slick, and I’m going to file that nugget away for later.
stealthcat | 2 years ago:

https://github.com/casadi/casadi

https://github.com/mitsuba-renderer/drjit

Dr.Jit is by the same author as pybind11 and nanobind.
cl3misch | 2 years ago:

https://jax.readthedocs.io/en/latest/faq.html#jit-decorated-...
tekknolagi | 2 years ago:

I recommend trying out TCC, which compiles C very fast.
ofou | 2 years ago:

https://github.com/tinygrad/tinygrad