dacox | 1 year ago
For example, consider a prompt sent to Llama 3.1 405B that uses 128k input tokens.
The KV cache will be 123 GB. No matter how many GPUs you shard the model across, that KV cache is not fitting in a single GPU's memory (an H100 has 80 GB).
mmoskal | 1 year ago
Also, 405B has 8 KV heads of size 128 (hidden_size / num_attention_heads), times 126 layers [0], times 2 (K and V), times 2 bytes (bf16), which works out to 504 KB per token. At FP8 it's 252 KB.
[0] https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/m...
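The per-token arithmetic in that comment can be checked with a quick sketch (the constant names are mine; the figures are the ones quoted above for Llama 3.1 405B):

```python
# Back-of-envelope KV-cache size per token for Llama 3.1 405B,
# using the numbers from the comment: 8 KV heads, head dim 128
# (hidden_size / num_attention_heads = 16384 / 128), 126 layers.
NUM_KV_HEADS = 8
HEAD_DIM = 128
NUM_LAYERS = 126

def kv_bytes_per_token(bytes_per_value: int) -> int:
    # Each layer stores K and V, each of shape (num_kv_heads, head_dim).
    return NUM_KV_HEADS * HEAD_DIM * NUM_LAYERS * 2 * bytes_per_value

print(kv_bytes_per_token(2))  # bf16: 516,096 bytes ~= 504 KiB per token
print(kv_bytes_per_token(1))  # fp8:  258,048 bytes ~= 252 KiB per token
```

Multiplying the bf16 figure by the context length gives the total cache footprint the thread is discussing.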