LLM training occasionally produces gradient spikes: gradient norms orders of magnitude larger than typical. Left unmitigated, a spike can push parameters into a bad region of the loss landscape, and the loss explodes. Gradient clipping caps the gradient norm and keeps training stable.
Why spikes happen
Most often, a combination of rare token sequences that produce extreme attention patterns, numerically unstable softmax inputs (very large logits), and RoPE positions outside the training distribution. Sometimes the cause is a bad batch: corrupt data, or adversarial, prompt-injection-shaped token sequences. Modern data pipelines filter out the obvious cases, but not all of them.
Norm clipping
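import torch
# Compute the global L2 norm of all gradients (params: a list of tensors):
norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params if p.grad is not None))
# If the norm exceeds max_norm, rescale every gradient by the same factor:
if norm > max_norm:
    scale = max_norm / norm
    for p in params:
        if p.grad is not None:
            p.grad.mul_(scale)
This treats all gradients as one big vector, so clipping preserves direction: every entry is scaled by the same factor and the ratios across parameters stay intact. A common max_norm for LLMs is 1.0. In PyTorch, torch.nn.utils.clip_grad_norm_(params, 1.0) does this in one call and returns the total norm measured before clipping.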
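A stdlib-only sketch of the same global-norm computation, using plain Python lists of floats in place of tensors so it runs without PyTorch. The helper name clip_by_global_norm is hypothetical; this mirrors the rescaling logic above and is not PyTorch's implementation.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale a list of gradient vectors so their combined L2 norm is <= max_norm.

    grads: list of lists of floats, one inner list per (hypothetical) parameter.
    Returns (clipped_grads, total_norm); total_norm is measured before clipping.
    """
    # Global L2 norm: every entry of every parameter forms one long vector.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm <= max_norm:
        # Below the cap: return an unchanged copy.
        return [list(grad) for grad in grads], total_norm
    # One shared scale factor, so the ratios between all entries are preserved.
    scale = max_norm / total_norm
    return [[g * scale for g in grad] for grad in grads], total_norm


# A spike: two "parameters" whose combined gradient has norm 50.
grads = [[30.0, 0.0], [0.0, 40.0]]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)     # 50.0
print(clipped)  # approximately [[0.6, 0.0], [0.0, 0.8]]
```

The clipped gradient has norm 1.0 and points the same way as the original: the 3:4 ratio between the two nonzero entries survives.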
Value clipping (less common)
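for p in params:
    if p.grad is not None:
        p.grad.clamp_(-c, c)  # clamp each entry to [-c, c] in place
Hard-clips each gradient entry independently. Simpler, but it distorts the gradient direction: a large entry is cut down to c while small entries pass through unchanged. PyTorch equivalent: torch.nn.utils.clip_grad_value_(params, c). Used in some RL contexts; rare for LM training.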
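A small stdlib-only sketch of the direction distortion, with a gradient as a plain Python list. The helper names clip_by_value and cosine are hypothetical; cosine similarity of 1.0 would mean the direction is unchanged.

```python
import math


def clip_by_value(grad, c):
    """Clamp each gradient entry independently to [-c, c] (like Tensor.clamp_(-c, c))."""
    return [max(-c, min(c, g)) for g in grad]


def cosine(u, v):
    """Cosine similarity between two vectors; 1.0 means identical direction."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


grad = [100.0, 1.0]             # one large entry, one small entry
clipped = clip_by_value(grad, c=1.0)
print(clipped)                  # [1.0, 1.0]
print(cosine(grad, clipped))    # ~0.714: the direction has changed
```

Global-norm clipping of the same vector would shrink both entries by the same factor and keep a cosine similarity of 1.0.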
Tuning max_norm
Too low throttles real learning, because most steps get clipped. Too high lets spikes through. Watch the gradient norm over time. If it sits mostly around 0.5 with occasional spikes to 100, max_norm=1 leaves normal steps untouched and caps the spikes. If it sits mostly around 5, max_norm=1 clips nearly every step and throttles learning; raise it to about 10.
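One way to act on the logged norms is to ask, for a candidate max_norm, what fraction of steps it would clip. A minimal sketch; clip_fraction and the two sample logs are hypothetical, chosen to mirror the two scenarios above.

```python
def clip_fraction(logged_norms, max_norm):
    """Fraction of logged steps whose gradient norm would be clipped at max_norm."""
    return sum(n > max_norm for n in logged_norms) / len(logged_norms)


# Hypothetical logs mirroring the two scenarios above.
healthy = [0.5] * 98 + [100.0] * 2   # mostly ~0.5, rare spikes to 100
high_baseline = [5.0] * 100          # every step sits around 5

print(clip_fraction(healthy, 1.0))         # 0.02: only the spikes are clipped
print(clip_fraction(high_baseline, 1.0))   # 1.0: every step is throttled
print(clip_fraction(high_baseline, 10.0))  # 0.0: normal steps pass through
```

A clip fraction near 1.0 is the "throttles everything" case; a small fraction that lines up with visible spikes is the intended behavior.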
Skip-step on bigger spikes
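norm = torch.nn.utils.clip_grad_norm_(params, max_norm)  # after loss.backward(); returns pre-clip norm
if norm > skip_threshold:  # e.g., 10 * max_norm
    optim.zero_grad()  # discard this step's gradients
    continue  # no optim.step(), so Adam's moments are untouched
optim.step()
optim.zero_grad()
A catastrophic gradient skips the optimizer step entirely. Because optim.step() never runs, the bad batch never enters Adam's moment estimates, so training recovers from it without polluted optimizer state. Reportedly standard in production LLM training (e.g., Llama and Mistral training code).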
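To make the point about Adam's state concrete, here is a toy, stdlib-only sketch with a single scalar parameter. The helper names, the defaults (max_norm=1.0 and skip_threshold=10 * max_norm, taken from the examples above), and the simplified moment update (no bias correction, no parameter update) are assumptions for illustration, not any framework's actual implementation.

```python
def adam_moments(m, v, grad, beta1=0.9, beta2=0.999):
    """One update of Adam's first (m) and second (v) moment estimates for a scalar."""
    return beta1 * m + (1 - beta1) * grad, beta2 * v + (1 - beta2) * grad * grad


def train_step(state, grad, max_norm=1.0, skip_threshold=10.0):
    """Skip, clip, or pass through one scalar gradient; return the action taken."""
    norm = abs(grad)  # for a single scalar "parameter", the L2 norm is |grad|
    if norm > skip_threshold:
        return "skip"  # the optimizer step never runs: m and v stay as they were
    if norm > max_norm:
        grad *= max_norm / norm  # ordinary norm clipping
        action = "clip"
    else:
        action = "ok"
    state["m"], state["v"] = adam_moments(state["m"], state["v"], grad)
    return action


state = {"m": 0.0, "v": 0.0}
print(train_step(state, 0.5))    # ok
print(train_step(state, 3.0))    # clip (rescaled to 1.0 before the update)
before = dict(state)
print(train_step(state, 500.0))  # skip
print(state == before)           # True: the spike never reached m or v
```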