Stupid question, but I thought transformers have O(n^2) memory usage. With 2M tokens, won't I need dozens and dozens of GPUs just to run the base LLaMA 2 models?
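For a sense of scale, here's the arithmetic behind that worry (a rough sketch, assuming fp16 scores and counting only a single attention head in a single layer; these assumptions are mine, not from the thread):

```python
# Naive attention materializes an (n, n) score matrix per head per layer.
n = 2 * 1024 * 1024          # a "2048k"-token context
bytes_fp16 = 2               # assumed fp16 storage for the scores
score_matrix = n * n * bytes_fp16
print(score_matrix / 2**40)  # → 8.0 (TiB, for ONE head in ONE layer)
```

So yes, materializing the full attention matrix at this length is hopeless, which is why the linear-memory approaches below matter.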
FlashAttention(2)[0] reduces attention's space complexity from quadratic to linear in context length. Compute is still O(n^2) in length though, AFAIK, so we'd expect these long sequences to take a while to process.
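To illustrate the linear-memory idea, here is a minimal NumPy sketch of FlashAttention-style blocked attention with an online softmax. This is a toy rendering of the algorithm from [0], not the real kernel (which is a fused, tiled GPU implementation); the point is just that no (n, n) matrix is ever materialized:

```python
import numpy as np

def naive_attention(q, k, v):
    # Materializes the full (n, n) score matrix: O(n^2) memory.
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def blocked_attention(q, k, v, block=16):
    # Online softmax over key/value blocks: only an (n, block) score
    # tile lives at any moment, so memory is linear in n.
    n, d = q.shape
    out = np.zeros_like(v)
    m = np.full(n, -np.inf)   # running row-wise max of the scores
    l = np.zeros(n)           # running softmax denominator
    for j in range(0, n, block):
        s = q @ k[j:j + block].T / np.sqrt(d)   # (n, block) tile
        m_new = np.maximum(m, s.max(axis=-1))
        scale = np.exp(m - m_new)               # rescale old partials
        p = np.exp(s - m_new[:, None])
        l = l * scale + p.sum(axis=-1)
        out = out * scale[:, None] + p @ v[j:j + block]
        m = m_new
    return out / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 8)) for _ in range(3))
assert np.allclose(naive_attention(q, k, v), blocked_attention(q, k, v))
```

Note the loop still touches every (query, key) pair, which is the O(n^2) compute that remains.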
Maybe computer processor hacks are used? Like, it's the equivalent of finding the eigenvalues of a matrix.
kristjansson|2 years ago
I'm a bit out of my depth, but I think ultra-long exact-attention work like this also probably has to answer some questions about where to put the KV-cache before it can be used in practice?
[0]: https://arxiv.org/abs/2205.14135
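Back-of-envelope on the KV-cache point: a rough sketch assuming LLaMA-2-7B-like hyperparameters (32 layers, 32 full KV heads of dim 128, fp16) — these figures are assumptions for illustration, not numbers from the thread:

```python
# KV-cache size = 2 (K and V) x layers x heads x head_dim x seq_len x bytes.
layers, heads, head_dim, bytes_fp16 = 32, 32, 128, 2  # assumed 7B-like config
seq_len = 2 * 1024 * 1024                             # "2048k" tokens
kv_bytes = 2 * layers * heads * head_dim * bytes_fp16 * seq_len
print(kv_bytes / 2**40)  # → 1.0 (TiB of K/V activations)
```

Exact attention at 2M tokens means keeping on the order of a terabyte of K/V state somewhere, which is why cache placement (HBM vs. host RAM vs. disk) is a real deployment question.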
makerdiety|2 years ago
I'm not as familiar with CPUs as I am with mathematical concepts. I don't know what the name for the processor bit hacking tricks is called. But that's maybe the general idea for data compression for LLMs/transformer models on CPUs, I think.
After all, notice how data compression improvements are only multiples of two. 128k tokens and 2048k tokens. There's an implementation dependent CPU optimization hack going on in there somewhere.
ErikBjare|2 years ago