Looks like a nice overview. I’ve implemented neural ODEs in JAX for low-dimensional problems and it works well, but I keep looking for a good, fast, CPU-first implementation that suits models that fit in cache and don’t need a GPU or heavy Torch/TF machinery.
marmaduke|1 year ago
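For anyone wanting to try the small, CPU-only case mentioned above, here is a minimal sketch in plain JAX with no extra libraries. It makes several assumptions that are not from the comment. It uses a fixed-step RK4 integrator written with `jax.lax.scan` and differentiates by backpropagating through the solver steps, not by an adjoint method. The hypothetical `init_params`/`vector_field` MLP, the step size `dt=0.1`, and the toy harmonic-oscillator target are all illustrative choices. The whole solve and its gradient compile into one `jax.jit` function, which tends to suit tiny models on CPU.

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny MLP vector field f(y; params) for a 2-D state (illustrative sizes).
def init_params(key, dim=2, hidden=16):
    k1, k2 = jax.random.split(key)
    return {
        "W1": 0.1 * jax.random.normal(k1, (dim, hidden)),
        "b1": jnp.zeros(hidden),
        "W2": 0.1 * jax.random.normal(k2, (hidden, dim)),
        "b2": jnp.zeros(dim),
    }

def vector_field(params, y):
    # dy/dt = MLP(y); autonomous, so time is not an input.
    h = jnp.tanh(y @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def rk4_step(params, y, dt):
    # One classical fourth-order Runge-Kutta step.
    k1 = vector_field(params, y)
    k2 = vector_field(params, y + 0.5 * dt * k1)
    k3 = vector_field(params, y + 0.5 * dt * k2)
    k4 = vector_field(params, y + dt * k3)
    return y + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

def odesolve(params, y0, dt, n_steps):
    # Fixed-step integration via scan; returns states at t = dt, 2*dt, ..., n_steps*dt.
    def step(y, _):
        y_next = rk4_step(params, y, dt)
        return y_next, y_next
    _, ys = jax.lax.scan(step, y0, None, length=n_steps)
    return ys  # shape (n_steps, dim)

def loss(params, y0, target):
    # Mean squared error between the solved trajectory and the target trajectory.
    pred = odesolve(params, y0, 0.1, target.shape[0])
    return jnp.mean((pred - target) ** 2)

def train(steps=200, lr=0.05):
    params = init_params(jax.random.PRNGKey(0))
    y0 = jnp.array([1.0, 0.0])
    ts = 0.1 * jnp.arange(1, 21)
    # Toy target: harmonic oscillator dy/dt = [y1, -y0] starting from [1, 0].
    target = jnp.stack([jnp.cos(ts), -jnp.sin(ts)], axis=1)
    # Gradient flows back through every RK4 step (discretize-then-optimize).
    grad_fn = jax.jit(jax.value_and_grad(loss))
    losses = []
    for _ in range(steps):
        value, grads = grad_fn(params, y0, target)
        # Plain gradient descent applied to every leaf of the params dict.
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        losses.append(float(value))
    return params, losses
```

Calling `params, losses = train()` runs a few hundred jitted steps on CPU. For adaptive step sizes or adjoint-based gradients, a dedicated solver library is the more robust route than this hand-rolled RK4.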
sitkack|1 year ago
JAX Talk: Diffrax https://www.youtube.com/watch?v=Jy5Jw8hNiAQ
barrenko|1 year ago