I've noticed that many JAX libraries (including those from Google) seem to adopt an object-oriented style closer to Torch/Keras than to the functional style JAX demonstrates in modules like jax.experimental.stax. This is disappointing, since stax is quite clean and these libraries seem to rely on a lot of hacks to make OO work with JAX. Is there an effort to implement and maintain more full-featured functional libraries in the jax/stax style?
alevskaya|5 years ago