smpanaro | 2 years ago
If you are generating token-by-token naively, you do need to pay the n^2 cost, since every token must attend to all the tokens before it. Generating up to 5 tokens starting from 2 (infer 3rd, infer 4th, infer 5th) will be much faster than starting from 1024 (infer 1025th, infer 1026th, ...) since your n is smaller. But each step is still n^2.
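To put rough numbers on that, here is a small Python sketch that counts query-key dot products. It assumes a single causal attention head, where token i scores itself and the i - 1 tokens before it, and assumes the naive approach reruns attention over the whole sequence for every new token. The function names are made up for illustration.

```python
def naive_step_cost(n: int) -> int:
    """Query-key dot products for one naive step over a sequence of length n.

    Assumes causal attention: token i (1-indexed) scores itself and the
    i - 1 tokens before it, so the total is 1 + 2 + ... + n = n(n+1)/2.
    """
    return n * (n + 1) // 2


def naive_generation_cost(start: int, new_tokens: int) -> int:
    """Total dot products to generate `new_tokens` tokens from a `start`-token prompt.

    Step k runs attention over start + k tokens, then appends one token.
    """
    return sum(naive_step_cost(start + k) for k in range(new_tokens))


print(naive_generation_cost(2, 3))     # 19 (2 -> 5 tokens)
print(naive_generation_cost(1024, 3))  # 1577476 (1024 -> 1027 tokens)
```

Both runs take the same number of steps, but each step's cost grows with the square of the current length, which is why starting from 1024 is so much slower.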
However, that is the naive approach. There is a common optimization, KV caching[1] (on by default for HuggingFace models), that caches the keys and values computed for earlier tokens so you only have to compute attention for the new token. So you get something like (infer 1025th; cache 1025, attend 1 new token over the other 1025; cache 1026, attend 1 new token over the other 1026; ...). Each step is then linear in the sequence length rather than quadratic: not quite constant time, but much better than n^2.
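A minimal pure-Python sketch of the idea for one attention head, assuming toy random weight matrices and fixed input vectors standing in for token embeddings (a real model caches keys and values per layer and per head, and feeds back whatever token it sampled). `ToyAttention` and `KVCache` are hypothetical names, not a HuggingFace API. The cached path projects only the new token and scores one query against the stored keys, so a step over n tokens does n dot products, yet it gives the same outputs as naively recomputing everything.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(w, x):
    """Multiply matrix w (list of rows) by vector x."""
    return [dot(row, x) for row in w]


def attend(q, keys, values):
    """Scaled dot-product attention for a single query over the given keys/values."""
    scale = math.sqrt(len(q))
    scores = [dot(q, k) / scale for k in keys]  # one dot product per key
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def random_matrix(rng, d):
    return [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(d)]


class ToyAttention:
    """One causal self-attention head with hypothetical random weights."""

    def __init__(self, d, seed=0):
        rng = random.Random(seed)
        self.wq = random_matrix(rng, d)
        self.wk = random_matrix(rng, d)
        self.wv = random_matrix(rng, d)

    def full(self, xs):
        """Naive: recompute K, V and causal attention for every position."""
        keys = [matvec(self.wk, x) for x in xs]
        values = [matvec(self.wv, x) for x in xs]
        outputs = []
        for i, x in enumerate(xs):
            q = matvec(self.wq, x)
            # Causal mask: position i only sees positions 0..i.
            outputs.append(attend(q, keys[: i + 1], values[: i + 1]))
        return outputs


class KVCache:
    """Keeps the keys and values of every token seen so far."""

    def __init__(self, layer):
        self.layer = layer
        self.keys = []
        self.values = []

    def step(self, x):
        """Process one new token: project it once, append its K/V, attend over the cache."""
        self.keys.append(matvec(self.layer.wk, x))
        self.values.append(matvec(self.layer.wv, x))
        q = matvec(self.layer.wq, x)
        return attend(q, self.keys, self.values)


# Usage: feed 6 toy "tokens" one at a time, both ways.
rng = random.Random(1)
xs = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(6)]
layer = ToyAttention(d=4)
cache = KVCache(layer)
cached_outputs = [cache.step(x) for x in xs]
# Naive generation reruns the whole prefix each step and keeps the last output.
naive_outputs = [layer.full(xs[: t + 1])[-1] for t in range(len(xs))]
```

The trade-off is memory: the cache holds one key and one value vector per token seen so far, in exchange for never recomputing them.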
I would imagine there are other optimizations too, but this is the one I've heard of.
[1] fairly code-y, but links to some other posts at the start: https://www.dipkumar.dev/becoming-the-unbeatable/posts/gpt-k...