Assignment 4 — Causal Self-Attention + KV Cache

This folder implements the Week 4 hands-on lab described in ../05_hands_on_lab.md.

Files

  • attention.py — Q/K/V projections, causal masking, split/merge heads, output projection
  • cache.py — KV cache container for inference-time decode
  • block.py — one GPT-style block with residual, norm, and FFN
  • generate.py — one-token decode loop with and without cache
  • verify.py — mask, shape, cache-growth, and cached-vs-naive checks

Shape flow

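For input X with shape [B, T, d_model], where B is the batch size, T the sequence length, and H the number of heads (in full-sequence mode T_q = T_k = T):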

X
-> Q/K/V projections              [B, T, d_model]
-> split heads                    [B, H, T, d_head]
-> score matrix QK^T              [B, H, T_q, T_k]
-> masked softmax weights         [B, H, T_q, T_k]
-> weighted values                [B, H, T_q, d_head]
-> merge heads                    [B, T_q, d_model]
-> output projection              [B, T_q, d_model]

With d_model = 128 and num_heads = 4, each head uses d_head = d_model / num_heads = 32.
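The shape flow above can be traced with a short NumPy sketch. This is not the code in attention.py: the weight names (W_q, W_k, W_v, W_o), the helper names, and the 1/sqrt(d_head) score scaling are assumptions made for illustration, following standard scaled dot-product attention.

```python
import numpy as np

B, T, d_model, H = 2, 5, 128, 4
d_head = d_model // H  # 32

rng = np.random.default_rng(0)
X = rng.standard_normal((B, T, d_model))
# Hypothetical projection weights; a real module would learn these.
W_q, W_k, W_v, W_o = (
    rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4)
)

def split_heads(x, num_heads):
    # [B, T, d_model] -> [B, H, T, d_head]
    b, t, d = x.shape
    return x.reshape(b, t, num_heads, d // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, d_model]
    b, h, t, dh = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, t, h * dh)

Q = split_heads(X @ W_q, H)                              # [B, H, T, d_head]
K = split_heads(X @ W_k, H)
V = split_heads(X @ W_v, H)

scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)   # [B, H, T_q, T_k]
mask = np.tril(np.ones((T, T), dtype=bool))              # lower-triangular causal mask
scores = np.where(mask, scores, -np.inf)                 # block future positions

# Numerically stable softmax over the key axis.
scores = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)  # [B, H, T_q, T_k]

context = weights @ V                                    # [B, H, T_q, d_head]
out = merge_heads(context) @ W_o                         # [B, T_q, d_model]

for name, arr in [("Q", Q), ("weights", weights), ("context", context), ("out", out)]:
    print(name, arr.shape)
```

Splitting and merging heads are pure reshapes and transposes, so merge_heads(split_heads(X, H)) returns X unchanged.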

Mask logic

In full-sequence mode the causal mask is lower triangular: the query at position t can attend to key positions <= t and to no future position.

For cached one-token decode, the query length is 1 and the key length grows with the prefix. The new token is allowed to read every cached key because all those positions are in its past.
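Both cases can be covered by one mask rule if the queries are treated as the last T_q positions of a length-T_k sequence. The sketch below assumes that alignment; the function name causal_mask is hypothetical and may not match what attention.py uses.

```python
import numpy as np

def causal_mask(t_q, t_k):
    """Boolean mask of shape [t_q, t_k]; True where a query may attend to a key.

    Assumes the queries are the last t_q positions of a length-t_k sequence.
    """
    offset = t_k - t_q                          # absolute position of the first query
    q_pos = np.arange(t_q)[:, None] + offset    # [t_q, 1]
    k_pos = np.arange(t_k)[None, :]             # [1, t_k]
    return k_pos <= q_pos

full = causal_mask(4, 4)     # full-sequence mode: lower triangular
decode = causal_mask(1, 6)   # cached decode: one query, six cached keys

print(full.astype(int))
print(decode.astype(int))
```

With t_q = 1 the single row is all True, which is why a one-token decode step needs no explicit masking against the cache.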

What caching changes

Without cache:

  • every new token recomputes K and V for the full prefix

With cache:

  • old K and V are stored once
  • each step computes Q/K/V only for the new token
  • new K/V append along the sequence axis
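The append step can be sketched as a small container. This is a hypothetical minimal version, not the class in cache.py; it assumes K and V are stored as [B, H, T, d_head], so the sequence axis is axis 2.

```python
import numpy as np

class KVCache:
    """Hypothetical minimal KV cache; the real cache.py may differ."""

    def __init__(self):
        self.k = None
        self.v = None

    def append(self, k_new, v_new):
        # k_new, v_new: [B, H, T_new, d_head]; concatenate along the sequence axis.
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = np.concatenate([self.k, k_new], axis=2)
            self.v = np.concatenate([self.v, v_new], axis=2)
        return self.k, self.v

    def __len__(self):
        # Number of cached positions.
        return 0 if self.k is None else self.k.shape[2]

cache = KVCache()
rng = np.random.default_rng(0)
B, H, d_head = 1, 4, 32

# Prefill with a 3-token prompt, then decode two tokens one at a time.
cache.append(rng.standard_normal((B, H, 3, d_head)),
             rng.standard_normal((B, H, 3, d_head)))
lengths = [len(cache)]
for _ in range(2):
    cache.append(rng.standard_normal((B, H, 1, d_head)),
                 rng.standard_normal((B, H, 1, d_head)))
    lengths.append(len(cache))

print(lengths)  # [3, 4, 5]
```

Concatenating on every step copies the whole cache each time. A preallocated buffer with a write index avoids that, at the cost of fixing a maximum length up front.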

The verification harness checks that cached and naive decode produce the same logits on the same prefix.
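The idea behind that check can be illustrated on a single attention head. The sketch below is not verify.py: it compares attention outputs rather than logits, and the weights, sizes, and function names are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.standard_normal((T, d))  # stand-in token embeddings
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def naive_last(prefix):
    """Recompute K and V for the whole prefix; return the last position's output."""
    Q, K, V = prefix @ W_q, prefix @ W_k, prefix @ W_v
    scores = Q @ K.T / np.sqrt(d)
    n = len(prefix)
    mask = np.tril(np.ones((n, n), dtype=bool))
    return (softmax(np.where(mask, scores, -np.inf)) @ V)[-1]

K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
max_diff = 0.0
for t in range(T):
    x_t = X[t:t + 1]                                  # [1, d], the new token only
    K_cache = np.concatenate([K_cache, x_t @ W_k])    # append new K row
    V_cache = np.concatenate([V_cache, x_t @ W_v])    # append new V row
    q_t = x_t @ W_q
    # No mask needed: every cached key is at or before the query position.
    cached = (softmax(q_t @ K_cache.T / np.sqrt(d)) @ V_cache)[0]
    naive = naive_last(X[:t + 1])
    max_diff = max(max_diff, float(np.abs(cached - naive).max()))

print("max |cached - naive| =", max_diff)
```

The two paths perform the same arithmetic in a different order, so the difference should sit at floating-point rounding level rather than being exactly zero; a tolerance-based comparison such as np.allclose is the usual way to assert it.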

Notes

  • This implementation keeps the scope focused on causal attention and cache mechanics.
  • It uses token embeddings and one GPT-style block.
  • It intentionally omits positional encodings so the cache comparison stays about attention correctness, not position bookkeeping.

Run

python3 verify.py
python3 generate.py