
05. Assignment 4 — Causal Self-Attention + KV Cache from Scratch

Week 4. Build the decoder-only attention core that Module 05 will assume you already understand.

Goal

Implement causal self-attention end to end:

  • Q/K/V projections
  • lower-triangular causal masking
  • multi-head splitting and concatenation
  • output projection
  • KV cache for one-token decode

Do this in NumPy or PyTorch. Do not use nn.MultiheadAttention as the implementation. You may use it only for verification.

Required build

  1. Single-head causal attention that visibly masks future tokens.
  2. Multi-head causal self-attention with split_heads and merge_heads.
  3. KV cache for inference-time decode.
  4. One tiny GPT-style block that wires attention + residual + norm + FFN.
  5. Verification harness that checks shapes and at least one numerical property.
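For item 4, the wiring of attention, residual, norm, and FFN can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not a reference solution: it uses pre-norm ordering, a ReLU feed-forward layer, a layer norm without learnable parameters, and a single-head attention stand-in for the multi-head layer. The names `gpt_block`, `layer_norm`, and the parameter dictionary `p` are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis; learnable scale/shift omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def causal_attention(x, W_q, W_k, W_v, W_o):
    # Single-head stand-in for the multi-head layer; x: [B, T, d_model].
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(x.shape[-1])   # [B, T, T]
    T = x.shape[1]
    allowed = np.tril(np.ones((T, T), dtype=bool))              # j <= i
    scores = np.where(allowed, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return (w @ V) @ W_o

def gpt_block(x, p):
    # Pre-norm ordering (an assumption): x + Attn(LN(x)), then x + FFN(LN(x)).
    x = x + causal_attention(layer_norm(x), p["W_q"], p["W_k"], p["W_v"], p["W_o"])
    h = np.maximum(0.0, layer_norm(x) @ p["W_1"])  # ReLU here; GELU is also common
    return x + h @ p["W_2"]

rng = np.random.default_rng(0)
d_model, d_ff = 128, 512
param_shapes = {
    "W_q": (d_model, d_model), "W_k": (d_model, d_model),
    "W_v": (d_model, d_model), "W_o": (d_model, d_model),
    "W_1": (d_model, d_ff), "W_2": (d_ff, d_model),
}
p = {name: rng.standard_normal(shape) * 0.02 for name, shape in param_shapes.items()}

x = rng.standard_normal((2, 8, d_model))
y = gpt_block(x, p)
print(y.shape)  # block preserves [B, T, d_model]
```

A useful property to test on a block like this: perturbing the last token of the input should leave the outputs at all earlier positions unchanged, because both the mask and the per-token norm respect causality.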

Suggested config for a manageable implementation:

  • d_model = 128
  • num_heads = 4
  • d_head = 32
  • max_seq_len = 64 or 128
  • batch_size = 2 for tests
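The suggested values can be kept in one small config object so that `d_head` is always derived from `d_model` and `num_heads` rather than set by hand. A minimal sketch, assuming a hypothetical `AttentionConfig` dataclass (the name and fields are illustrative and not required by the assignment):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AttentionConfig:  # hypothetical name, for illustration only
    d_model: int = 128
    num_heads: int = 4
    max_seq_len: int = 64

    def __post_init__(self):
        # Heads must partition the model dimension evenly.
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

    @property
    def d_head(self) -> int:
        # Derived, so it cannot drift out of sync with d_model and num_heads.
        return self.d_model // self.num_heads

cfg = AttentionConfig()
print(cfg.d_head)
```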

You are not chasing model quality this week. You are chasing correctness and explainability.

Deliverables

  1. attention.py — Q/K/V projections, mask, multi-head attention, output projection
  2. cache.py — KV cache class or helper functions
  3. block.py — one GPT-style block with residual + norm + FFN
  4. verify.py — shape tests, mask tests, and a small comparison against PyTorch or a hand-worked example
  5. generate.py — one-token decode loop using the cache
  6. README.md — explain the shape flow, mask logic, and what caching changes

Success criteria

  • Future positions receive zero attention weight after masking and softmax.
  • Attention output returns to shape [B, T, d_model] after merging heads.
  • Cache length increases by exactly one token per decode step.
  • Naive decode and cached decode produce the same logits for the same prefix, up to floating-point tolerance.
  • README explains the code well enough for an interview whiteboard answer.

Required checks

Check 1 — Mask correctness

Use a toy score matrix and show that the attention weights on future positions are exactly zero after softmax.
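One way to run this check in NumPy is sketched below. It assumes the common convention that rows index query positions and columns index key positions, and it masks by writing `-inf` into the scores before softmax. The helper names `causal_mask` and `masked_softmax` are hypothetical.

```python
import numpy as np

def causal_mask(T):
    # True where query position i may attend to key position j (j <= i).
    return np.tril(np.ones((T, T), dtype=bool))

def masked_softmax(scores, mask):
    # Disallowed positions become -inf, so exp() maps them to exactly 0.
    masked = np.where(mask, scores, -np.inf)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(masked)
    return exp / exp.sum(axis=-1, keepdims=True)

T = 4
scores = np.arange(T * T, dtype=np.float64).reshape(T, T)  # toy score matrix
weights = masked_softmax(scores, causal_mask(T))

# Strictly upper-triangular entries are the "future" positions.
future = np.triu(np.ones((T, T), dtype=bool), k=1)
assert np.all(weights[future] == 0.0)          # no weight on the future
assert np.allclose(weights.sum(axis=-1), 1.0)  # each row is still a distribution
print(np.round(weights, 3))
```

The diagonal is always allowed, so every row has at least one finite score and the row maximum used for stabilization is never `-inf`.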

Check 2 — Shape trace

Print or document the shapes for:

X -> Q/K/V -> split heads -> scores -> weights -> head outputs -> merged output
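The trace above can be produced with a short script like the following. It is a sketch that assumes the `[B, H, T, d_head]` layout for split heads and random weights scaled by 0.02; the `split_heads` and `merge_heads` bodies shown here are one possible implementation, not the required one.

```python
import numpy as np

B, T, d_model, H = 2, 8, 128, 4
d_head = d_model // H
rng = np.random.default_rng(0)

def split_heads(x, H):
    # [B, T, d_model] -> [B, H, T, d_head]
    B, T, d_model = x.shape
    return x.reshape(B, T, H, d_model // H).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, d_model]
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

X = rng.standard_normal((B, T, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) * 0.02 for _ in range(4))

Q, K, V = X @ W_q, X @ W_k, X @ W_v                        # each [B, T, d_model]
Qh, Kh, Vh = (split_heads(t, H) for t in (Q, K, V))        # each [B, H, T, d_head]
scores = Qh @ Kh.transpose(0, 1, 3, 2) / np.sqrt(d_head)   # [B, H, T, T]
scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)    # [B, H, T, T]
heads = weights @ Vh                                       # [B, H, T, d_head]
out = merge_heads(heads) @ W_o                             # [B, T, d_model]

shapes = {"X": X, "Q": Q, "Q split": Qh, "scores": scores,
          "weights": weights, "head outputs": heads, "merged output": out}
for name, arr in shapes.items():
    print(f"{name:>14}: {arr.shape}")
```

A quick sanity check on the two helpers is that `merge_heads(split_heads(X, H))` reproduces `X` exactly, since the pair is only a reshape and a transpose.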

Check 3 — Cache correctness

Run decode twice:

  • once by recomputing the full prefix each step
  • once using the cache

Confirm the logits match.
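The comparison can be sketched on a single attention layer before wiring in the full model. The example below is a simplified stand-in: it uses one head, has no output head, and therefore compares attention outputs rather than logits. The `KVCache` class and the function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16
W_q, W_k, W_v = (rng.standard_normal((d, d)) * 0.1 for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attend_full(X):
    # Recompute causal attention over the whole prefix. X: [T, d] -> [T, d].
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones_like(scores, dtype=bool)), scores, -np.inf)
    return softmax(scores) @ V

class KVCache:  # hypothetical minimal cache
    def __init__(self, d):
        self.K = np.zeros((0, d))
        self.V = np.zeros((0, d))

    def append(self, k, v):
        # Grow along the time axis (axis 0), one token per call.
        self.K = np.concatenate([self.K, k], axis=0)
        self.V = np.concatenate([self.V, v], axis=0)

def attend_cached(x_new, cache):
    # x_new: [1, d]. No mask needed: the cache holds only past and current tokens.
    cache.append(x_new @ W_k, x_new @ W_v)
    scores = (x_new @ W_q) @ cache.K.T / np.sqrt(d)
    return softmax(scores) @ cache.V

X = rng.standard_normal((6, d))
cache = KVCache(d)
max_diff = 0.0
for t in range(X.shape[0]):
    naive = attend_full(X[: t + 1])[-1]           # last row of the full recompute
    cached = attend_cached(X[t : t + 1], cache)[0]
    assert cache.K.shape[0] == t + 1              # cache grows by one token per step
    max_diff = max(max_diff, float(np.abs(naive - cached).max()))

print(f"max abs difference: {max_diff:.2e}")
```

Comparing with a tolerance (for example `np.allclose`) rather than exact equality is the safer choice, because the two paths perform the same arithmetic in different matrix shapes and may differ in the last few bits.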

Check 4 — Interview explanation

Record a 3-5 minute verbal explanation of your implementation. If you cannot explain it aloud, you do not own it yet.

Stretch goals

  • Add causal attention heatmap visualization.
  • Add RoPE after you finish the base version.
  • Explain grouped-query attention versus full multi-head attention.
  • Measure cache memory use for context lengths 128, 512, and 2048.
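For the cache memory stretch goal, a back-of-the-envelope estimate is a useful baseline to compare measurements against. The sketch below assumes float32 storage, a single layer, batch size 1, and the suggested `num_heads = 4`, `d_head = 32`; the function name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(seq_len, num_layers=1, batch_size=1,
                   num_heads=4, d_head=32, bytes_per_elem=4):
    # Two tensors (K and V), each [batch, heads, seq_len, d_head], per layer.
    # bytes_per_elem=4 assumes float32; use 2 for float16/bfloat16.
    return 2 * num_layers * batch_size * num_heads * seq_len * d_head * bytes_per_elem

for T in (128, 512, 2048):
    print(f"T={T:5d}: {kv_cache_bytes(T) / 1024:.0f} KiB")
```

Under these assumptions the estimate grows linearly with context length: 128 KiB at 128 tokens, 512 KiB at 512 tokens, and 2048 KiB at 2048 tokens. Multiply by the number of layers and the batch size for a full model.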

Common failure modes

| Failure | What to inspect first |
|---|---|
| low loss but nonsense generation | missing or wrong mask orientation |
| shape mismatch in score computation | wrong transpose of K |
| output shape stuck at [B, H, T, d_head] | forgot merge step |
| cache grows incorrectly | concatenating along wrong axis |
| outputs differ between naive and cached decode | position handling or stale cache |

Why this assignment matters

This is the decoder core. If you can build this calmly, then Module 05's training pipeline discussion lands on solid ground. If you cannot, Module 05 becomes vocabulary without mechanics.

Bridge forward

Next module — 05_llm_training_pipeline — covers how these transformer blocks get trained at scale: pretraining, SFT, RLHF, and the full pipeline from raw text to deployed model.