
03. Week 4 — Causal Attention & Coding

Companion to 02_explainer.md. Use this file as the compact formula-and-pattern sheet.

Section 1 — The opening bug

If attention is unmasked during training, position i can see tokens j > i. That leaks the answer: the model can copy the next token instead of predicting it. Training loss falls for the wrong reason, and generation quality collapses at inference time.

Core mismatch:

  • training without a mask = future tokens visible
  • generation at inference time = future tokens absent

Section 2 — Causal mask

Autoregressive rule:

query position i may attend only to key positions <= i

Lower-triangular mask for T=4:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

NumPy:

import numpy as np

def causal_mask(T):
    # True on and below the diagonal: query i may see keys 0..i
    return np.tril(np.ones((T, T), dtype=bool))

Batched multi-head broadcast shape:

mask = causal_mask(T)[None, None, :, :]   # [1, 1, T, T]
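A quick check of how the [1, 1, T, T] mask broadcasts against a [B, H, T, T] score tensor. The sizes B = 2, H = 8, T = 4 are arbitrary, the all-zero scores are a stand-in for real logits, and `causal_mask` is repeated so the snippet runs on its own.

```python
import numpy as np

def causal_mask(T):
    # True where query i may attend to key j (j <= i)
    return np.tril(np.ones((T, T), dtype=bool))

B, H, T = 2, 8, 4                          # arbitrary illustrative sizes
scores = np.zeros((B, H, T, T))            # stand-in for real attention logits
mask = causal_mask(T)[None, None, :, :]    # [1, 1, T, T]

# Size-1 axes stretch across batch and heads, so one mask serves every head.
print(np.broadcast_shapes(mask.shape, scores.shape))  # (2, 8, 4, 4)
masked = np.where(mask, scores, -1e9)
print(masked.shape)                                    # (2, 8, 4, 4)
```

No per-head copy of the mask is ever materialized; broadcasting handles it.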

Section 3 — Scaled dot-product attention

Formula:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Masked version:

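# Q, K, V: [B, H, T, d_k]; mask: [1, 1, T, T]; softmax as in Section 9
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)   # [B, H, T, T]
scores = np.where(mask, scores, -1e9)                  # block future keys
weights = softmax(scores, axis=-1)                     # normalize over keys
out = weights @ V                                      # [B, H, T, d_k]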
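The masked version can be wrapped into one self-contained function and checked on random inputs. This is a sketch; the function name `masked_attention` and the sizes are illustrative choices. The checks confirm that each row of weights sums to 1 and that nothing above the diagonal receives weight.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max first so exp never overflows
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def masked_attention(Q, K, V, mask):
    # Q, K, V: [B, H, T, d_k]; mask: broadcastable to [B, H, T, T]
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)  # [B, H, T, T]
    scores = np.where(mask, scores, -1e9)                  # block future keys
    weights = softmax(scores, axis=-1)                     # each row sums to 1
    return weights @ V, weights                            # [B, H, T, d_k]

rng = np.random.default_rng(0)
B, H, T, d_k = 2, 4, 5, 8
Q, K, V = (rng.standard_normal((B, H, T, d_k)) for _ in range(3))
mask = np.tril(np.ones((T, T), dtype=bool))[None, None]
out, weights = masked_attention(Q, K, V, mask)

print(out.shape)                            # (2, 4, 5, 8)
print(np.allclose(weights.sum(-1), 1.0))    # True
print(np.triu(weights[0, 0], k=1).max())    # 0.0
```

Query 0 can only see key 0, so its weight on that key is exactly 1.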

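Why sqrt(d_k)? If the components of q and k are independent with mean 0 and variance 1, the dot product q · k has variance d_k, so raw scores grow with dimension. Large logits push softmax toward a one-hot output, where gradients are close to zero. Dividing by sqrt(d_k) brings the score variance back to about 1.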
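An empirical check of the scaling argument, assuming query and key entries are i.i.d. standard normal (an idealization; trained activations are not). The sample size and the d_k values are arbitrary. The unscaled variance should track d_k, and the scaled variance should sit near 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000  # number of sampled (q, k) pairs
results = {}

for d_k in (16, 64, 256):
    q = rng.standard_normal((n, d_k))
    k = rng.standard_normal((n, d_k))
    dots = np.sum(q * k, axis=1)      # raw dot products q . k
    scaled = dots / np.sqrt(d_k)      # what attention actually feeds to softmax
    results[d_k] = (dots.var(), scaled.var())
    print(d_k, round(dots.var(), 1), round(scaled.var(), 2))
```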

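Why -1e9 or -inf before softmax? Forbidden positions must receive probability exactly zero while the allowed weights still sum to 1. A huge negative logit makes its exp(...) underflow to 0, so softmax normalizes over the allowed keys only.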
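A small numerical comparison for masking before versus after softmax. The logits are made-up values for one query at position 1 scoring keys 0, 1, 2. Masking the logits gives a valid distribution directly; zeroing probabilities afterward leaves a row that no longer sums to 1, because the future key already took part in the normalizer.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 3.0])       # made-up scores for keys 0, 1, 2
allowed = np.array([True, True, False])  # key 2 is in the future

# Mask the logits, then normalize over allowed keys only.
p_good = softmax(np.where(allowed, logits, -1e9))

# Normalize over all keys, then zero the future one.
p_bad = softmax(logits) * allowed

print(p_good.round(3), p_good.sum())
print(p_bad.round(3), p_bad.sum())       # sum is well below 1
```

The two only agree if p_bad is renormalized as an extra step; masking the logits gets there in one pass.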

Section 4 — Q/K/V projections

Given hidden states X [B, T, D]:

Q = X W_Q
K = X W_K
V = X W_V

Typical shapes:

W_Q, W_K, W_V : [D, D]
Q, K, V       : [B, T, D]
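A shape check for the projections, with random weights standing in for learned ones and arbitrary sizes B = 2, T = 5, D = 16. NumPy's `@` broadcasts the [D, D] weight over the batch axis, and each position is projected independently of the others.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, D = 2, 5, 16                      # arbitrary illustrative sizes
X = rng.standard_normal((B, T, D))      # hidden states
W_Q, W_K, W_V = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(3))

Q = X @ W_Q   # [B, T, D]
K = X @ W_K   # [B, T, D]
V = X @ W_V   # [B, T, D]

print(Q.shape, K.shape, V.shape)             # (2, 5, 16) (2, 5, 16) (2, 5, 16)
# Row t of Q depends only on X[:, t]; mixing across positions happens later.
print(np.allclose(Q[:, 3], X[:, 3] @ W_Q))   # True
```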

Intuition:

  • Query = what this token wants
  • Key = what this token can match
  • Value = what this token can contribute

Useful interview sentence: Queries and keys decide who talks to whom. Values decide what gets said.

Section 5 — Multi-head attention

If D = 512 and H = 8, then d_head = 64.

Shape flow:

X                     [B, T, D]
Q/K/V projections     [B, T, D]
split heads           [B, H, T, d_head]
scores                [B, H, T, T]
weights               [B, H, T, T]
head outputs          [B, H, T, d_head]
merge heads           [B, T, D]
output projection     [B, T, D]

Reference helpers:

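def split_heads(x, H):
    # [B, T, D] -> [B, H, T, d_head]; head h owns channels h*d_head:(h+1)*d_head
    B, T, D = x.shape
    assert D % H == 0, "D must be divisible by H"
    d_head = D // H
    return x.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, H * d_head]; exact inverse of split_heads
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)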
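The shape flow can be traced end to end with the two helpers, using the D = 512, H = 8 example. Weights are random placeholders, B = 2 and T = 6 are arbitrary, and the helpers are repeated so the snippet runs on its own.

```python
import numpy as np

def split_heads(x, H):
    B, T, D = x.shape
    return x.reshape(B, T, H, D // H).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

rng = np.random.default_rng(0)
B, T, D, H = 2, 6, 512, 8
X = rng.standard_normal((B, T, D))
W_Q, W_K, W_V, W_O = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4))

Q = split_heads(X @ W_Q, H)                              # [B, H, T, d_head]
K = split_heads(X @ W_K, H)
V = split_heads(X @ W_V, H)
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(D // H)   # [B, H, T, T]
scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -1e9)
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                # [B, H, T, T]
heads = weights @ V                                      # [B, H, T, d_head]
out = merge_heads(heads) @ W_O                           # [B, T, D]

print(Q.shape, scores.shape, heads.shape, out.shape)
# (2, 8, 6, 64) (2, 8, 6, 6) (2, 8, 6, 64) (2, 6, 512)
print(np.array_equal(merge_heads(split_heads(X, H)), X))  # True: lossless round trip
```

Splitting is only a reshape and transpose, so no information is lost; each head simply works on its own 64-channel slice.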

Section 6 — KV cache

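During one-token-at-a-time decoding, the keys and values of earlier tokens do not change from step to step (with causal masking, a token's K and V depend only on that token and the ones before it). Recomputing them at every step is wasteful, so each layer caches them.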

Per-layer cache shapes:

cache_k : [B, H, t, d_head]
cache_v : [B, H, t, d_head]

New step shapes:

new_q, new_k, new_v : [B, H, 1, d_head]

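Append new_k, new_v to the cache along the time axis (it now holds t + 1 positions), then compute:

scores = new_q @ cache_k.transpose(0, 1, 3, 2) / np.sqrt(d_head)   -> [B, H, 1, t + 1]
out    = softmax(scores, axis=-1) @ cache_v                       -> [B, H, 1, d_head]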
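A minimal sketch of one decode step for a single layer, assuming the new token's q, k, v are already projected and split into heads. The name `decode_step` is hypothetical. The check compares the cached step against the last row of full causal attention over the same t + 1 positions; the two should match.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def decode_step(new_q, new_k, new_v, cache_k, cache_v):
    # new_*: [B, H, 1, d_head]; cache_*: [B, H, t, d_head]
    cache_k = np.concatenate([cache_k, new_k], axis=2)   # [B, H, t+1, d_head]
    cache_v = np.concatenate([cache_v, new_v], axis=2)
    d_head = new_q.shape[-1]
    scores = new_q @ cache_k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [B, H, 1, t+1]
    out = softmax(scores) @ cache_v                      # [B, H, 1, d_head]
    return out, cache_k, cache_v

rng = np.random.default_rng(0)
B, H, t, d = 2, 4, 5, 8
Q, K, V = (rng.standard_normal((B, H, t + 1, d)) for _ in range(3))

# Cache holds positions 0..t-1; position t is the new token.
out, ck, cv = decode_step(Q[:, :, -1:], K[:, :, -1:], V[:, :, -1:],
                          K[:, :, :-1], V[:, :, :-1])

# Reference: full causal attention over all t+1 positions, last query row.
full = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d)
full = np.where(np.tril(np.ones((t + 1, t + 1), dtype=bool)), full, -1e9)
ref = (softmax(full) @ V)[:, :, -1:]

print(out.shape, ck.shape)      # (2, 4, 1, 8) (2, 4, 6, 8)
print(np.allclose(out, ref))    # True
```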

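Trade-off:

  • pro: lower decode latency, since each step computes K and V only for the new token
  • con: higher inference memory use, since the cache grows linearly with batch size and sequence length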
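The memory cost is plain arithmetic: every layer stores one K and one V tensor of shape [B, H, t, d_head]. The configuration below (32 layers, 32 heads, d_head = 128, 2-byte values, 4096 cached positions) is a hypothetical example, not a specific model, and `kv_cache_bytes` is an illustrative helper.

```python
def kv_cache_bytes(num_layers, B, H, t, d_head, bytes_per_value=2):
    # 2 tensors (K and V) per layer, each holding B * H * t * d_head values
    return 2 * num_layers * B * H * t * d_head * bytes_per_value

# Hypothetical config: 32 layers, 32 heads, d_head = 128, 16-bit values
size = kv_cache_bytes(num_layers=32, B=1, H=32, t=4096, d_head=128)
print(size / 2**30, "GiB")   # 2.0 GiB
```

Doubling either the batch size or the cached length doubles this figure.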

Section 7 — Prefill vs decode

This distinction matters because the two phases enforce causality differently:

  • Prefill: process the whole prompt in parallel; causal mask required.
  • Decode: process one new token at a time; future tokens do not exist yet, so the cache structure itself enforces causality and no mask is needed.
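A sketch of both phases for a single layer, assuming per-position q, k, v are given directly (random here) rather than produced by projections of generated tokens. `prefill` and `decode` are hypothetical helper names. Prefill runs masked attention over the prompt and keeps its K and V as the cache; decode then adds one position at a time with no mask. Stitched together, the outputs should match one masked pass over the full sequence.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def prefill(q, k, v):
    # q, k, v: [B, H, T0, d]; mask needed because all prompt queries run at once
    T0, d = q.shape[2], q.shape[3]
    s = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)
    s = np.where(np.tril(np.ones((T0, T0), dtype=bool)), s, -1e9)
    return softmax(s) @ v, k, v          # outputs plus the initial cache

def decode(q1, k1, v1, cache_k, cache_v):
    # q1, k1, v1: [B, H, 1, d]; no mask: the cache holds only past + current keys
    cache_k = np.concatenate([cache_k, k1], axis=2)
    cache_v = np.concatenate([cache_v, v1], axis=2)
    s = q1 @ cache_k.transpose(0, 1, 3, 2) / np.sqrt(q1.shape[-1])
    return softmax(s) @ cache_v, cache_k, cache_v

rng = np.random.default_rng(1)
B, H, T, T0, d = 1, 2, 7, 4, 8           # 4 prompt positions, 3 decoded
q, k, v = (rng.standard_normal((B, H, T, d)) for _ in range(3))

outs, ck, cv = prefill(q[:, :, :T0], k[:, :, :T0], v[:, :, :T0])
steps = [outs]
for i in range(T0, T):
    o, ck, cv = decode(q[:, :, i:i+1], k[:, :, i:i+1], v[:, :, i:i+1], ck, cv)
    steps.append(o)
stitched = np.concatenate(steps, axis=2)  # [B, H, T, d]

full, _, _ = prefill(q, k, v)             # one masked pass over everything
print(np.allclose(stitched, full))        # True
```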

Section 8 — Common bugs

| Symptom | Likely cause | Fix |
|---|---|---|
| Low loss, bad generation | missing mask | apply lower-triangular mask before softmax |
| Shape crash in score matmul | wrong K transpose | use K.transpose(0, 1, 3, 2) |
| Wrong number of channels after attention | forgot merge | merge heads back to [B, T, D] |
| Slow decoding | no KV cache | store past K/V per layer |
| Weird outputs across prompts | stale cache | reset cache per request |
| NaNs in attention | overflow in exp, or a row with every key masked | subtract the row max before exp (stable softmax); keep the diagonal unmasked; prefer a large finite fill such as -1e9 over -inf |
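A small demonstration of the all-masked-row case: with a -inf fill, a row where every key is masked becomes NaN, while a large finite fill (-1e9) yields a finite, uniform row. The 2x3 scores are made up, and the fully masked second row stands in for something like a padded query.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # on an all -inf row: -inf - (-inf) = nan
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[0.5, 1.0, -0.3],
                   [0.2, 0.1, 0.4]])
mask = np.array([[True, False, False],      # normal causal row
                 [False, False, False]])    # every key masked

with np.errstate(invalid="ignore"):         # silence the nan warning for the demo
    p_inf = softmax(np.where(mask, scores, -np.inf))
p_big = softmax(np.where(mask, scores, -1e9))

print(np.isnan(p_inf[1]).all())   # True
print(p_big[1])                   # finite, uniform over the 3 keys
```

Keeping the diagonal unmasked guarantees every causal row has at least one allowed key, so the NaN case only appears when extra masking (such as padding) empties a row.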

Section 9 — Minimal reference implementation

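import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def split_heads(x, H):
    # [B, T, D] -> [B, H, T, D // H]
    B, T, D = x.shape
    return x.reshape(B, T, H, D // H).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, H * d_head]
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

def causal_self_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    # X: [B, T, D]; W_*: [D, D]; D must be divisible by num_heads
    B, T, D = X.shape
    d_head = D // num_heads

    Q = split_heads(X @ W_Q, num_heads)   # [B, H, T, d_head]
    K = split_heads(X @ W_K, num_heads)
    V = split_heads(X @ W_V, num_heads)

    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)        # [B, H, T, T]
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]  # [1, 1, T, T]
    scores = np.where(mask, scores, -1e9)  # future keys get zero weight
    weights = softmax(scores, axis=-1)

    out = weights @ V                      # [B, H, T, d_head]
    out = merge_heads(out)                 # [B, T, D]
    out = out @ W_O                        # output projection
    return out, weights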
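A property check for the reference implementation: with causal masking, changing the input at position j must leave outputs at positions before j unchanged. The snippet restates the implementation compactly so it runs on its own; weights are random and the sizes (B = 2, T = 6, D = 32, H = 4) are arbitrary.

```python
import numpy as np

def causal_self_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    # Compact restatement of the reference implementation
    B, T, D = X.shape
    dh = D // num_heads

    def split(x):
        return x.reshape(B, T, num_heads, dh).transpose(0, 2, 1, 3)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    s = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(dh)
    s = np.where(np.tril(np.ones((T, T), dtype=bool)), s, -1e9)
    w = np.exp(s - s.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)
    out = (w @ V).transpose(0, 2, 1, 3).reshape(B, T, D) @ W_O
    return out, w

rng = np.random.default_rng(0)
B, T, D, H = 2, 6, 32, 4
X = rng.standard_normal((B, T, D))
Ws = [rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4)]
out, w = causal_self_attention(X, *Ws, num_heads=H)

X2 = X.copy()
X2[:, 4] = rng.standard_normal((B, D))       # perturb position 4 only
out2, _ = causal_self_attention(X2, *Ws, num_heads=H)

print(np.allclose(out[:, :4], out2[:, :4]))  # True: earlier positions unaffected
print(np.allclose(out[:, 4:], out2[:, 4:]))  # False: position 4 onward changes
```

Running the same check with the mask line removed should make the first comparison fail, which is exactly the leak described in Section 1.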


Self-check

  1. What exact bug appears if you forget causal masking?
  2. Why do we mask logits instead of masking probabilities afterward?
  3. What are the roles of Q, K, and V?
  4. What exact shape does the score tensor have in multi-head attention?
  5. Why does KV cache improve speed but hurt memory?
  6. What is the difference between prefill and decode?