Unlike training, where all tokens in a sequence are processed in parallel, inference generates tokens one by one. To produce a full sentence, the model must therefore run one forward pass per generated token.
The following video from HuggingFace illustrates how it works.
Before generating the first token, the LLM first runs over all input tokens to build up the context. This stage is called the prefill stage.
Because the input tokens are already known, the prefill stage can process them all in parallel.
After prefill is done, the decode stage begins: the LLM iterates, generating tokens one at a time.
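To see the one-forward-pass-per-token behavior concretely, here is a minimal greedy-decoding sketch with GPT-2 (the prompt and the 10 generated tokens are arbitrary choices for illustration, and no caching is used yet, so the whole sequence is re-processed each step):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

# Prompt and number of new tokens are arbitrary choices for illustration.
input_ids = tokenizer("A red jacket compliments any outfit.", return_tensors="pt").input_ids

with torch.no_grad():
    for _ in range(10):                   # one forward pass per generated token
        logits = model(input_ids).logits  # re-run the model on the whole sequence so far
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))
```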
Now let’s see why token generation has to be done autoregressively. The reason lies in how attention is calculated.
In decoder-only transformer-based models such as GPT, attention is computed as:
$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$
where query (Q), key (K), and value (V) vectors are calculated as:
$$ Q = W^q x, \quad K = W^k x, \quad V = W^v x $$
Here, the query-key dot products score how relevant each token in the context is to the current token, and the output is a weighted sum of the values, aggregating information over the previous context.
Therefore, computing self-attention takes 5 matmul operations in total:

1. $Q = W^q x$ (query projection)
2. $K = W^k x$ (key projection)
3. $V = W^v x$ (value projection)
4. $QK^T$ (attention scores)
5. $\text{softmax}(QK^T / \sqrt{d_k})\, V$ (weighted sum of the values)
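As a minimal single-head sketch of these five matmuls in PyTorch (the dimensions and the random projection weights are placeholders, not taken from a real model):

```python
import torch

D, N = 768, 7                       # model dimension and context length (arbitrary)
x = torch.randn(N, D)               # embeddings of the N context tokens
W_q, W_k, W_v = (torch.randn(D, D) for _ in range(3))

Q = x @ W_q                                # 1. query projection
K = x @ W_k                                # 2. key projection
V = x @ W_v                                # 3. value projection
scores = Q @ K.T / (D ** 0.5)              # 4. QK^T, scaled by sqrt(d_k)
out = torch.softmax(scores, dim=-1) @ V    # 5. attention weights times V
```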
Because the key and value matrices carry the context, the K/V entries produced in previous iterations can be cached and reused instead of recomputing the entire K and V from scratch. This is called the key-value (KV) cache.
Let the model dimension be D and the number of attention heads be H, so the dimension per head is D/H. The key (or value) cache then has shape [H, N, D/H] for a single batch, where N is the number of tokens in the context.
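A quick back-of-the-envelope sketch of the cache size this implies, assuming GPT-2-small numbers (D = 768, H = 12, 12 layers) and fp16 storage; both the model sizes and the byte width are my assumptions here:

```python
# Per-layer K (or V) cache has H * N * (D/H) = N * D elements.
D, H, LAYERS, BYTES = 768, 12, 12, 2     # GPT-2-small sizes, fp16 (assumed)
N = 7                                    # tokens currently in the context

per_layer = H * N * (D // H)                   # one key (or value) tensor: [H, N, D/H]
total_bytes = 2 * LAYERS * per_layer * BYTES   # keys and values, across all layers
print(f"{total_bytes / 1024:.0f} KiB for {N} tokens")   # ~252 KiB
```

The cache grows linearly with N and with batch size, which is why long contexts and large batches are memory-hungry at inference time.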
As an example, say we have 7 tokens in the input:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    "gpt2", torch_dtype=torch.float16, use_cache=True
).to("cuda")
model.config.pad_token_id = model.config.eos_token_id

input = "A red jacket compliments any outfit."
input = tokenizer(input, return_tensors="pt").to("cuda")
# input["input_ids"].shape: torch.Size([1, 7])

model.generate(**input, max_length=50)
```
In the first iteration (prefill), the query, key, and value are computed for all 7 input tokens at once, so each has shape [1, 12, 7, 64] for GPT-2 (batch of 1, 12 heads, head dimension 768/12 = 64).
In every later iteration, only the new token's query, key, and value need to be computed (shape [1, 12, 1, 64]), and the fresh key/value are appended to the cached ones, which saves a huge amount of computation for steps 2 and 3.
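These shapes can be checked directly by inspecting `past_key_values` (a sketch reusing the GPT-2 setup above; depending on the transformers version, the cache is either a tuple of (key, value) pairs per layer or a `Cache` object that supports the same indexing):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("A red jacket compliments any outfit.", return_tensors="pt").input_ids

with torch.no_grad():
    # Prefill: all 7 prompt tokens in one pass.
    out = model(input_ids, use_cache=True)
    key, value = out.past_key_values[0]       # layer-0 cache
    print(key.shape)                          # torch.Size([1, 12, 7, 64])

    # One decode step: feed only the newly chosen token plus the cache.
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    out = model(next_token, past_key_values=out.past_key_values, use_cache=True)
    key, value = out.past_key_values[0]
    print(key.shape)                          # torch.Size([1, 12, 8, 64])
```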
Note that the cache does not save any computation for steps 4 and 5: as the context grows, K and V get larger, and those matmuls take longer to compute.
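A rough per-decode-step FLOP count for one attention layer makes this growth explicit (a sketch under my own assumptions: single batch, a multiply-add counted as 2 FLOPs, and a single head of dimension D; the function name is mine):

```python
def attention_flops_per_decode_step(n_context: int, d_model: int = 768) -> int:
    """Rough FLOPs for one attention layer at one decode step (multiply-add = 2 FLOPs)."""
    q_projection = 2 * d_model * d_model        # step 1: new token only, constant
    kv_projection = 2 * 2 * d_model * d_model   # steps 2-3: new token only, constant thanks to the cache
    qk_scores = 2 * d_model * n_context         # step 4: q (1 x D) times K^T (D x N), linear in N
    weighted_values = 2 * n_context * d_model   # step 5: weights (1 x N) times V (N x D), linear in N
    return q_projection + kv_projection + qk_scores + weighted_values

for n in (128, 1024, 8192):
    print(n, attention_flops_per_decode_step(n))
```

Steps 1-3 stay constant per step thanks to the cache, while steps 4 and 5 scale linearly with the context length N.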