Inference for an autoregressive LLM is a loop: given a prompt, predict the next token, append it, repeat. The naive version is slow because it reprocesses the whole sequence at every step. The optimized version reuses past work through the KV cache. Both fit in a few lines of Python-style pseudocode.

Naive version (no cache)

def generate(model, prompt, max_new):
    # Assumes: model(tokens) -> logits of shape [len(tokens), vocab];
    # tokenize, sample and EOS are provided elsewhere.
    tokens = list(tokenize(prompt))
    for _ in range(max_new):
        logits = model(tokens)              # full forward over ALL tokens
        next_token = sample(logits[-1])     # only need last position
        tokens.append(next_token)
        if next_token == EOS:
            break
    return tokens

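Simple but wasteful: at step t the model reruns every layer over all N + t positions (N prompt tokens plus t generated so far), recomputing keys and values it already computed on the previous step. Over T new tokens that is T·N + T(T-1)/2 positions processed, so the work grows as O(T²), and attention alone costs O((N + t)²) per step. Acceptable for roughly T < 100; painful past that.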
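A runnable toy version of the naive loop makes the growth visible. It assumes a stand-in model (`toy_model`, which always predicts `last + 1` modulo a 10-token vocabulary), greedy `sample`, and `EOS = 0`; all three are hypothetical, chosen only so the loop runs. The toy model logs how many positions each forward pass sees.

```python
# Hypothetical stand-ins so the loop is runnable: EOS id, vocab size, toy model.
EOS = 0
VOCAB = 10


def toy_model(tokens, log):
    """Pretend forward pass: returns one logits row per input position.

    Appends the number of positions processed to `log` so work can be counted.
    Arbitrary rule for illustration: the next token is (token + 1) % VOCAB.
    """
    log.append(len(tokens))
    rows = []
    for tok in tokens:
        row = [0.0] * VOCAB
        row[(tok + 1) % VOCAB] = 1.0
        rows.append(row)
    return rows


def sample(row):
    # Greedy decoding: pick the id with the highest score.
    return max(range(len(row)), key=row.__getitem__)


def generate_naive(prompt_tokens, max_new, log):
    tokens = list(prompt_tokens)
    for _ in range(max_new):
        logits = toy_model(tokens, log)   # full forward over ALL tokens
        next_token = sample(logits[-1])   # only the last row is used
        tokens.append(next_token)
        if next_token == EOS:
            break
    return tokens


log = []
out = generate_naive([1, 2, 3], max_new=4, log=log)
print(out)  # [1, 2, 3, 4, 5, 6, 7]
print(log)  # [3, 4, 5, 6]
```

The log reads 3, 4, 5, 6: step t processes N + t positions, so T steps cost T·N + T(T-1)/2 positions in total (18 here). The T(T-1)/2 term is the quadratic part.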

KV-cached version

def generate_cached(model, prompt, max_new):
    # Assumes: model(tokens, cache) -> (logits, updated_cache); cache=None means empty.
    tokens = list(tokenize(prompt))
    if max_new <= 0:
        return tokens
    # Prefill: one forward pass over the whole prompt fills the cache
    logits, kv_cache = model(tokens, cache=None)
    next_token = sample(logits[-1])
    tokens.append(next_token)

    # Decode loop: process ONE token at a time, reusing cache
    for _ in range(max_new - 1):
        if next_token == EOS:          # also catches EOS sampled during prefill
            break
        logits, kv_cache = model(
            [next_token], cache=kv_cache)
        next_token = sample(logits[-1])  # only one position in this call
        tokens.append(next_token)
    return tokens

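Each decode step feeds just one token through the network; the K and V of every earlier position are read from kv_cache instead of being recomputed. This is safe because attention is causal: a position's K and V at every layer depend only on itself and earlier tokens, so they never change once computed. The per-token layers (projections, MLP) now process N + T positions in total instead of T·N + T(T-1)/2: linear in T instead of quadratic. Attention still reads the whole cache, so step t costs O(N + t) and the full generation O(T·(N + T)), roughly O(T·N) when the prompt dominates.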
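A toy check of why caching is safe, sketched in pure Python for a single attention head. Everything here is an illustrative assumption: head dimension `D = 4`, random projection matrices, and random input vectors `xs` standing in for one layer's input hidden states. In a real model each layer (and head) keeps its own cache. The sketch compares projecting only the newest position and appending to a cache against recomputing K and V for every position.

```python
import math
import random

D = 4  # toy head dimension (assumption)


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matvec(W, x):
    # W is a list of rows; returns W @ x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    # Scaled dot-product attention for one query over the given keys/values.
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # numerically stable softmax
    z = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / z for i in range(D)]


random.seed(0)
Wq, Wk, Wv = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)


def last_output_no_cache(xs):
    # Naive: recompute K and V for EVERY position, then attend from the last one.
    keys = [matvec(Wk, x) for x in xs]
    values = [matvec(Wv, x) for x in xs]
    return attend(matvec(Wq, xs[-1]), keys, values)


class KVCache:
    """Hypothetical single-head cache: lists of past K and V vectors."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        # Project ONLY the new position, append its K/V, attend over the cache.
        self.keys.append(matvec(Wk, x))
        self.values.append(matvec(Wv, x))
        return attend(matvec(Wq, x), self.keys, self.values)


xs = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(6)]
cache = KVCache()
for t, x in enumerate(xs, start=1):
    cached = cache.step(x)
    full = last_output_no_cache(xs[:t])
    assert all(abs(a - b) < 1e-12 for a, b in zip(cached, full))
print(len(cache.keys))  # 6
```

The cached path does one projection per step instead of t, yet produces the same attention output at every step.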

Prefill vs decode

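Prefill: process the user's prompt (potentially thousands of tokens) in a single forward pass, with all prompt positions computed in parallel along the sequence dimension. Decode: generate one token per forward pass, so each step handles a single new position. Prefill is compute-bound: large matrix multiplies keep the arithmetic units busy. Decode is memory-bandwidth-bound: each step streams all the weights and the whole KV cache from memory to do very little math. Because the bottlenecks differ, optimize the two phases separately.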
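A rough arithmetic-intensity estimate (FLOPs per byte moved) for one weight matrix shows why the phases differ. This is a simplification that assumes every weight, input, and output value crosses memory exactly once in BF16. The 4096 dimensions are chosen to match a typical 8B-model hidden size; `matmul_intensity` is a hypothetical helper.

```python
def matmul_intensity(seq_len, d_in, d_out, bytes_per_val=2):
    """FLOPs per byte for activations[seq_len, d_in] @ weights[d_in, d_out].

    Simplified model: each weight, input and output value is moved once.
    """
    flops = 2 * seq_len * d_in * d_out          # multiply + add per MAC
    bytes_moved = bytes_per_val * (
        d_in * d_out          # weights
        + seq_len * d_in      # input activations
        + seq_len * d_out     # output activations
    )
    return flops / bytes_moved


print(matmul_intensity(2048, 4096, 4096))  # prefill, 2048-token prompt: 1024.0
print(matmul_intensity(1, 4096, 4096))     # one decode step: just under 1.0
```

Prefill reuses each loaded weight across every prompt position, while a single decode step uses each weight once. Batching many decode requests raises the effective row count, which is why serving systems batch sequences together during decode.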

The cache structure

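# Per layer (n_kv_heads = number of key/value heads: equal to
# n_heads for standard multi-head attention, fewer with GQA):
k_cache: [batch, n_kv_heads, seq_so_far, head_dim]
v_cache: [batch, n_kv_heads, seq_so_far, head_dim]

# Total cache size for whole model (L = number of layers; 2 = K and V):
2 * L * n_kv_heads * head_dim * seq * batch * bytes_per_val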

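For Llama 3 8B (32 layers, 8 KV heads thanks to grouped-query attention, head_dim 128) at seq=8192, batch=1, BF16: 2 × 32 × 8 × 128 × 8192 × 2 bytes = 1 GiB of KV cache; at batch=8, 8 GiB. A model of the same shape with full multi-head attention (32 KV heads, as in Llama 2 7B) needs 4× as much: 4 GiB at batch=1, 32 GiB at batch=8. The cache grows linearly in both sequence length and batch size, so at long context or high batch it can exceed the ~16 GB of BF16 weights. That is a major reason for FP8/INT4 KV cache quantization.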
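The size formula as a small function, evaluated for the Llama 3 8B shape (32 layers, 8 KV heads, head_dim 128) at seq=8192. `kv_cache_bytes` is a hypothetical helper, and the FP8 line simply assumes 1 byte per cached value, ignoring any per-block scale factors a real quantization scheme would store.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_val):
    """Size of the K and V caches for the whole model, in bytes."""
    # Factor 2: one K tensor and one V tensor per layer.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_val


GiB = 2 ** 30

# Llama 3 8B shape: 32 layers, 8 KV heads (grouped-query attention), head_dim 128.
for batch in (1, 8):
    bf16 = kv_cache_bytes(32, 8, 128, 8192, batch, bytes_per_val=2)
    fp8 = kv_cache_bytes(32, 8, 128, 8192, batch, bytes_per_val=1)
    print(f"batch={batch}: BF16 {bf16 / GiB:.1f} GiB, FP8 {fp8 / GiB:.1f} GiB")
# batch=1: BF16 1.0 GiB, FP8 0.5 GiB
# batch=8: BF16 8.0 GiB, FP8 4.0 GiB

# Same shape with full multi-head attention (32 KV heads):
print(kv_cache_bytes(32, 32, 128, 8192, 1, 2) / GiB)  # 4.0
```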

Stopping criteria

Generate until an EOS token is sampled, max_tokens is reached, or a stop string appears in the decoded output. Chat models usually end a turn with a special chat-format token, such as <|im_end|> in ChatML-style templates. Beam search adds its own early-stopping criteria. Best practice: always set max_tokens to bound runaway generation.
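The three stopping checks can be sketched as one function called after each decode step. `StopCriteria` and `check` are hypothetical names, and `eos_id=2` and the `"\nUser:"` stop string are arbitrary example values. The stop string is matched on decoded text rather than token ids because a string can span token boundaries.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class StopCriteria:
    """Hypothetical per-request stopping rules."""
    eos_id: int
    max_tokens: int
    stop_strings: List[str] = field(default_factory=list)

    def check(self, new_ids: List[int], text: str) -> Tuple[bool, str]:
        """Return (should_stop, text_to_return) after a decode step.

        new_ids: ids generated so far; text: their decoded string.
        """
        if new_ids and new_ids[-1] == self.eos_id:
            return True, text
        for s in self.stop_strings:
            idx = text.find(s)
            if idx != -1:
                return True, text[:idx]  # drop the stop string and anything after
        if len(new_ids) >= self.max_tokens:
            return True, text
        return False, text


crit = StopCriteria(eos_id=2, max_tokens=5, stop_strings=["\nUser:"])
print(crit.check([5, 6], "Hi"))                    # (False, 'Hi')
print(crit.check([5, 6, 7], "Hello\nUser: more"))  # (True, 'Hello')
```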

AR generation = loop with KV cache. Prefill the prompt, then decode one token at a time. At long context or large batch, cache memory can exceed weight memory.