03. Week 4: Causal Attention & Coding
Companion to 02_explainer.md. Use this file as the compact formula-and-pattern sheet.
Section 1: The opening bug
If attention is unmasked during training,
position i can see tokens j > i.
That leaks the answer.
Loss falls for the wrong reason.
Generation quality collapses later.
Core mismatch:
- training without mask = future visible
- inference during generation = future absent
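The leak can be made visible with a small sketch. It assumes a single head, random Q and K, and toy sizes T=4, d=8 (all chosen here for illustration), and measures how much attention weight lands on future positions with and without a mask.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T, d = 4, 8                      # toy sizes, assumed for illustration
Q = rng.normal(size=(T, d))
K = rng.normal(size=(T, d))
scores = Q @ K.T / np.sqrt(d)    # [T, T]

mask = np.tril(np.ones((T, T), dtype=bool))   # True where j <= i
future = ~mask                                # True where j > i

unmasked = softmax(scores)
masked = softmax(np.where(mask, scores, -1e9))

# Total attention weight placed on future tokens in each case.
leak_unmasked = unmasked[future].sum()
leak_masked = masked[future].sum()
print(f"weight on future tokens, unmasked: {leak_unmasked:.3f}")
print(f"weight on future tokens, masked:   {leak_masked:.3f}")
```

Any nonzero weight on j > i during training is information the model will not have at generation time.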
Section 2: Causal mask
Autoregressive rule: position i may attend only to positions j ≤ i (itself and the past).
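As a formula, assuming x_t denotes the token at position t (notation introduced here for illustration), the model factorizes a sequence left to right, which is why attention at position i has to stop at j = i:

```latex
\[
p(x_1, \dots, x_T) \;=\; \prod_{t=1}^{T} p\left(x_t \mid x_1, \dots, x_{t-1}\right)
\]
```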
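Lower-triangular mask for T=4 (rows = query positions, columns = key positions, 1 = may attend):
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1
NumPy:
mask = np.tril(np.ones((T, T), dtype=bool))
Batched multi-head broadcast shape:
mask[None, None, :, :] has shape [1, 1, T, T] and broadcasts against scores [B, H, T, T].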
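The mask construction and its broadcast over batch and heads can be checked with a short NumPy sketch. The sizes B, H and the zero-filled scores tensor are placeholders assumed for illustration.

```python
import numpy as np

B, H, T = 2, 3, 4                              # placeholder sizes
mask = np.tril(np.ones((T, T), dtype=bool))    # [T, T], True where j <= i
print(mask.astype(int))

mask_b = mask[None, None, :, :]                # [1, 1, T, T]
scores = np.zeros((B, H, T, T))                # stand-in for real attention scores
masked_scores = np.where(mask_b, scores, -1e9) # broadcasts to [B, H, T, T]
print(mask_b.shape, masked_scores.shape)
```

Because the two leading axes have size 1, the same [T, T] pattern is reused for every batch element and every head without copying.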
Section 3: Scaled dot-product attention
Formula:
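The standard scaled dot-product attention formula, with d_k the per-head key dimension, reads:

```latex
\[
\mathrm{Attention}(Q, K, V) \;=\; \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
\]
```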
Masked version:
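Written as a formula, assuming an additive mask matrix M (the symbol M is introduced here for illustration; replacing entries with np.where has the same effect as adding M):

```latex
\[
\mathrm{Attention}(Q, K, V) \;=\; \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V,
\qquad
M_{ij} =
\begin{cases}
0 & j \le i \\
-\infty & j > i
\end{cases}
\]
```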
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
scores = np.where(mask, scores, -1e9)
weights = softmax(scores, axis=-1)
out = weights @ V
Why sqrt(d_k)?
Without scaling,
dot products grow with dimension.
Softmax saturates.
Gradients become poor.
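A rough numerical check of this argument, as a sketch: it assumes query and key entries are independent standard normals (an idealization, not a property of trained models) and uses sample sizes picked for illustration. Under that assumption the raw dot-product variance grows roughly like d_k, while the scaled version stays near 1.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, num_keys = 1000, 8            # illustrative sample sizes
stats = {}
for d_k in (4, 64, 512):
    q = rng.normal(size=(n, d_k))
    k = rng.normal(size=(n, num_keys, d_k))
    raw = np.einsum("nd,ntd->nt", q, k)      # unscaled dot products, [n, num_keys]
    scaled = raw / np.sqrt(d_k)
    stats[d_k] = {
        "raw_var": raw.var(),
        "scaled_var": scaled.var(),
        # Mean of the largest softmax weight per row: near 1.0 means saturated.
        "raw_peak": softmax(raw).max(axis=-1).mean(),
        "scaled_peak": softmax(scaled).max(axis=-1).mean(),
    }
    s = stats[d_k]
    print(f"d_k={d_k:4d}  var raw={s['raw_var']:8.1f}  var scaled={s['scaled_var']:.2f}  "
          f"peak raw={s['raw_peak']:.2f}  peak scaled={s['scaled_peak']:.2f}")
```

A peak weight close to 1 means softmax has collapsed onto one key, which is the regime where its gradient with respect to the other logits is close to zero.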
Why -1e9 or -inf before softmax?
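Because forbidden positions must receive probability zero. Masking the logits lets softmax renormalize over the allowed positions, so each row still sums to 1; zeroing probabilities afterward would not.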
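The difference between masking logits and masking probabilities can be seen on a 3x3 toy example. The score values below are made up for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

scores = np.array([[2.0, 1.0, 0.5],
                   [0.3, 1.5, 2.5],
                   [1.0, 1.0, 1.0]])            # made-up logits
mask = np.tril(np.ones((3, 3), dtype=bool))

# Mask the logits, then normalize: rows are valid distributions.
w_logit = softmax(np.where(mask, scores, -1e9))

# Normalize first, then zero out future positions: rows lose mass.
w_prob = np.where(mask, softmax(scores), 0.0)

print("row sums, masked logits:", w_logit.sum(axis=-1))
print("row sums, masked probs: ", w_prob.sum(axis=-1))
```

Masking after softmax would need an extra renormalization step, and the softmax denominator would still have been computed from future tokens.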
Section 4: Q/K/V projections
Given hidden states X [B, T, D]:
Q = X @ W_Q
K = X @ W_K
V = X @ W_V
Typical shapes:
W_Q, W_K, W_V [D, D]
Q, K, V [B, T, D]
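A quick shape check of the projections, as a sketch. The sizes B, T, D and the random [D, D] weight matrices are assumptions for illustration, not values from this sheet.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, D = 2, 5, 16                        # illustrative sizes
X = rng.normal(size=(B, T, D))            # hidden states

# One learned matrix per role; random stand-ins here.
W_Q = rng.normal(size=(D, D)) / np.sqrt(D)
W_K = rng.normal(size=(D, D)) / np.sqrt(D)
W_V = rng.normal(size=(D, D)) / np.sqrt(D)

# The matmul acts on the last axis, so B and T pass through unchanged.
Q = X @ W_Q
K = X @ W_K
V = X @ W_V
print(Q.shape, K.shape, V.shape)
```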
Intuition:
- Query = what this token wants
- Key = what this token can match
- Value = what this token can contribute
Useful interview sentence: Queries and keys decide who talks to whom. Values decide what gets said.
Section 5: Multi-head attention
If D = 512 and H = 8,
then d_head = 64.
Shape flow:
X [B, T, D]
Q/K/V projections [B, T, D]
split heads [B, H, T, d_head]
scores [B, H, T, T]
weights [B, H, T, T]
head outputs [B, H, T, d_head]
merge heads [B, T, D]
output projection [B, T, D]
Reference helpers:
def split_heads(x, H):
    B, T, D = x.shape
    d_head = D // H
    return x.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)
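A round-trip check of the two helpers, using the D = 512, H = 8 example from this section. The batch size, sequence length, and the arange-filled input are assumptions chosen so that channel positions are easy to verify.

```python
import numpy as np

def split_heads(x, H):
    B, T, D = x.shape
    d_head = D // H
    return x.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

B, T, D, H = 2, 5, 512, 8                      # B and T are illustrative
x = np.arange(B * T * D, dtype=np.float64).reshape(B, T, D)

heads = split_heads(x, H)                      # [B, H, T, d_head]
restored = merge_heads(heads)                  # [B, T, D]
print(heads.shape, restored.shape)

# Head h holds channels h*d_head .. (h+1)*d_head - 1 of each token.
print(np.array_equal(heads[0, 1, 0], x[0, 0, 64:128]))
```

The transpose in merge_heads matters: reshaping [B, H, T, d_head] straight to [B, T, D] would interleave tokens and heads incorrectly.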
Section 6: KV cache
During one-token-at-a-time decoding, recomputing old keys and values is wasteful. So we cache them.
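Per-layer cache shapes (K_cache, V_cache, new_q, and T_past are illustrative names):
K_cache [B, H, T_past, d_head]
V_cache [B, H, T_past, d_head]
New step shapes:
new_q, new_k, new_v [B, H, 1, d_head]
Append new_k, new_v to the cache along the time axis,
then compute:
scores = new_q @ K_cache.transpose(0, 1, 3, 2) / np.sqrt(d_head)   [B, H, 1, T_past + 1]
out = softmax(scores, axis=-1) @ V_cache   [B, H, 1, d_head]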
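A minimal sketch of one cached decode step for a single layer. The function name decode_step, the argument names, and the toy sizes are hypothetical; real implementations usually preallocate the cache instead of concatenating.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def decode_step(new_q, new_k, new_v, k_cache, v_cache):
    """new_*: [B, H, 1, d_head]; caches: [B, H, T_past, d_head]."""
    # Append the new key/value along the time axis.
    k_cache = np.concatenate([k_cache, new_k], axis=2)
    v_cache = np.concatenate([v_cache, new_v], axis=2)
    d_head = new_q.shape[-1]
    # One query row against all cached keys: [B, H, 1, T_past + 1].
    scores = new_q @ k_cache.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    weights = softmax(scores, axis=-1)
    out = weights @ v_cache                    # [B, H, 1, d_head]
    return out, k_cache, v_cache

rng = np.random.default_rng(0)
B, H, T_past, d_head = 1, 2, 3, 4              # toy sizes
k_cache = rng.normal(size=(B, H, T_past, d_head))
v_cache = rng.normal(size=(B, H, T_past, d_head))
new_q, new_k, new_v = (rng.normal(size=(B, H, 1, d_head)) for _ in range(3))

out, k_cache, v_cache = decode_step(new_q, new_k, new_v, k_cache, v_cache)
print(out.shape, k_cache.shape, v_cache.shape)
```

No mask appears in the step: the cache only ever contains positions up to the current one.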
Trade-off:
- pro: lower decode latency
- con: higher inference memory use
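The memory side of the trade-off can be estimated with simple arithmetic. The helper kv_cache_bytes and the configuration below (12 layers, 12 heads, d_head 64, 1024 tokens, 2 bytes per element) are hypothetical values chosen for illustration.

```python
def kv_cache_bytes(num_layers, batch, num_heads, seq_len, d_head, bytes_per_elem):
    # Factor 2: one K tensor and one V tensor per layer.
    return 2 * num_layers * batch * num_heads * seq_len * d_head * bytes_per_elem

total = kv_cache_bytes(num_layers=12, batch=1, num_heads=12,
                       seq_len=1024, d_head=64, bytes_per_elem=2)
print(f"{total / 2**20:.1f} MiB")   # 36.0 MiB
```

The estimate grows linearly in batch size and sequence length, which is why long contexts and large batches make the cache the dominant inference memory cost.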
Section 7: Prefill vs decode
This distinction matters.
- Prefill: process the whole prompt in parallel; causal mask required.
- Decode: process one new token at a time; future tokens are absent, so cache structure itself enforces causality.
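The two phases can be compared against full causal attention in a small sketch. It assumes a single head and a single layer, and it treats Q, K, V for every position as already computed (random stand-ins here); the helper attend and all sizes are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def attend(Q, K, V, mask=None):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    return softmax(scores) @ V

rng = np.random.default_rng(0)
T_prompt, T_new, d = 4, 3, 8
T = T_prompt + T_new
Q, K, V = (rng.normal(size=(T, d)) for _ in range(3))

# Reference: one pass over the full sequence with a causal mask.
reference = attend(Q, K, V, np.tril(np.ones((T, T), dtype=bool)))

# Prefill: whole prompt in parallel, causal mask required.
prefill_mask = np.tril(np.ones((T_prompt, T_prompt), dtype=bool))
prefill_out = attend(Q[:T_prompt], K[:T_prompt], V[:T_prompt], prefill_mask)
k_cache, v_cache = K[:T_prompt], V[:T_prompt]

# Decode: one token at a time, no mask, cache grows by one row per step.
decode_outs = []
for t in range(T_prompt, T):
    k_cache = np.concatenate([k_cache, K[t:t + 1]], axis=0)
    v_cache = np.concatenate([v_cache, V[t:t + 1]], axis=0)
    decode_outs.append(attend(Q[t:t + 1], k_cache, v_cache))

combined = np.concatenate([prefill_out] + decode_outs, axis=0)
print(np.allclose(combined, reference))
```

Prefill plus cached decode reproduces the masked full pass row for row, which is the sense in which the cache structure enforces causality.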
Section 8: Common bugs
| Symptom | Likely cause | Fix |
|---|---|---|
| Low loss, bad generation | missing mask | apply lower-triangular mask before softmax |
| Shape crash in score matmul | wrong K transpose | use K.transpose(0, 1, 3, 2) |
| Wrong number of channels after attention | forgot merge | merge heads back to [B, T, D] |
| Slow decoding | no KV cache | store past K/V per layer |
| Weird outputs across prompts | stale cache | reset cache per request |
| NaNs in attention | numerical issue or all-masked row | keep diagonal, use stable softmax |
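The all-masked-row case from the last table row can be reproduced in a few lines. This sketch assumes a -inf fill value and a strictly lower-triangular mask (diagonal excluded) as the buggy variant; both are choices made here for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

T = 3
scores = np.zeros((T, T))
strict = np.tril(np.ones((T, T), dtype=bool), k=-1)  # diagonal dropped: row 0 has no allowed key
causal = np.tril(np.ones((T, T), dtype=bool))        # diagonal kept

# Row 0 is all -inf, so (-inf) - (-inf) = nan inside softmax.
with np.errstate(invalid="ignore"):
    bad = softmax(np.where(strict, scores, -np.inf))
good = softmax(np.where(causal, scores, -np.inf))

print("NaN with diagonal dropped:", np.isnan(bad).any())
print("NaN with diagonal kept:   ", np.isnan(good).any())
```

With a large finite fill such as -1e9, an all-masked row produces uniform weights instead of NaN, which hides the bug without fixing it.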
Section 9: Minimal reference implementation
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def causal_self_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    # Relies on split_heads and merge_heads from Section 5.
    B, T, D = X.shape
    d_head = D // num_heads
    Q = split_heads(X @ W_Q, num_heads)  # [B, H, T, d_head]
    K = split_heads(X @ W_K, num_heads)  # [B, H, T, d_head]
    V = split_heads(X @ W_V, num_heads)  # [B, H, T, d_head]
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [B, H, T, T]
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]  # [1, 1, T, T]
    scores = np.where(mask, scores, -1e9)  # block attention to j > i
    weights = softmax(scores, axis=-1)  # [B, H, T, T]
    out = weights @ V  # [B, H, T, d_head]
    out = merge_heads(out)  # [B, T, D]
    out = out @ W_O  # [B, T, D]
    return out, weights
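One way to sanity-check an implementation like this is a perturbation test: change only the last token and confirm that no earlier output moves. The sketch below repeats the reference functions so it runs on its own; the sizes, random weights, and the +10.0 perturbation are assumptions for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def split_heads(x, H):
    B, T, D = x.shape
    return x.reshape(B, T, H, D // H).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

def causal_self_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    B, T, D = X.shape
    d_head = D // num_heads
    Q = split_heads(X @ W_Q, num_heads)
    K = split_heads(X @ W_K, num_heads)
    V = split_heads(X @ W_V, num_heads)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]
    weights = softmax(np.where(mask, scores, -1e9), axis=-1)
    return merge_heads(weights @ V) @ W_O, weights

rng = np.random.default_rng(0)
B, T, D, H = 2, 6, 16, 4                       # toy sizes
X = rng.normal(size=(B, T, D))
W_Q, W_K, W_V, W_O = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(4))

out, weights = causal_self_attention(X, W_Q, W_K, W_V, W_O, H)

# Perturb only the final position and rerun.
X2 = X.copy()
X2[:, -1, :] += 10.0
out2, _ = causal_self_attention(X2, W_Q, W_K, W_V, W_O, H)

print("earlier positions unchanged:", np.allclose(out[:, :-1], out2[:, :-1]))
print("last position changed:      ", not np.allclose(out[:, -1], out2[:, -1]))
```

If the mask were missing, the first check would fail, which makes this a cheap regression test for the "low loss, bad generation" bug.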
Section 10: Reference material
YouTube
- Let's build GPT: from scratch, in code, spelled out. — the best live-coding walkthrough of decoder-only causal attention.
- Let's reproduce GPT-2 (124M) — useful for seeing prefill, training, and generation code together.
Blogs
- Understanding and Coding Self-Attention, Multi-Head Attention, Cross-Attention, and Causal-Attention in LLMs — practical coding-first guide.
- Transformers from Scratch - Peter Bloem — excellent derivation of attention mechanics.
- The Illustrated GPT-2 — good for intuition on decoder-only generation and attention flow.
Self-check
- What exact bug appears if you forget causal masking?
- Why do we mask logits instead of masking probabilities afterward?
- What are the roles of Q, K, and V?
- What exact shape does the score tensor have in multi-head attention?
- Why does KV cache improve speed but hurt memory?
- What is the difference between prefill and decode?