Mehdi2277 | 2 years ago
Additionally, many operations that run on GPU are simply unsupported on TPU. Sparse tensors have pretty limited support, and there are a bunch of models that will crash on TPU and require refactoring. Sometimes that means pretty heavy refactoring, on the order of thousands of lines.
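To give a rough idea of what that kind of refactoring tends to look like: XLA on TPU wants static shapes, so a common workaround is to replace a variable-length sparse feature with fixed-shape padded arrays plus a mask. The snippet below is only a minimal sketch in plain Python, not code from any real model; `pad_sparse_rows`, `masked_sum`, `max_len`, and `pad_id` are hypothetical names made up for illustration.

```python
# Hypothetical illustration (not from any real model or library):
# turn variable-length "sparse" rows into fixed-shape padded rows plus a mask.

def pad_sparse_rows(rows, max_len, pad_id=0):
    """rows: one variable-length list of integer ids per example.

    Returns (ids, mask), each with len(rows) rows of exactly max_len entries.
    Rows longer than max_len are truncated; shorter rows are padded with pad_id.
    """
    ids, mask = [], []
    for row in rows:
        kept = row[:max_len]
        n_pad = max_len - len(kept)
        ids.append(kept + [pad_id] * n_pad)
        mask.append([1] * len(kept) + [0] * n_pad)
    return ids, mask


def masked_sum(values, mask):
    """Sum each row of values, ignoring the padded positions."""
    return [
        sum(v * m for v, m in zip(value_row, mask_row))
        for value_row, mask_row in zip(values, mask)
    ]


rows = [[7, 3], [5], [2, 9, 4, 1]]
ids, mask = pad_sparse_rows(rows, max_len=3)
print(ids)                    # [[7, 3, 0], [5, 0, 0], [2, 9, 4]]
print(mask)                   # [[1, 1, 0], [1, 0, 0], [1, 1, 1]]
print(masked_sum(ids, mask))  # [10, 5, 15]
```

Note the last row gets truncated, which is exactly the sort of behavior change you have to audit everywhere the sparse feature was used, and that is where the line count of the refactor adds up.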
edit: PyTorch is even worse. PyTorch does not implement efficient TPU device data loading, and its performance is generally poor, nowhere near the TensorFlow/JAX numbers. I'm unaware of any PyTorch benchmarks where TPU actually wins. For TensorFlow/JAX, if you can get it running and your model suits TPU assumptions (so a basic CNN), then yes, it can be cost effective. For PyTorch, even simple cases tend to lose.