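FlashAttention is a re-implementation of the attention operation that respects the GPU memory hierarchy. Standard attention writes a huge N×N matrix to slow HBM (the GPU's main memory). FlashAttention instead computes attention in tiles that fit in fast on-chip SRAM and never materializes the full matrix. The result is exact attention whose memory grows linearly rather than quadratically with sequence length (the original paper reports 10-20x memory savings at 2K-4K tokens), plus a several-fold speedup of the attention computation itself.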

The bottleneck

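Attention has two steps: compute the score matrix S = Q@K^T / sqrt(d), where d is the head dimension, then output = softmax(S)@V. S has N×N entries, which is enormous for long sequences: an 8K (8,192-token) context gives about 67 million entries per head, per sequence. Writing this matrix (and its softmax) to HBM and reading it back dominates runtime, so standard attention is memory-bound, not compute-bound.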
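To make the 8K figure concrete, here is a back-of-the-envelope sketch of the memory needed just to store S. The fp16 storage (2 bytes per entry), 32 heads and batch size 8 are illustrative assumptions, not taken from any particular model; the softmax output and the backward pass need further N×N buffers on top of this.

```python
def score_matrix_entries(seq_len: int) -> int:
    """Entries in one N x N score matrix (one head, one sequence)."""
    return seq_len * seq_len


def score_matrix_bytes(seq_len: int, heads: int = 1, batch: int = 1,
                       bytes_per_entry: int = 2) -> int:
    """Bytes to materialize S for every head and sequence (fp16 by default)."""
    return score_matrix_entries(seq_len) * heads * batch * bytes_per_entry


n = 8192
print(f"entries per head: {score_matrix_entries(n):,}")
print(f"MiB per head (fp16): {score_matrix_bytes(n) / 2**20:.0f}")
# Assumed configuration: 32 heads, batch of 8 sequences, one layer.
print(f"GiB for 32 heads x batch 8: "
      f"{score_matrix_bytes(n, heads=32, batch=8) / 2**30:.0f}")
```

Under these assumptions a single layer's score matrices alone take 32 GiB, which is why materializing S stops being viable as context grows.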

The trick

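Split Q, K and V into tiles. For each tile of Q, iterate over the tiles of K and V and accumulate the softmax incrementally. An online softmax combines the partial results without ever holding the full S matrix. It keeps a running row maximum and a running sum of exponentials, and rescales the earlier partial results whenever the maximum grows. Each tile's work happens in fast SRAM.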
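The loop structure can be illustrated with a small pure-Python sketch. It is not the FlashAttention kernel, which runs fused tile-level matrix multiplies on the GPU. It only reproduces the same tiling-plus-online-softmax arithmetic. The function names and the block sizes block_q and block_kv are hypothetical. The sketch checks the tiled result against a naive version that builds every row of the full score matrix.

```python
import math
import random


def naive_attention(q, k, v):
    """Reference attention: computes every row of the full N x N score matrix."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_q=2, block_kv=3):
    """Tiled attention with an online softmax.

    The full score matrix is never built: only the scores of one query row
    against one K/V tile exist at a time. Each query row carries a running
    max m, a running sum l of exp(score - m), and an unnormalized output acc.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for q0 in range(0, len(q), block_q):
        q_tile = q[q0:q0 + block_q]
        m = [-math.inf] * len(q_tile)
        l = [0.0] * len(q_tile)
        acc = [[0.0] * d_v for _ in q_tile]
        for kv0 in range(0, len(k), block_kv):
            k_tile = k[kv0:kv0 + block_kv]
            v_tile = v[kv0:kv0 + block_kv]
            for r, qi in enumerate(q_tile):
                scores = [scale * sum(a * b for a, b in zip(qi, kj))
                          for kj in k_tile]
                m_new = max(m[r], max(scores))
                # Rescale what was accumulated under the old max
                # (exp(-inf) == 0.0 on the first tile).
                correction = math.exp(m[r] - m_new)
                p = [math.exp(s - m_new) for s in scores]
                l[r] = l[r] * correction + sum(p)
                acc[r] = [acc[r][c] * correction
                          + sum(pj * vj[c] for pj, vj in zip(p, v_tile))
                          for c in range(d_v)]
                m[r] = m_new
        # Normalize once, after the last K/V tile.
        out.extend([a / l[r] for a in acc[r]] for r in range(len(q_tile)))
    return out


random.seed(0)
N, D = 7, 4  # 7 is deliberately not a multiple of either block size
Q = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]

ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V)
max_err = max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb))
print(f"max abs difference: {max_err:.2e}")
```

The correction factor exp(m_old - m_new) is what lets earlier partial sums be reused after the maximum changes. Deferring the division by l until the last K/V tile, as above, mirrors the choice FlashAttention-2 makes to save work inside the loop.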

Versions

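FlashAttention v1 (2022): the original; up to ~3x end-to-end training speedup on GPT-2 in the paper. v2 (2023): better parallelism and work partitioning across GPU threads; roughly 2x faster than v1. v3 (2024): uses Hopper-specific features such as TMA and FP8; roughly 1.5-2x faster than v2 on H100 in FP16. All versions compute exact attention rather than an approximation, so outputs match vanilla attention up to floating-point rounding. The FP8 path in v3 deliberately trades some precision for speed.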

When you don't need to think about it

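Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention automatically selects a FlashAttention kernel when the inputs allow it (for example, CUDA tensors in fp16 or bf16 with a supported head dimension). Hugging Face Transformers enables it by passing attn_implementation='flash_attention_2' to from_pretrained, which requires the separate flash-attn package. You almost never write it yourself.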
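Here is a minimal sketch of calling the fused PyTorch API and checking it against the explicit formula. The tensor shapes are arbitrary. It runs on CPU, where PyTorch uses whatever CPU backend is available, so it demonstrates the API and its semantics rather than the GPU FlashAttention kernel itself. On a CUDA GPU with fp16/bf16 inputs, the same call may dispatch to FlashAttention. PyTorch 2.3+ also provides the torch.nn.attention.sdpa_kernel context manager for restricting which backends are allowed.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 16, 8  # arbitrary example sizes
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Fused path: PyTorch picks a backend (flash, memory-efficient, or math).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit path for comparison: materializes the full seq_len x seq_len scores.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal_mask, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v

print(out.shape, (out - ref).abs().max().item())
```

Both paths apply the same 1/sqrt(head_dim) scaling and causal mask. The fused call simply never hands the full score matrix back to Python.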

Where it matters

Long context (>2K), large-batch inference, and training. For short-sequence inference (<512 tokens) the gain is small. For 8K+ context it can be the difference between 'works' and 'OOM on an 80GB GPU'.

Don't materialize the N×N matrix. Tile + online softmax + SRAM. Several-fold faster attention and 10-20x memory savings at long context.