alexandercheema | 1 year ago
The way this works is that each device holds a partition of the model (for now a contiguous set of layers). E.g. let's say you have 3 devices and the model is 32 layers. Device 1 could hold layers 1-10, device 2 holds 11-20 and device 3 holds 21-32. Each device executes the layers it's responsible for and passes the output of its last layer (the activations) on to the next device.
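A minimal sketch of the partitioning described above, assuming near-equal contiguous ranges. The names `partition_layers` and `run_pipeline` are made up for illustration and are not the project's API, and the "layers" are plain Python callables standing in for real transformer layers. An even split of 32 layers over 3 devices comes out slightly different from the 1-10 / 11-20 / 21-32 example, which is just one possible assignment.

```python
def partition_layers(num_layers, num_devices):
    """Split layers 1..num_layers into contiguous, near-equal (start, end) ranges."""
    base, extra = divmod(num_layers, num_devices)
    ranges = []
    start = 1
    for device in range(num_devices):
        # The first `extra` devices take one additional layer each.
        size = base + (1 if device < extra else 0)
        ranges.append((start, start + size - 1))
        start += size
    return ranges


def run_pipeline(x, layers, ranges):
    """Simulate the hand-off: each device runs its own layers in order,
    then passes its last output (the activations) to the next device."""
    for start, end in ranges:
        for layer in layers[start - 1:end]:  # ranges are 1-indexed, inclusive
            x = layer(x)
        # In a real setup, x would be sent over the network at this point.
    return x


if __name__ == "__main__":
    # Toy stand-ins for layers: layer i just adds i to its input.
    toy_layers = [lambda x, i=i: x + i for i in range(1, 33)]
    ranges = partition_layers(32, 3)
    print(ranges)
    print(run_pipeline(0, toy_layers, ranges))
```

The result is the same no matter how the ranges are chosen, as long as they are contiguous and cover every layer once. Only the hand-off points move.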
The activations are ~8KB for Llama-3-8B and ~32KB for Llama-3-70B (the size is linear in the number of parameters in a layer, and Llama-3-70B also has more layers). Generally, the larger the model gets in terms of parameters, the more layers it ends up having, so activation size scales sub-linearly with total parameter count. I expect Llama-3-405B to have activations on the order of 100KB.
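For a rough back-of-the-envelope check, here is a sketch that assumes the tensor handed between devices is one hidden-state vector per token. The helper `activation_bytes` is hypothetical. The hidden sizes are taken from the published Llama 3 model configs, and the bytes per element depend on the dtype in use, so this is an assumption and the exact figures may differ from the ones quoted above.

```python
# Hidden sizes from the published Llama 3 model configs.
HIDDEN_SIZE = {"Llama-3-8B": 4096, "Llama-3-70B": 8192}


def activation_bytes(hidden_size, bytes_per_element=2, tokens=1):
    """Bytes needed to ship `tokens` hidden-state vectors between devices.

    bytes_per_element is an assumption: 2 for 16-bit floats, 4 for 32-bit.
    """
    return hidden_size * bytes_per_element * tokens


if __name__ == "__main__":
    for name, hidden in HIDDEN_SIZE.items():
        for width in (2, 4):
            kib = activation_bytes(hidden, width) / 1024
            print(f"{name}: {kib:.0f} KiB per token at {width} bytes/element")
```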
This is totally acceptable to send over a local network. The main issue you run into is latency, not bandwidth. Since LLMs are autoregressive (tokens are generated serially), every network hop adds to the time per token, which limits the throughput of a single request. However, over a local network latency is generally very low (<5ms in my experience). And even when latency is higher, it's still useful depending on the use-case, since you can get a lot of throughput with pipeline parallelism (overlapping requests): https://pytorch.org/docs/stable/pipeline.html
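To make the latency vs. throughput point concrete, here is an idealized model, not a benchmark. It assumes every stage takes the same compute time per token, every hop costs the same latency, and a token makes one hop per stage (including the trip back to the first device). The function names and the example numbers are made up for illustration.

```python
def single_stream_tokens_per_sec(stage_compute_s, num_stages, hop_latency_s):
    """One request: each token has to pass through every stage and hop serially."""
    per_token_s = num_stages * (stage_compute_s + hop_latency_s)
    return 1.0 / per_token_s


def pipelined_tokens_per_sec(stage_compute_s, num_stages, hop_latency_s,
                             concurrent_requests):
    """Upper bound on aggregate throughput with overlapping requests.

    Each request still advances at the single-stream rate, but independent
    requests can occupy different stages at once. The total is capped by
    how fast a single stage can process tokens.
    """
    per_request = single_stream_tokens_per_sec(
        stage_compute_s, num_stages, hop_latency_s)
    stage_limit = 1.0 / stage_compute_s
    return min(concurrent_requests * per_request, stage_limit)


if __name__ == "__main__":
    # Assumed numbers: 10 ms compute per stage, 3 devices, 5 ms per hop.
    print(single_stream_tokens_per_sec(0.010, 3, 0.005))
    for n in (1, 3, 10):
        print(n, pipelined_tokens_per_sec(0.010, 3, 0.005, n))
```

In this model latency only hurts the per-request rate. With enough requests in flight the aggregate rate is bounded by stage compute, which is why higher latency can still be fine for throughput-oriented use-cases.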