Inference for an autoregressive LLM is a loop: given a prompt, predict the next token, append it, repeat. The naive version is slow because every step recomputes the keys and values of every earlier token. The optimized version stores them in a KV cache instead. Both fit in a few lines of pseudocode.
Naive version (no cache)
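def generate(model, prompt, max_new):
    # tokenize, sample, and EOS stand in for your tokenizer,
    # sampling rule, and end-of-sequence token id
    tokens = tokenize(prompt)
    for _ in range(max_new):
        logits = model(tokens)            # full forward over ALL tokens
        next_token = sample(logits[-1])   # only the last position is used
        tokens.append(next_token)
        if next_token == EOS:
            break
    return tokens
Simple but wasteful: at step t the model re-runs the full forward pass over all N + t tokens (N prompt tokens plus t generated so far), even though only the last position's logits are used. Over T new tokens that adds up to T·N + T(T-1)/2 token positions, i.e. O(T·N + T²), and because attention inside each pass is itself quadratic in sequence length, attention work grows even faster. Acceptable for T < 100; painful past that.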
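To make the cost claim concrete, here is a runnable version of the same loop with a stub standing in for the model. `CountingStubModel`, `greedy`, and `EOS = 0` are hypothetical illustrations, not a real API: the stub never predicts EOS and simply counts how many token positions each forward pass has to process.

```python
EOS = 0  # hypothetical end-of-sequence token id


class CountingStubModel:
    """Hypothetical stand-in for an LLM. It never predicts EOS and counts
    how many token positions each forward pass has to process."""

    def __init__(self, vocab_size: int = 16):
        self.vocab_size = vocab_size
        self.positions_processed = 0

    def __call__(self, tokens):
        # A real model runs every layer over every position; we only count.
        self.positions_processed += len(tokens)
        logits = []
        for i in range(len(tokens)):
            row = [0.0] * self.vocab_size
            row[1 + i % (self.vocab_size - 1)] = 1.0  # never token 0 (EOS)
            logits.append(row)
        return logits


def greedy(logits_row):
    # Greedy "sampling": pick the highest-scoring token id.
    return max(range(len(logits_row)), key=logits_row.__getitem__)


def generate(model, prompt_tokens, max_new):
    tokens = list(prompt_tokens)
    for _ in range(max_new):
        logits = model(tokens)            # full forward over ALL tokens
        next_token = greedy(logits[-1])   # only the last position is used
        tokens.append(next_token)
        if next_token == EOS:
            break
    return tokens


model = CountingStubModel()
N, T = 100, 50
out = generate(model, [1] * N, T)
print(len(out), model.positions_processed)  # prints: 150 6225
```

With N = 100 and T = 50 the counter lands exactly on T·N + T(T-1)/2 = 6225 positions, even though the final sequence is only 150 tokens long.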
KV-cached version
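def generate_cached(model, prompt, max_new):
    tokens = tokenize(prompt)
    if max_new <= 0:
        return tokens
    # Prefill: process the whole prompt in one pass and fill the cache
    logits, kv_cache = model(tokens, cache=None)
    next_token = sample(logits[-1])
    tokens.append(next_token)
    # Decode loop: feed ONE token at a time, reusing the cache
    for _ in range(max_new - 1):
        if next_token == EOS:             # also catches EOS sampled at prefill
            break
        logits, kv_cache = model([next_token], cache=kv_cache)
        next_token = sample(logits[-1])   # only one position was fed in
        tokens.append(next_token)
    return tokens
Each decode step feeds just 1 new token through the model; the K and V of all earlier tokens are read from kv_cache instead of being recomputed. For T new tokens after an N-token prompt, the model processes N + T - 1 token positions in total instead of O(T·N + T²), so the projection and MLP work becomes linear in T instead of quadratic. Attention is still not free: decode step t attends over all N + t cached positions, so total attention work is O(T·(N + T)), which is why long contexts stay expensive even with a cache.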
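A minimal sketch of why the cache gives the same answer as recomputation, assuming a single attention head with random toy weights in NumPy. The dimension `d`, the weights `Wq`/`Wk`/`Wv`, and the `KVCache` class are illustrative, not a real library's API, and the inputs are random vectors standing in for token embeddings (in a real decoder the next input would be the embedding of the sampled token).

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # toy head dimension (assumption)
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))


def attend(q, K, V):
    # q: [d], K/V: [t, d] -> softmax(K q / sqrt(d)) weighted sum of V, shape [d]
    scores = K @ q / np.sqrt(d)
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ V


def full_causal(x):
    # No cache: recompute K and V for every position on every call.
    Q, K, V = x @ Wq, x @ Wk, x @ Wv
    return np.stack([attend(Q[i], K[: i + 1], V[: i + 1]) for i in range(len(x))])


class KVCache:
    def __init__(self):
        self.K = np.empty((0, d))
        self.V = np.empty((0, d))

    def step(self, x_new):
        # x_new: [n, d] new positions. Project ONLY these, append to the cache,
        # then let each new position attend causally over everything cached.
        start = len(self.K)
        self.K = np.vstack([self.K, x_new @ Wk])
        self.V = np.vstack([self.V, x_new @ Wv])
        Q = x_new @ Wq
        outs = [
            attend(Q[i], self.K[: start + i + 1], self.V[: start + i + 1])
            for i in range(len(x_new))
        ]
        return np.stack(outs)


x = rng.standard_normal((12, d))      # 10 "prompt" + 2 "generated" positions
cache = KVCache()
prefill_out = cache.step(x[:10])      # prefill: all prompt positions at once
decode_outs = [cache.step(x[i : i + 1]) for i in range(10, 12)]  # 1 at a time
cached = np.vstack([prefill_out, *decode_outs])
print(np.allclose(cached, full_causal(x)))
```

The cached and recomputed outputs agree up to floating-point rounding, but each decode step only projects one new row of K and V.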
Prefill vs decode
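Prefill processes the user's prompt (potentially thousands of tokens) in a single forward pass, with all prompt positions computed in parallel along the sequence dimension. Decode then generates one token per forward pass, so each step is small. Prefill is compute-bound: large matrix multiplies over many tokens keep the arithmetic units busy. Decode is memory-bandwidth-bound: every step must stream all the weights and the whole KV cache from memory to produce just one token per sequence. The two phases are therefore optimized separately.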
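The compute-bound vs bandwidth-bound split can be sketched with arithmetic intensity (FLOPs per byte of memory traffic) for one linear layer. This is a rough model under stated assumptions: a square 4096×4096 BF16 weight, each operand read or written exactly once, and a hypothetical accelerator whose ridge point (peak FLOP/s divided by memory bandwidth) is taken as 300 FLOP/byte. `matmul_intensity` and `RIDGE` are illustrative names.

```python
def matmul_intensity(n_tokens: int, d_in: int, d_out: int, bytes_per_val: int = 2) -> float:
    """FLOPs per byte moved for y = x @ W with x: [n_tokens, d_in], W: [d_in, d_out].

    Assumes each operand (W, x, y) crosses the memory bus exactly once.
    """
    flops = 2 * n_tokens * d_in * d_out  # one multiply + one add per MAC
    bytes_moved = bytes_per_val * (d_in * d_out + n_tokens * d_in + n_tokens * d_out)
    return flops / bytes_moved


RIDGE = 300.0  # hypothetical accelerator: peak FLOP/s / memory bytes/s

for phase, n in [("decode", 1), ("prefill", 2048)]:
    ai = matmul_intensity(n, 4096, 4096)
    bound = "compute-bound" if ai > RIDGE else "memory-bandwidth-bound"
    print(f"{phase:8s} n={n:5d} {ai:8.1f} FLOP/byte -> {bound}")
```

Decode comes out at about 1 FLOP/byte and prefill at n/2 = 1024 FLOP/byte for this square layer. Batching B decode requests raises n to B for the weight matmuls, which is one reason servers batch decode steps; attention over each sequence's own KV cache does not benefit from batching in the same way.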
The cache structure
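# Per layer (all heads stacked in one tensor):
k_cache: [batch, n_kv_heads, seq_so_far, head_dim]
v_cache: [batch, n_kv_heads, seq_so_far, head_dim]
# Total cache size for the whole model (L layers, h = n_kv_heads):
2 * L * h * head_dim * seq * batch * bytes_per_val
For Llama 3 8B (32 layers, 8 KV heads via grouped-query attention, head_dim 128) at seq=8192, batch=1, BF16: 2 × 32 × 8 × 128 × 8192 × 2 bytes = 1 GiB just for the KV cache, and 8 GiB at batch=8. Without GQA (32 KV heads, as in Llama 2 7B), the same settings need 4 GiB per sequence. At long context or large batch the KV cache can exceed the model weights themselves (about 16 GB for 8B parameters in BF16), which is a main motivation for FP8/INT4 KV cache quantization.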
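The size formula is easy to check in code. The shapes below come from the published Llama 3 8B configuration (32 layers, 8 KV heads, head_dim 128); the `ModelShape` dataclass and `kv_cache_bytes` helper are illustrative names. Note that h must be the number of KV heads, which under grouped-query attention is smaller than the number of query heads.

```python
from dataclasses import dataclass


@dataclass
class ModelShape:
    n_layers: int
    n_kv_heads: int  # KV heads, not query heads (they differ under GQA)
    head_dim: int


def kv_cache_bytes(m: ModelShape, seq: int, batch: int, bytes_per_val: float) -> float:
    # 2 tensors (K and V) per layer, each [batch, n_kv_heads, seq, head_dim]
    return 2 * m.n_layers * m.n_kv_heads * m.head_dim * seq * batch * bytes_per_val


GIB = 1024 ** 3
llama3_8b = ModelShape(n_layers=32, n_kv_heads=8, head_dim=128)
no_gqa = ModelShape(n_layers=32, n_kv_heads=32, head_dim=128)  # same shape, plain MHA

print(kv_cache_bytes(llama3_8b, 8192, 1, 2) / GIB)    # 1.0  (BF16, batch 1)
print(kv_cache_bytes(llama3_8b, 8192, 8, 2) / GIB)    # 8.0  (BF16, batch 8)
print(kv_cache_bytes(no_gqa, 8192, 1, 2) / GIB)       # 4.0  (no GQA)
print(kv_cache_bytes(llama3_8b, 8192, 1, 1) / GIB)    # 0.5  (FP8)
print(kv_cache_bytes(llama3_8b, 8192, 1, 0.5) / GIB)  # 0.25 (INT4)
```

Because every term is a multiplicative factor, halving bytes_per_val (BF16 to FP8) halves the cache, exactly as halving the batch or the context would.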
Stopping criteria
Generate until an EOS token is sampled, max_tokens is reached, or a stop string is matched. For chat models the end marker is usually an end-of-turn token from the chat template, such as <|im_end|>. Beam search adds its own early-stopping rule for deciding when enough beams have finished. Best practice: always set max_tokens to bound runaway generation.
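A minimal sketch of the three stopping checks, assuming the caller tracks both the generated token ids and their decoded text (`check_stop` is a hypothetical helper, not a library function). Stop strings are matched on text rather than token ids because a string such as "\nUser:" can be split across several tokens; trimming the matched string from the output is a common convention, not a universal one.

```python
from typing import Optional, Sequence, Tuple


def check_stop(
    generated_ids: Sequence[int],
    generated_text: str,
    eos_id: int,
    max_tokens: int,
    stop_strings: Sequence[str] = (),
) -> Tuple[Optional[str], str]:
    """Return (reason, text_to_emit); reason is None if generation should continue."""
    # 1. EOS: the most recent token is the end-of-sequence id.
    if generated_ids and generated_ids[-1] == eos_id:
        return "eos", generated_text
    # 2. Stop strings: cut at the EARLIEST match of any stop string.
    hits = [i for i in (generated_text.find(s) for s in stop_strings) if i != -1]
    if hits:
        return "stop_string", generated_text[: min(hits)]
    # 3. Length cap: bounds runaway generation.
    if len(generated_ids) >= max_tokens:
        return "max_tokens", generated_text
    return None, generated_text


print(check_stop([5, 9], "Answer: 4\nUser:", eos_id=2, max_tokens=100,
                 stop_strings=["\nUser:"]))
```

Rescanning the whole text on every step costs time proportional to its length; a production loop would typically scan only the newly decoded characters plus a window as long as the longest stop string.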