Assignment 4: Causal Self-Attention + KV Cache
This folder implements the Week 4 hands-on lab described in ../05_hands_on_lab.md.
Files
- attention.py: Q/K/V projections, causal masking, split/merge heads, output projection
- cache.py: KV cache container for inference-time decode
- block.py: one GPT-style block with residual connections, normalization, and an FFN
- generate.py: one-token decode loop, with and without the cache
- verify.py: mask, shape, cache-growth, and cached-vs-naive checks
Shape flow
For input X with shape [B, T, d_model] (batch size, sequence length, model width), split into H heads of width d_head:
X
-> Q/K/V projections [B, T, d_model]
-> split heads [B, H, T, d_head]
-> score matrix QK^T [B, H, T_q, T_k]
-> masked softmax weights [B, H, T_q, T_k]
-> weighted values [B, H, T_q, d_head]
-> merge heads [B, T_q, d_model]
-> output projection [B, T_q, d_model]
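With d_model = 128 and num_heads = 4, each head uses d_head = d_model / num_heads = 128 / 4 = 32. T_q and T_k are the query and key lengths. In full-sequence mode both equal T. In cached decode, T_q = 1 and T_k is the prefix length.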
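The shape flow above can be traced end to end in a minimal NumPy sketch. Several details are assumptions for illustration only: NumPy itself, the random projection weights, B = 2, T = 5, and the 1/sqrt(d_head) score scaling. The repository's attention.py may use a different framework or layout.

```python
import numpy as np

# Illustrative sizes; d_model and num_heads match the values quoted above.
B, T, d_model, num_heads = 2, 5, 128, 4
d_head = d_model // num_heads  # 128 // 4 = 32

rng = np.random.default_rng(0)
X = rng.standard_normal((B, T, d_model))
# Hypothetical bias-free projection weights for Q, K, V and the output.
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

def split_heads(x):
    # [B, T, d_model] -> [B, T, H, d_head] -> [B, H, T, d_head]
    b, t, _ = x.shape
    return x.reshape(b, t, num_heads, d_head).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, H, d_head] -> [B, T, d_model]
    b, h, t, dh = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, t, h * dh)

Q, K, V = X @ W_q, X @ W_k, X @ W_v                          # each [B, T, d_model]
Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)  # each [B, H, T, d_head]
scores = Qh @ Kh.transpose(0, 1, 3, 2) / np.sqrt(d_head)     # [B, H, T_q, T_k]
mask = np.tril(np.ones((T, T), dtype=bool))                  # causal: key j <= query i
scores = np.where(mask, scores, -np.inf)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)               # [B, H, T_q, T_k]
context = weights @ Vh                                       # [B, H, T_q, d_head]
merged = merge_heads(context)                                # [B, T_q, d_model]
out = merged @ W_o                                           # [B, T_q, d_model]
print(Qh.shape, scores.shape, context.shape, out.shape)
```

Split and merge are pure reshapes and transposes, so merge_heads(split_heads(x)) returns x unchanged. That round trip is a cheap shape check.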
Mask logic
In full-sequence mode the causal mask is lower triangular.
That means the query at position t may attend to key positions <= t, but never to future positions.
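A small NumPy sketch of the full-sequence mask follows. The helper names causal_mask and masked_softmax are hypothetical. Masked scores are set to -inf before the softmax, so future positions receive exactly zero weight.

```python
import numpy as np

def causal_mask(T):
    # True where key position j <= query position i (lower triangle incl. diagonal).
    return np.tril(np.ones((T, T), dtype=bool))

def masked_softmax(scores, mask):
    # Disallowed entries become -inf; exp(-inf) = 0, so they get zero weight.
    s = np.where(mask, scores, -np.inf)
    s = s - s.max(axis=-1, keepdims=True)  # each row keeps its diagonal, so the max is finite
    w = np.exp(s)
    return w / w.sum(axis=-1, keepdims=True)

T = 4
mask = causal_mask(T)
print(mask.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]

# With equal scores, query 1 spreads its weight evenly over positions 0 and 1 only.
w = masked_softmax(np.zeros((T, T)), mask)
```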
For cached one-token decode, the query length is 1 and the key length grows with the prefix.
The new token is allowed to read every cached key because all those positions are in its past.
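One way to see why decode needs no masking: a block of T_q queries at the end of a length-T_k sequence uses the mask tril(ones(T_q, T_k), k = T_k - T_q). With T_q = 1, that mask is a single all-True row, which is the last row of the full mask. The helper name offset_causal_mask is hypothetical and is not taken from the repository.

```python
import numpy as np

def offset_causal_mask(T_q, T_k):
    # Hypothetical helper: queries occupy the last T_q positions of a length-T_k
    # sequence, so query i sits at absolute position (T_k - T_q) + i and may read
    # keys j <= (T_k - T_q) + i. np.tril with k = T_k - T_q encodes exactly that.
    return np.tril(np.ones((T_q, T_k), dtype=bool), k=T_k - T_q)

# Full-sequence mode: T_q == T_k gives the ordinary lower-triangular mask.
full = offset_causal_mask(4, 4)

# Cached one-token decode: T_q == 1, so the single row is all True.
m = offset_causal_mask(1, 6)
print(m.shape, m.all())  # (1, 6) True
```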
What caching changes
Without cache:
- every new token recomputes K and V for the full prefix
With cache:
- old K and V are stored once
- each step computes Q/K/V only for the new token
- new K/V append along the sequence axis
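The cache behaviour described above amounts to concatenating along the sequence axis. The sketch below is a hypothetical minimal container written in NumPy. The class name, the [B, H, T, d_head] layout, and the method names are assumptions, and the repository's cache.py may differ.

```python
import numpy as np

class KVCache:
    # Hypothetical minimal container holding per-head keys/values
    # shaped [B, H, T_cached, d_head].
    def __init__(self):
        self.k = None
        self.v = None

    def append(self, k_new, v_new):
        # k_new, v_new: [B, H, T_new, d_head]; concatenate on the sequence axis (2).
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = np.concatenate([self.k, k_new], axis=2)
            self.v = np.concatenate([self.v, v_new], axis=2)
        return self.k, self.v

    def __len__(self):
        # Number of cached positions.
        return 0 if self.k is None else self.k.shape[2]

cache = KVCache()
B, H, d_head = 1, 4, 32
rng = np.random.default_rng(0)
for step in range(3):
    # Each decode step contributes K/V for exactly one new token.
    k_t = rng.standard_normal((B, H, 1, d_head))
    v_t = rng.standard_normal((B, H, 1, d_head))
    K, V = cache.append(k_t, v_t)
    print(step, K.shape)
# 0 (1, 4, 1, 32)
# 1 (1, 4, 2, 32)
# 2 (1, 4, 3, 32)
```

np.concatenate copies the whole cache on every step, so the per-step copy cost grows with the prefix. Preallocating a fixed-capacity buffer and writing into it is a common alternative.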
The verification harness checks that cached and naive decode produce the same logits on the same prefix.
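The equivalence that the harness checks can be illustrated with a single-head NumPy sketch. For brevity it compares attention outputs rather than logits and omits the output projection and the rest of the block. Like the notes below, it uses no positional encoding. All names, the width of 16, and the random inputs are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16  # hypothetical small width; single head for brevity
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(3))

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def naive_last(X):
    # Recompute Q/K/V for the whole prefix X [T, d_model]; return the last position's output.
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    T = X.shape[0]
    scores = Q @ K.T / np.sqrt(d_model)
    scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
    return (softmax(scores) @ V)[-1]

def cached_step(x, cache):
    # Project only the new token x [d_model], append its k/v, then attend over the cache.
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    cache["k"].append(k)
    cache["v"].append(v)
    K, V = np.stack(cache["k"]), np.stack(cache["v"])  # [T, d_model]
    w = softmax(q @ K.T / np.sqrt(d_model))            # no mask: every cached key is in the past
    return w @ V

tokens = rng.standard_normal((6, d_model))  # stand-in for token embeddings
cache = {"k": [], "v": []}
max_diff = 0.0
for t in range(len(tokens)):
    naive = naive_last(tokens[: t + 1])
    cached = cached_step(tokens[t], cache)
    max_diff = max(max_diff, float(np.abs(naive - cached).max()))
print(max_diff)  # expect only floating-point noise
```

Exact equality is not guaranteed, because the two paths may order floating-point operations differently. A tolerance-based comparison such as np.allclose is therefore the safer check.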
Notes
- This implementation keeps the scope focused on causal attention and cache mechanics.
- It uses token embeddings and one GPT-style block.
- It intentionally omits positional encodings so the cache comparison stays about attention correctness, not position bookkeeping.