

bodono | 5 years ago

I know where you're coming from, but TF in my opinion was very user-hostile even on arrival. I can't tell you how much hair-pulling I did over tf.conds, tf.while_loops and the whole gather / scatter paradigm for simple indexing into arrays. I really think the people working on it wanted users to write TF code in a certain, particular way and made it really difficult to use it in other ways. Just thinking back on that time still raises my blood pressure! So far JAX is much better, and I'm cautiously optimistic they have learned lessons from TF.


gas9S9zw3P9c|5 years ago

I had the opposite experience. The early TF versions were difficult to use in that they required a lot of boilerplate code to do simple things, but at least there was no hidden complexity. I knew exactly what my code did and what was going on under the hood. When I use today's high-level, opaque TF libraries, I have no idea what's going on, and it's much harder to debug subtle problems. The workflow went from "Damn, I need to write 200 lines of code to do this simple thing" to "I need to spend an hour looking through library documentation, gotchas, deprecation issues and TF-internal code to figure out which function to call with which parameters, and then check whether it actually does exactly what I need." I much prefer the former.

Having barriers to entry is not always a bad thing: it forces people to learn and understand concepts instead of blindly copying and pasting code from a Medium article and praying that it works.

But I agree with you that there are many different use cases. Those people who want to do high-level work (I have some images, just give me a classifier) shouldn't need to deal with that complexity. IMO the big mistake was trying to merge all these different use cases into one framework. Let's hope JAX doesn't go down the same route.

brilee|5 years ago

(googler)

Not quite sure why you picked those particular examples... JAX also requires usage of lax.cond, lax.while_loop, and ops.segment_sum. Only gather has been improved with slice notation support. IMO, TF has landed on a pretty nice solution to cond/while_loop via AutoGraph.
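To make the indexing point concrete, here is a minimal sketch of how gather, scatter and segment reductions look in JAX. It assumes a recent JAX release, where updates use the functional `.at[...]` syntax (older versions exposed helpers such as `jax.ops.index_update` instead); the array values are made up for illustration.

```python
import jax.numpy as jnp
from jax import ops

x = jnp.array([10.0, 20.0, 30.0, 40.0])

# Gather: ordinary NumPy-style indexing (integer arrays, slices) is lowered
# to XLA gather for you, so no explicit gather call is needed.
picked = x[jnp.array([3, 0])]   # [40., 10.]
window = x[1:3]                 # [20., 30.]

# Scatter: JAX arrays are immutable, so updates go through .at[...] and
# return a new array; x itself is left unchanged.
updated = x.at[jnp.array([0, 2])].set(0.0)   # [0., 20., 0., 40.]
accum = x.at[jnp.array([1, 1])].add(1.0)     # duplicates accumulate: [10., 22., 30., 40.]

# Segment reduction: sum the entries of `data` that share a segment id.
data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 1])
sums = ops.segment_sum(data, segment_ids, num_segments=2)  # [3., 7.]
```

The design choice here is that indexing reads like NumPy while the compiled program still contains gather/scatter ops; only segment-style reductions need a dedicated function.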

joaogui1|5 years ago

While JAX has those operations, you don't always need them. It depends on which transformations you want to apply (JIT or grad), and they have been working on making normal control structures compatible with all transformations.
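A minimal sketch of that distinction, assuming a recent JAX release; `relu_python` and `relu_lax` are hypothetical names chosen for illustration. A Python `if` on the input value works under `grad` alone, but under `jit` tracing only sees an abstract value, so the branch has to be written with `lax.cond` (or the argument marked static).

```python
import jax
from jax import lax


def relu_python(x):
    # Plain Python branch on the value of x. Under grad alone, x still
    # carries a concrete value, so Python can evaluate `x > 0`.
    if x > 0:
        return x
    return 0.0 * x


def relu_lax(x):
    # lax.cond records both branches in the traced computation and picks
    # one at run time, so this version also works under jit.
    return lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)


# grad only: Python control flow is fine.
g_python = jax.grad(relu_python)(2.0)          # 1.0

# jit + grad: use the structured control-flow primitive instead.
g_lax = jax.jit(jax.grad(relu_lax))(2.0)       # 1.0
g_lax_neg = jax.jit(jax.grad(relu_lax))(-2.0)  # 0.0
```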

iflp|5 years ago

You can't blame the TF people for things like while_loop. Those are inherited from Theano, and back then the dynamic graph idea wasn't obvious.

JAX is indeed a different situation as it has a more original design (although TF1 came with a huge improvement in compilation speed, so maybe there were innovations under the hood). But I don't know if I like it. The framework itself is quite neat, but last time I checked, the accompanying NN libraries had horrifying designs.

MiroF|5 years ago

> tf.conds, tf.while_loops and the whole gather / scatter paradigm

I'm ill-informed - but isn't that exactly what lax is?

mattjjatgoogle|5 years ago

The difference is that in TF1 you had to use tf.cond, tf.while_loop etc for differentiable control flow. In JAX you can differentiate Python control flow directly, e.g.:

  In [1]: from jax import grad
  
  In [2]: def f(x):
     ...:     if x > 0:
     ...:         return 3. * x ** 2
     ...:     else:
     ...:         return 5. * x ** 3
     ...:
  
  In [3]: grad(f)(1.)
  Out[3]: DeviceArray(6., dtype=float32)
  
  In [4]: grad(f)(-1.)
  Out[4]: DeviceArray(15., dtype=float32)

In the above example, the control flow happens in Python, just as it would in PyTorch. (That's not surprising, since JAX grew out of the original Autograd [1]!)

Structured control flow functions like lax.cond, lax.scan, etc. exist so that you can, for example, stage control flow out of Python and into an end-to-end compiled XLA computation with jax.jit. In other words, some JAX transformations place more constraints on your Python code than others, but you can just opt into the ones you want. (More generally, the lax module lets you program XLA HLO pretty directly [2].)
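As a sketch of what staging a loop out of Python can look like, the hypothetical functions below use `lax.scan` inside `jax.jit`, and `grad` still differentiates through the compiled loop. This assumes a recent JAX release; the recurrence `h <- w * h + x` is only an illustrative choice.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def cumulative_sum(xs):
    # lax.scan threads a carry through the loop body; the whole loop is
    # staged into one XLA computation rather than unrolled in Python.
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (next carry, per-step output)

    _, totals = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return totals


@jax.jit
def decayed_sum(w, xs):
    # Recurrence h <- w * h + x, starting from h = 0; only the final
    # carry is returned, so the per-step output is None.
    def step(h, x):
        return w * h + x, None

    h_final, _ = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return h_final


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
totals = cumulative_sum(xs)             # [1., 3., 6., 10.]
value = decayed_sum(0.5, xs)            # w^3 + 2w^2 + 3w + 4 at w = 0.5 -> 6.125
slope = jax.grad(decayed_sum)(0.5, xs)  # 3w^2 + 4w + 3 at w = 0.5 -> 5.75
```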

Disclaimer: I work on JAX!

[1] https://github.com/hips/autograd
[2] https://www.tensorflow.org/xla/operation_semantics