To expand on this link: JAX is probably the closest you're going to get right now to 'I'll "program" in LinAlg, and a JIT can compile it to whatever wonky way your HW requires.' JAX implements a good portion of the NumPy interface, which is the most common interface for linear-algebra-heavy code in Python, so you can often just write NumPy code with `jax.numpy` in place of `numpy`, then wrap it in `jax.jit` to have it compiled for the GPU (or whatever accelerator backend JAX finds).
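A minimal sketch of that swap, assuming `jax` and `numpy` are installed. The function names `normalize_rows_np` and `normalize_rows_jax` are made up for illustration. The code falls back to CPU if no GPU is present.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Plain NumPy version: scale each row to unit L2 norm.
def normalize_rows_np(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Same body with `jnp` instead of `np`, wrapped in jax.jit so it is
# traced once and compiled for the default backend.
@jax.jit
def normalize_rows_jax(x):
    return x / jnp.linalg.norm(x, axis=1, keepdims=True)

x = np.arange(1.0, 7.0, dtype=np.float32).reshape(2, 3)

out_np = normalize_rows_np(x)
out_jax = normalize_rows_jax(x)  # first call triggers compilation

print(jax.devices())  # lists the devices JAX will run on
print(np.allclose(out_np, np.asarray(out_jax), atol=1e-5))
```

Not every NumPy idiom carries over: JAX arrays are immutable, so in-place updates like `x[0] = 1` need rewriting (e.g. with `x.at[0].set(1)`), and Python control flow that depends on array values doesn't trace under `jit` without extra work.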
creata|11 months ago
imtringued|11 months ago
JAX genuinely deserves to exist alongside PyTorch. It's not just Google's latest framework that you're forced to use to target TPUs.