bodono|5 years ago
I know where you're coming from, but TF in my opinion was very user-hostile even on arrival. I can't tell you how much hair-pulling I did over tf.conds, tf.while_loops and the whole gather / scatter paradigm for simple indexing into arrays. I really think the people working on it wanted users to write TF code in a certain, particular way and made it really difficult to use it in other ways. Just thinking back on that time still raises my blood pressure! So far Jax is much better and I'm cautiously optimistic they have learned lessons from TF.
gas9S9zw3P9c|5 years ago
Having barriers to entry is not always a bad thing - they force people to learn and understand concepts instead of blindly copying and pasting code from a Medium article and praying that it works.
But I agree with you that there are many different use cases. Those people who want to do high-level work (I have some images, just give me a classifier) shouldn't need to deal with that complexity. IMO the big mistake was trying to merge all these different use cases into one framework. Let's hope JAX doesn't go down the same route.
brilee|5 years ago
Not quite sure why you picked those particular examples... JAX also requires usage of lax.cond, lax.while_loop, and ops.segment_sum. Only gather has been improved with slice notation support. IMO, TF has landed on a pretty nice solution to cond/while_loop via AutoGraph.
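For concreteness, a minimal sketch of the primitives mentioned above (the toy function names are mine, not from the thread):

```python
import jax
import jax.numpy as jnp
from jax import lax

# lax.cond: a functional if/else that jit can stage into XLA.
def abs_val(x):
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

# lax.while_loop: a functional while loop; state is carried explicitly.
def sum_to(n):
    def cond_fn(state):
        i, total = state
        return i < n
    def body_fn(state):
        i, total = state
        return i + 1, total + i
    _, total = lax.while_loop(cond_fn, body_fn, (0, 0))
    return total

# Gather via slice notation, as noted above.
x = jnp.arange(10)
first_three = x[:3]
```

Both functions work eagerly and under `jax.jit`, e.g. `jax.jit(abs_val)(-3.0)`.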
joaogui1|5 years ago
iflp|5 years ago
JAX is indeed a different situation as it has a more original design (although TF1 came with a huge improvement in compilation speed, so maybe there were innovations under the hood). But I don't know if I like it. The framework itself is quite neat, but last time I checked, the accompanying NN libraries had horrifying designs.
MiroF|5 years ago
I'm ill-informed - but isn't that exactly what lax is?
mattjjatgoogle|5 years ago
Structured control-flow functions like lax.cond, lax.scan, etc. exist so that you can, for example, stage control flow out of Python and into an end-to-end compiled XLA computation with jax.jit. In other words, some JAX transformations place more constraints on your Python code than others, but you can just opt into the ones you want. (More generally, the lax module lets you program XLA HLO pretty directly [2].)
Disclaimer: I work on JAX!
[1] https://github.com/hips/autograd [2] https://www.tensorflow.org/xla/operation_semantics
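To illustrate staging control flow out of Python, here is a minimal sketch of lax.scan under jax.jit (a toy cumulative sum; the function names are my own, not from the comment):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def cumsum_scan(xs):
    # The loop body is staged into a single compiled XLA computation,
    # rather than being unrolled Python-side iteration by iteration.
    def step(carry, x):
        carry = carry + x
        return carry, carry  # (new carry, per-step output)
    _, ys = lax.scan(step, 0.0, xs)
    return ys
```

A plain Python `for` loop over a traced array would be unrolled at trace time instead; lax.scan keeps the loop structure inside the compiled computation.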