I have a project using tfjs, and jax-js is a very exciting alternative. However, while porting, I'm struggling a lot with the `.ref` and `.dispose()` APIs. Coming from tfjs, where you clean up intermediate tensors with `tf.tidy(() => { ... })`, the API in jax-js seems very low-level and error-prone. Is that something that can be improved, or is it inherent to how jax-js works?
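To make the `.ref`/`.dispose()` model concrete, here is a minimal sketch, assuming move semantics: each operation consumes the handle it is called on, and `.ref` bumps a reference count so an array survives being passed to an operation. `RcBuffer`, `RcArray`, `add`, `toArray`, and `isFreed` are hypothetical names; this is not jax-js's actual implementation or exact API.

```typescript
// Hypothetical sketch of reference counting with move semantics.
// NOT jax-js's implementation; RcBuffer/RcArray/add/toArray are illustrative names.

class RcBuffer {
  refCount = 1;
  freed = false;
  constructor(public data: number[]) {}

  release(): void {
    this.refCount -= 1;
    if (this.refCount === 0) {
      this.freed = true; // a real backend would free GPU memory here
      this.data = [];
    }
  }
}

class RcArray {
  private consumed = false;
  constructor(private readonly buf: RcBuffer) {}

  static of(data: number[]): RcArray {
    return new RcArray(new RcBuffer(data));
  }

  private assertAlive(): void {
    if (this.consumed) throw new Error("array was already consumed or disposed");
  }

  // `.ref` adds a reference and returns a second handle to the same buffer,
  // so one handle can be consumed while the other stays usable.
  get ref(): RcArray {
    this.assertAlive();
    this.buf.refCount += 1;
    return new RcArray(this.buf);
  }

  // Operations take ownership of (consume) the handle they are called on.
  add(c: number): RcArray {
    this.assertAlive();
    const out = RcArray.of(this.buf.data.map((v) => v + c));
    this.dispose();
    return out;
  }

  // Drops this handle's reference; the buffer is freed when the count hits 0.
  dispose(): void {
    this.assertAlive();
    this.consumed = true;
    this.buf.release();
  }

  // Reads the values without consuming the handle (a simplification).
  toArray(): number[] {
    this.assertAlive();
    return [...this.buf.data];
  }

  get isFreed(): boolean {
    return this.buf.freed;
  }
}

// Usage: keep `x` alive across an operation by passing `x.ref` instead of `x`.
const x = RcArray.of([1, 2, 3]);
const y = x.ref.add(1); // consumes the extra reference; x is still usable
console.log(x.toArray(), y.toArray());
const z = x.add(10); // consumes x's last reference, so its buffer is freed
console.log(x.isFreed, z.toArray());
y.dispose();
z.dispose();
```

Because ownership in this model is tracked per handle rather than per lexical scope, nothing in it depends on the surrounding code running synchronously.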
I don't think `tf.tidy()` is a sound API under jvp/grad transformations. It also prevents you from using async code, which makes it incompatible with GPU backends (or forces you to block the page), a pretty big issue. https://github.com/tensorflow/tfjs/issues/5468
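The async problem can be illustrated with a small, hypothetical model of scope-based cleanup (not tfjs's actual code): `tidy` records every array allocated while its callback runs and disposes them, except the return value, once the callback returns. An `async` callback returns a Promise at its first `await`, so the scope closes too early. `Tracked`, `tidy`, and `demo` are illustrative names.

```typescript
// Hypothetical model of scope-based cleanup in the style of tf.tidy();
// NOT tfjs's implementation. Tracked/tidy/demo are illustrative names.

const scopeStack: Tracked[][] = [];

class Tracked {
  disposed = false;
  readonly scoped: boolean; // was a tidy scope active when this was allocated?

  constructor() {
    const scope = scopeStack[scopeStack.length - 1];
    this.scoped = scope !== undefined;
    scope?.push(this); // register with the innermost active scope
  }

  dispose(): void {
    this.disposed = true;
  }
}

// Runs fn synchronously, then disposes everything allocated during the call
// except the returned value.
function tidy<T>(fn: () => T): T {
  const scope: Tracked[] = [];
  scopeStack.push(scope);
  try {
    const result = fn();
    for (const t of scope) {
      if (t !== (result as unknown)) t.dispose();
    }
    return result;
  } finally {
    scopeStack.pop();
  }
}

async function demo(): Promise<{ aDisposedEarly: boolean; bScoped: boolean }> {
  return tidy(async () => {
    const a = new Tracked(); // allocated while the scope is active
    await Promise.resolve(); // stand-in for an async GPU readback
    // tidy() returned at the await above and has already disposed `a`,
    // even though this callback is still using it.
    const aDisposedEarly = a.disposed;
    const b = new Tracked(); // no scope is active anymore, so b is never cleaned up
    return { aDisposedEarly, bScoped: b.scoped };
  });
}

demo().then((r) => console.log(r));
```

In this model `aDisposedEarly` comes out `true` and `bScoped` comes out `false`: arrays created before the `await` are freed while still in use, and arrays created after it belong to no scope at all.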
Thanks for the feedback, though; I'm just explaining how we arrived at this API. I hope you'll at least try it out. Hopefully you'll see while developing that refs are more flexible than the alternatives.
ekzhang|1 month ago
mlajtos|1 month ago