05. Assignment 4: Causal Self-Attention + KV Cache from Scratch
Week 4. Build the decoder-only attention core that Module 05 will assume you already understand.
Goal
Implement causal self-attention end to end:
- Q/K/V projections
- lower-triangular causal masking
- multi-head splitting and concatenation
- output projection
- KV cache for one-token decode
Do this in NumPy or PyTorch.
Do not use nn.MultiheadAttention as the implementation.
You may use it only for verification.
Required build
- Single-head causal attention that visibly masks future tokens.
- Multi-head causal self-attention with `split_heads` and `merge_heads`.
- KV cache for inference-time decode.
- One tiny GPT-style block that wires attention + residual + norm + FFN.
- Verification harness that checks shapes and at least one numerical property.
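The attention core in the list above can be sketched in NumPy as follows. This is a minimal sketch, not a reference solution: it assumes bias-free projections stored in a hypothetical `params` dict (`W_q`, `W_k`, `W_v`, `W_o`, each `[d_model, d_model]`) and shows `split_heads`, the lower-triangular mask, softmax, `merge_heads`, and the output projection.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max along the axis so exp() cannot overflow.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def split_heads(x, num_heads):
    # [B, T, d_model] -> [B, T, H, d_head] -> [B, H, T, d_head]
    B, T, d_model = x.shape
    return x.reshape(B, T, num_heads, d_model // num_heads).transpose(0, 2, 1, 3)


def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, H, d_head] -> [B, T, H * d_head]
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)


def causal_self_attention(x, params, num_heads):
    # params: hypothetical dict of bias-free projections W_q, W_k, W_v, W_o, each [d_model, d_model].
    q = split_heads(x @ params["W_q"], num_heads)
    k = split_heads(x @ params["W_k"], num_heads)
    v = split_heads(x @ params["W_v"], num_heads)
    T, d_head = q.shape[-2], q.shape[-1]
    # [B, H, T, d_head] @ [B, H, d_head, T] -> [B, H, T, T]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_head)
    # Lower-triangular mask: query i may attend to keys 0..i only.
    allowed = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(allowed, scores, -np.inf)
    weights = softmax(scores, axis=-1)
    out = merge_heads(weights @ v)  # [B, T, d_model]
    return out @ params["W_o"], weights


rng = np.random.default_rng(0)
B, T, d_model, num_heads = 2, 6, 128, 4
params = {name: rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
          for name in ("W_q", "W_k", "W_v", "W_o")}
x = rng.standard_normal((B, T, d_model))
y, weights = causal_self_attention(x, params, num_heads)
print(y.shape, weights.shape)  # (2, 6, 128) (2, 4, 6, 6)
```

Calling it with `num_heads=1` gives the single-head case, which is a convenient first milestone before splitting heads. Returning `weights` is a design choice that makes the mask directly inspectable in tests.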
Recommended scope
Suggested config for a manageable implementation:
- `d_model = 128`
- `num_heads = 4`
- `d_head = 32`
- `max_seq_len = 64` or `128`
- `batch_size = 2` for tests
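These numbers are not independent: `d_head` is `d_model / num_heads` (128 / 4 = 32). A small config object can enforce that. The `AttentionConfig` class below is a hypothetical helper, not something the assignment requires.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionConfig:
    # Defaults follow the suggested config; the class itself is hypothetical.
    d_model: int = 128
    num_heads: int = 4
    max_seq_len: int = 128
    batch_size: int = 2

    def __post_init__(self):
        # Heads must tile d_model exactly, or split_heads cannot reshape cleanly.
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model={self.d_model} is not divisible by num_heads={self.num_heads}"
            )

    @property
    def d_head(self) -> int:
        return self.d_model // self.num_heads


cfg = AttentionConfig()
print(cfg.d_head)  # 32
```

Deriving `d_head` instead of storing it means the three values can never disagree.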
You are not chasing model quality this week. You are chasing correctness and explainability.
Deliverables
- `attention.py`: Q/K/V projections, mask, multi-head attention, output projection
- `cache.py`: KV cache class or helper functions
- `block.py`: one GPT-style block with residual + norm + FFN
- `verify.py`: shape tests, mask tests, and a small comparison against PyTorch or a hand-worked example
- `generate.py`: one-token decode loop using the cache
- `README.md`: explain the shape flow, the mask logic, and what caching changes
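One way the `block.py` wiring could look is sketched below. It assumes pre-norm placement (LayerNorm before each sublayer, as in GPT-2), a 4x-wide FFN with tanh-approximated GELU, and hypothetical parameter names; post-norm is an equally valid choice for this assignment as long as the README says which one you used.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each token's feature vector, then apply learned scale and shift.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


def gelu(x):
    # Tanh approximation of GELU (the form used in GPT-2).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def causal_attention(x, p, num_heads):
    # Compact multi-head causal attention: split heads, mask, softmax, merge, project.
    B, T, D = x.shape
    dh = D // num_heads

    def split(t):  # [B, T, D] -> [B, H, T, dh]
        return t.reshape(B, T, num_heads, dh).transpose(0, 2, 1, 3)

    q, k, v = split(x @ p["W_q"]), split(x @ p["W_k"]), split(x @ p["W_v"])
    s = q @ np.swapaxes(k, -1, -2) / np.sqrt(dh)
    s = np.where(np.tril(np.ones((T, T), dtype=bool)), s, -np.inf)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    merged = (w @ v).transpose(0, 2, 1, 3).reshape(B, T, D)
    return merged @ p["W_o"]


def ffn(x, p):
    # Position-wise MLP: expand to 4 * d_model, apply GELU, project back to d_model.
    return gelu(x @ p["W_1"] + p["b_1"]) @ p["W_2"] + p["b_2"]


def gpt_block(x, p, num_heads):
    # Pre-norm wiring: each sublayer reads a normalized copy and adds its result to the residual stream.
    x = x + causal_attention(layer_norm(x, p["ln1_g"], p["ln1_b"]), p, num_heads)
    x = x + ffn(layer_norm(x, p["ln2_g"], p["ln2_b"]), p)
    return x


def init_block_params(d_model, rng):
    # Hypothetical initialization, scaled by 1/sqrt(fan_in) to keep activations near unit scale.
    s = 1.0 / np.sqrt(d_model)
    p = {n: rng.standard_normal((d_model, d_model)) * s for n in ("W_q", "W_k", "W_v", "W_o")}
    p["W_1"], p["b_1"] = rng.standard_normal((d_model, 4 * d_model)) * s, np.zeros(4 * d_model)
    p["W_2"], p["b_2"] = rng.standard_normal((4 * d_model, d_model)) / np.sqrt(4 * d_model), np.zeros(d_model)
    p["ln1_g"], p["ln1_b"] = np.ones(d_model), np.zeros(d_model)
    p["ln2_g"], p["ln2_b"] = np.ones(d_model), np.zeros(d_model)
    return p


rng = np.random.default_rng(0)
params = init_block_params(128, rng)
x = rng.standard_normal((2, 16, 128))
y = gpt_block(x, params, num_heads=4)
print(y.shape)  # (2, 16, 128)
```

Because LayerNorm and the FFN act on each position independently, the only place information moves between tokens is the masked attention. Perturbing the last token should therefore leave every earlier output unchanged, which makes a cheap causality test for the whole block.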
Success criteria
- Future positions receive zero probability after masking.
- Attention output returns to shape `[B, T, d_model]` after merging heads.
- Cache length increases by exactly one token per decode step.
- Naive decode and cached decode produce the same logits for the same prefix, up to floating-point tolerance.
- README explains the code well enough for an interview whiteboard answer.
Required checks
Check 1: Mask correctness
Use a toy score matrix and prove that future positions become zero after softmax.
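A minimal NumPy version of this check is shown below, assuming a 4x4 toy score matrix whose largest values deliberately sit in the future (upper-triangular) positions, so a broken mask would be obvious. The `masked_softmax` name is illustrative.

```python
import numpy as np


def masked_softmax(scores):
    T = scores.shape[-1]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True strictly above the diagonal
    masked = np.where(future, -np.inf, scores)
    masked = masked - masked.max(axis=-1, keepdims=True)  # each row keeps its finite diagonal
    e = np.exp(masked)  # exp(-inf) is exactly 0.0
    return e / e.sum(axis=-1, keepdims=True)


scores = np.arange(16, dtype=float).reshape(4, 4)  # toy scores; the biggest ones are in the future
probs = masked_softmax(scores)
print(np.round(probs, 3))  # row 0 is [1, 0, 0, 0]: the first token can only attend to itself

# Every strictly-future entry is exactly zero, and each row is still a distribution.
assert np.all(probs[np.triu_indices(4, k=1)] == 0.0)
assert np.allclose(probs.sum(axis=-1), 1.0)
```

Masking with `-inf` before the softmax (rather than zeroing probabilities afterwards) is what keeps each row summing to 1 over the allowed positions only.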
Check 2: Shape trace
Print or document the shapes for:
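A minimal NumPy sketch of such a trace follows. The checkpoints it records (input, per-head Q/K/V, scores, softmax weights, per-head outputs, merged heads, output projection) are an assumption about what is worth tracing, using the suggested config with `B = 2`, `T = 8`.

```python
import numpy as np

B, T, d_model, H = 2, 8, 128, 4
d_head = d_model // H
rng = np.random.default_rng(0)
x = rng.standard_normal((B, T, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))


def split_heads(t):
    # [B, T, d_model] -> [B, H, T, d_head]
    return t.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)


trace = {"x": x.shape}                                       # (2, 8, 128)
q, k, v = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)
trace["q, k, v (after split_heads)"] = q.shape               # (2, 4, 8, 32)
scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_head)
trace["scores = q @ k^T"] = scores.shape                     # (2, 4, 8, 8)
scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
trace["softmax weights"] = weights.shape                     # (2, 4, 8, 8)
heads = weights @ v
trace["weights @ v"] = heads.shape                           # (2, 4, 8, 32)
merged = heads.transpose(0, 2, 1, 3).reshape(B, T, d_model)
trace["merge_heads"] = merged.shape                          # (2, 8, 128)
out = merged @ W_o
trace["output projection"] = out.shape                       # (2, 8, 128)

for name, shape in trace.items():
    print(f"{name:30s} {shape}")
```

The two `[B, H, T, T]` entries are the only ones that grow quadratically with sequence length, which is worth calling out in the README.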
Check 3: Cache correctness
Run decode twice:
- once by recomputing the full prefix at each step
- once using the cache

Confirm that the logits match within floating-point tolerance (for example with `np.allclose` or `torch.allclose`).
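The check can be sketched end to end on a deliberately tiny model. Everything here is an assumption made for brevity: a hypothetical one-layer model (token embeddings plus learned absolute position embeddings, one multi-head causal attention sublayer with a residual, a linear readout, no norm or FFN), batch size 1, and greedy decoding.

```python
import numpy as np

rng = np.random.default_rng(0)
V, D, H, MAX_T = 50, 32, 4, 16  # hypothetical tiny vocab and model sizes for this check
DH = D // H
P = {
    "tok": rng.standard_normal((V, D)) * 0.1,      # token embeddings
    "pos": rng.standard_normal((MAX_T, D)) * 0.1,  # learned absolute position embeddings
    **{n: rng.standard_normal((D, D)) / np.sqrt(D) for n in ("W_q", "W_k", "W_v", "W_o")},
    "W_out": rng.standard_normal((D, V)) / np.sqrt(D),  # readout to vocab logits
}


def split(t):
    # [T, D] -> [H, T, DH]; batch size 1 is implied to keep the sketch short.
    return t.reshape(t.shape[0], H, DH).transpose(1, 0, 2)


def attend(q, k, v, causal):
    s = q @ np.swapaxes(k, -1, -2) / np.sqrt(DH)  # [H, Tq, Tk]
    if causal:
        Tq, Tk = s.shape[-2:]
        s = np.where(np.tril(np.ones((Tq, Tk), dtype=bool)), s, -np.inf)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    merged = (w @ v).transpose(1, 0, 2).reshape(q.shape[1], D)  # merge heads -> [Tq, D]
    return merged @ P["W_o"]


def embed(tokens, start):
    # Token embedding plus the position embedding for absolute positions start, start + 1, ...
    return P["tok"][tokens] + P["pos"][start:start + len(tokens)]


def naive_logits(tokens):
    # Recompute Q, K, V for the whole prefix every call; keep only the last position's logits.
    x = embed(tokens, 0)
    q, k, v = (split(x @ P[n]) for n in ("W_q", "W_k", "W_v"))
    h = x + attend(q, k, v, causal=True)
    return (h @ P["W_out"])[-1]


class KVCache:
    # Stores per-head keys and values as [H, T, DH]; T grows by one per decode step.
    def __init__(self):
        self.k = np.zeros((H, 0, DH))
        self.v = np.zeros((H, 0, DH))

    def append(self, k_new, v_new):
        # Concatenate along the time axis (axis 1 in [H, T, DH]).
        self.k = np.concatenate([self.k, k_new], axis=1)
        self.v = np.concatenate([self.v, v_new], axis=1)

    def __len__(self):
        return self.k.shape[1]


def cached_step(token, cache):
    # The new token's position is the current cache length, not 0.
    x = embed([token], len(cache))
    q, k, v = (split(x @ P[n]) for n in ("W_q", "W_k", "W_v"))
    cache.append(k, v)
    # One query at the newest position may see every cached key, so no mask is needed.
    h = x + attend(q, cache.k, cache.v, causal=False)
    return (h @ P["W_out"])[-1]


tokens = [3, 17, 8]  # arbitrary prompt token ids
cache = KVCache()
for t in tokens:
    last = cached_step(t, cache)  # feed the prompt one token at a time
for _ in range(5):
    assert len(cache) == len(tokens)
    assert np.allclose(naive_logits(tokens), last)
    tokens.append(int(np.argmax(last)))  # greedy next token
    last = cached_step(tokens[-1], cache)
print("naive and cached logits match for", len(tokens), "tokens")
```

Two details carry most of the bugs listed in the failure-mode table: the position offset passed to `embed` in `cached_step`, and the concatenation axis in `KVCache.append`. Growing the cache with `np.concatenate` reallocates every step; preallocating a `[H, max_seq_len, DH]` buffer and writing into it is a common alternative once the simple version is correct.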
Check 4: Interview explanation
Record a 3-5 minute verbal explanation of your implementation. If you cannot explain it aloud, you do not own it yet.
Stretch goals
- Add a causal attention heatmap visualization.
- Add RoPE after you finish the base version.
- Explain grouped-query attention versus full multi-head attention.
- Measure cache memory use for context lengths 128, 512, and 2048.
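Before measuring, it helps to know what to expect. The KV cache holds one K and one V tensor per layer, each `[batch, num_kv_heads, seq_len, d_head]`, so its size is linear in context length. The estimate below assumes float32, a single block, batch size 1, and the suggested `d_head = 32`; the 2-KV-head variant is a hypothetical grouped-query setup where pairs of the 4 query heads share one K/V head. Measured process memory will come out higher because of allocator overhead and any preallocation.

```python
import numpy as np


def kv_cache_bytes(seq_len, num_layers=1, batch=1, num_kv_heads=4, d_head=32, bytes_per_elem=4):
    # Factor 2: one K tensor and one V tensor per layer, each [batch, num_kv_heads, seq_len, d_head].
    return 2 * num_layers * batch * num_kv_heads * seq_len * d_head * bytes_per_elem


for T in (128, 512, 2048):
    mha = kv_cache_bytes(T)                   # full multi-head: 4 K/V heads
    gqa = kv_cache_bytes(T, num_kv_heads=2)   # hypothetical GQA: 2 K/V heads shared by 4 query heads
    print(f"T={T:5d}  MHA: {mha // 1024:5d} KiB   GQA (2 KV heads): {gqa // 1024:5d} KiB")
# MHA: 128, 512, 2048 KiB; GQA with 2 KV heads halves each figure.

# Cross-check the formula against the bytes held by actual float32 arrays.
k = np.zeros((1, 4, 512, 32), dtype=np.float32)
v = np.zeros_like(k)
assert k.nbytes + v.nbytes == kv_cache_bytes(512)
```

The same formula also explains the grouped-query stretch goal: GQA keeps every query head but shrinks the cache in proportion to the number of K/V heads.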
Common failure modes
| Failure | What to inspect first |
|---|---|
| low loss but nonsense generation | missing or wrong mask orientation |
| shape mismatch in score computation | wrong transpose of K |
| output shape stuck at `[B, H, T, d_head]` | forgot merge step |
| cache grows incorrectly | concatenating along wrong axis |
| outputs differ between naive and cached decode | position handling or stale cache |
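Two rows of the table can be reproduced in a few lines of NumPy. The sketch assumes tensors laid out as `[B, H, T, d_head]`. For a 4-D array, `.T` reverses all axes rather than swapping the last two, which is a common source of the score-computation shape mismatch. The cache must also grow along the time axis.

```python
import numpy as np

B, H, T, d_head = 2, 4, 5, 32
rng = np.random.default_rng(0)
q = rng.standard_normal((B, H, T, d_head))
k = rng.standard_normal((B, H, T, d_head))

# Bug: .T reverses *all* axes of a 4-D array, giving [d_head, T, H, B].
print(k.T.shape)  # (32, 5, 4, 2)
try:
    q @ k.T
except ValueError as err:
    print("shape mismatch:", err)

# Fix: swap only the last two axes, giving [B, H, d_head, T].
scores = q @ np.swapaxes(k, -1, -2)
print(scores.shape)  # (2, 4, 5, 5)

# Cache growth: a new key is [B, H, 1, d_head] and is appended on the time axis (axis 2).
new_k = rng.standard_normal((B, H, 1, d_head))
grown = np.concatenate([k, new_k], axis=2)
print(grown.shape)  # (2, 4, 6, 32)
```

In PyTorch the equivalent fix is `k.transpose(-2, -1)`.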
Why this hands-on lab matters
This is the decoder core. If you can build this calmly, then Module 05's training pipeline discussion lands on solid ground. If you cannot, Module 05 becomes vocabulary without mechanics.
Bridge forward
The next module, `05_llm_training_pipeline`, covers how these transformer blocks get trained at scale: pretraining, SFT, RLHF, and the full pipeline from raw text to deployed model.