Multi-head attention runs h separate attention operations in parallel on different projections of the input, then concatenates the results. The intuition that different heads can specialize in different patterns is sound. The math is mostly bookkeeping for tensor reshapes.
Single head vs multi-head
Single-head with d_k = d:
Q = X · W_Q ∈ ℝ^(N × d)
Multi-head with h heads, d_k = d/h:
Q_full = X · W_Q ∈ ℝ^(N × d)
Q_split = reshape Q_full to (N, h, d_k)
Q_per_head[k] = Q_split[:, k, :] ∈ ℝ^(N × d_k)
One big matmul produces the projections for all heads at once; the reshape then splits the result along the head dimension. Total compute and parameter count are the same as in the single-head case, but the result is conceptually h separate attention operations.
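The reshape can be checked numerically. The following is a minimal NumPy sketch, not a reference implementation; the sizes (N = 6, d = 16, h = 4) and the random inputs are assumptions chosen only for illustration. It shows that slicing head k out of the reshaped tensor gives the same result as projecting X with just the matching column block of W_Q.

```python
import numpy as np

# Illustrative sizes (assumptions): N tokens, model width d, h heads.
N, d, h = 6, 16, 4
d_k = d // h

rng = np.random.default_rng(0)
X = rng.standard_normal((N, d))
W_Q = rng.standard_normal((d, d))

# One big matmul, then a reshape that splits the last axis into (h, d_k).
Q_full = X @ W_Q                      # shape (N, d)
Q_split = Q_full.reshape(N, h, d_k)   # shape (N, h, d_k)

# Equivalent view: head k only ever sees columns k*d_k .. (k+1)*d_k of W_Q.
k = 2
Q_head_direct = X @ W_Q[:, k * d_k:(k + 1) * d_k]   # shape (N, d_k)

same = np.allclose(Q_split[:, k, :], Q_head_direct)
print(Q_split.shape, same)
```

In other words, the single d × d matrix W_Q can be read as h narrower d × d_k matrices stored side by side, which is why the parameter count does not change.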
Parallel attention per head
for k in 0..h-1:
    head_k = Attention(Q[k], K[k], V[k]) ∈ ℝ^(N × d_k)
Each head computes attention independently on its own slice of Q, K, and V. In practice the loop is replaced by batched matmuls over the head axis: scores ∈ ℝ^(h × N × N), output ∈ ℝ^(h × N × d_k).
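To make the loop-versus-batch equivalence concrete, here is a small NumPy sketch. It assumes that Attention means standard scaled dot-product attention, softmax(Q Kᵀ / √d_k) V, with no mask; the text above does not spell this out, and the sizes and random inputs are illustrative only.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Illustrative sizes (assumptions): N tokens, h heads, d_k per-head width.
N, h, d_k = 5, 4, 8
rng = np.random.default_rng(0)
Q = rng.standard_normal((h, N, d_k))
K = rng.standard_normal((h, N, d_k))
V = rng.standard_normal((h, N, d_k))

# Batched form: matmul broadcasts over the leading head axis.
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # shape (h, N, N)
weights = softmax(scores, axis=-1)                 # each row sums to 1
out_batched = weights @ V                          # shape (h, N, d_k)

# Explicit per-head loop, for comparison.
out_loop = np.stack([
    softmax(Q[k] @ K[k].T / np.sqrt(d_k)) @ V[k]
    for k in range(h)
])

match = np.allclose(out_batched, out_loop)
print(scores.shape, out_batched.shape, match)
```

The batched version does the same arithmetic as the loop; it only hands the head axis to the matmul routine so that all heads are processed in one call.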
Concatenation and output projection
concat = reshape (head_0, ..., head_h-1) → ℝ^(N × d)
output = concat · W_O ∈ ℝ^(N × d)
The concatenated head outputs go through one more linear projection, W_O ∈ ℝ^(d × d). This is the step that mixes the heads' outputs, and it is essential: without W_O, each head's output would stay in its own d_k-wide slice and the heads would never be combined inside the attention block.
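The mixing role of W_O can be seen in a short NumPy sketch. The per-head outputs here are random stand-ins and the sizes are assumptions for illustration. The sketch also checks an equivalent reading of the same operation: concatenating and multiplying by W_O equals summing each head's output times its own d_k-row block of W_O.

```python
import numpy as np

# Illustrative sizes (assumptions): N tokens, h heads, d_k per-head width.
N, h, d_k = 5, 4, 8
d = h * d_k

rng = np.random.default_rng(0)
heads = rng.standard_normal((h, N, d_k))   # stand-in for per-head attention outputs
W_O = rng.standard_normal((d, d))

# (h, N, d_k) -> (N, h, d_k) -> (N, d): head k lands in columns k*d_k .. (k+1)*d_k.
concat = heads.transpose(1, 0, 2).reshape(N, d)
output = concat @ W_O                      # shape (N, d)

# Same result as a sum of per-head contributions through row blocks of W_O.
output_sum = sum(heads[k] @ W_O[k * d_k:(k + 1) * d_k, :] for k in range(h))

equal = np.allclose(output, output_sum)
print(output.shape, equal)
```

Every output column therefore receives a contribution from every head, which is the sense in which W_O lets the heads combine.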
Total parameter count
MHA params = 4 * d * d (W_Q, W_K, W_V, W_O all d×d)
= 4 * d²
For d = 2048 that is about 16.8M parameters per attention block (biases ignored). With L = 24 layers, attention alone accounts for roughly 400M parameters, a significant share for an SLM. GQA (next section) reduces the K and V portion.
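The arithmetic is easy to reproduce. This short sketch counts only the four weight matrices and ignores biases; the helper name mha_params is hypothetical, introduced here for illustration.

```python
def mha_params(d: int) -> int:
    """Weights in W_Q, W_K, W_V, W_O, each d x d; biases are ignored."""
    return 4 * d * d

d, num_layers = 2048, 24
per_block = mha_params(d)          # 4 * 2048^2
total = per_block * num_layers     # attention weights across all layers

print(f"{per_block:,} per block, {total:,} across {num_layers} layers")
# prints: 16,777,216 per block, 402,653,184 across 24 layers
```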
GQA — Grouped Query Attention
A modern variant in which query heads outnumber KV heads. Llama 3 8B uses 32 query heads and 8 KV heads, so each group of 4 query heads shares one K head and one V head. This makes the KV cache 4× smaller at inference with minimal quality loss, which matters most for long-context inference on memory-constrained hardware such as CPUs.
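The sharing pattern can be sketched in NumPy as follows. The head counts match the Llama 3 8B figures quoted above; everything else is an assumption made for illustration: the small N and d_k, the random inputs, unmasked scaled dot-product attention, and the convention that consecutive query heads form a group. Real implementations usually avoid materializing the repeated K and V.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Head counts from the text; N and d_k are small illustrative values.
n_q_heads, n_kv_heads = 32, 8
group = n_q_heads // n_kv_heads            # 4 query heads per KV head
N, d_k = 5, 8

rng = np.random.default_rng(0)
Q = rng.standard_normal((n_q_heads, N, d_k))
K = rng.standard_normal((n_kv_heads, N, d_k))   # only this much is cached
V = rng.standard_normal((n_kv_heads, N, d_k))

# Assumed grouping: query head i reads KV head i // group.
K_exp = np.repeat(K, group, axis=0)        # shape (32, N, d_k)
V_exp = np.repeat(V, group, axis=0)

scores = Q @ K_exp.transpose(0, 2, 1) / np.sqrt(d_k)   # shape (32, N, N)
out = softmax(scores) @ V_exp                          # shape (32, N, d_k)

# KV-cache size relative to full multi-head attention with 32 KV heads.
cache_ratio = (n_q_heads * N * d_k) / K.size
print(out.shape, cache_ratio)
```

The output still has one slice per query head, so the concatenation and W_O step is unchanged; only the K and V tensors (and the projections that produce them) shrink.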