
07. Multi-head causal attention — one tensor story from X to Y

~12 min read. This is the implementation interviewers want. Every tensor shape stays visible.

Built on the ELI5 in 00-eli5.md. The parallel graders and the exam rule now combine into one complete masked attention pass.


1) The whole movie first

  • We want one clean tensor story.
  • Input goes in.
  • Output comes out.
  • Nothing stays mysterious.
    ┌──────────────────────────────┐
    │ X                [B, T, D]   │
    └──────────────┬───────────────┘
            ┌──────┼──────┐
            ▼      ▼      ▼
          XW_Q   XW_K   XW_V ──→ Q,K,V [B,T,D]
         split_heads ──→ [B,H,T,d_head]
          Q @ K^T ──→ scores [B,H,T,T]
       mask ──→ softmax ──→ weights
     weights @ V ──→ heads [B,H,T,d_head]
    merge_heads ──→ [B,T,D] ──→ @ W_O ──→ Y [B,T,D]
    
  • That is the entire attention movie.
  • The projection lens creates Q, K, and V.
  • The parallel graders appear after the split.
  • The exam rule appears at the score matrix.
  • The answer sheet supplies sequence length T.
  • Later, the memory shortcut will reuse old K and V.
  • This file covers the unfused baseline implementation.
  • Learn this version first.
  • Fused kernels are still doing this math.
  • They just do it faster.
  • If you know the movie,
  • optimised code becomes readable.
  • If you do not know the movie,
  • optimised code feels magical.

2) Helper functions stay tiny

  • Small helpers make the main function readable.
  • We need softmax, split, merge, and a causal mask.
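    import numpy as np
    
    def softmax(x, axis=-1):
        # Subtract the max along `axis` so np.exp never overflows.
        x = x - np.max(x, axis=axis, keepdims=True)
        exp_x = np.exp(x)
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
    
    def causal_mask(T):
        # [T, T] boolean; True strictly above the diagonal marks future positions.
        return np.triu(np.ones((T, T), dtype=bool), k=1)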
    
  • Numerical stability matters in softmax.
  • Subtracting the row max prevents overflow.
  • The mask is upper triangular.
  • True means block that future position.
  • split_heads and merge_heads were covered already.
  • We reuse them unchanged here.
  • One detail causes many mistakes.
  • Scale scores by sqrt(d_head).
  • Do not scale by sqrt(D).
  • Each head computes its own dot products.
  • Each head lives in width d_head.
  • So each head uses its own denominator.
  • Dividing by the larger sqrt(D) shrinks every score, so softmax comes out too flat.
  • Flat weights weaken selective attention.
  • That bug often hides because code still runs.
  • Another hidden bug is masking with the wrong broadcast shape.
  • Index the [T, T] mask as mask[None, None, :, :] to add singleton axes for batch and head.
  • That lets one mask broadcast across every batch element and every head.
  • Good helper design keeps this broadcast obvious.
  • Good variable names keep debugging short.
  • Small details like this separate readable attention code from brittle attention code.
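Both hidden bugs above are easy to see in a few lines of NumPy. This is an illustrative sketch: the raw scores, the sizes d_head = 16 and D = 64, and the batch, head, and sequence sizes are made up for the demo. With these numbers the largest probability drops from about 0.51 to about 0.42 when the wrong denominator is used, and the [1, 1, T, T] mask broadcasts over every batch element and head.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# Illustrative raw dot products for one query row (values chosen by hand).
raw = np.array([4.0, 2.0, 0.0])
d_head, D = 16, 64  # assumed sizes: 4 heads of width 16

right = softmax(raw / np.sqrt(d_head))  # divide by 4
wrong = softmax(raw / np.sqrt(D))       # divide by 8: flatter distribution
print(right.round(3), wrong.round(3))

# Broadcasting one [T, T] mask over [B, H, T, T] scores.
B, H, T = 2, 4, 3
scores = np.zeros((B, H, T, T))
mask = np.triu(np.ones((T, T), dtype=bool), k=1)
print(mask[None, None, :, :].shape)  # (1, 1, 3, 3)
masked = np.where(mask[None, None, :, :], -1e9, scores)
weights = softmax(masked, axis=-1)
# With all-zero scores, each row spreads evenly over its visible positions:
# rows are [1, 0, 0], [0.5, 0.5, 0], [1/3, 1/3, 1/3].
print(weights[0, 0])
```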

3) The core function with shapes

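    def multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, num_heads):
        B, T, D = X.shape                            # X: [B, T, D]
        assert D % num_heads == 0, "D must split evenly across heads"
        d_head = D // num_heads
        Q = split_heads(X @ W_Q, num_heads)          # [B, H, T, d_head]
        K = split_heads(X @ W_K, num_heads)          # [B, H, T, d_head]
        V = split_heads(X @ W_V, num_heads)          # [B, H, T, d_head]
        scores = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(d_head)  # [B, H, T, T]
        mask = causal_mask(T)                        # [T, T], True = future
        scores = np.where(mask[None, None, :, :], -1e9, scores)  # broadcast over B, H
        weights = softmax(scores, axis=-1)           # [B, H, T, T], rows sum to 1
        heads = weights @ V                          # [B, H, T, d_head]
        merged = merge_heads(heads)                  # [B, T, D]
        return merged @ W_O, weights                 # Y: [B, T, D]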

  • Start with X shaped [B, T, D].
  • Project once for queries.
  • Project once for keys.
  • Project once for values.
  • Split all three into head-major layout.
  • Compute dot products inside each head.
  • That gives scores shaped [B, H, T, T].
  • Divide by sqrt(d_head).
  • Apply the exam rule before softmax.
  • Softmax turns scores into probabilities.
  • Probabilities sum to one across the last axis.
  • Mix values using those probabilities.
  • Merge head outputs back to width D.
  • Apply output projection W_O.
  • Return both Y and weights for inspection.
  • The projection lens appears in four linear maps: W_Q, W_K, W_V, and W_O.
  • The parallel graders appear in split, score, and merge.
  • The answer sheet sets the size of the T by T attention table.
  • The memory shortcut is absent here only because this is the full-pass version.
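To run the core function end to end, split_heads and merge_heads need concrete definitions. The versions below are assumed reimplementations (the earlier file's exact code may differ), and the toy sizes and random weights exist only so the shape and masking checks have something to inspect. The function is repeated so the block runs on its own.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def causal_mask(T):
    return np.triu(np.ones((T, T), dtype=bool), k=1)

# Assumed helpers: [B, T, D] -> [B, H, T, d_head] and back.
def split_heads(x, num_heads):
    B, T, D = x.shape
    return x.reshape(B, T, num_heads, D // num_heads).transpose(0, 2, 1, 3)

def merge_heads(h):
    B, H, T, d_head = h.shape
    return h.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

def multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    B, T, D = X.shape
    d_head = D // num_heads
    Q = split_heads(X @ W_Q, num_heads)
    K = split_heads(X @ W_K, num_heads)
    V = split_heads(X @ W_V, num_heads)
    scores = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(d_head)
    scores = np.where(causal_mask(T)[None, None, :, :], -1e9, scores)
    weights = softmax(scores, axis=-1)
    return merge_heads(weights @ V) @ W_O, weights

# Toy sizes and random weights, chosen only for the checks below.
rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 8, 2
X = rng.standard_normal((B, T, D))
W_Q, W_K, W_V, W_O = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4))

Y, weights = multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, H)
print(Y.shape, weights.shape)                  # (2, 5, 8) (2, 2, 5, 5)
print(np.allclose(weights.sum(-1), 1))         # True: rows sum to one
print(np.allclose(np.triu(weights, k=1), 0))   # True: no weight on the future
print(np.allclose(merge_heads(split_heads(X, H)), X))  # True: split/merge round-trip
print(weights[0, 0, 2].round(3))               # sample row: last two entries are 0
```

np.triu acts on the last two axes of a 4-D array, so one call checks the exam rule for every batch element and head at once.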

4) One worked numerical example

  • Use one token position with two heads.
  • Suppose token 2 is the current row.
  • Head 0 weights are [0.80, 0.20, 0.00].
  • Head 1 weights are [0.05, 0.90, 0.05].
  • The same token now has two reading styles.
  • Head 0 looks farther back.
  • Head 1 looks more locally.
  • Give head 0 these value vectors.
  • v0 = [10, 1].
  • v1 = [4, 2].
  • v2 = [7, 3].
  • Then head 0 output is
  • 0.80 * [10, 1] + 0.20 * [4, 2] + 0.00 * [7, 3].
  • That becomes [8.8, 1.2].
  • Give head 1 these value vectors.
  • u0 = [1, 9].
  • u1 = [3, 6].
  • u2 = [5, 2].
  • Then head 1 output is
  • 0.05 * [1, 9] + 0.90 * [3, 6] + 0.05 * [5, 2].
  • That becomes [3.0, 5.95].
  • The two heads are not redundant.
  • One head emphasised long-range information.
  • One head emphasised local information.
  • Merge concatenates those outputs.
  • W_O later mixes them into one layer voice.
  • Because of the exam rule,
  • neither head looked ahead to a future token.
  • Because of the answer sheet,
  • both heads used the same three visible positions.
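The arithmetic above can be checked in a few lines of NumPy. Each head's output is its weight row times that head's value matrix, and merging is concatenation along the feature axis. W_O is left out because the example gives no values for it.

```python
import numpy as np

# Attention weights for the current row (token 2), one row per head.
w_head0 = np.array([0.80, 0.20, 0.00])
w_head1 = np.array([0.05, 0.90, 0.05])

# Value vectors for positions 0, 1, 2 in each head (d_head = 2).
V_head0 = np.array([[10, 1], [4, 2], [7, 3]], dtype=float)  # v0, v1, v2
V_head1 = np.array([[1, 9], [3, 6], [5, 2]], dtype=float)   # u0, u1, u2

out0 = w_head0 @ V_head0   # approx [8.8, 1.2]
out1 = w_head1 @ V_head1   # approx [3.0, 5.95]

# merge_heads concatenates the head outputs along the feature axis.
merged = np.concatenate([out0, out1])   # approx [8.8, 1.2, 3.0, 5.95]
print(out0, out1, merged)
```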

5) What changes during decoding and debugging

  • During training or prefill, we process all T positions together.
  • During autoregressive decoding, we add one new token each step.
  • That is where the memory shortcut matters.
  • We cache old keys and old values.
  • We compute only the new query.
  • We also compute the new token's key and value.
  • The same score logic still applies.
  • The same exam rule still applies.
  • The new token may attend to cached positions and itself.
  • Debugging still starts with shapes.
  • Confirm Q, K, and V are [B, H, T, d_head].
  • Confirm scores are [B, H, T, T].
  • Confirm weights sum to one on the last axis.
  • Confirm merged output is [B, T, D].
  • Confirm scaling used d_head, not D.
  • Confirm masking happened before softmax.
  • If any of those checks fail,
  • the attention path is wrong.
  • Learn these checks once.
  • They save hours later.
  • Also print one sample row of weights.
  • Future positions should receive effectively zero probability.
  • Allowed past positions should sum to one.
  • Those two checks validate masking and softmax together.
  • They are cheap and reliable.
  • Use them whenever a new attention implementation behaves strangely.
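A minimal sketch of the memory shortcut, assuming the cache is a plain (K, V) tuple that grows by concatenation along the time axis. Real inference code usually preallocates the cache instead. The names decode_step and full_pass, and the split/merge helpers, are hypothetical stand-ins written for this sketch. The final comparison doubles as a debugging check: feeding tokens one at a time should reproduce the full masked pass.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def split_heads(x, H):   # hypothetical helper: [B, T, D] -> [B, H, T, d_head]
    B, T, D = x.shape
    return x.reshape(B, T, H, D // H).transpose(0, 2, 1, 3)

def merge_heads(h):      # hypothetical helper: [B, H, T, d_head] -> [B, T, D]
    B, H, T, dh = h.shape
    return h.transpose(0, 2, 1, 3).reshape(B, T, H * dh)

def full_pass(X, W_Q, W_K, W_V, W_O, H):
    # Unfused baseline: all T positions at once, with the causal mask.
    B, T, D = X.shape
    dh = D // H
    Q, K, V = (split_heads(X @ W, H) for W in (W_Q, W_K, W_V))
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(dh)
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(mask[None, None], -1e9, scores)
    return merge_heads(softmax(scores) @ V) @ W_O

def decode_step(x_new, cache, W_Q, W_K, W_V, W_O, H):
    # x_new: [B, 1, D], the newest token only.
    dh = x_new.shape[-1] // H
    q = split_heads(x_new @ W_Q, H)   # [B, H, 1, dh]: only the new query
    k = split_heads(x_new @ W_K, H)   # [B, H, 1, dh]
    v = split_heads(x_new @ W_V, H)   # [B, H, 1, dh]
    if cache is None:
        K, V = k, v
    else:
        K = np.concatenate([cache[0], k], axis=2)   # [B, H, t, dh]
        V = np.concatenate([cache[1], v], axis=2)   # [B, H, t, dh]
    scores = q @ K.transpose(0, 1, 3, 2) / np.sqrt(dh)   # [B, H, 1, t]
    # No mask needed: the cache holds only past positions plus this token.
    weights = softmax(scores)
    y = merge_heads(weights @ V) @ W_O   # [B, 1, D]
    return y, (K, V)

rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 8, 2
X = rng.standard_normal((B, T, D))
W_Q, W_K, W_V, W_O = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4))

Y_full = full_pass(X, W_Q, W_K, W_V, W_O, H)
cache, steps = None, []
for t in range(T):
    y_t, cache = decode_step(X[:, t:t + 1], cache, W_Q, W_K, W_V, W_O, H)
    steps.append(y_t)
Y_decode = np.concatenate(steps, axis=1)
print(np.allclose(Y_full, Y_decode))   # True: same math, one row at a time
```

The design choice to skip the mask in decode_step is safe only because the cache never contains future tokens; if a cache were ever padded or reused across sequences, a mask would be needed again.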

Where this lives in the wild

  • nanoGPT: CausalSelfAttention shows the full decoder attention path clearly.
  • Hugging Face GPT-2: production PyTorch code implements masked multi-head attention with split and merge.
  • PyTorch: scaled_dot_product_attention fuses the score, mask, softmax, and value-mixing steps behind one call.
  • JAX/Flax: MultiHeadDotProductAttention packages projections, masking, and value mixing.
  • TensorRT-LLM: inference kernels optimise the same tensor movie for speed.

Pause and recall

  1. Why must scaling use sqrt(d_head) instead of sqrt(D)?
  2. Which line enforces the exam rule in code?
  3. In the worked example, which head focused more locally?
  4. Where does the memory shortcut enter this implementation during decoding?

Interview Q&A

Q1. Walk me from X to Y in multi-head causal attention.
  • Project to Q, K, V, split heads, score, mask, softmax, mix values, merge, then apply W_O.
  • Common wrong answer to avoid: "Attention is just Q times K times V."

Q2. Why are different heads useful?
  • They learn different attention patterns over the same sequence positions.
  • Common wrong answer to avoid: "Heads only exist for model parallelism."

Q3. Where does the causal mask act?
  • It acts on scores before softmax, not after value mixing.
  • Common wrong answer to avoid: "Apply the mask after the weighted sum."


Apply now (5 min)

  • Quick exercise.
  • Write the full function from memory without copying.
  • Annotate every tensor with its shape.
  • Then change D = 12 and H = 3.
  • State the new d_head.
  • Sketch from memory.
  • Draw the whole tensor movie from X to Y.
  • Mark where the projection lens, parallel graders, exam rule, answer sheet, and memory shortcut each appear.

Bridge. We now have several head outputs merged back to width D. Concatenation alone still does not mix them fully. → 08-output-projection.md