
LLM Inference: Autoregressive Generation and Attention KV Cache

This post explains the basics of LLM inference, focusing mainly on how it differs from LLM training.

Autoregressive Text Generation #

Unlike training, where all tokens in a sequence are processed in parallel, inference generates tokens one by one. Therefore, producing a full sentence requires several forward passes, one per generated token. The following video from HuggingFace illustrates how it works.

Autoregressive token generation. Source: HuggingFace

Before generating the first token, the LLM first consumes all of the input tokens to build up the context. This stage is called the prefill stage. Because the input tokens are already known, they can be processed in parallel. Once prefill is done, the decode stage begins, in which the LLM iterates to generate tokens one by one.
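As a rough sketch of this loop (greedy decoding with GPT-2 chosen only for illustration; model.generate() performs the same iteration internally, with caching and sampling options on top):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("A red jacket", return_tensors="pt").input_ids

with torch.no_grad():
    for _ in range(5):
        # One forward pass per new token; the whole sequence is fed each time.
        logits = model(input_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))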

Transformer Attention Layer #

For a more detailed explanation, I recommend reading Jay Alammar’s blog post and watching a YouTube video.

Now let’s see why token generation has to be done autoregressively. The reason lies in how attention is calculated. In decoder-only transformer-based models such as GPT, attention looks like this:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$

where query (Q), key (K), and value (V) vectors are calculated as:

$$ Q = W^q x, \quad K = W^k x, \quad V = W^v x $$

Here, the keys encode how each previous token relates to the current token, and the values carry the token content that is combined, as a weighted sum, over the previous context.

Therefore, we have five matmul operations in total to calculate self-attention (sketched in code after the list):

  1. $Q = W^qx$
  2. $K = W^kx$
  3. $V = W^vx$
  4. $QK^T$
  5. $\text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$
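To make these steps concrete, here is a toy single-head version in plain PyTorch (row-vector convention, so the projections are written x @ W rather than $W x$; all shapes and values are made up for illustration):

import torch

N, D, d_k = 7, 16, 16              # context length, model dim, head dim
x = torch.randn(N, D)              # token representations
W_q, W_k, W_v = (torch.randn(D, d_k) for _ in range(3))

Q = x @ W_q                        # step 1
K = x @ W_k                        # step 2
V = x @ W_v                        # step 3
scores = (Q @ K.T) / d_k ** 0.5    # step 4, scaled by sqrt(d_k)
# Causal mask: each token attends only to itself and earlier tokens.
causal = torch.tril(torch.ones(N, N, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
out = torch.softmax(scores, dim=-1) @ V   # step 5
print(out.shape)                   # torch.Size([7, 16])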

Key-Value Cache #

Because the key and value matrices maintain the context, the parts generated in previous iterations can be cached and reused instead of recomputing the entire K and V from scratch every iteration. This is called the key-value (KV) cache. Denoting the model dimension by D and the number of attention heads by H, the dimension per head is D/H. The key/value tensor for a single batch then has shape [H, N, D/H], where N is the number of tokens in the context.
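As a rough back-of-the-envelope estimate (my own numbers, assuming a GPT-2-small-like configuration: D=768, H=12, 12 layers, fp16), the cache grows linearly with the context length N:

# Per layer, each of the H heads caches [N, D/H] elements, so a full K or V
# tensor holds N * D elements. Total cache size is then
# 2 (K and V) x layers x batch x N x D x bytes per element.
D, H, num_layers, bytes_per_elem = 768, 12, 12, 2   # fp16
N, batch = 1024, 1

cache_bytes = 2 * num_layers * batch * N * D * bytes_per_elem
print(f"{cache_bytes / 2**20:.1f} MiB")  # 36.0 MiB for a 1024-token context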

As an example, say we have 7 tokens in the input:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, use_cache=True).to("cuda")
model.config.pad_token_id = model.config.eos_token_id

input = "A red jacket compliments any outfit."
input = tokenizer(input, return_tensors="pt").to("cuda")

# input["input_ids"].shape: torch.Size([1, 7])

model.generate(**input, max_length=50)

In the first iteration (the prefill pass), query, key, and value have the following shapes:

>>> query.shape, key.shape, value.shape
(torch.Size([1, 12, 7, 64]), torch.Size([1, 12, 7, 64]), torch.Size([1, 12, 7, 64]))

where the batch size is 1, H=12, N=7, and D/H=64 (so D=768).
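One way to check these shapes without stepping into the model code is to run a single forward pass and inspect the returned past_key_values (structure as in transformers v4.36, where it is a tuple of per-layer (key, value) pairs; newer versions may wrap it in a Cache object):

with torch.no_grad():
    out = model(**input, use_cache=True)

print(len(out.past_key_values))          # 12: one (key, value) pair per layer
print(out.past_key_values[0][0].shape)   # key:   torch.Size([1, 12, 7, 64])
print(out.past_key_values[0][1].shape)   # value: torch.Size([1, 12, 7, 64])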

Code location: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/gpt2/modeling_gpt2.py

query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

After prefill is done, each decode iteration computes query, key, and value for only the newly generated token:

>>> query.shape, key.shape, value.shape
(torch.Size([1, 12, 1, 64]), torch.Size([1, 12, 1, 64]), torch.Size([1, 12, 1, 64]))

These are then concatenated with the cached (past) key and value:

if layer_past is not None:
    past_key, past_value = layer_past
    key = torch.cat((past_key, key), dim=2)
    value = torch.cat((past_value, value), dim=2)
>>> key.shape, value.shape
(torch.Size([1, 12, 8, 64]), torch.Size([1, 12, 8, 64]))

which saves a huge amount of computation in steps 2 and 3. Note that this does not save computation in steps 4 and 5: as the context grows, K and V get larger, so those matmuls take longer to compute.
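A quick (and unscientific) way to see the effect is to time generate() with and without the cache, continuing the setup above; the cached run should typically be noticeably faster, although exact numbers depend on hardware and sequence length:

import time

def time_generate(use_cache: bool) -> float:
    # Synchronize so we measure actual GPU work, not just kernel launches.
    torch.cuda.synchronize()
    start = time.time()
    model.generate(**input, max_length=200, use_cache=use_cache)
    torch.cuda.synchronize()
    return time.time() - start

print(f"with KV cache:    {time_generate(True):.2f}s")
print(f"without KV cache: {time_generate(False):.2f}s")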