top | item 40109867

ddjohnson | 1 year ago

Author of Penzai here! In idiomatic Penzai usage, you should always discharge all effects before running your model. While it's true you can't do `discharge_effect(jax.grad(your_model_here))`, you can still do `jax.grad(discharge_effect(your_model_here))`, which is probably what you meant to do anyway in most cases. Once you've wrapped your model in a handler layer, it has a pure interface again, which makes it fully compatible with all arbitrary JAX transformations. The intended use of effects is as an internal helper to simplify plumbing of values into and out of layers, not as something that affects the top-level interface of using the model!
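To make the ordering concrete, here's a minimal sketch of the pattern in plain JAX. The names `effectful_model` and `discharge_effect` are hypothetical stand-ins (Penzai's real handler API differs); the point is only that the handler closes over the side value and exposes a pure `(params, x)` interface, which `jax.grad` can then transform:

```python
import jax
import jax.numpy as jnp

def effectful_model(params, x, *, side_input):
    # Pretend `side_input` was requested via an effect rather than
    # passed as an argument; the handler below supplies it.
    return jnp.sum(params * x + side_input)

def discharge_effect(model, side_value):
    # Toy "handler": closes over the side value, exposing a pure
    # (params, x) -> scalar function that JAX can transform freely.
    def pure_model(params, x):
        return model(params, x, side_input=side_value)
    return pure_model

params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([4.0, 5.0, 6.0])
pure = discharge_effect(effectful_model, side_value=jnp.ones(3))

# jax.grad(discharge_effect(model)) works: the handled model is pure.
# d/dparams of sum(params * x + 1) is just x.
grads = jax.grad(pure)(params, x)
```

The reverse order fails for the same reason `jax.grad` of any non-pure interface fails: the transformation only sees the positional arguments, so the effect has to be turned into ordinary data flow first.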

(As an example of this, the GemmaTransformer example model uses the SideInput effect internally to do attention masking. But it exposes a pure functional interface by using a handler internally, so you can call it anywhere you could call an Equinox model, and you shouldn't have to think about the effect system at all as a user of the model.)

It's not clear to me what the semantics of ordinary JAX transformations like `lax.scan` should be if the model has side effects. But if your model has no effects, or if you've already handled them explicitly, it's perfectly fine to use `lax.scan`. This is similar to how it works in ordinary JAX: if you try to `lax.scan` over a function that mutates Python state, you'll probably hit an error or get something unexpected. But if you mutate Python state internally inside `lax.scan`, it works fine.
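A small illustration of that last point, assuming only standard JAX semantics: Python-level mutation inside a scanned body happens at trace time, not once per scan step, so the numeric result is correct but the mutation does not scale with the number of steps:

```python
import jax
import jax.numpy as jnp

trace_log = []  # ordinary Python state, mutated inside the scan body

def body(carry, x):
    # This append runs while JAX traces the body (a small, fixed
    # number of times), NOT once per element of the scanned array.
    trace_log.append("traced")
    new_carry = carry + x
    return new_carry, new_carry

total, partials = jax.lax.scan(body, 0.0, jnp.arange(100.0))
# total is sum(0..99) == 4950.0, but trace_log holds only a few
# entries, far fewer than the 100 scan steps.
```

This is why relying on Python state for the *result* of a scanned computation goes wrong, while using it for trace-time bookkeeping (logging, debugging) is harmless.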

I'll also note that adding support for higher-order layer combinators (like "layer scan") is something that's on the roadmap! The goal would be to support some of the fancier features of libraries like Flax when you need them, while still admitting a simple purely-functional mental model when you don't.
