I'm curious: what interop difficulties have you run into in JAX? In my experience, the JAX ecosystem is quite modular, and most JAX libraries work pretty well together. Penzai's core visualization tooling should work out of the box for most JAX NN libraries, and Penzai's neural net components are compatible with existing JAX optimization libraries (like Optax) and data loaders (like tfds/seqio or grain). (Interop with PyTorch seems more difficult, of course!)
yklcs|1 year ago
1. Milestone paper introducing novel method is published with green-field implementation
2. Bunch of papers extend milestone paper with brown-field implementation
3. Goto 1
Most things in 1 are written in PyTorch, which means 2 also has to be in PyTorch. I know this isn't JAX's fault, but I don't think JAX's philosophy of staying unopinionated and low-level is helping. It seems like the community agreeing on a single set of DL libraries around JAX would help it gain some momentum.