01. Week 4 — Causal Attention & Coding

Key concepts to master

  • causal masking is an information firewall for autoregressive training
  • lower-triangular mask means position i attends only to positions j <= i (itself and everything before it)
  • masking happens on attention logits before softmax
  • Q, K, and V are separate learned projections of the same hidden states
  • multi-head attention = split the hidden width into heads, run attention per head in parallel, concatenate, then apply the output projection
  • KV cache stores past keys and values for fast decoding
  • prefill and decode are different inference phases
  • tensor shape tracking is a first-class engineering skill
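
A minimal sketch of the masking concepts above, assuming NumPy and a single attention head with square [T, T] logits. These notes name no framework, and the function name `causal_attention_weights` is illustrative. The lower-triangular mask is applied to the logits, and softmax runs only afterwards.

```python
import numpy as np

def causal_attention_weights(logits):
    """Mask [T, T] attention logits causally, then softmax over the key axis."""
    T = logits.shape[-1]
    # True where query position i may attend to key position j (j <= i).
    allowed = np.tril(np.ones((T, T), dtype=bool))
    # Disallowed logits become -inf so they receive exactly zero weight.
    masked = np.where(allowed, logits, -np.inf)
    # Subtract the row max for numerical stability (the diagonal is always finite).
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
weights = causal_attention_weights(rng.normal(size=(4, 4)))
```

Each row of `weights` sums to 1, and every entry above the diagonal is exactly zero, so no position puts weight on its future.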

🧠 Mental models

  • Causal mask: "A one-way privacy screen that lets each position read only itself and its past."
  • Q/K/V projections: "Three views of the same token: what I am asking, what I contain, and what I can pass along."
  • Multi-head attention: "Slice the hidden width into parallel workbenches, run attention separately, then bolt the pieces back together."
  • KV cache: "A growing append-only ledger of past keys and values during decoding."
  • Prefill vs decode: "Prefill reads the whole prompt once; decode adds one new step at a time."
  • Tensor shape tracking: "A type system for dimensions — if the shape story is wrong, the code is wrong."
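
The KV cache and prefill/decode models above can be sketched for a single head, assuming NumPy, no batch dimension, and randomly initialized projection matrices. The names `prefill`, `decode_step`, and the dict-based cache are hypothetical and chosen only for illustration. For brevity, `prefill` here only fills the cache and does not return the prompt's own attention outputs.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(x, Wq, Wk, Wv):
    """Reference path: recompute Q, K, V for all T positions. x: [T, D]."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    T = x.shape[0]
    logits = q @ k.T / np.sqrt(q.shape[-1])
    allowed = np.tril(np.ones((T, T), dtype=bool))
    return softmax(np.where(allowed, logits, -np.inf)) @ v

def prefill(prompt, Wk, Wv):
    """Read the whole prompt once and start the ledger. prompt: [T, D]."""
    return {"k": prompt @ Wk, "v": prompt @ Wv}

def decode_step(x_new, cache, Wq, Wk, Wv):
    """x_new: [1, D]. Append this step's key and value, then attend over the cache."""
    q = x_new @ Wq  # queries are used once and never cached
    cache["k"] = np.concatenate([cache["k"], x_new @ Wk], axis=0)  # grow along time
    cache["v"] = np.concatenate([cache["v"], x_new @ Wv], axis=0)
    logits = q @ cache["k"].T / np.sqrt(q.shape[-1])  # [1, T_cached]
    # No mask needed: every cached key is at or before the current position.
    return softmax(logits) @ cache["v"]

rng = np.random.default_rng(0)
D = 8
Wq, Wk, Wv = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
prompt = rng.normal(size=(5, D))
x_new = rng.normal(size=(1, D))

cache = prefill(prompt, Wk, Wv)
step_out = decode_step(x_new, cache, Wq, Wk, Wv)
reference = full_causal_attention(np.concatenate([prompt, x_new], axis=0), Wq, Wk, Wv)
```

Under these assumptions `step_out` matches the last row of `reference`: the cached path gives the same result as recomputing everything, while projecting only one new token per step.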

⚠️ Common traps

  • Applying the causal mask after softmax instead of on the logits: future logits still enter the softmax normalizer, so the surviving weights depend on future tokens and rows no longer sum to 1.
  • Building the wrong broadcast shape for masks over [batch, heads, query_len, key_len].
  • Mixing up reshape and transpose order when splitting into heads or stitching heads back together.
  • Concatenating cache tensors on the wrong axis, or caching queries instead of keys and values.
  • Forgetting that prefill handles many query tokens while decode usually handles one query token against long cached K/V.
  • Losing track of tensor layout assumptions when optimizing attention code.
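
The first trap can be checked numerically. A small sketch, assuming NumPy and a single head with [T, T] logits; the helper `softmax` and all variable names are illustrative, not taken from any course code.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

T = 4
rng = np.random.default_rng(0)
logits = rng.normal(size=(T, T))
allowed = np.tril(np.ones((T, T), dtype=bool))

# Correct order: mask the logits, then softmax.
correct = softmax(np.where(allowed, logits, -np.inf))

# Trap: softmax over everything, then zero out future positions.
wrong = np.where(allowed, softmax(logits), 0.0)

row_sums_correct = correct.sum(axis=-1)
row_sums_wrong = wrong.sum(axis=-1)

# Perturb a logit at a future position (query 0, key 3) and recompute both.
bumped = logits.copy()
bumped[0, 3] += 5.0
correct_bumped = softmax(np.where(allowed, bumped, -np.inf))
wrong_bumped = np.where(allowed, softmax(bumped), 0.0)
```

In the post-softmax version every row except the last sums to less than 1, and changing a future logit shifts the weight that query 0 places on key 0. With the mask on the logits, that weight stays fixed.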

🔗 Prerequisites & connections

  • Builds on: Module 03 transformer block structure, Module 02 attention math, and Module 01 tensor/backprop basics.
  • Feeds into: training and serving tradeoffs, inference latency analysis, and systems-level optimization in Modules 05-06.

💬 Interview phrasing

  • "Where exactly do you apply a causal mask in attention, and why there?"
  • "Given hidden states of shape [B, T, D], walk me through the Q/K/V and multi-head reshape pipeline."
  • "What does a KV cache store, and what changes between prefill and decode?"
  • "How would you debug a shape mismatch in batched multi-head attention?"
  • "Why can a model seem to train fine even if masking is subtly wrong?"
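
One possible walk-through of the [B, T, D] pipeline asked about above, assuming NumPy, random weights, no biases, and that D divides evenly by the head count H. The helper names `split_heads` and `merge_heads` are hypothetical.

```python
import numpy as np

B, T, D, H = 2, 5, 8, 2
d_head = D // H
rng = np.random.default_rng(0)
x = rng.normal(size=(B, T, D))
Wq, Wk, Wv, Wo = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(4))

def split_heads(t):
    # [B, T, D] -> [B, T, H, d_head] -> [B, H, T, d_head]
    return t.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)

def merge_heads(t):
    # [B, H, T, d_head] -> [B, T, H, d_head] -> [B, T, D]
    return t.transpose(0, 2, 1, 3).reshape(B, T, D)

# Three separate projections of the same hidden states, each split into heads.
q, k, v = split_heads(x @ Wq), split_heads(x @ Wk), split_heads(x @ Wv)

logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [B, H, T, T]
mask = np.tril(np.ones((T, T), dtype=bool))             # [T, T], broadcasts over B and H
logits = np.where(mask, logits, -np.inf)
logits = logits - logits.max(axis=-1, keepdims=True)
weights = np.exp(logits)
weights /= weights.sum(axis=-1, keepdims=True)          # softmax over the key axis

out = merge_heads(weights @ v) @ Wo                     # [B, T, D]
```

Reshaping `x` straight to `(B, H, T, d_head)` produces the right shape but scrambles which values belong to which token and head. That is why the reshape comes first and the transpose second when splitting, and the reverse order when merging.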

⏱️ Difficulty markers

  • 🟢 causal masking
  • 🟡 Q/K/V projections
  • 🟡 prefill vs decode
  • 🔴 multi-head tensor reshaping
  • 🔴 KV cache implementation
  • 🔴 tensor shape tracking

Self-check questions

  1. Why can unmasked training produce low loss but poor generation?
  2. What is the correct causal mask shape in the batched multi-head case?
  3. Why do we need separate Q/K/V projections?
  4. What exact reshape and transpose steps create heads?
  5. What does a KV cache store, and what trade-off does it introduce?
  6. Walk through the forward pass from input embeddings, through one transformer block, to output logits.
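
One way to check an answer to question 2 is to build the mask and let broadcasting do the rest. A sketch assuming NumPy, a boolean "may attend" convention, and that the queries are the last q_len positions of the key sequence; `causal_mask` is an illustrative name.

```python
import numpy as np

def causal_mask(q_len, k_len):
    """Boolean mask of shape [1, 1, q_len, k_len]; True means "may attend".

    Assumes the queries are the last q_len positions of the k_len keys, which
    covers prefill (q_len == k_len) and cached decode (q_len == 1).
    """
    offset = k_len - q_len
    # Query row i sits at absolute position offset + i and sees keys j <= offset + i.
    mask = np.tril(np.ones((q_len, k_len), dtype=bool), k=offset)
    return mask[None, None, :, :]  # singleton batch and head axes for broadcasting

B, H = 2, 3
prefill_logits = np.zeros((B, H, 4, 4))  # many queries against equally many keys
decode_logits = np.zeros((B, H, 1, 5))   # one query against five cached keys

prefill_masked = np.where(causal_mask(4, 4), prefill_logits, -np.inf)
decode_masked = np.where(causal_mask(1, 5), decode_logits, -np.inf)
```

The leading singleton axes let one mask serve every batch element and head. In the decode case the mask is all True, because the single new query may attend to every cached key.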

Bridge forward

Next module — 05_llm_training_pipeline — covers how these transformer blocks get trained at scale: pretraining, SFT, RLHF, and the full pipeline from raw text to deployed model.