kanaffa12345 | 4 years ago
given who you are (i googled your name), i'm surprised you would say this. jax does not jit-compile python in any sense of the word `Python`. jax is a tracing mechanism for a very particular set of "programs" specified using python. i put programs in quotes because you couldn't even use it to trace through `if __name__ == "__main__"`, since it doesn't know (or care) anything about python namespaces. it's right there in the first sentence of the description:
>JAX is Autograd and XLA
autograd for tracing and building the tape (a wengert list), and xla for the backend (i.e., the actual kernels). there is no sense in which jax will ever play a role in things like faster hash tables, more efficient loads/stores, or virtual function calls.
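A toy sketch of the "tape" idea, assuming scalar values and only `+` and `*`: each operation appends a node (its parents plus local derivatives) to a Wengert list during the forward pass, and a reverse sweep over that list accumulates gradients. The names `Tape`, `Var` and `grad` are hypothetical; this is in the spirit of the original Autograd library, not JAX's actual internals (JAX traces into its own intermediate representation, the jaxpr, and hands that to XLA). Only operations on `Var` objects get recorded; any other Python that runs during the call leaves nothing on the tape.

```python
class Tape:
    """A Wengert list: nodes[i] holds (parent_index, local_derivative) pairs for node i."""

    def __init__(self):
        self.nodes = []

    def push(self, parents):
        self.nodes.append(parents)
        return len(self.nodes) - 1

    def var(self, value):
        # Input variables have no parents.
        return Var(value, self, self.push([]))


class Var:
    """A scalar whose arithmetic is recorded on a shared tape (hypothetical toy class)."""

    def __init__(self, value, tape, index):
        self.value = value
        self.tape = tape
        self.index = index

    def __add__(self, other):
        # d(a + b)/da = 1, d(a + b)/db = 1
        idx = self.tape.push([(self.index, 1.0), (other.index, 1.0)])
        return Var(self.value + other.value, self.tape, idx)

    def __mul__(self, other):
        # d(a * b)/da = b, d(a * b)/db = a
        idx = self.tape.push([(self.index, other.value), (other.index, self.value)])
        return Var(self.value * other.value, self.tape, idx)


def grad(tape, output):
    """Reverse sweep: nodes are appended in evaluation order, so walking the
    list backwards visits every node after all of its consumers."""
    adjoints = [0.0] * len(tape.nodes)
    adjoints[output.index] = 1.0
    for i in reversed(range(len(tape.nodes))):
        for parent, local in tape.nodes[i]:
            adjoints[parent] += adjoints[i] * local
    return adjoints


# f(x, y) = x * y + x, so df/dx = y + 1 and df/dy = x.
tape = Tape()
x = tape.var(3.0)
y = tape.var(4.0)
out = x * y + x
adj = grad(tape, out)
print(out.value, adj[x.index], adj[y.index])  # 15.0 5.0 3.0
```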
in fact it doesn't even jit in the conventional sense, since no new machine code is generated based on the code paths taken (it simply selects kernels that have already been compiled). not that i fault you for the substitution, since everyone in ML does it (pytorch claims to jit as well).
erwincoumans | 4 years ago
You're missing my point: all of those efforts aim to make slow Python code run faster. So claiming that 'these two things have nothing to do with each other' is wrong, because they share that goal.
Some of it involves making CPython faster, some means moving execution into C (NumPy is mentioned in that PDF), and some involves JIT compilation and moving execution onto a GPU or TPU (for example, using XLA). The common part is making Python code run faster. Some of that is automatic; some requires manual effort.
JAX can jit some Python functions, but it cannot efficiently jit everything. That is what I meant by decoration and 'some effort': for example, replacing `if` conditions with np.where, etc. See also https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
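To illustrate why a Python `if` has to be rewritten, here is a toy tracer using a made-up program format; `Sym`, `where`, `trace`, `run` and `OPS` are hypothetical names, not JAX's API. Tracing runs the function once on a symbolic input, so `if x > 0` has no concrete boolean to branch on and fails, while a branch-free `where` is recorded as one more operation, and the resulting straight-line program handles both signs. In real JAX, `jax.numpy.where` and `jax.lax.cond` are the corresponding tools.

```python
class Sym:
    """Toy traced value: operations on it are recorded, not computed."""

    def __init__(self, name, prog):
        self.name = name
        self.prog = prog

    def emit(self, op, *args):
        # Append (result_name, op, arg_names_or_constants) and return the result symbol.
        out = Sym(f"v{len(self.prog)}", self.prog)
        self.prog.append((out.name, op, [a.name if isinstance(a, Sym) else a for a in args]))
        return out

    def __gt__(self, other):
        return self.emit("gt", self, other)

    def __mul__(self, other):
        return self.emit("mul", self, other)

    def __bool__(self):
        # No concrete value exists at trace time, so Python cannot branch on it.
        raise TypeError("cannot use a traced value in a Python `if`")


def where(cond, a, b):
    """Branch-free select, recorded as a single primitive (toy analogue of np.where)."""
    return cond.emit("where", cond, a, b)


def trace(fn):
    """Run fn once on a symbolic input; return the recorded program and its output name."""
    prog = []
    out = fn(Sym("x", prog))
    return prog, out.name


OPS = {
    "gt": lambda a, b: a > b,
    "mul": lambda a, b: a * b,
    "where": lambda c, a, b: a if c else b,  # both a and b were already computed
}


def run(program, x):
    """Interpret a traced program on a concrete input."""
    prog, result = program
    env = {"x": x}
    for name, op, args in prog:
        vals = [env[a] if isinstance(a, str) else a for a in args]
        env[name] = OPS[op](*vals)
    return env[result]


def leaky_if(x):
    if x > 0:  # needs a concrete bool, so this fails during tracing
        return x
    return x * 0.5


def leaky_where(x):
    return where(x > 0, x, x * 0.5)


try:
    trace(leaky_if)
except TypeError as err:
    print("tracing leaky_if failed:", err)

program = trace(leaky_where)
print(run(program, 3.0), run(program, -2.0))  # 3.0 -1.0
```

The design choice mirrors the trade-off in the comment: the `where` version always computes both branches and then selects, which is what lets it be captured as a single traced program.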
My background is in physics simulation, and I advise the Brax team, which is essentially accelerating a physics engine written in Python by running it on accelerators (see https://github.com/google/brax). The entire physics step, including collision detection and the physics solver, is jit compiled.
kanaffa12345 | 4 years ago
no, you're missing my point
>making slow Python code run faster
there is not a single org anywhere in the world that uses pure python to do numerics. kids do that during their first linear algebra or ml class. that's it.
>For example replacing IF conditions by np.where etc
i've already addressed this: this is not jit compilation.