Multi-head attention runs h separate attention operations in parallel, each on a different learned projection of the input, then concatenates the results. The intuition that different heads can specialize on different patterns holds up in practice: heads in trained models often behave quite differently from one another. The math is mostly bookkeeping for tensor reshapes.

Single head vs multi-head

Single head, with d_k = d (X ∈ ℝ^(N × d) holds N token vectors of width d; W_Q ∈ ℝ^(d × d)):
  Q = X · W_Q    ∈ ℝ^(N × d)

Multi-head with h heads, d_k = d/h:
  Q_full = X · W_Q   ∈ ℝ^(N × d)
  Q_split = reshape Q_full to (N, h, d_k)
  Q_per_head[k] = Q_split[:, k, :]   ∈ ℝ^(N × d_k)

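One big matmul produces every head's projection at once; the reshape then splits the feature dimension across heads. Compared with single-head attention at the same d, the parameter count is identical and the compute is essentially the same (h heads each computing N² · d_k scores is N² · d in total, since h · d_k = d). Conceptually, though, there are h separate attention operations.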
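
A minimal NumPy sketch of the projection-plus-reshape step. The sizes (N = 5, d = 16, h = 4) and the random weights are illustrative assumptions, not values from the text. It checks that each head's slice of the big projection equals a small projection with the matching column block of W_Q.

```python
import numpy as np

# Illustrative sizes (assumed): sequence length N, model width d, h heads.
N, d, h = 5, 16, 4
d_k = d // h

rng = np.random.default_rng(0)
X = rng.standard_normal((N, d))     # input, one row per token
W_Q = rng.standard_normal((d, d))   # one projection matrix covering all heads

# One big matmul produces every head's queries at once...
Q_full = X @ W_Q                           # (N, d)
# ...then a reshape splits the feature axis into (h, d_k).
Q_split = Q_full.reshape(N, h, d_k)        # (N, h, d_k)
Q_per_head = Q_split.transpose(1, 0, 2)    # (h, N, d_k): head axis first

# Same result as h small projections, each using a column block of W_Q.
for k in range(h):
    W_Q_k = W_Q[:, k * d_k:(k + 1) * d_k]  # (d, d_k)
    assert np.allclose(Q_per_head[k], X @ W_Q_k)

print(Q_per_head.shape)  # (4, 5, 4)
```

Head k therefore owns columns k·d_k through (k+1)·d_k − 1 of W_Q; the reshape never moves data between heads.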

Parallel attention per head

for k in 0..h-1:
  head_k = Attention(Q[k], K[k], V[k])   ∈ ℝ^(N × d_k)

Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V   (softmax taken over each row)

Each head computes attention independently on its own slice of Q, K and V (K and V are projected and reshaped exactly like Q). In practice the loop is a single batched matmul over the head axis: scores ∈ ℝ^(h × N × N), output ∈ ℝ^(h × N × d_k).
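
A NumPy sketch of the batched form, assuming head-first arrays of shape (h, N, d_k) and toy random inputs. The function name batched_attention is hypothetical. The loop at the end recomputes each head on its own to confirm the batched result matches.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def batched_attention(Q, K, V):
    """Scaled dot-product attention for all heads at once.

    Q, K, V: arrays of shape (h, N, d_k), head axis first.
    Returns (output of shape (h, N, d_k), raw scores of shape (h, N, N)).
    """
    d_k = Q.shape[-1]
    # Matmul broadcasts over the leading head axis: (h, N, d_k) @ (h, d_k, N).
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, scores

h, N, d_k = 4, 5, 4
rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((h, N, d_k)) for _ in range(3))

out, scores = batched_attention(Q, K, V)

# The explicit per-head loop gives the same numbers.
for k in range(h):
    w = softmax(Q[k] @ K[k].T / np.sqrt(d_k))
    assert np.allclose(out[k], w @ V[k])

print(scores.shape, out.shape)  # (4, 5, 5) (4, 5, 4)
```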

Concatenation and output projection

concat = reshape (head_0, ..., head_(h-1)) → ℝ^(N × d)   (heads side by side along the feature axis)
output = concat · W_O    ∈ ℝ^(N × d)

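The concatenated head outputs pass through one more linear projection, W_O ∈ ℝ^(d × d). This is the step where the heads' outputs are mixed, and it is essential: within the attention block, W_O is the only place where information from different heads is combined. Without it, each output feature would come from exactly one head.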
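
A NumPy sketch of concatenation followed by W_O, with random stand-ins for the head outputs (an assumption for illustration). It also checks an equivalent reading: concat · W_O equals the sum over heads of head_k times its own row block of W_O, which is why every output column can draw on every head.

```python
import numpy as np

N, d, h = 5, 16, 4
d_k = d // h
rng = np.random.default_rng(2)
heads = rng.standard_normal((h, N, d_k))   # stand-in for the h head outputs
W_O = rng.standard_normal((d, d))

# Concatenate along the feature axis: (h, N, d_k) -> (N, h, d_k) -> (N, d).
concat = heads.transpose(1, 0, 2).reshape(N, d)
output = concat @ W_O                      # (N, d)

# Equivalent view: each head is projected by its own row block of W_O and
# the results are summed, so each output column mixes all heads.
mixed = sum(heads[k] @ W_O[k * d_k:(k + 1) * d_k, :] for k in range(h))
assert np.allclose(output, mixed)

# The concat is just the head outputs laid side by side.
assert np.allclose(concat[:, d_k:2 * d_k], heads[1])
print(output.shape)  # (5, 16)
```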

Total parameter count

MHA params = 4 * d * d   (W_Q, W_K, W_V, W_O all d×d; biases ignored)
           = 4 * d²

For d = 2048 that is 4 × 2048² ≈ 16.8M parameters per attention block; with L = 24 layers, about 403M for attention alone. That is a significant share of a small language model (SLM). GQA (next section) reduces the K and V part of this count.
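
The arithmetic above as a small Python function. The name attention_params is hypothetical, and the head counts (16 query heads, 4 KV heads) are assumptions chosen only to illustrate the GQA case; d = 2048 and L = 24 come from the text. Biases are ignored, matching the 4 · d² formula.

```python
def attention_params(d, n_heads, n_kv_heads=None):
    """Weight count of one attention block's projections (biases ignored).

    Assumes d_k = d // n_heads. With n_kv_heads == n_heads this is standard MHA,
    4 * d**2; fewer KV heads shrink only W_K and W_V (the GQA case).
    """
    if n_kv_heads is None:
        n_kv_heads = n_heads              # plain MHA
    d_k = d // n_heads
    q_and_o = 2 * d * d                   # W_Q and W_O are both d x d
    k_and_v = 2 * d * (n_kv_heads * d_k)  # W_K and W_V are d x (n_kv_heads * d_k)
    return q_and_o + k_and_v

# d = 2048 and L = 24 come from the text; the head counts are hypothetical.
per_block = attention_params(2048, n_heads=16)
print(per_block)                 # 16777216 (about 16.8M)
print(24 * per_block)            # 402653184 (about 403M)
print(attention_params(2048, n_heads=16, n_kv_heads=4))  # 10485760
```

In the MHA case the head count cancels out (h · d_k = d), which is why the formula depends only on d.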

GQA — Grouped Query Attention

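A modern variant in which query heads outnumber KV heads. Llama 3 8B, for example, uses 32 query heads and 8 KV heads, so each group of 4 query heads shares one K head and one V head. This makes the KV cache at inference 4× smaller than with full MHA, with little loss in quality. The saving matters most for long-context inference on memory-constrained hardware such as CPUs.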
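
A minimal NumPy sketch of grouped-query attention, assuming the common approach of repeating each KV head across its group of query heads (the function name grouped_query_attention is hypothetical, and production kernels often avoid materializing the repeated tensors). Head counts match the Llama 3 8B figures above; sequence length and head size are toy values.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (n_q, N, d_k); K, V: (n_kv, N, d_k), with n_q a multiple of n_kv.

    Query heads form n_kv consecutive groups; every head in group g attends
    with K[g] and V[g]. An inference cache would store only these n_kv heads.
    """
    n_q, n_kv = Q.shape[0], K.shape[0]
    assert n_q % n_kv == 0
    group = n_q // n_kv
    # Repeat each KV head `group` times so shapes line up with the query heads.
    K_rep = np.repeat(K, group, axis=0)       # (n_q, N, d_k)
    V_rep = np.repeat(V, group, axis=0)
    d_k = Q.shape[-1]
    weights = softmax(Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k))
    return weights @ V_rep                    # (n_q, N, d_k)

# Head counts from the text (Llama 3 8B); N and d_k are toy values.
n_q, n_kv, N, d_k = 32, 8, 6, 8
rng = np.random.default_rng(3)
Q = rng.standard_normal((n_q, N, d_k))
K = rng.standard_normal((n_kv, N, d_k))
V = rng.standard_normal((n_kv, N, d_k))
out = grouped_query_attention(Q, K, V)

# Query heads 0-3 share KV head 0, heads 4-7 share KV head 1, and so on.
w5 = softmax(Q[5] @ K[1].T / np.sqrt(d_k))
assert np.allclose(out[5], w5 @ V[1])

# The cache holds n_kv heads instead of n_q: 32 / 8 = 4x smaller.
print(out.shape, n_q // n_kv)   # (32, 6, 8) 4
```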

Multi-head = one matmul + reshape + h parallel attentions + concat + W_O. GQA shares each K,V head across a group of query heads to shrink the KV cache.
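
Putting the pipeline together, a NumPy sketch of one full attention layer that covers both MHA (n_kv_heads == n_heads) and GQA. The function name attention_layer and all sizes are hypothetical, and the sketch omits masking, biases and positional encoding, which real models add.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_layer(X, W_Q, W_K, W_V, W_O, n_heads, n_kv_heads):
    """Hypothetical end-to-end layer: MHA when n_kv_heads == n_heads, GQA otherwise.

    X: (N, d). W_Q, W_O: (d, d). W_K, W_V: (d, n_kv_heads * d_k).
    No masking, biases, or positional encoding.
    """
    N, d = X.shape
    d_k = d // n_heads
    # 1. One matmul per projection, then reshape to head-first layout.
    Q = (X @ W_Q).reshape(N, n_heads, d_k).transpose(1, 0, 2)
    K = (X @ W_K).reshape(N, n_kv_heads, d_k).transpose(1, 0, 2)
    V = (X @ W_V).reshape(N, n_kv_heads, d_k).transpose(1, 0, 2)
    # 2. GQA: reuse each KV head for n_heads // n_kv_heads query heads.
    group = n_heads // n_kv_heads
    K, V = np.repeat(K, group, axis=0), np.repeat(V, group, axis=0)
    # 3. All heads' attention in one batched matmul.
    heads = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)) @ V   # (h, N, d_k)
    # 4. Concatenate heads along features and mix them with W_O.
    concat = heads.transpose(1, 0, 2).reshape(N, d)
    return concat @ W_O

N, d, h, n_kv = 6, 32, 8, 2
d_k = d // h
rng = np.random.default_rng(4)
X = rng.standard_normal((N, d))
W_Q, W_O = rng.standard_normal((d, d)), rng.standard_normal((d, d))
W_K = rng.standard_normal((d, n_kv * d_k))
W_V = rng.standard_normal((d, n_kv * d_k))

y = attention_layer(X, W_Q, W_K, W_V, W_O, n_heads=h, n_kv_heads=n_kv)
print(y.shape)   # (6, 32)
```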