top | item 35116191

wjessup | 3 years ago

The context-length limit comes from the size of the word position embedding matrix. It isn't a config issue or an API limitation. The matrix is part of the model itself, and its size is decided before training. You can't change it.

What does that mean?

For each token in your input or in the generated output, the model needs some representation of what that token's position means.

So there is the word position embedding matrix, which holds one vector per position. The matrix has "only" 1024 rows for GPT-2 and 2048 for GPT-3. The width of each row varies too: 768 for GPT-2 small, up to 12,288 for the largest GPT-3.

So the WPE (word position embedding) matrix is 1024x768 for GPT-2 small and 2048x12288 for GPT-3 175B.
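To put those shapes in perspective, here is the parameter count of the position matrix for the four GPT-2 sizes, which all share the 1024-position limit (widths 768, 1024, 1280 and 1600 as listed in the GPT-2 paper). Just arithmetic; `wpe_params` is a throwaway helper, not part of any library:

```python
# (n_positions, d_model) of the learned position-embedding matrix for each GPT-2 size.
GPT2_WPE_SHAPES = {
    "small": (1024, 768),
    "medium": (1024, 1024),
    "large": (1024, 1280),
    "xl": (1024, 1600),
}


def wpe_params(n_positions: int, d_model: int) -> int:
    """One learned vector of width d_model per position."""
    return n_positions * d_model


for name, shape in GPT2_WPE_SHAPES.items():
    print(f"gpt2-{name}: {shape[0]} x {shape[1]} = {wpe_params(*shape):,} parameters")
```

Small next to the full models, but every row is trained, and positions are indexed 0 to 1023, so there is simply no vector for position 1024.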

At inference time, the vector for each position is added to the word token embedding of every token in the original prompt plus every generated token.
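A minimal NumPy sketch of that step, assuming GPT-2-style learned embeddings. The sizes are toy values and the matrices are random stand-ins for trained weights; `wte`/`wpe` follow GPT-2's checkpoint naming, while `embed` is a hypothetical helper. The key line is the lookup `wpe[positions]`: once prompt plus generated tokens exceed the number of rows, there is nothing to look up.

```python
import numpy as np

# Toy sizes for illustration; GPT-2 small would be n_positions=1024, d_model=768.
VOCAB_SIZE = 100
N_POSITIONS = 8
D_MODEL = 4

rng = np.random.default_rng(0)
# Random stand-ins for the learned matrices; a trained model would load these.
wte = rng.normal(size=(VOCAB_SIZE, D_MODEL))   # one row per vocabulary token
wpe = rng.normal(size=(N_POSITIONS, D_MODEL))  # one row per position


def embed(token_ids):
    """Return the input to the first transformer block: token vector + position vector."""
    n = len(token_ids)
    if n > N_POSITIONS:
        # wpe has no row for position N_POSITIONS or beyond.
        raise ValueError(f"sequence of {n} tokens exceeds {N_POSITIONS} positions")
    positions = np.arange(n)
    return wte[token_ids] + wpe[positions]


x = embed([5, 17, 42])
print(x.shape)  # (3, 4)
```

Calling `embed(list(range(9)))` raises, which is the toy version of hitting the context limit.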

kir-gadjello|3 years ago

Positional embeddings are tricky; it very much depends on the specific embedding method chosen. Some more advanced methods keep performance intact, or even slightly improve it, when the context length is increased beyond what was used for the main pretraining run.

As is often the case with these large models, you can change it by finetuning on longer-context samples from the same dataset, which takes a really small amount of compute compared to the million hours spent training the thing.
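For a learned position matrix, the new rows have to come from somewhere before that finetuning starts. One known approach (used by Longformer to grow RoBERTa's 512 learned positions to 4096) is to initialize them by copying the existing rows. A minimal NumPy sketch, assuming a plain `(n_positions, d_model)` array; `extend_position_embeddings` is a hypothetical helper, and it only gives the new rows a starting point, the finetuning does the real work:

```python
import numpy as np


def extend_position_embeddings(wpe: np.ndarray, new_n_positions: int) -> np.ndarray:
    """Grow a learned (n_positions, d_model) matrix by repeating its rows.

    Rows beyond the original length start as copies of earlier rows;
    the model still needs finetuning on longer sequences afterwards.
    """
    old_n = wpe.shape[0]
    if new_n_positions <= old_n:
        return wpe[:new_n_positions].copy()
    reps = -(-new_n_positions // old_n)  # ceiling division
    return np.tile(wpe, (reps, 1))[:new_n_positions]


wpe = np.random.default_rng(0).normal(size=(1024, 768))
longer = extend_position_embeddings(wpe, 4096)
print(longer.shape)  # (4096, 768)
```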

toxik|3 years ago

You'd get this issue even without position embeddings. Attention computes an inner product between every pair of input tokens, so for N tokens with embedding size E the cost is N^2 x E. Squares grow really fast.
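A small NumPy sketch of that point, with a single attention head and no masking or softmax (simplifications for illustration; `attention_scores` is a hypothetical helper). The score matrix has one entry per pair of tokens, so doubling N quadruples it, regardless of how positions are encoded:

```python
import numpy as np


def attention_scores(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Scaled dot-product scores: one entry per (query token, key token) pair."""
    e = q.shape[1]
    return q @ k.T / np.sqrt(e)


rng = np.random.default_rng(0)
E = 64
for n in (256, 512, 1024):
    q = rng.normal(size=(n, E))
    k = rng.normal(size=(n, E))
    s = attention_scores(q, k)
    # s is n x n; computing it takes about n * n * E multiply-adds.
    print(n, s.shape, n * n * E)
```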

visarga|3 years ago

Where did you get that GPT-3 has 12288-size token embeddings? I think that's the internal or output size of the token inside the transformer layers, not in the embedding table.

afro88|3 years ago

Thanks for explaining, very enlightening.

7to2|3 years ago

Do you know what the WPEs are for LLaMA?

sebzim4500|3 years ago

It doesn't really use them. It uses something called RoPE (rotary position embedding), which is hardcoded rather than learned and is applied multiplicatively at every layer to both the query and the key.
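A minimal NumPy sketch of RoPE as described in the RoFormer paper linked below: consecutive pairs of query/key features are rotated by an angle proportional to the token position, with frequencies base^(-2i/d) and base = 10000. Implementations differ in how they pair dimensions, so treat this as an illustration rather than LLaMA's actual code; `rope` is a hypothetical helper. The values are never rotated.

```python
import numpy as np


def rope(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive feature pairs of x (shape (n, d), d even) by position-dependent angles."""
    n, d = x.shape
    if d % 2 != 0:
        raise ValueError("feature dimension must be even")
    inv_freq = base ** (-np.arange(0, d, 2) / d)      # theta_i = base^(-2i/d), shape (d/2,)
    angles = positions[:, None] * inv_freq[None, :]   # shape (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin         # 2D rotation of each pair
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q = rng.normal(size=(1, 8))
k = rng.normal(size=(1, 8))
# Same offset (2) at different absolute positions gives the same score.
s_near = (rope(q, np.array([3.0])) @ rope(k, np.array([1.0])).T).item()
s_far = (rope(q, np.array([103.0])) @ rope(k, np.array([101.0])).T).item()
print(np.isclose(s_near, s_far))  # True
```

Because both vectors are rotated, the score depends only on the offset between positions, so there is no learned table with a fixed number of rows (though quality can still degrade past the length the model was trained on).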

https://arxiv.org/abs/2104.09864