JAX code usually ends up being way faster than equivalent torch code for me, even with torch.compile. There are common performance killers, though, notably using Python control flow (if statements, loops) instead of JAX's structured primitives (jnp.where, jax.lax.cond, jax.lax.scan, etc.).
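A minimal sketch of the difference, assuming a recent JAX release. The function names are hypothetical and no timings are claimed. Under `jax.jit`, a Python `for` over a static range runs at trace time and unrolls into one op per iteration, so graph size and compile time grow with the loop length. `lax.scan` traces the body once. A Python `if` on a traced value fails outright under `jit`; `jnp.where` (elementwise) or `lax.cond` (branching) keeps the decision inside the compiled graph.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Python loop: unrolled at trace time into x.shape[0] separate add ops.
@jax.jit
def cumsum_python_loop(x):
    total = jnp.zeros((), dtype=x.dtype)
    out = []
    for i in range(x.shape[0]):  # shape is static under jit, so this traces fine
        total = total + x[i]
        out.append(total)
    return jnp.stack(out)

# lax.scan: the step function is traced once and compiled into a single loop.
@jax.jit
def cumsum_scan(x):
    def step(carry, xi):
        carry = carry + xi
        return carry, carry  # (new carry, per-step output)
    _, ys = lax.scan(step, jnp.zeros((), dtype=x.dtype), x)
    return ys

# `if x > 0:` on a traced array would raise a concretization error under jit.
# jnp.where selects elementwise without Python branching.
@jax.jit
def relu_where(x):
    return jnp.where(x > 0, x, 0.0)

# lax.cond picks one branch at run time; both branches must return
# arrays of the same shape and dtype.
@jax.jit
def scale_if_positive(x, s):
    return lax.cond(jnp.sum(x) > 0, lambda v: v * s, lambda v: v, x)
```

For an input of `[1, 2, 3, 4]`, both cumulative-sum versions return `[1, 3, 6, 10]`; the difference shows up in compile time and program size as the length grows, not in the result.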
leviliebvin|4 months ago
pama|4 months ago