01. Week 4 — Causal Attention & Coding
Key concepts to master
- causal masking is an information firewall for autoregressive training
- lower-triangular mask means position i sees only positions <= i
- masking happens on attention logits before softmax
- Q, K, and V are separate learned projections of the same hidden states
- multi-head attention = split width, run parallel attentions, concatenate, output project
- KV cache stores past keys and values for fast decoding
- prefill and decode are different inference phases
- tensor shape tracking is a first-class engineering skill
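The first three bullets can be made concrete with a minimal single-head sketch. It assumes NumPy and unbatched inputs of shape [T, d]; the function name `causal_attention` is illustrative, not taken from any library. The lower-triangular mask is applied to the logits, so masked entries become exactly zero after softmax and each row still sums to 1.

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head causal self-attention for unbatched q, k, v of shape [T, d]."""
    T, d = q.shape
    logits = q @ k.T / np.sqrt(d)                     # [T, T], row i = query i
    allowed = np.tril(np.ones((T, T), dtype=bool))    # True where key j <= query i
    logits = np.where(allowed, logits, -np.inf)       # mask the logits, before softmax
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)                          # exp(-inf) == 0 for masked entries
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
T, d = 4, 8
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
out, weights = causal_attention(q, k, v)
```

Because row i only ever reads keys and values at positions <= i, editing the last token's key or value leaves every earlier output row unchanged.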
🧠 Mental models
- Causal mask: "A one-way privacy screen that lets each position read only its past."
- Q/K/V projections: "Three views of the same token: what I am asking, what I contain, and what I can pass along."
- Multi-head attention: "Slice the hidden width into parallel workbenches, run attention separately, then bolt the pieces back together."
- KV cache: "A growing append-only ledger of past keys and values during decoding."
- Prefill vs decode: "Prefill reads the whole prompt once; decode adds one new step at a time."
- Tensor shape tracking: "A type system for dimensions — if the shape story is wrong, the code is wrong."
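The "parallel workbenches" and "type system for dimensions" pictures can be sketched as follows. This is a minimal NumPy version under stated assumptions: bias-free square projection matrices `w_q`, `w_k`, `w_v`, `w_o` of shape [D, D], and hypothetical helper names `split_heads` and `merge_heads`. Every line carries its shape as a comment, which is the habit the last mental model describes.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def split_heads(x, num_heads):
    B, T, D = x.shape
    assert D % num_heads == 0, "width must divide evenly across heads"
    x = x.reshape(B, T, num_heads, D // num_heads)    # [B, T, H, Dh]
    return x.transpose(0, 2, 1, 3)                    # [B, H, T, Dh]

def merge_heads(x):
    B, H, T, Dh = x.shape
    x = x.transpose(0, 2, 1, 3)                       # [B, T, H, Dh]
    return x.reshape(B, T, H * Dh)                    # [B, T, D]

def multi_head_causal_attention(x, w_q, w_k, w_v, w_o, num_heads):
    B, T, D = x.shape
    q = split_heads(x @ w_q, num_heads)               # [B, H, T, Dh]
    k = split_heads(x @ w_k, num_heads)               # [B, H, T, Dh]
    v = split_heads(x @ w_v, num_heads)               # [B, H, T, Dh]
    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])  # [B, H, T, T]
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]  # [1, 1, T, T]
    logits = np.where(mask, logits, -np.inf)          # broadcasts over B and H
    ctx = softmax(logits) @ v                         # [B, H, T, Dh]
    return merge_heads(ctx) @ w_o                     # [B, T, D]

rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 16, 4
x = rng.standard_normal((B, T, D))
w_q, w_k, w_v, w_o = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4))
y = multi_head_causal_attention(x, w_q, w_k, w_v, w_o, H)
```

Reshaping first and transposing second is what keeps each head's slice of the width contiguous per token; `merge_heads` undoes the two steps in reverse order, so `merge_heads(split_heads(x, H))` returns `x` unchanged.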
⚠️ Common traps
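- Applying the causal mask after softmax instead of on the logits: future logits still enter the softmax denominator, so information leaks into the surviving weights, and those weights no longer sum to 1.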
- Building the wrong broadcast shape for masks over [batch, heads, query_len, key_len].
- Mixing up reshape and transpose order when splitting into heads or stitching heads back together.
- Concatenating cache tensors on the wrong axis, or caching queries instead of keys and values.
- Forgetting that prefill handles many query tokens while decode usually handles one query token against long cached K/V.
- Losing track of tensor layout assumptions when optimizing attention code.
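The cache-axis and prefill/decode traps can be checked with a small sketch. It assumes NumPy, a [B, H, T, Dh] layout (so the time axis is axis 2), and a hypothetical `KVCache` class. For brevity, q, k, and v are generated up front; in a real model the new token's q, k, and v would come from projecting its hidden state at each decode step.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attend(q, k, v):
    """Causal attention. q: [B, H, Tq, Dh]; k, v: [B, H, Tk, Dh] with Tk >= Tq."""
    Tq, Tk = q.shape[2], k.shape[2]
    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])   # [B, H, Tq, Tk]
    # Query row i sits at absolute position Tk - Tq + i, so shift the diagonal.
    mask = np.tril(np.ones((Tq, Tk), dtype=bool), k=Tk - Tq)
    return softmax(np.where(mask, logits, -np.inf)) @ v

class KVCache:
    """Append-only store for keys and values (never queries)."""
    def __init__(self):
        self.k = None
        self.v = None

    def append(self, k_new, v_new):
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = np.concatenate([self.k, k_new], axis=2)  # time axis, not heads
            self.v = np.concatenate([self.v, v_new], axis=2)
        return self.k, self.v

rng = np.random.default_rng(0)
B, H, T, Dh = 2, 2, 5, 4
q, k, v = (rng.standard_normal((B, H, T, Dh)) for _ in range(3))
full = attend(q, k, v)                                  # reference: all 5 tokens at once

cache = KVCache()
# Prefill: many query tokens, cache filled in one shot.
k_c, v_c = cache.append(k[:, :, :4], v[:, :, :4])
prefill_out = attend(q[:, :, :4], k_c, v_c)             # [B, H, 4, Dh]
# Decode: one query token against all cached keys and values.
k_c, v_c = cache.append(k[:, :, 4:5], v[:, :, 4:5])
decode_out = attend(q[:, :, 4:5], k_c, v_c)             # [B, H, 1, Dh]
```

Prefill followed by one decode step reproduces the rows of the full causal pass, which is a convenient regression test when optimizing cache code.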
🔗 Prerequisites & connections
- Builds on: Module 03 transformer block structure, Module 02 attention math, and Module 01 tensor/backprop basics.
- Feeds into: training and serving tradeoffs, inference latency analysis, and systems-level optimization in Modules 05-06.
💬 Interview phrasing
- "Where exactly do you apply a causal mask in attention, and why there?"
- "Given hidden states of shape [B, T, D], walk me through the Q/K/V and multi-head reshape pipeline."
- "What does a KV cache store, and what changes between prefill and decode?"
- "How would you debug a shape mismatch in batched multi-head attention?"
- "Why can a model seem to train fine even if masking is subtly wrong?"
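For the first and last questions, a small numerical comparison can support the answer. The sketch below assumes NumPy and uses random logits for a single head. It contrasts masking the logits with zeroing weights after softmax, then perturbs one future logit to show which variant reacts.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T = 4
logits = rng.standard_normal((T, T))                 # row i = query i, column j = key j
allowed = np.tril(np.ones((T, T), dtype=bool))

# Correct: mask the logits, then normalize over allowed keys only.
correct = softmax(np.where(allowed, logits, -np.inf))
# Wrong: normalize over all keys, then zero out the future.
wrong = softmax(logits) * allowed

row_sums_correct = correct.sum(axis=-1)              # each row sums to 1
row_sums_wrong = wrong.sum(axis=-1)                  # rows with masked entries sum to < 1

# Perturb a future logit for query 0 (key 3) and recompute both variants.
logits2 = logits.copy()
logits2[0, 3] += 5.0
correct2 = softmax(np.where(allowed, logits2, -np.inf))
wrong2 = softmax(logits2) * allowed
```

In the correct variant, row 0 is identical before and after the perturbation. In the wrong variant, the weight on key 0 shifts because the future logit sits in the softmax denominator, so the output at position 0 depends on a token it should not see.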
⏱️ Difficulty markers
- 🟢 causal masking
- 🟡 Q/K/V projections
- 🟡 prefill vs decode
- 🔴 multi-head tensor reshaping
- 🔴 KV cache implementation
- 🔴 tensor shape tracking
Self-check questions
- Why can unmasked training produce low loss but poor generation?
- What is the correct causal mask shape in the batched multi-head case?
- Why do we need separate Q/K/V projections?
- What exact reshape and transpose steps create heads?
- What does KV cache store, and what trade-off does it introduce?
- Walk through one transformer block from input embeddings to logits.
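For the KV cache trade-off question, a back-of-the-envelope calculation helps: the cache saves recomputation but its memory grows linearly with batch size and sequence length. The sketch below assumes standard multi-head attention where every head stores its own keys and values; the function name and the configuration numbers are hypothetical, chosen only for illustration.

```python
def kv_cache_bytes(num_layers, batch, num_heads, seq_len, head_dim, bytes_per_elem=2):
    """Approximate KV cache size; bytes_per_elem=2 assumes 16-bit storage."""
    per_tensor = batch * num_heads * seq_len * head_dim   # one K or one V, one layer
    return 2 * num_layers * per_tensor * bytes_per_elem   # factor 2: keys and values

# Hypothetical configuration, not tied to any specific model.
size = kv_cache_bytes(num_layers=32, batch=1, num_heads=32, seq_len=4096, head_dim=128)
print(size / 2**30)  # 2.0 (GiB)
```

Doubling either the sequence length or the batch size doubles this figure, which is why cache memory often limits serving throughput before compute does.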
Bridge forward
Next module — 05_llm_training_pipeline — covers how these transformer blocks get trained at scale: pretraining, SFT, RLHF, and the full pipeline from raw text to deployed model.