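Self-attention is the operation that lets a transformer relate tokens to each other. The formula is short, and the intuition is almost mechanical: each query asks a question, the keys answer it, and the matching values get aggregated. Implemented efficiently (for example with FlashAttention, which computes exact attention without ever storing the full score matrix), it is also the operation that lets transformers scale to long context.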

The formula

Attention(Q, K, V) = softmax(Q · Kᵀ / sqrt(d_k)) · V

Shapes (single-head):
  Q ∈ ℝ^(N × d_k)   queries
  K ∈ ℝ^(N × d_k)   keys
  V ∈ ℝ^(N × d_v)   values
  output ∈ ℝ^(N × d_v)

Inputs come from linear projections of the token embeddings X ∈ ℝ^(N × d): Q = X·W_Q, K = X·W_K, V = X·W_V. W_Q and W_K are learned d × d_k matrices, and W_V is a learned d × d_v matrix (one set per head).
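A minimal sketch of the projections in plain Python, using nested lists as matrices so it runs with no dependencies. The sizes (N = 3, d = 4, d_k = 2, d_v = 5), the random weights, and the helper names `matmul` and `random_matrix` are illustrative assumptions; in a real model W_Q, W_K and W_V are trained parameters, and d_v is chosen to differ from d_k here only to make the shapes easy to tell apart.

```python
import random

def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix, both stored as lists of rows."""
    m = len(b)
    assert all(len(row) == m for row in a), "inner dimensions must match"
    return [[sum(row[k] * b[k][j] for k in range(m)) for j in range(len(b[0]))]
            for row in a]

def random_matrix(rows, cols, rng):
    """Hypothetical stand-in for a learned weight matrix (real weights are trained)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
N, d, d_k, d_v = 3, 4, 2, 5           # illustrative sizes only

X = random_matrix(N, d, rng)          # token embeddings, N x d
W_Q = random_matrix(d, d_k, rng)      # d x d_k
W_K = random_matrix(d, d_k, rng)      # d x d_k
W_V = random_matrix(d, d_v, rng)      # d x d_v

Q = matmul(X, W_Q)                    # N x d_k
K = matmul(X, W_K)                    # N x d_k
V = matmul(X, W_V)                    # N x d_v
print(len(Q), len(Q[0]), len(V), len(V[0]))  # 3 2 3 5
```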

Step 1: scores

scores = Q · Kᵀ            ∈ ℝ^(N × N)
         scores[i, j] = ⟨Q[i], K[j]⟩

Each entry is the dot product (similarity) of query i with key j. For N = 2048, that is 2048² ≈ 4.2M scores per head. This N × N matrix is what makes attention O(N²) in both compute and memory, and it is the key bottleneck for long context.
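The scores step in the same plain-Python style, plus the arithmetic behind the O(N²) claim. The helper names, the toy matrices, and the assumption of 4-byte fp32 scores for a single head and a single sequence are illustrative choices; kernels that keep scores in fp16/bf16 would halve the byte count, but not the quadratic growth.

```python
def dot(u, v):
    """Inner product <u, v> of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def attention_scores(Q, K):
    """Raw scores Q · Kᵀ: entry [i][j] is <Q[i], K[j]>, giving an N x N matrix."""
    return [[dot(q, k) for k in K] for q in Q]

def score_matrix_bytes(n, bytes_per_score=4):
    """Bytes needed to store one head's full n x n score matrix (default: fp32)."""
    return n * n * bytes_per_score

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(attention_scores(Q, K))            # [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0], [3.0, 7.0, 11.0]]
print(2048 * 2048)                       # 4194304 scores for N = 2048
print(score_matrix_bytes(2048) / 2**20)  # 16.0 (MiB per head, fp32)
print(score_matrix_bytes(4096) / 2**20)  # 64.0: doubling N quadruples the memory
```

The last two lines are the whole long-context problem in miniature: the cost grows with the square of the sequence length, per head and per sequence in the batch.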

Step 2: scale and softmax

attn = softmax(scores / sqrt(d_k))    ∈ ℝ^(N × N)
row i is the attention distribution from query i over keys

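Dividing by sqrt(d_k) keeps the logits in a well-behaved range: if the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance d_k, and the scaling brings it back to 1 so the softmax does not saturate into a near one-hot output with vanishing gradients (see the softmax article). Each row of attn sums to 1, and attn[i, j] is how much token i attends to token j.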
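A minimal sketch of the scale-and-softmax step, assuming scores are stored as a list of rows. Subtracting each row's maximum before exponentiating is a standard numerical-stability trick rather than part of the formula; it does not change the result because softmax is invariant to adding a constant to every logit. The helper names and toy scores are illustrative.

```python
import math

def softmax(row):
    """Numerically stable softmax: subtracting the max leaves the result unchanged."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_softmax(scores, d_k):
    """Divide every score by sqrt(d_k), then apply softmax to each row independently."""
    scale = 1.0 / math.sqrt(d_k)
    return [softmax([s * scale for s in row]) for row in scores]

scores = [[1.0, 3.0, 5.0],
          [2.0, 4.0, 6.0]]
attn = scaled_softmax(scores, d_k=4)
for row in attn:
    print([round(w, 3) for w in row], round(sum(row), 6))  # each row sums to 1
```

Both rows come out identical here: the second row of scores is the first plus a constant, which the softmax ignores.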

Step 3: aggregate values

output = attn · V    ∈ ℝ^(N × d_v)
output[i] = sum over j of attn[i,j] * V[j]

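Each output is a soft-weighted average of value vectors, and the attention weights determine how much each value contributes. This is the "differentiable lookup" at the heart of attention: a hard dictionary lookup returns exactly one value, whereas here every value contributes in proportion to its weight, so gradients flow back to all queries, keys and values.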
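A minimal sketch of the aggregation step with hand-picked weights, assuming attn is a list of rows that each sum to 1 and V is a list of value vectors. The first row is one-hot, so it reduces to an ordinary hard lookup of V[0]; the other rows blend several values. The function name and inputs are illustrative.

```python
def aggregate(attn, V):
    """output[i] = sum over j of attn[i][j] * V[j]; attn is N x N, V is N x d_v."""
    d_v = len(V[0])
    return [[sum(w * v[c] for w, v in zip(weights, V)) for c in range(d_v)]
            for weights in attn]

V = [[1.0, 0.0],
     [0.0, 1.0],
     [4.0, 4.0]]
attn = [[1.0, 0.0, 0.0],     # one-hot: a hard lookup of V[0]
        [0.5, 0.5, 0.0],     # equal blend of V[0] and V[1]
        [0.25, 0.25, 0.5]]   # dominated by V[2]
print(aggregate(attn, V))    # [[1.0, 0.0], [0.5, 0.5], [2.25, 2.25]]
```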

Causal mask for next-token training

# additive mask: -inf above the diagonal (future positions), 0 elsewhere
mask[i, j] = -inf if j > i else 0
masked_scores = scores / sqrt(d_k) + mask
attn = softmax(masked_scores)

For autoregressive LLMs, the query at position i should only attend to keys at positions ≤ i. The mask sets future positions to -∞ before the softmax; since exp(-∞) = 0, those positions receive exactly zero weight. The diagonal (j = i) is never masked, so every row keeps at least one finite score and the softmax stays well defined.
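A minimal sketch of the causal mask as an additive matrix of zeros and -inf, applied to the scaled scores before a stable softmax. It relies on Python's float("-inf") and on math.exp returning 0.0 for -inf; the helper names and the 3 × 3 toy scores are illustrative assumptions.

```python
import math

NEG_INF = float("-inf")

def causal_mask(n):
    """Additive mask: 0 where j <= i (past and present), -inf where j > i (future)."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]

def softmax(row):
    m = max(row)                             # finite, because the diagonal is never masked
    exps = [math.exp(x - m) for x in row]    # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def masked_attention_weights(scores, d_k):
    """softmax(scores / sqrt(d_k) + mask), row by row."""
    mask = causal_mask(len(scores))
    scale = 1.0 / math.sqrt(d_k)
    return [softmax([s * scale + m for s, m in zip(row, mrow)])
            for row, mrow in zip(scores, mask)]

scores = [[2.0, 1.0, 0.5],
          [1.0, 3.0, 0.0],
          [0.5, 0.5, 0.5]]
weights = masked_attention_weights(scores, d_k=1)
for row in weights:
    print([round(w, 3) for w in row])  # row 0 is [1.0, 0.0, 0.0]: token 0 sees only itself
```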

In short: scores = Q·Kᵀ / sqrt(d_k), plus the causal mask for autoregressive LLMs; a row-wise softmax turns them into attention weights, which then aggregate V.
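To tie the steps together, here is a minimal end-to-end sketch of single-head causal attention in plain Python, plus a second version that computes the same output by streaming over the keys with the online-softmax recurrence (a running max, a running normaliser and a running weighted sum), so no row of scores is ever stored. That recurrence is the idea FlashAttention builds on; FlashAttention itself also tiles queries and keys into blocks held in fast on-chip memory, which this sketch does not attempt. Function names and toy inputs are illustrative assumptions.

```python
import math

def causal_attention(Q, K, V):
    """Single-head causal attention: softmax(Q·Kᵀ/sqrt(d_k) + mask) · V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        # Only keys 0..i are scored; dropping j > i is equivalent to adding -inf.
        s = [scale * sum(a * b for a, b in zip(q, K[j])) for j in range(i + 1)]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        total = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(i + 1)) / total
                    for c in range(len(V[0]))])
    return out

def causal_attention_online(Q, K, V):
    """Same output, streaming over keys with the online-softmax recurrence."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        m = float("-inf")          # running max of the scores seen so far
        denom = 0.0                # running sum of exp(score - m)
        acc = [0.0] * len(V[0])    # running sum of exp(score - m) * V[j]
        for j in range(i + 1):
            s = scale * sum(a * b for a, b in zip(q, K[j]))
            m_new = max(m, s)
            correction = math.exp(m - m_new)   # rescales old terms to the new max
            p = math.exp(s - m_new)
            denom = denom * correction + p
            acc = [a * correction + p * v for a, v in zip(acc, V[j])]
            m = m_new
        out.append([a / denom for a in acc])
    return out

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
a = causal_attention(Q, K, V)
b = causal_attention_online(Q, K, V)
print(a[0])  # [1.0, 2.0]: the first token can only attend to itself
print(max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)))  # near 0 (rounding only)
```

The online version keeps O(d_v) state per query instead of a full row of N scores, at the cost of one extra rescaling per key.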