I recently tried to port my model to JAX. Got it all working the "JAX WAY", and I believe I did everything correctly, with one neat top-level .jit() applied to the training step. Unfortunately, I could not replicate the performance boost of torch.compile(). I have not yet delved under the hood to find the culprit, but my model is fairly simple, so I was sort of expecting JAX JIT to perform just as well as torch.compile(), if not better. Has anyone else had similar experiences?
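For readers wondering what "one top-level jit on the training step" looks like, here is a minimal sketch. It assumes a toy linear model with MSE loss and plain SGD. The names `init_params`, `loss_fn`, `train_step`, `time_steps` and `LEARNING_RATE` are hypothetical and are not taken from the poster's code. It also shows two timing details that often skew JAX vs. PyTorch comparisons. First, the first call includes tracing and XLA compilation, so it is excluded from the timed loop. Second, JAX dispatches work asynchronously, so the clock should only stop after `jax.block_until_ready`.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical hyperparameter, kept as a module-level constant so jit treats it as static.
LEARNING_RATE = 1e-2


def init_params(key, in_dim, out_dim):
    # Hypothetical tiny linear model, used only for illustration.
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros((out_dim,)),
    }


def loss_fn(params, x, y):
    # Mean squared error of a linear prediction.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # Forward pass, backward pass and SGD update are compiled into one XLA program.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


def time_steps(params, x, y, n_steps=100):
    # Warm-up call: triggers tracing + compilation, excluded from the timing.
    params, loss = train_step(params, x, y)
    jax.block_until_ready((params, loss))
    start = time.perf_counter()
    for _ in range(n_steps):
        params, loss = train_step(params, x, y)
    # Dispatch is asynchronous; wait for the device before stopping the clock.
    jax.block_until_ready((params, loss))
    elapsed = time.perf_counter() - start
    return params, loss, elapsed / n_steps


if __name__ == "__main__":
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    params = init_params(k1, 64, 1)
    x = jax.random.normal(k2, (256, 64))
    y = jax.random.normal(k3, (256, 1))
    params, loss, per_step = time_steps(params, x, y)
    print(f"final loss {float(loss):.4f}, {per_step * 1e3:.3f} ms/step")
```

If per-step numbers still lag torch.compile() after warm-up and proper blocking, the gap is more likely in the model or input pipeline than in the benchmark harness. Data loading on the host and shapes that change between calls (each new shape forces a recompile) are two common places to look.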
yberreby|4 months ago
leviliebvin|4 months ago