jimfleming | 5 years ago

I've noticed that many JAX libraries (including those from Google) seem to adopt an object-oriented style closer to Torch/Keras than to JAX's functional style, as demonstrated in modules like jax.experimental.stax. This is disappointing, since stax is quite clean and these libraries seem to rely on a lot of hacks to make OO work with JAX. Is there an effort to implement and maintain more full-featured functional libraries in the jax/stax style?

alevskaya|5 years ago

I've been involved with jax/stax/trax/flax. I think the real issue with the stax-like functional form is that it gets unwieldy very quickly when dealing with more complicated models that are natively general graphs, as opposed to simple sequential pipelines that can be trivially mapped to a combinator expression tree. Of course there are many solutions here, but ultimately if you're building an NN library you need to build something that ML researchers actually want to use daily, and that tends to look closer to hackable PyTorch-like DSLs than to higher-order functional code, which often looks elegant but tends to hurt readability and rework speed.
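
To make the sequential-vs-graph point concrete, here is a rough sketch in plain Python. The combinators below (`Scale`, `serial`, `parallel`, `FanOut`, `Identity`, `FanInSum`) are hypothetical toy stand-ins written for this illustration; they borrow names and the (init_fun, apply_fun) pairing from stax but are not the real stax API, and they work on plain floats rather than arrays.

```python
# Toy stax-style combinators (hypothetical stand-ins, NOT the real stax API).
# A "layer" is a pair (init_fun, apply_fun); parameters are passed explicitly.

def Scale(factor):
    """Layer with a single parameter that multiplies its input."""
    def init_fun():
        return factor
    def apply_fun(params, x):
        return params * x
    return init_fun, apply_fun

def serial(*layers):
    """Run layers one after another, feeding each output to the next layer."""
    inits, applies = zip(*layers)
    def init_fun():
        return [init() for init in inits]
    def apply_fun(params, x):
        for p, apply in zip(params, applies):
            x = apply(p, x)
        return x
    return init_fun, apply_fun

def parallel(*layers):
    """Apply the i-th layer to the i-th element of a list of inputs."""
    inits, applies = zip(*layers)
    def init_fun():
        return [init() for init in inits]
    def apply_fun(params, xs):
        return [apply(p, x) for p, apply, x in zip(params, applies, xs)]
    return init_fun, apply_fun

def FanOut(n):
    """Copy one input into a list of n inputs (no parameters)."""
    return (lambda: None), (lambda params, x: [x] * n)

Identity = (lambda: None), (lambda params, x: x)       # pass-through branch
FanInSum = (lambda: None), (lambda params, xs: sum(xs))  # merge branches by summing

# A simple sequential pipeline maps cleanly onto a combinator expression.
seq_init, seq_apply = serial(Scale(2.0), Scale(3.0))

# A single skip connection, y = f(x) + x, already needs fan-out/fan-in plumbing.
res_init, res_apply = serial(FanOut(2), parallel(Scale(2.0), Identity), FanInSum)

# The same skip connection written as ordinary code over named values.
def residual_direct(w, x):
    return w * x + x

seq_out = seq_apply(seq_init(), 1.0)
res_out = res_apply(res_init(), 1.0)
direct_out = residual_direct(2.0, 1.0)
```

Both residual versions compute the same thing, but the combinator form has to encode the graph's wiring in the expression tree itself. With several branches that cross or get reused, that plumbing tends to grow faster than the model, which is roughly the unwieldiness described above.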