LLM training occasionally produces gradient spikes: gradient norms orders of magnitude larger than typical. Left unmitigated, they push parameters into bad regions and the loss diverges. Gradient clipping caps the gradient norm and keeps training stable.


Why spikes happen

Most spikes trace back to rare token sequences with extreme attention patterns, numerically unstable softmax inputs, and RoPE positions outside the training distribution. Some come from bad batches (corrupt data, prompt-injection-shaped tokens). Modern data pipelines filter out the obvious cases, but not all of them.
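To make the softmax point concrete, here is a minimal pure-Python sketch (the helper names naive_softmax and stable_softmax are hypothetical, chosen for illustration). Exponentiating large logits directly overflows; subtracting the maximum logit first avoids that. Python's math.exp raises OverflowError, whereas low-precision tensors would typically produce inf and then NaN. Note that the stable version still returns a nearly one-hot distribution, the kind of extreme attention pattern mentioned above.

```python
import math

def naive_softmax(logits):
    # Exponentiates raw logits: math.exp overflows for inputs above roughly 709.
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(logits):
    # Shift by the max so the largest exponent is exp(0) = 1; tiny terms underflow to 0.0.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1000.0, 999.0, 0.0]
try:
    naive_softmax(logits)
except OverflowError:
    print("naive softmax overflowed")

probs = stable_softmax(logits)
print(probs)  # roughly [0.731, 0.269, 0.0]
```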

Norm clipping

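import torch

# Compute the global L2 norm over all parameter gradients (accumulated in fp32):
grads = [p.grad for p in params if p.grad is not None]
norm = torch.sqrt(sum((g.float() ** 2).sum() for g in grads))

# If the norm exceeds max_norm, rescale every gradient by the same factor:
if norm > max_norm:
    scale = max_norm / norm
    for g in grads:
        g.mul_(scale)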

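This treats all gradients as one concatenated vector and scales it uniformly, so the update direction is preserved (the ratios between gradient entries are unchanged) and only the magnitude shrinks. A common max_norm for LLMs is 1.0. In PyTorch: torch.nn.utils.clip_grad_norm_(params, 1.0), which also returns the pre-clip total norm so it can be logged.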
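A minimal pure-Python sketch of global-norm clipping, assuming each parameter's gradient is given as a flat list of floats (clip_by_global_norm is a hypothetical name). Like clip_grad_norm_, it returns the pre-clip norm; unlike PyTorch, it omits the small epsilon added to the denominator. The example shows every entry scaled by the same factor, so ratios between entries survive clipping.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale gradients so their combined L2 norm is at most max_norm.

    grads: list of flat gradient lists, one per parameter tensor.
    Returns (clipped_grads, total_norm_before_clipping).
    """
    # One norm over all parameters, as if they were a single long vector.
    total_norm = math.sqrt(sum(x * x for g in grads for x in g))
    if total_norm <= max_norm:
        return [list(g) for g in grads], total_norm
    scale = max_norm / total_norm  # same factor for every entry
    return [[x * scale for x in g] for g in grads], total_norm

grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm is 5.0
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # approximately [[0.6, 0.0], [0.0, 0.8]]: same 3:4 ratio, norm 1.0
```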


Value clipping (less common)

c = 1.0  # per-entry clip value (illustrative)
for p in params:
    if p.grad is not None:
        p.grad.clamp_(-c, c)

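Each gradient entry is hard-clipped independently. Simpler, but it distorts the gradient direction: large entries are cut while small ones pass through unchanged. Used in some RL contexts; rare for LM training. PyTorch provides this as torch.nn.utils.clip_grad_value_(params, c).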
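The direction distortion shows up in a two-entry example, sketched below in plain Python (clip_by_value and cosine are hypothetical helpers). Value clipping turns [10, 1] into [1, 1], while norm clipping shrinks the vector without rotating it; cosine similarity with the original gradient measures how far the direction moved.

```python
import math

def clip_by_value(grad, c):
    """Clamp every entry of a flat gradient vector to [-c, c] independently."""
    return [max(-c, min(c, x)) for x in grad]

def cosine(a, b):
    """Cosine similarity between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

grad = [10.0, 1.0]

value_clipped = clip_by_value(grad, c=1.0)   # [1.0, 1.0]
print(cosine(grad, value_clipped))           # about 0.77: direction has changed

n = math.sqrt(sum(x * x for x in grad))
norm_clipped = [x / n for x in grad]         # norm 1, same direction
print(cosine(grad, norm_clipped))            # 1.0 up to rounding
```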

Tuning max_norm

If max_norm is too low, most steps get clipped and real learning is throttled. If it is too high, spikes get through. Log the pre-clip gradient norm at every step and watch it over time. If it sits around 0.5 with occasional spikes to 100, max_norm=1 leaves normal steps untouched and caps the spikes. If it typically sits around 5, max_norm=1 clips every step; raise it to 10.
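One way to act on the logged norms is to check what fraction of steps a candidate max_norm would clip. The sketch below assumes you already have a list of pre-clip norms; the numbers are made up to mirror the two scenarios above, and clip_fraction is a hypothetical helper.

```python
def clip_fraction(norm_history, max_norm):
    """Fraction of logged steps whose pre-clip gradient norm exceeds max_norm."""
    return sum(n > max_norm for n in norm_history) / len(norm_history)

# Made-up history: mostly around 0.5, with two large spikes.
healthy = [0.5, 0.45, 0.55, 0.5, 100.0, 0.48, 0.52, 0.5, 90.0, 0.5]
print(clip_fraction(healthy, 1.0))   # 0.2: only the spikes are clipped

# Made-up history sitting around 5.
high = [5.0, 4.8, 5.2, 5.1, 4.9]
print(clip_fraction(high, 1.0))      # 1.0: every step is clipped
print(clip_fraction(high, 10.0))     # 0.0: normal steps pass untouched
```

A clip fraction near 1.0 is the "throttles everything" case; a small fraction that lines up with visible spikes is the intended behavior.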

Skip-step on bigger spikes

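norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # returns pre-clip norm
if norm > skip_threshold:    # e.g., 10 * max_norm
    optim.zero_grad()        # discard this step's gradients entirely
    continue                 # inside the training loop: skip optim.step()
optim.step()
optim.zero_grad()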

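When a gradient is catastrophic, skip the optimizer step entirely. Clipping alone would still feed a rescaled copy of the bad gradient into Adam's first- and second-moment estimates, which then influence many subsequent updates; skipping lets training recover from a bad batch without polluting Adam's state. Standard in production LLM training (Llama, Mistral training code).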
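The sketch below combines clipping and skipping in one toy update, assuming a flat list-of-floats gradient and tracking only Adam's first and second moments (the parameter update itself is omitted). step_with_skip and its defaults are illustrative, with skip_threshold = 10 * max_norm following the example above. A skipped step leaves the moments and step counter untouched, while a clipped step still folds a rescaled gradient into them.

```python
import math

def step_with_skip(grad, state, max_norm=1.0, skip_threshold=10.0,
                   beta1=0.9, beta2=0.999):
    """Toy Adam-moment update with norm clipping and a skip guard.

    grad: flat list of floats. state: dict with 'm', 'v' lists and step count 't'.
    Returns 'skipped', 'clipped', or 'applied'.
    """
    norm = math.sqrt(sum(x * x for x in grad))
    if norm > skip_threshold:
        return "skipped"  # moments and step count stay exactly as they were
    status = "applied"
    if norm > max_norm:
        grad = [x * (max_norm / norm) for x in grad]  # global-norm clip
        status = "clipped"
    state["t"] += 1
    # Exponential moving averages of the gradient and its square, as in Adam.
    state["m"] = [beta1 * m + (1 - beta1) * g for m, g in zip(state["m"], grad)]
    state["v"] = [beta2 * v + (1 - beta2) * g * g for v, g in zip(state["v"], grad)]
    return status

state = {"m": [0.0, 0.0], "v": [0.0, 0.0], "t": 0}
for g in ([0.3, 0.4], [3.0, 4.0], [300.0, 400.0]):
    print(step_with_skip(g, state))   # applied, clipped, skipped
print(state["t"])                     # 2: the skipped step did not count
```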

In short: clip the global gradient norm at max_norm=1, watch gradient norms over time, and skip the optimizer step entirely on catastrophic spikes.