Looks like a nice overview. I've implemented neural ODEs in JAX for low-dimensional problems and it works well, but I keep looking for a good, fast, CPU-first implementation that suits models that fit in cache and doesn't require a GPU or big Torch/TF machinery.
sitkack|1 year ago
JAX Talk: Diffrax https://www.youtube.com/watch?v=Jy5Jw8hNiAQ
yberreby|1 year ago
`vmap`-able differential equation solving is really cool.
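A minimal sketch of what "vmap-able" means here, written in plain JAX with a hand-rolled fixed-step RK4 rather than Diffrax's API (`rk4_solve` and `decay` are hypothetical names). Because the solver is an ordinary JAX function built on `jax.lax.scan`, `jax.vmap` turns a single-trajectory solve into a batched one with no extra code.

```python
import jax
import jax.numpy as jnp


def rk4_solve(f, y0, t0, t1, n_steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 with fixed-step classic RK4."""
    # Match the time variable's dtype to the state so the scan carry is stable.
    t0 = jnp.asarray(t0, dtype=y0.dtype)
    dt = (t1 - t0) / n_steps

    def step(carry, _):
        t, y = carry
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt / 2 * k1)
        k3 = f(t + dt / 2, y + dt / 2 * k2)
        k4 = f(t + dt, y + dt * k3)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, y_next), None

    # scan compiles the loop once instead of unrolling n_steps Python iterations.
    (_, y_final), _ = jax.lax.scan(step, (t0, y0), None, length=n_steps)
    return y_final


def decay(t, y):
    # Simple linear test problem dy/dt = -y with exact solution y0 * exp(-t).
    return -y


# Batch over initial conditions only; the vector field and time span are shared.
batched_solve = jax.vmap(lambda y0: rk4_solve(decay, y0, 0.0, 1.0, 100))
y0s = jnp.array([1.0, 2.0, 3.0])
ys = batched_solve(y0s)
```

Diffrax layers adaptive step-size control, many solvers, and adjoint methods on top of the same idea; this sketch only illustrates the batching.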
marmaduke|1 year ago
JAX is fun, but not as effective as I'd like for CPU.
barrenko|1 year ago
kk58|1 year ago
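Neural ODE reframes this: instead of stacking a fixed number of discrete layers, focus on how the hidden state changes. A neural network f parameterizes the derivative dh/dt = f(h, t, θ), so the forward pass becomes a continuous path from the input to the output. An ODE solver computes that path step by step (adaptive solvers pick the number of steps themselves), and training adjusts the parameters θ so the endpoint matches the training data, either by backpropagating through the solver or with the adjoint method.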
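In the original Neural ODE formulation it is the hidden state h, not the weights, that evolves under dh/dt = f(h, t, θ); the weights θ stay fixed during a forward pass and are updated by ordinary gradient descent. Below is a minimal sketch of that in plain JAX, assuming a tiny one-layer vector field, a fixed-step RK4 solver instead of an adaptive one, and a single toy input/target pair. `rk4_solve`, `dynamics`, `forward`, and the data are hypothetical, and gradients come from backpropagating through the solver steps rather than from the adjoint method.

```python
import jax
import jax.numpy as jnp


def rk4_solve(f, y0, t0, t1, n_steps, params):
    """Fixed-step RK4 for dy/dt = f(t, y, params); differentiable with jax.grad."""
    dt = (t1 - t0) / n_steps

    def step(carry, _):
        t, y = carry
        k1 = f(t, y, params)
        k2 = f(t + dt / 2, y + dt / 2 * k1, params)
        k3 = f(t + dt / 2, y + dt / 2 * k2, params)
        k4 = f(t + dt, y + dt * k3, params)
        return (t + dt, y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)), None

    init = (jnp.asarray(t0, dtype=y0.dtype), y0)
    (_, y1), _ = jax.lax.scan(step, init, None, length=n_steps)
    return y1


def dynamics(t, h, params):
    # Hypothetical vector field: a one-layer network giving dh/dt.
    W, b = params
    return jnp.tanh(W @ h + b)


def forward(params, x):
    # The input is the initial state; the output is the state at t = 1.
    return rk4_solve(dynamics, x, 0.0, 1.0, 20, params)


def loss(params, x, y_target):
    return jnp.sum((forward(params, x) - y_target) ** 2)


key = jax.random.PRNGKey(0)
params = (0.1 * jax.random.normal(key, (2, 2)), jnp.zeros(2))
x = jnp.array([1.0, 0.0])
y_target = jnp.array([0.0, 1.0])

initial_loss = loss(params, x, y_target)
grad_fn = jax.jit(jax.grad(loss))
lr = 0.1
for _ in range(200):
    # Training updates the weights; the ODE only governs the hidden state.
    grads = grad_fn(params, x, y_target)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
final_loss = loss(params, x, y_target)
```

Swapping the fixed-step loop for an adaptive solver with an adjoint (for example via Diffrax) is what makes the approach practical on harder problems; the fixed-step version just keeps the sketch short.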