Naive attention writes the full N×N attention matrix to GPU main memory (HBM). FlashAttention computes attention block by block in SRAM (fast on-chip memory) and never materializes the full matrix. Same math, typically 5-10× faster on long contexts.


The IO problem

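# Naive (typical unfused version, each step round-trips through HBM):
# 1. S = Q·Kᵀ / √d         → write [N, N] scores to HBM
# 2. P = softmax(S) by row → read [N, N], write [N, N] back to HBM
# 3. O = P·V               → read [N, N] again, write [N, dv] to HBM
#
# For N=8192: 8192² scores × 4 bytes (fp32) = 256 MiB per head, per sequence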

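HBM (GPU global memory) is slow relative to SRAM (on-chip memory with roughly 10× the bandwidth but ~1000× less capacity). The full N×N matrix is far too large for SRAM, so materializing it forces it into HBM, and each of the steps above pays for reading or writing it there. FlashAttention avoids this.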
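The scores figure in the comment above is simple arithmetic: N² elements times the bytes per element, multiplied by the number of heads and sequences. A small sketch of that calculation follows; the helper name `score_matrix_bytes` and the 32-head configuration are hypothetical, chosen only to show how quickly the total grows.

```python
def score_matrix_bytes(n: int, bytes_per_elem: int, heads: int = 1, batch: int = 1) -> int:
    """Bytes needed to materialize the N x N score matrix for every head and sequence."""
    return batch * heads * n * n * bytes_per_elem


MiB = 2 ** 20
print(score_matrix_bytes(8192, 4) / MiB)                    # fp32, one head: 256.0
print(score_matrix_bytes(8192, 2) / MiB)                    # fp16/bf16, one head: 128.0
print(score_matrix_bytes(8192, 2, heads=32) / MiB / 1024)   # hypothetical 32 heads, fp16: 4.0 (GiB)
```

Note that the head dimension d never appears: the score matrix size depends only on sequence length, which is why long contexts hurt so much.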

Tiling and recompute

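for q_block in Q_tiles:                        # [Br, d] tile of queries
    m = full(Br, -inf); l = zeros(Br)          # running row max and row sum
    acc = zeros(Br, dv)                        # unnormalized output rows
    for k_block, v_block in zip(K_tiles, V_tiles):
        # all in SRAM:
        s = q_block @ k_block.T / sqrt(d)      # [Br, Bc] block of scores
        m_new = maximum(m, s.max(axis=1))      # updated running max
        p = exp(s - m_new[:, None])            # unnormalized probabilities
        c = exp(m - m_new)                     # rescale factor for old stats
        l = c * l + p.sum(axis=1)
        acc = c[:, None] * acc + p @ v_block
        m = m_new
    write acc / l[:, None] to HBM              # normalize once, write once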

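Each block is small enough to fit in SRAM. Online softmax keeps a running row max and row sum across key blocks and rescales the accumulated output whenever the max grows, so the result is exact without ever materializing the full softmax. The output is normalized and written to HBM once. Extra memory drops from O(N²) to O(N) for the running statistics; K and V are re-read once per query block, so HBM traffic is not literally O(N·d), but the N×N scores never make a round trip through HBM. Recompute happens in the backward pass: instead of storing the attention matrix for gradients, FlashAttention recomputes each block from Q, K, V and the saved softmax statistics.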
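The tiling loop above can be checked end to end in plain Python. This is a minimal, deliberately slow sketch using nested lists instead of a tensor library: `naive_attention` and `tiled_attention` are names made up for illustration, the block sizes are arbitrary, and V shares Q's width only for brevity. It shows that online softmax with rescaling reproduces naive attention up to rounding.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference attention: builds each full row of the N x N score matrix."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)                                  # stabilize exp()
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / total
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_q=4, block_k=4):
    """Block-wise attention with online softmax; no N x N matrix is stored."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for qs in range(0, len(Q), block_q):
        q_block = Q[qs:qs + block_q]
        m = [-math.inf] * len(q_block)          # running row max
        l = [0.0] * len(q_block)                # running sum of exponentials
        acc = [[0.0] * dv for _ in q_block]     # unnormalized output rows
        for ks in range(0, len(K), block_k):
            k_block, v_block = K[ks:ks + block_k], V[ks:ks + block_k]
            for i, q in enumerate(q_block):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_block]
                m_new = max(m[i], max(s))
                c = math.exp(m[i] - m_new)      # rescale old stats to the new max
                p = [math.exp(x - m_new) for x in s]
                l[i] = c * l[i] + sum(p)
                acc[i] = [c * acc[i][j] + sum(pj * v[j] for pj, v in zip(p, v_block))
                          for j in range(dv)]
                m[i] = m_new
        for i in range(len(q_block)):
            out.append([a / l[i] for a in acc[i]])  # normalize once per row
    return out


random.seed(0)
N, d = 10, 8
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)] for _ in range(3))
ref = naive_attention(Q, K, V)
got = tiled_attention(Q, K, V, block_q=3, block_k=4)   # ragged last tiles on purpose
print(max(abs(a - b) for ra, rb in zip(ref, got) for a, b in zip(ra, rb)))
```

The tiled version only ever holds one block of scores per query row plus the running statistics. A real kernel performs the same inner loops as fused matrix multiplies in SRAM rather than Python loops.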


The math equivalence

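The output is mathematically identical to naive attention (floating-point results can differ slightly because additions happen in a different order). Online softmax is an algebraic rearrangement: softmax(s) = exp(s - max) / sum(exp(s - max)), where subtracting the max changes nothing except keeping exp() from overflowing. Per-block maxima and sums can be combined exactly by rescaling each partial sum to the new overall max. This is the standard log-sum-exp trick, applied incrementally.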
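For two key blocks, the merge rule that makes this exact can be written out as follows. The symbols are introduced here for illustration: for block i, $m_i$ is its row max, $\ell_i$ its sum of shifted exponentials, and $\tilde{o}_i$ its unnormalized weighted sum of value rows.

```latex
\begin{aligned}
m_i &= \max_j s^{(i)}_j, \qquad
\ell_i = \sum_j e^{\,s^{(i)}_j - m_i}, \qquad
\tilde{o}_i = \sum_j e^{\,s^{(i)}_j - m_i}\, v^{(i)}_j \\[4pt]
m &= \max(m_1, m_2), \qquad
\ell = e^{\,m_1 - m}\,\ell_1 + e^{\,m_2 - m}\,\ell_2, \qquad
\tilde{o} = e^{\,m_1 - m}\,\tilde{o}_1 + e^{\,m_2 - m}\,\tilde{o}_2 \\[4pt]
\text{output} &= \frac{\tilde{o}}{\ell}
  = \frac{\sum_j e^{\,s_j - m}\, v_j}{\sum_j e^{\,s_j - m}}
  = \operatorname{softmax}(s)\, V
\end{aligned}
```

Applying this merge once per incoming block gives the running update in the tiling loop; the factors $e^{m_1 - m}$ and $e^{m_2 - m}$ are exactly the rescaling step.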

Speed and memory

# FlashAttention-2:
# Memory: O(N·d)  (inputs, outputs, O(N) softmax stats; no N×N matrix)
# Time:   O(N²·d) FLOPs, same as naive in the forward pass,
#         yet 5-10x faster wall-clock

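The compute count is the same; the memory access pattern is far better. PyTorch's F.scaled_dot_product_attention dispatches to a FlashAttention-2 kernel automatically when the GPU, dtype (fp16 or bf16), and input shapes support it, and falls back to other backends otherwise. FlashAttention is standard in production training and inference.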
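A quick way to see the same-compute claim is to count multiply-adds in the two matrix products, once for the whole matrices and once summed over tiles. This sketch counts only Q·Kᵀ and P·V at 2 FLOPs per multiply-add; the function names and block sizes are illustrative assumptions, and the online-softmax rescaling adds a small amount of elementwise work that is not counted here.

```python
def naive_matmul_flops(n: int, d: int, dv: int) -> int:
    # Q·Kᵀ is [n, d] x [d, n]; P·V is [n, n] x [n, dv]; 2 FLOPs per multiply-add
    return 2 * n * n * d + 2 * n * n * dv


def tiled_matmul_flops(n: int, d: int, dv: int, block_q: int, block_k: int) -> int:
    total = 0
    for qs in range(0, n, block_q):
        br = min(block_q, n - qs)          # last query tile may be smaller
        for ks in range(0, n, block_k):
            bc = min(block_k, n - ks)      # last key tile may be smaller
            total += 2 * br * bc * d       # q_block · k_blockᵀ
            total += 2 * br * bc * dv      # p · v_block
    return total


print(naive_matmul_flops(8192, 64, 64))              # 17179869184
print(tiled_matmul_flops(8192, 64, 64, 128, 64))     # 17179869184
```

The tiles partition the N×N score grid exactly, so the matmul work is unchanged; the speedup comes from where the intermediate data lives, not from doing less arithmetic.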

CPU equivalents

CPU attention can use the same tiling idea, with L2/L3 cache standing in for SRAM. llama.cpp's CPU attention computes block-wise so the working set fits in L2/L3 cache. The win is less dramatic than GPU FlashAttention, but still about 2× for long contexts.
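One way to pick a tile size on CPU is to find the largest block whose working set fits a target cache budget. The sketch below is a hypothetical heuristic, not llama.cpp's actual logic: `largest_square_tile` and its working-set accounting (query, key, and value tiles, a B×B score block, and an output accumulator, fp32 by default) are assumptions made for illustration.

```python
def largest_square_tile(cache_bytes: int, d: int, dv: int, bytes_per_elem: int = 4):
    """Largest power-of-two tile size B whose (hypothetical) working set fits in cache_bytes.

    Working set per step: B x d queries, B x d keys, B x dv values,
    B x B scores, and a B x dv output accumulator. Returns None if even B=8 does not fit.
    """
    best = None
    b = 8
    while b <= 4096:
        elems = b * d + b * d + b * dv + b * b + b * dv
        if elems * bytes_per_elem <= cache_bytes:
            best = b
        b *= 2
    return best


print(largest_square_tile(1 << 20, 128, 128))   # 1 MiB budget, d = dv = 128: 256
print(largest_square_tile(2 << 20, 128, 128))   # 2 MiB budget: 512
```

Real caches are shared with other data and often with other cores, so in practice one would target a fraction of the nominal cache size rather than all of it.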

FlashAttention: tile attention in SRAM, never materialize N×N matrix. Same math, 5-10× faster. Standard in PyTorch SDPA.