
mattjjatgoogle | 5 years ago

The difference is that in TF1 you had to use tf.cond, tf.while_loop, etc. for differentiable control flow. In JAX you can differentiate Python control flow directly, e.g.:

  In [1]: from jax import grad
  
  In [2]: def f(x):
     ...:     if x > 0:
     ...:         return 3. * x ** 2
     ...:     else:
     ...:         return 5. * x ** 3
     ...:
  
  In [3]: grad(f)(1.)
  Out[3]: DeviceArray(6., dtype=float32)
  
  In [4]: grad(f)(-1.)
  Out[4]: DeviceArray(15., dtype=float32)
In the above example, the control flow happens in Python, just as it would in PyTorch. (That's not surprising, since JAX grew out of the original Autograd [1]!)
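The same point holds for Python loops, not just conditionals: since differentiation traces the path the Python code actually executes, an ordinary `for` loop works too. A minimal sketch (function name `g` is illustrative, not from the thread):

```python
from jax import grad

def g(x):
    # Three iterations of squaring: overall g(x) = x ** 8,
    # expressed with a plain Python loop.
    for _ in range(3):
        x = x * x
    return x

# d/dx x**8 = 8 * x**7, so at x = 2. this is 8 * 128 = 1024.
print(grad(g)(2.))  # 1024.0
```

Here the loop unrolls during tracing, so grad sees a straight-line computation.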

Structured control flow functions like lax.cond, lax.scan, etc. exist so that you can, for example, stage control flow out of Python and into an end-to-end compiled XLA computation with jax.jit. In other words, some JAX transformations place more constraints on your Python code than others, but you can just opt into the ones you want. (More generally, the lax module lets you program XLA HLO pretty directly [2].)
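To make the staging point concrete, here is a sketch of the same branching function rewritten with lax.cond so that it can be jit-compiled end to end (a Python `if` on the argument would fail under jit, because the traced value has no concrete boolean value):

```python
import jax
from jax import lax

@jax.jit
def f(x):
    # lax.cond stages both branches into the compiled XLA
    # computation and selects between them at run time.
    return lax.cond(x > 0,
                    lambda x: 3. * x ** 2,
                    lambda x: 5. * x ** 3,
                    x)

print(jax.grad(f)(1.))   # 6.0
print(jax.grad(f)(-1.))  # 15.0
```

The gradients match the pure-Python version above; the difference is only where the control flow lives (in Python vs. inside the compiled computation).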

Disclaimer: I work on JAX!

[1] https://github.com/hips/autograd

[2] https://www.tensorflow.org/xla/operation_semantics


p1esk | 5 years ago

What would you say the main advantage of JAX is over PyTorch?