maytc | 2 years ago

Stupid question, but I thought transformers have O(n^2) memory usage. With 2M tokens, wouldn't I need dozens and dozens of GPUs just to run the base LLaMA2 models?
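For context, the quadratic memory cost in the question comes from naive attention materializing the full n-by-n score matrix. A minimal NumPy sketch (the sequence length and head dimension here are illustrative, not LLaMA2's actual sizes):

```python
import numpy as np

n, d = 1024, 64  # illustrative sequence length and head dimension
rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Naive attention materializes an (n, n) score matrix: O(n^2) memory.
scores = Q @ K.T / np.sqrt(d)             # shape (n, n)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
out = weights @ V                          # shape (n, d)

# At n = 2,000,000 this score matrix alone would be roughly
# 2e6 * 2e6 * 2 bytes ≈ 8 TB per head in fp16, hence the question.
```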

kristjansson | 2 years ago

FlashAttention (and FlashAttention-2) [0] reduces the space complexity of attention from quadratic to linear in context length. Compute is still O(n^2) in length though, AFAIK, so we'd expect these long sequence lengths to take some time to compute.

I'm a bit out of my depth, but I think ultra-long exact-attention work like this also probably has to answer some questions about where to put the KV-cache before it can be used in practice?

[0]: https://arxiv.org/abs/2205.14135
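The memory-saving idea behind FlashAttention can be sketched with an online softmax that walks over the keys in blocks, so the full n-by-n score matrix is never materialized. This is only the high-level trick, not the tiled/fused GPU kernel from the paper, and the block size and shapes here are illustrative:

```python
import numpy as np

def attention_blocked(Q, K, V, block=128):
    """Exact attention in O(n * d) memory via an online softmax over key blocks.

    Sketch of the memory-saving idea behind FlashAttention only; the real
    implementation is a fused GPU kernel. Compute remains O(n^2) in length.
    """
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(Q)
    m = np.full(n, -np.inf)   # running row-wise max (numerical stability)
    l = np.zeros(n)           # running softmax normalizer
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T * scale                  # only (n, block) at a time
        m_new = np.maximum(m, s.max(axis=1))
        correction = np.exp(m - m_new)        # rescale previous partial sums
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=1)
        out = out * correction[:, None] + p @ Vb
        m = m_new
    return out / l[:, None]
```

The loop still does the same total number of Q·K multiplications, which is why the time complexity stays quadratic even as peak memory drops to linear.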

makerdiety | 2 years ago

Maybe computer processor hacks are used? Like, it's the equivalent of finding the eigenvalues of a matrix.

I'm not as familiar with CPUs as I am with mathematical concepts, so I don't know what the processor bit-hacking tricks are called. But that's maybe the general idea behind data compression for LLMs/transformer models on CPUs, I think.

After all, notice how the improvements come only in multiples of two: 128k tokens and 2048k tokens. There's an implementation-dependent CPU optimization hack going on in there somewhere.

ErikBjare | 2 years ago

Such optimizations generally don't change the time complexity.