

jdeaton | 1 year ago

JAX has a subsystem called Pallas [1] with a Triton-like programming model, and it includes an example implementation of Flash Attention [2] that is quite fast. On TPUs, I've heard that the XLA compiler already emits a flash-attention-like computation graph for a regular JAX implementation of attention, so there's no need for a specialized kernel in that case.

1. https://jax.readthedocs.io/en/latest/pallas/index.html

2. https://github.com/jax-ml/jax/blob/main/jax/experimental/pal...
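
For reference, here is roughly what I mean by a "regular JAX implementation of attention". This is only a minimal single-head sketch: the function name, shapes, and random inputs are made up for illustration, and it does not by itself show what XLA does with it on a TPU (that part is still just what I've heard). It is not the Pallas kernel from [2].

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Plain softmax attention; q, k, v have shape (seq_len, head_dim)."""
    scale = q.shape[-1] ** -0.5
    logits = (q @ k.T) * scale                 # (seq_len, seq_len) score matrix
    weights = jax.nn.softmax(logits, axis=-1)  # each row sums to 1
    return weights @ v                         # (seq_len, head_dim)


# jit hands the whole function to XLA, which is free to fuse the steps.
attention = jax.jit(reference_attention)

# Hypothetical inputs, only to make the sketch runnable.
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (128, 64))
k = jax.random.normal(key_k, (128, 64))
v = jax.random.normal(key_v, (128, 64))
out = attention(q, k, v)
```

Written this way, the code materializes the full (seq_len, seq_len) score matrix, at least on paper. Avoiding that is the whole point of Flash Attention, so whether you need a hand-written kernel comes down to what the compiler does with this graph on your hardware.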
