timlarshanson | 1 year ago
After reading a few times, I gather that, rather than kernelizing or linearizing attention (which has been thoroughly explored in the literature), they are using an MLP to do run-time modelling of the attention operation. If that's the case (which is interesting, sure): 1 -- why didn't they say this plainly? 2 -- why does eq. 12 show the memory MLP being indexed by the key, whereas eq. 15 shows it indexed by the query? 3 -- what's with all the extra LSTM-esque forget and remember gates? Meh. Wouldn't trust it without ablations.
I guess if an MLP can model a radiance field (NeRF) well, it stands to reason it can approximate attention too. The Q, K, V projection matrices would still need to be learned beforehand using standard training.
While the memory and compute savings are clear, it's uncertain whether this helps with reasoning, or generalization thereof; I doubt it does.
andy12_ | 1 year ago
Eq. 15 is simply the operation that queries a value previously inserted into memory by earlier tokens via eq. 12.
Basically, for each autoregressively processed segment you do:
1) Test-time inference: query values from memory with eq. 15.
2) Test-time training: associate new keys and values into the memory with the loss from eq. 12.
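The two steps above can be sketched in a few lines, with a linear memory standing in for the MLP (toy code; `write`/`read` and the squared-error loss are my framing of eqs. 12/15, not the paper's exact update):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
M = np.zeros((d, d))  # linear "memory" M(x) = M @ x, standing in for the MLP

def write(M, k, v, lr=0.5):
    """Test-time training (eq. 12 style): one gradient step on 0.5 * ||M k - v||^2."""
    grad = np.outer(M @ k - v, k)  # d(loss)/dM
    return M - lr * grad

def read(M, q):
    """Test-time inference (eq. 15 style): query the memory."""
    return M @ q

# insert one key/value pair, then retrieve with a query equal to the key
k = rng.standard_normal(d); k /= np.linalg.norm(k)
v = rng.standard_normal(d)
for _ in range(100):
    M = write(M, k, v)

print(np.allclose(read(M, k), v, atol=1e-3))  # → True
```

The point of the toy: gradient descent on keys at write time, plain evaluation on queries at read time, exactly the train/inference split the comment describes.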
The forget and remember gates are there because... well, the architecture in general is very similar to an LSTM, but it uses test-time gradient descent to decide what to insert into the long-term memory.
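If I read the update right, it amounts to a momentum-plus-decay rule on the memory's parameters, roughly like this (my reading, not verbatim from the paper; the gates alpha/eta/theta are data-dependent there but constants here, and a matrix memory stands in for the MLP):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4
M = np.zeros((d, d))  # memory parameters (an MLP's weights in the paper)
S = np.zeros((d, d))  # running "surprise" (momentum over past gradients)

alpha, eta, theta = 0.1, 0.9, 0.5  # forget / momentum / step-size gates
                                   # (constants for this sketch only)

for _ in range(50):
    k = rng.standard_normal(d)
    v = rng.standard_normal(d)
    grad = np.outer(M @ k - v, k)  # gradient of 0.5 * ||M k - v||^2 w.r.t. M
    S = eta * S - theta * grad     # remember gate: momentum accumulates surprise
    M = (1 - alpha) * M + S        # forget gate: decay old memory, add the update
```

Squint and it is an LSTM cell state, except the "input" being gated in is a gradient rather than a learned projection.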
timlarshanson | 1 year ago
Seems the implicit assumption, then, is that M(q) -> v 'looks like' or 'is smooth like' the dot product; otherwise 'train on keys, run inference on queries' wouldn't work? (Safe assumption imo given the l2 norm, and in general; unsafe if q and k come from different distributions.)
Correct me if I'm wrong, but typically k and v are generated via affine projections K, V of the tokens; if M is matrix-valued and there are no forget and remember gates (to somehow approximate the softmax?), then M = V K^-1 (or the pseudoinverse, since K generally isn't square).
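For the linear, gate-free case that claim checks out numerically (toy sketch; I make K square and invertible for the demo, which real projections usually aren't, so in practice you'd recover a pseudoinverse instead):

```python
import numpy as np

rng = np.random.default_rng(2)
d = 6
K = rng.standard_normal((d, d))  # key projection (square/invertible for the demo)
V = rng.standard_normal((d, d))  # value projection

# fit a linear memory M mapping k = K x  ->  v = V x by least squares
X = rng.standard_normal((d, 1000))  # a batch of token vectors
Ks, Vs = K @ X, V @ X
M = Vs @ np.linalg.pinv(Ks)

print(np.allclose(M, V @ np.linalg.inv(K), atol=1e-6))  # → True
```

So without gates the memory collapses to a fixed linear map between the two projections; whatever the gates buy you has to come from the nonlinearity and the test-time dynamics.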
ljlolel | 1 year ago