mattjjatgoogle | 5 years ago
In [1]: from jax import grad

In [2]: def f(x):
   ...:     if x > 0:
   ...:         return 3. * x ** 2
   ...:     else:
   ...:         return 5. * x ** 3
   ...:

In [3]: grad(f)(1.)
Out[3]: DeviceArray(6., dtype=float32)

In [4]: grad(f)(-1.)
Out[4]: DeviceArray(15., dtype=float32)
In the above example, the control flow happens in Python, just as it would in PyTorch. (That's not surprising, since JAX grew out of the original Autograd [1]!)

Structured control flow functions like lax.cond, lax.scan, etc. exist so that you can, for example, stage control flow out of Python and into an end-to-end compiled XLA computation with jax.jit. In other words, some JAX transformations place more constraints on your Python code than others, but you can just opt into the ones you want. (More generally, the lax module lets you program XLA HLO pretty directly [2].)
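To make the "opt in" point concrete, here is a rough sketch of the same function with the branch staged out via lax.cond so that it also works under jax.jit. The names f_python, f_staged, grad_staged, and jit_failed are made up for this illustration, and it assumes a reasonably recent JAX release; the exact exception type raised by the Python-branch version under jit varies between versions, so the sketch catches a generic Exception.

```python
from jax import grad, jit, lax

def f_python(x):
    # Python-level branch: fine under grad alone, because x is concrete there.
    if x > 0:
        return 3. * x ** 2
    else:
        return 5. * x ** 3

def f_staged(x):
    # Same function, with the branch expressed as lax.cond so that jit
    # can trace it into a single XLA computation.
    return lax.cond(x > 0,
                    lambda v: 3. * v ** 2,   # taken when x > 0
                    lambda v: 5. * v ** 3,   # taken otherwise
                    x)

# Compose transformations: differentiate, then compile end to end.
grad_staged = jit(grad(f_staged))
pos = float(grad_staged(1.))    # d/dx 3x^2 at x = 1
neg = float(grad_staged(-1.))   # d/dx 5x^3 at x = -1

# Under jit, the Python `if` sees an abstract tracer instead of a value,
# so tracing f_python fails. The error class is version-dependent, hence
# the broad except (an assumption of this sketch).
try:
    jit(f_python)(1.)
    jit_failed = False
except Exception:
    jit_failed = True
```

The gradients should match the un-jitted Python-branch version above; only the jit-compatible form needs the structured primitive.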
Disclaimer: I work on JAX!
[1] https://github.com/hips/autograd
[2] https://www.tensorflow.org/xla/operation_semantics
p1esk | 5 years ago