mattjjatgoogle|3 years ago
> It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928).
We've improved some of these pytree error messages, but it seems the vmap one is still not great. Thanks for the ping on it.
> Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out.
That was indeed a longstanding issue in pmap's implementation. And since people came to expect jit to be "built in" to pmap, it wasn't easy to revise.
However, we recently (https://github.com/google/jax/pull/11854) made `jax.disable_jit()` work with pmap, in the sense that it makes pmap execute eagerly, so that you can print/pdb/etc to your heart's content. (The pmap successor, shard_map (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...), is eager by default. Also it has uniformly good error messages from the start!)
> Next time I encounter something particularly opaque, I'll share on the github issue tracker.
Thank you for the constructive feedback!
6gvONxR4sf7o|3 years ago
Higher order functions are difficult in general, and it would be fantastic to have core patterns or tools for breaking them open.
patrickkidger|3 years ago
If so, then allow me to make my usual advert here for Equinox:
https://github.com/patrick-kidger/equinox
This actually works with JAX's native transformations. (There's no `equinox.vmap`, for example.)
On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable in that it has a forward pass, and a pytree in that it records its parameters in its tree structure.
mattjjatgoogle|3 years ago
1. as you say, exposing patterns and tools for library authors to implement transformations/higher-order primitives using JAX's machinery rather than requiring each library to introduce bespoke magic to do the same;
2. adding JAX core infrastructure which directly solves the common problems that libraries tend to solve independently (and with bespoke magic).