zackangelo | 2 months ago
The way it typically works in an attention block is this: each node is assigned a smaller slice of the Q, K and V linear layers (typically a subset of the attention heads) and processes it independently. Attention, RoPE, norm, etc. are then run on that node-specific output. Finally, when the output linear layer is applied, an "all reduce" is computed, which combines the outputs of all the nodes.
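A minimal single-process sketch of that scheme, assuming plain multi-head causal attention with heads split evenly across nodes (RoPE and norms are left out for brevity, and every function name here is hypothetical). Each simulated node multiplies by a column slice of W_q/W_k/W_v and the matching row slice of W_o, producing a partial output; the "all reduce" is just an elementwise sum of those partials, which should match the unsharded result.

```python
import math
import random

def matmul(a, b):
    """(n x k) @ (k x m) for matrices stored as lists of rows."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def cols(m, start, stop):
    """Column slice m[:, start:stop]."""
    return [row[start:stop] for row in m]

def attend(q, k, v):
    """Causal scaled dot-product attention for a single head."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for t in range(len(q)):
        scores = [scale * sum(a * b for a, b in zip(q[t], k[s])) for s in range(t + 1)]
        m = max(scores)
        w = [math.exp(sc - m) for sc in scores]
        z = sum(w)
        out.append([sum(w[s] * v[s][j] for s in range(t + 1)) / z
                    for j in range(len(v[0]))])
    return out

def mha(x, wq, wk, wv, wo, n_heads, head_dim):
    """Multi-head attention followed by the output projection."""
    # k and v here are exactly what this node would keep in its KV cache
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    concat = [[] for _ in x]
    for h in range(n_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim
        o = attend(cols(q, lo, hi), cols(k, lo, hi), cols(v, lo, hi))
        for row, o_row in zip(concat, o):
            row.extend(o_row)
    return matmul(concat, wo)

def tensor_parallel_mha(x, wq, wk, wv, wo, n_heads, head_dim, world_size):
    """Simulate sharding heads across world_size nodes, then an all-reduce."""
    hpr = n_heads // world_size  # heads per rank
    partials = []
    for rank in range(world_size):
        lo, hi = rank * hpr * head_dim, (rank + 1) * hpr * head_dim
        # Column shard of Q/K/V weights plus the matching row shard of W_o
        partials.append(mha(x, cols(wq, lo, hi), cols(wk, lo, hi),
                            cols(wv, lo, hi), wo[lo:hi], hpr, head_dim))
    # "All reduce": elementwise sum of every rank's partial output
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(x))]

def rand_matrix(rows, ncols, rng):
    return [[rng.uniform(-1, 1) for _ in range(ncols)] for _ in range(rows)]

rng = random.Random(0)
d_model, n_heads, head_dim, seq_len, world = 8, 4, 2, 5, 2
x = rand_matrix(seq_len, d_model, rng)
wq, wk, wv = (rand_matrix(d_model, n_heads * head_dim, rng) for _ in range(3))
wo = rand_matrix(n_heads * head_dim, d_model, rng)

full = mha(x, wq, wk, wv, wo, n_heads, head_dim)
sharded = tensor_parallel_mha(x, wq, wk, wv, wo, n_heads, head_dim, world)
max_err = max(abs(a - b) for fr, sr in zip(full, sharded) for a, b in zip(fr, sr))
print(f"max abs difference: {max_err:.2e}")
```

The difference should be at floating-point rounding level, because concatenating all heads and multiplying by W_o is the same as summing each node's (its heads times its rows of W_o).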
EDIT: I just realized this wasn't clear. It means that each node ends up holding only the portion of the KV cache that corresponds to its K/V tensor shards. This can change based on the specific style of attention (e.g., with GQA, when there are fewer KV heads than ranks, you end up having to replicate some KV heads across nodes, etc.).
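To make the GQA case concrete, here is a rough sketch of one possible placement policy (an assumption for illustration, not necessarily what any particular framework does): query heads are split contiguously across ranks; KV heads are sharded when there are at least as many as ranks, and otherwise each KV head is replicated on several consecutive ranks. All names and config numbers below are hypothetical.

```python
def query_heads_for_rank(rank, world_size, n_heads):
    """Contiguous block of query heads owned by `rank`."""
    per_rank = n_heads // world_size
    return list(range(rank * per_rank, (rank + 1) * per_rank))

def kv_heads_for_rank(rank, world_size, n_kv_heads):
    """KV heads a rank must store (and cache) under a simple placement policy."""
    if n_kv_heads >= world_size:
        # Enough KV heads: shard them, no replication needed
        assert n_kv_heads % world_size == 0
        per_rank = n_kv_heads // world_size
        return list(range(rank * per_rank, (rank + 1) * per_rank))
    # Fewer KV heads than ranks: each KV head is replicated on
    # world_size // n_kv_heads consecutive ranks
    assert world_size % n_kv_heads == 0
    replicas = world_size // n_kv_heads
    return [rank // replicas]

def kv_cache_bytes(n_local_kv_heads, head_dim, seq_len, n_layers, bytes_per_elem=2):
    """Per-rank KV cache size; the factor 2 covers both K and V."""
    return 2 * n_layers * n_local_kv_heads * head_dim * seq_len * bytes_per_elem

# Hypothetical config: 8 query heads sharing 2 KV heads (GQA), 4 ranks
n_heads, n_kv_heads, world_size = 8, 2, 4
group = n_heads // n_kv_heads  # query heads that share one KV head
for rank in range(world_size):
    q_heads = query_heads_for_rank(rank, world_size, n_heads)
    kv_heads = kv_heads_for_rank(rank, world_size, n_kv_heads)
    # Every local query head must find its KV head on the same rank
    assert all(q // group in kv_heads for q in q_heads)
    print(rank, q_heads, kv_heads)

per_rank_bytes = kv_cache_bytes(1, head_dim=128, seq_len=4096, n_layers=32)
print(f"per-rank KV cache: {per_rank_bytes / 2**20:.0f} MiB")
```

With 2 KV heads on 4 ranks, each head lives on two ranks, so the total KV memory across the cluster is twice what the unreplicated cache would need, while each rank still holds only one head's worth.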
liuliu | 2 months ago
What I am asking, however, is whether that will speed up decoding as linearly as it does prefilling.
awnihannun | 2 months ago
In our benchmarks with MLX / mlx-lm, it's as much as 3.5x faster for token generation (decoding) at batch size 1 across 4 machines. In that regime you are memory-bandwidth bound, so sharding the model and KV cache 4 ways means each machine only needs to read a quarter as much memory per token.
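A back-of-the-envelope model of that argument, as a sketch only: assume per-token time is the bytes each machine must stream divided by its memory bandwidth, plus an assumed fixed per-token cost for the all-reduces. The sizes, bandwidth, and 1 ms communication cost below are made-up illustrative values, not MLX measurements.

```python
def decode_tokens_per_sec(bytes_per_token, bandwidth, world_size=1, comm_s=0.0):
    """Rough throughput for batch-size-1 decoding when memory-bandwidth bound.

    bytes_per_token: weights + KV cache that must be streamed per token
    bandwidth: per-machine memory bandwidth in bytes/s
    comm_s: per-token time spent in all-reduces (an assumption, not measured)
    """
    # Each machine reads only its 1/world_size shard of weights and KV cache
    read_s = (bytes_per_token / world_size) / bandwidth
    return 1.0 / (read_s + comm_s)

# Illustrative numbers only, not measurements
bytes_per_token = 40e9  # hypothetical model + KV cache size in bytes
bandwidth = 800e9       # hypothetical per-machine memory bandwidth (B/s)

base = decode_tokens_per_sec(bytes_per_token, bandwidth)
ideal = decode_tokens_per_sec(bytes_per_token, bandwidth, world_size=4)
with_comm = decode_tokens_per_sec(bytes_per_token, bandwidth, world_size=4, comm_s=1e-3)
print(f"ideal speedup: {ideal / base:.2f}x")
print(f"with 1 ms of all-reduce per token: {with_comm / base:.2f}x")
```

Under these assumptions the ideal speedup is exactly 4x, and any nonzero per-token communication cost pulls it below that, which is consistent with a measured figure somewhat under linear.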