
13. Debugging causal attention — silent bugs, direct checks

~10 min read. Attention failures often look legal to PyTorch and wrong to humans.

Built on the ELI5 in 00-eli5.md.

The exam rule is the first thing bugs violate.

The safest response is not guessing.

It is checking shapes, probabilities, and cache parity.


1) Attention bugs are usually silent

Many attention bugs do not crash.

They still produce tensors.

They still produce gradients.

Loss can even move in the right direction.

That is why these bugs waste time.

A wrong transpose is still valid tensor math.

A misplaced mask is still valid tensor math.

A missing merge step is still valid tensor math.

The model keeps running.

Generation just feels strange.
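To make this concrete, here is a minimal NumPy sketch of two of these silent bugs. NumPy stands in for PyTorch so it runs anywhere, and the sizes are arbitrary assumptions. When T equals d_head, a forgotten key transpose still multiplies cleanly. A botched head merge (a bare reshape with no transpose) still yields a [B,T,D]-shaped tensor. Both keep the right shape and carry the wrong values.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, T, d_head = 1, 2, 4, 4   # T == d_head on purpose, so the missing transpose still runs

q = rng.standard_normal((B, H, T, d_head))
k = rng.standard_normal((B, H, T, d_head))

scores_ok = q @ k.swapaxes(-1, -2)    # correct: contract over d_head -> [B, H, T, T]
scores_bug = q @ k                    # bug: no transpose, still a legal [B, H, T, T] matmul

out = rng.standard_normal((B, H, T, d_head))                       # per-head outputs
merged_ok = out.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)    # move H next to d_head, then merge
merged_bug = out.reshape(B, T, H * d_head)                         # bug: reshape without the transpose

print(scores_ok.shape == scores_bug.shape, np.allclose(scores_ok, scores_bug))  # True False
print(merged_ok.shape == merged_bug.shape, np.allclose(merged_ok, merged_bug))  # True False
```

Nothing raises. Only a value check exposes either bug.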

So debugging must be procedural.

Do not stare only at the final text.

Check the layer contracts directly.

Check legality first.

Check tensor shapes next.

Check cache behavior last.

That order catches most failures.

2) Walk the failure chain in execution order

Attention has a natural path.

Make Q, K, and V.

Split heads.

Compute scores.

Apply the causal mask.

Run softmax.

Mix values.

Merge heads.

Reuse cached keys and values during decode.

Debug in that same order.

If score matmul crashes, inspect head shapes and key transpose.

If future positions get probability mass, inspect mask placement.

If outputs keep a head axis, inspect the merge step.

If cached decode looks wrong, compare its logits with naive full-prefix decode on the same prefix.

A tiny trace is enough.

X [B,T,D]

Q,K,V [B,T,D]

split [B,H,T,d_head]

scores [B,H,T,T]

weights [B,H,T,T]

head output [B,H,T,d_head]

merge [B,T,D]

That trace should feel boring.

Boring is good.

Reliable debugging is repetitive.
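The trace can be enforced in code with one assert per line. This is a sketch under stated assumptions: NumPy in place of PyTorch, random weights, the standard 1/sqrt(d_head) score scaling, and no output projection after the merge. The function name causal_attention is illustrative, not any library's API.

```python
import numpy as np

def causal_attention(X, Wq, Wk, Wv, H):
    """Multi-head causal self-attention with a shape assert at each trace step."""
    B, T, D = X.shape
    assert D % H == 0, "D must split evenly across heads"
    d_head = D // H

    Q, K, V = X @ Wq, X @ Wk, X @ Wv                        # [B, T, D] each
    assert Q.shape == K.shape == V.shape == (B, T, D)

    def split(M):                                            # [B, T, D] -> [B, H, T, d_head]
        return M.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)
    Qh, Kh, Vh = split(Q), split(K), split(V)
    assert Qh.shape == Kh.shape == Vh.shape == (B, H, T, d_head)

    scores = Qh @ Kh.swapaxes(-1, -2) / np.sqrt(d_head)     # [B, H, T, T]
    assert scores.shape == (B, H, T, T)

    # Mask logits BEFORE softmax; tril keeps the diagonal (the self-position).
    legal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(legal, scores, -np.inf)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    assert weights.shape == (B, H, T, T)

    head_out = weights @ Vh                                  # [B, H, T, d_head]
    assert head_out.shape == (B, H, T, d_head)

    merged = head_out.transpose(0, 2, 1, 3).reshape(B, T, D)  # [B, T, D]
    assert merged.shape == (B, T, D)
    return merged, weights

rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 8, 2
X = rng.standard_normal((B, T, D))
Wq, Wk, Wv = (rng.standard_normal((D, D)) for _ in range(3))
out, w = causal_attention(X, Wq, Wk, Wv, H)
print(out.shape, w.shape)   # (2, 5, 8) (2, 2, 5, 5)
```

Because every assert sits at a trace boundary, a failure names the exact step that broke the contract.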

3) One numerical check catches many mask bugs

Use one row and compute it by hand.

Let raw logits be [2, 4, 6].

Let the legal mask row be [1, 1, 0].

The correct path masks logits before softmax.

So logits become [2, 4, -inf].

Exponentials become [7.39, 54.60, 0].

The sum is 61.99.

Probabilities become [0.119, 0.881, 0.000].

The future position gets zero mass.
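The correct path fits in a few lines of plain Python. This is a minimal sketch, assuming the mask convention above (1 = legal, 0 = future). The softmax helper is our own.

```python
import math

def softmax(xs):
    m = max(xs)                                  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]         # math.exp(-inf) is exactly 0.0
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 4.0, 6.0]
mask = [1, 1, 0]                                 # 1 = legal, 0 = future position

# Correct path: replace illegal logits with -inf, then softmax.
masked = [x if keep else -math.inf for x, keep in zip(logits, mask)]
probs = softmax(masked)
print([round(p, 3) for p in probs])              # [0.119, 0.881, 0.0]
```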

Now take the wrong path.

Softmax first on [2, 4, 6].

Exponentials become [7.39, 54.60, 403.43].

The sum is 465.42.

Probabilities become [0.016, 0.117, 0.867].

Now multiply by the mask after softmax.

You get [0.016, 0.117, 0.000].

That row sums to 0.133.

That is broken.

Attention rows should sum to one.
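Here is a minimal plain-Python sketch of this wrong path. It uses the same 1 = legal, 0 = future mask convention, and the softmax helper is our own. The broken row sum is a one-line check.

```python
import math

def softmax(xs):
    m = max(xs)                                  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 4.0, 6.0]
mask = [1, 1, 0]                                 # 1 = legal, 0 = future position

# Wrong path: softmax over every logit first, then zero the future position.
probs = softmax(logits)
broken = [p * keep for p, keep in zip(probs, mask)]
print([round(p, 3) for p in broken])             # [0.016, 0.117, 0.0]
print(round(sum(broken), 3))                     # 0.133
```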

So the rule is simple.

Mask logits first.

Then softmax.

Never reverse that order.

4) Use one decision tree, not intuition

When the model acts weird, run a short tree.

┌─ output wrong? ─┐
│                 │
├── score matmul crashes?
│   └── yes ──→ check K transpose dims
├── future positions get nonzero weight?
│   └── yes ──→ mask logits before softmax
├── self-position missing?
│   └── yes ──→ keep diagonal in tril mask
├── output shape still has H axis?
│   └── yes ──→ transpose + reshape to [B,T,D]
└── cached ≠ naive decode?
    └── yes ──→ check cache reset + append axis

This tree is plain on purpose.

It matches the execution path.

It turns “weird output” into direct checks.
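The mask-related branches of the tree can be turned into assertions on a weight tensor. This is a minimal NumPy sketch. check_causal_weights is a hypothetical helper, weights are assumed shaped [..., T, T] with keys on the last axis, and the tolerance is an arbitrary choice. The crash and head-axis branches are plain shape checks, so they are left out here.

```python
import numpy as np

def check_causal_weights(weights, atol=1e-6):
    """Return the tree branches that fire for attention weights shaped [..., T, T]."""
    T = weights.shape[-1]
    problems = []
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # strictly above the diagonal
    if np.any(np.abs(weights[..., future]) > atol):
        problems.append("future positions get nonzero weight: mask logits before softmax")
    diag = np.diagonal(weights, axis1=-2, axis2=-1)
    if not np.all(diag > 0):                              # also catches NaN from all-masked rows
        problems.append("self-position missing: keep diagonal in tril mask")
    if not np.allclose(weights.sum(axis=-1), 1.0, atol=atol):
        problems.append("rows do not sum to one: mask before softmax, not after")
    return problems

T = 4
legal = np.tril(np.ones((T, T), dtype=bool))
good = legal / legal.sum(axis=-1, keepdims=True)          # uniform over legal positions
leaky = np.full((T, T), 1.0 / T)                          # no mask applied at all
post = np.where(legal, leaky, 0.0)                        # mask applied after softmax

print(check_causal_weights(good))    # []
print(check_causal_weights(leaky))   # flags future positions
print(check_causal_weights(post))    # flags the row sums
```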

One final habit matters.

Compare cached decode with naive decode on the same prefix.

The logits should match up to floating-point tolerance.

If they do not, suspect stale cache state.

Also suspect append order or position handling.

That comparison catches many serving bugs fast.
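Here is a NumPy sketch of that comparison for one attention head. Assumptions: a single head, no positional encoding (so position handling is not exercised), random weights, and a hypothetical W_out projection that turns the attention output into toy logits. KVCache is an illustrative class, not any library's API. The last lines show how a missing reset leaves stale state that breaks parity.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, vocab = 6, 8, 11                     # prefix length, width, toy vocab size (assumptions)
X = rng.standard_normal((T, D))            # stand-in embeddings for one sequence
Wq, Wk, Wv = (rng.standard_normal((D, D)) for _ in range(3))
W_out = rng.standard_normal((D, vocab))    # hypothetical projection to logits

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # stable softmax over the last axis
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def naive_last_logits(prefix):
    """Recompute causal attention over the whole prefix; return last-position logits."""
    Q, K, V = prefix @ Wq, prefix @ Wk, prefix @ Wv
    t = len(prefix)
    scores = Q @ K.T / np.sqrt(D)
    scores = np.where(np.tril(np.ones((t, t), dtype=bool)), scores, -np.inf)
    return (softmax(scores) @ V)[-1] @ W_out

class KVCache:
    def __init__(self):
        self.reset()

    def reset(self):
        self.k = np.zeros((0, D))          # time axis first: [t, D]
        self.v = np.zeros((0, D))

    def step(self, x):
        """Append this token's K and V on the time axis, then attend from its query."""
        self.k = np.concatenate([self.k, (x @ Wk)[None, :]], axis=0)
        self.v = np.concatenate([self.v, (x @ Wv)[None, :]], axis=0)
        w = softmax((x @ Wq) @ self.k.T / np.sqrt(D))
        return (w @ self.v) @ W_out

cache = KVCache()
cached = np.stack([cache.step(X[t]) for t in range(T)])
naive = np.stack([naive_last_logits(X[: t + 1]) for t in range(T)])
print(np.allclose(cached, naive))          # True

# Bug: reuse the cache for a new pass over the sequence without reset().
stale = np.stack([cache.step(X[t]) for t in range(T)])
print(np.allclose(stale, naive))           # False
```

The cached step needs no mask because the cache only ever holds past positions plus the current one. The naive path still needs the tril mask because it scores the whole prefix at once.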


Where this lives in the wild

  • Hugging Face attention modules — reference implementations rely on exact shape contracts and mask tests.
  • TensorRT-LLM kernel validation — optimized kernels are checked numerically against slower reference kernels.
  • vLLM cache debugging — decode correctness is often verified by comparing cached and uncached outputs.
  • Open-source LLaMA forks — early attention bugs frequently came from transpose mistakes and stale KV state.
  • Internal model eval pipelines — behaviour checks reveal leakage even when loss curves look normal.

Pause and recall

  • Why are attention bugs often harder than ordinary shape bugs?
  • What exact symptom appears when masking happens after softmax?
  • Which shape trace should you print before chasing fancy explanations?
  • Why is cached-versus-naive decode such a strong test?

Interview Q&A

Q. What do you inspect first when debugging attention?

A. I print shapes after each transformation from X through head merge.

Common wrong answer to avoid: “I read the generated text and guess from the vibe.”

Q. How do you verify a causal mask quickly?

A. I inspect one attention row and confirm every future position gets zero probability.

Common wrong answer to avoid: “Low training loss means the mask is probably correct.”

Q. How do you validate a KV cache implementation?

A. I compare cached decode logits with naive full-prefix decode on the same prefix.

Common wrong answer to avoid: “If decode is faster, the cache must be right.”


Apply now (5 min)

Take logits [1, 3, 5].

Mask the future position correctly.

Compute the legal softmax by hand.

Then repeat the wrong path with masking after softmax.

Write one sentence saying what breaks.

Next, sketch from memory:

  • the seven-line shape trace,
  • the decision tree,
  • and the rule for cache parity testing.

Bridge. Good. We can now debug the clean attention story.

The next file admits which production details this module intentionally simplified.

14-honest-admission.md