
que-encrypt | 9 months ago

PyTorch/XLA is barely supported in the PyTorch ecosystem. For instance, PyTorch Lightning still doesn't easily support TPU pods: there is only a single short page about Google Colab v2-8 TPUs, and it is out of date. Then there are the many PyTorch libraries and implementations that call .cuda() directly, which ties them to CUDA devices. More limitations are listed at: https://lightning.ai/docs/pytorch/stable/accelerators/tpu_fa...

I haven't worked with TensorFlow, but I've heard it's a pain even when using GPUs.

JAX is the real deal, and it makes my code relatively easy to move between GPUs and TPUs. The exception is custom Pallas kernels (e.g. flash vs. splash attention), but swapping those is usually not a massive code change. The downside is that, largely because of network effects, JAX often lacks the pre-existing implementations you can find for PyTorch.
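
To illustrate the portability point, here is a minimal sketch (not the commenter's code). The attention function is a hypothetical example written only with jax.numpy and compiled with jax.jit. Because XLA compiles it for whichever backend JAX was installed with (CPU, GPU, or TPU), the same source runs unchanged, with no .cuda()-style device calls. It deliberately does not cover the custom Pallas kernel case, which is where backend-specific code creeps back in.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v):
    # Hypothetical example: plain scaled dot-product attention using only
    # jax.numpy. Nothing here names a specific accelerator.
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum("...qd,...kd->...qk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)


# jax.jit compiles through XLA for the default backend (CPU, GPU, or TPU);
# the Python source is identical on all of them.
attention_jit = jax.jit(attention)

if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    kq, kk, kv = jax.random.split(key, 3)
    q = jax.random.normal(kq, (2, 4, 8))
    k = jax.random.normal(kk, (2, 5, 8))
    v = jax.random.normal(kv, (2, 5, 8))
    out = attention_jit(q, k, v)
    print(jax.default_backend())  # whichever backend JAX picked
    print(out.shape)
```

The contrast with the .cuda() problem is that placement is left to JAX: arrays land on the default device and the jitted function follows them, so moving from a GPU box to a TPU VM is mostly a matter of installing the matching jax build.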
