02. Causal Attention & Coding — Narrative Explainer
Companion to 03_study_material.md. The study material gives the compressed formulas. This file gives the moving picture.
Read this before you code. Then keep it beside your editor.
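Table of contents
- ELI5 — the exam hall version (start here)
- Quick audit — what Module 05 quietly assumes
- Chapter 1: Opening failure
  - 1.1 Attention without masking lets the model cheat
  - 1.2 Why this matters in interviews and production
- Chapter 2: Causal masking
  - 2.1 The autoregressive contract
  - 2.2 The lower-triangular mask
  - 2.3 Why -inf must be added before softmax
  - 2.4 A clean NumPy implementation
  - 2.5 Numerical walkthrough with real numbers
  - 2.6 ASCII shape picture
  - 2.7 Common masking bugs
  - 2.8 ELI5 callback
- Chapter 3: Q/K/V projections from scratch
  - 3.1 One token, three jobs
  - 3.2 Linear layers and shapes
  - 3.3 Small NumPy code for the projections
  - 3.4 Numerical walkthrough with one token
  - 3.5 Two-token picture
  - 3.6 Why not use one shared projection?
  - 3.7 Shapes before and after projection
  - 3.8 Batched coding pattern you should memorize
  - 3.9 Common Q/K/V bugs
  - 3.10 ELI5 callback
- Chapter 4: Multi-head attention coding
  - 4.1 Why multiple heads exist
  - 4.2 Head dimensions
  - 4.3 Split-head code
  - 4.4 The full multi-head causal attention path
  - 4.5 ASCII tensor movie
  - 4.6 Numerical intuition with two heads
  - 4.7 Why concatenate and then project again?
  - 4.8 One transformer block shape trace
  - 4.9 Common multi-head bugs
  - 4.10 ELI5 callback
- Chapter 5: KV cache for inference
  - 5.1 Why naive decoding wastes work
  - 5.2 What exactly gets cached
  - 5.3 Shapes inside the cache
  - 5.4 Clean cache code
  - 5.5 Do we still need a mask during one-token decoding?
  - 5.6 Why the speedup is large
  - 5.7 Numerical intuition with four steps
  - 5.8 Memory cost is the trade-off
  - 5.9 Cache bugs to expect
  - 5.10 ELI5 callback
- Retrieval prompts
- Honest admission
- Chapter 6: Recap & application
  - 6.1 Failure-fix chain
  - 6.2 Full transformer block shape walk
  - 6.3 Key points to remember
  - 6.4 Important interview questions
  - 6.5 Production experience notes
  - 6.6 Apply now — graded exercises
  - 6.7 Bridge to Module 05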
ELI5 — the exam hall version
Imagine a student writing a long exam. She can re-read her own previous answers. She cannot peek at the next answer. That is the whole spirit of causal attention. Each new answer depends on the earlier answers. So the student keeps scanning backwards. Never forwards. That exam rule is the first big idea.
Now give the story some fixed names. We will keep using these names throughout.
- The exam rule = the causal mask.
- The answer sheet = the context window.
- The memory shortcut = the KV cache.
- The projection lens = the Q/K/V projections.
- The parallel graders = the multi-head attention heads.
Suppose the student has already written three tokens. Now she wants to write token four. She may look at tokens one, two, and three. She may not look at anything after them, because it does not exist yet.
The answer sheet is only so large. Maybe 128 tokens. Maybe 4096 tokens. Maybe more. But it is still a sheet with edges. That edge is the context window.
The student does not re-read the entire notebook every second. She keeps sticky notes for earlier useful facts. Those sticky notes are the memory shortcut. That shortcut is the KV cache.
The student also changes glasses depending on the job. One pair asks, "What am I searching for?" One pair asks, "What clues do I contain?" One pair asks, "What information should I share?" Those are the projection lenses. They become queries, keys, and values.
Finally, imagine five graders sitting together. One grader checks subject continuity. One checks punctuation. One checks names and entities. One checks local grammar. One checks long-range callback words. Those are the parallel graders. That is multi-head attention.
So the model is doing something very human-looking. It reads its own past. It keeps useful notes. It uses different lenses. It lets different graders focus on different patterns. Then it writes one more token.
Now the important warning. If you accidentally let the student peek ahead, training becomes an open-book answer-copying contest. Loss looks fantastic. Generation becomes terrible, because the model learned to cheat, not to predict.
That is why causal attention feels subtle at first. The math is short. The discipline is not. You must enforce the exam rule everywhere.
Good. Now leave the classroom picture in your head. We will write the actual matrix code.
Quick audit — what Module 05 quietly assumes
Before you move ahead, Module 05 will quietly assume these are already automatic.
- You can implement scaled dot-product attention without a library helper.
- You know the causal mask shape for training: [T, T] or broadcasted [1, 1, T, T].
- You can explain why masking happens before softmax.
- You can derive Q, K, and V from the same input tensor.
- You can split [B, T, D] into heads and merge it back correctly.
- You can track tensor shapes through one transformer block.
- You understand what KV cache stores at inference.
- You know the compute-versus-memory trade-off of caching.
If even two bullets still feel foggy, this module is not optional revision. It is a gap-closing module.
Chapters 2 to 5 exist exactly for these gaps. Take them seriously.
Chapter 1: Opening failure
1.1 Attention without masking lets the model cheat
Here is the classic mistake. You implement attention exactly once. It runs. The shapes look right. The loss crashes downward. You feel proud for thirty minutes. Then generation looks like nonsense. The broken code often looks this innocent:
import numpy as np
def softmax(x, axis=-1):
    # Subtract the row max for numerical stability, then normalize.
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)
def bad_attention(q, k, v):
    # Note: there is no causal mask anywhere in this function.
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)
    return weights @ v
There is no mask. So position i can now attend to tokens j > i. That means the answer is present inside the input. The model no longer learns prediction. It learns leakage.
Take a toy sequence: [I, love, causal, attention].
At position zero, we want I to help predict love. At position one, we want [I, love] to help predict causal. At position two, we want [I, love, causal] to help predict attention.
But unmasked attention gives position zero visibility into love, causal, and attention. That is illegal information. The exam paper is open at future pages.
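To see the leak numerically, here is a minimal sketch. It assumes random, hypothetical q and k for the four toy tokens, and it uses a helper named bad_attention_weights, which is ours for illustration. It measures how much probability each position places on strictly later positions.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def bad_attention_weights(q, k):
    # Same scoring as bad_attention, but returns the weights for inspection.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores, axis=-1)

rng = np.random.default_rng(0)
T, d = 4, 8                      # four toy tokens: I, love, causal, attention
q = rng.standard_normal((T, d))  # hypothetical queries
k = rng.standard_normal((T, d))  # hypothetical keys
weights = bad_attention_weights(q, k)

# Probability mass on strictly future positions (j > i), one number per row.
future_mass = np.triu(weights, k=1).sum(axis=-1)
print(future_mass)  # row 0 leaks onto love, causal, attention; the last row has no future
```

Any nonzero entry before the last row is information the model will not have at generation time.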
So train-time reality becomes a lie: every position predicts a token it can already see.
Then inference arrives. Now the future tokens are absent. Because the model is generating them step by step. The shortcut disappears. The model panics. Output quality collapses.
This mismatch is the real bug. Not just a wrong matrix. A wrong training game.
The model solved an easier task than the one you wanted. That is why loss can be excellent. And outputs can still be garbage.
This exact failure shows up in interviews. An interviewer may ask, "Why can training loss be low but generation bad?" If you say only "overfitting," you miss the sharper answer here. Sometimes the model simply saw the future.
1.2 Why this matters in interviews and production
Causal self-attention is probably the most common LLM coding question, especially for AI engineer, applied scientist, and inference engineer roles. Why this one? Because it tests several real skills together:
- Matrix multiplication intuition.
- Tensor shape discipline.
- Numerical stability discipline.
- Ability to explain training versus inference mismatch.
- Ability to debug a silent correctness bug.
It also matters in production. A missing mask does not always explode. That is why it is dangerous. It often gives believable metrics first, then betrays you later. Senior engineers learn this lesson repeatedly. The most expensive bugs are often silent bugs, not loud ones. So keep the opening failure in mind. Every later mechanism is a guardrail against that failure.
Chapter 2: Causal masking
2.1 The autoregressive contract
Autoregressive means one thing. The prediction made at position t may use tokens <= t only, and it predicts token t + 1. No future leakage. No exceptions.
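For a sequence of length four, the legal visibility looks like this (row = query position, column = key position):
[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]
A 1 means allowed. A 0 means blocked.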
That is a lower-triangular matrix. Lower triangle allowed. Upper triangle forbidden.
This is the exam rule in matrix form. Very plain. Very strict.
2.2 The lower-triangular mask
In NumPy, we usually build it like this:
mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
For seq_len = 4, this returns:
array([
[ True, False, False, False],
[ True, True, False, False],
[ True, True, True, False],
[ True, True, True, True],
])
This mask has shape [T, T].
But real models are usually batched. And multi-headed. So attention scores often have shape [B, H, T, T].
Where:
- B = batch size.
- H = number of heads.
- T = sequence length.
So we broadcast the mask to match. A common pattern is:
mask = mask[None, None, :, :]
Now the mask shape is [1, 1, T, T].
Broadcasting stretches it over every batch item. And every head. That is exactly what we want.
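A small sketch of that broadcast, using hypothetical random scores with B = 2 and H = 3:

```python
import numpy as np

B, H, T = 2, 3, 4
scores = np.random.default_rng(1).standard_normal((B, H, T, T))  # hypothetical scores

mask = np.tril(np.ones((T, T), dtype=bool))  # [T, T]
mask = mask[None, None, :, :]                # [1, 1, T, T]

# np.where broadcasts the [1, 1, T, T] mask across every batch item and head.
masked = np.where(mask, scores, -1e9)
print(mask.shape, masked.shape)  # (1, 1, 4, 4) (2, 3, 4, 4)
```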
2.3 Why -inf must be added before softmax
Attention starts with scores. Then softmax converts scores into probabilities. So the mask must act on the scores, not on the already-normalized weights. Correct pattern:
scores = np.where(mask, scores, -1e9)
weights = softmax(scores, axis=-1)
Why -1e9 or -inf? Because the exponential of a huge negative number is effectively zero. So forbidden positions receive zero probability mass.
Think row-wise. Suppose one score row has four entries, and the fourth position is illegal.
After masking, that fourth entry becomes -1e9 (or -inf), while the first three keep their values.
After softmax, that last probability becomes zero. Not small by accident. Zero by construction.
If you mask after softmax, you create a new problem. The remaining probabilities no longer sum to one. You must renormalize. And many beginners forget.
So the clean rule is this. Mask logits. Then softmax. Never the other way around.
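The difference is easy to check. A minimal sketch with one hypothetical score row whose fourth position is illegal:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

row = np.array([1.0, 2.0, 0.5, 3.0])           # hypothetical scores
allowed = np.array([True, True, True, False])  # position 3 is the future

good = softmax(np.where(allowed, row, -np.inf))  # mask logits, then softmax
bad = np.where(allowed, softmax(row), 0.0)       # softmax, then zero the future
fixed = bad / bad.sum()                          # the renormalization people forget

print(good.sum(), bad.sum())  # good sums to 1; bad sums to less than 1
```

Renormalizing bad recovers good here, but masking the logits gets there in one step and never needs the repair.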
2.4 A clean NumPy implementation
Start with the single-head case. That makes the logic obvious.
import numpy as np
def softmax(x, axis=-1):
    # Subtract the row max for numerical stability, then normalize.
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)
def causal_mask(seq_len):
    # True on and below the diagonal: position i may see positions 0..i.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))
def causal_attention(q, k, v):
    # q, k: [T, d_k], v: [T, d_v]
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)        # [T, T]
    mask = causal_mask(q.shape[0])         # [T, T]
    scores = np.where(mask, scores, -1e9)  # mask logits before softmax
    weights = softmax(scores, axis=-1)     # [T, T]
    out = weights @ v                      # [T, d_v]
    return out, weights
This function assumes q, k, and v are already projected. It only does scoring, masking, softmax, and weighted averaging.
Now the batched multi-head version. Same logic. Only more axes.
import numpy as np
# Reuses softmax() from the single-head block above.
def causal_attention_batched(q, k, v):
    # q, k, v: [B, H, T, d_head]
    B, H, T, d_head = q.shape
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)         # [B, H, T, T]
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]  # [1, 1, T, T]
    scores = np.where(mask, scores, -1e9)
    weights = softmax(scores, axis=-1)                             # [B, H, T, T]
    out = weights @ v                                              # [B, H, T, d_head]
    return out, weights
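Three invariants make the batched version easy to trust: every row sums to one, nothing sits above the diagonal, and position zero can only copy its own value. A sketch with random hypothetical inputs, re-declaring the functions so it runs on its own:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def causal_attention_batched(q, k, v):
    B, H, T, d_head = q.shape
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]
    weights = softmax(np.where(mask, scores, -1e9), axis=-1)
    return weights @ v, weights

rng = np.random.default_rng(0)
B, H, T, d_head = 2, 2, 5, 4
q, k, v = (rng.standard_normal((B, H, T, d_head)) for _ in range(3))
out, weights = causal_attention_batched(q, k, v)

assert out.shape == (B, H, T, d_head)
assert np.allclose(weights.sum(axis=-1), 1.0)       # every row is a distribution
assert np.all(np.triu(weights, k=1) == 0.0)         # zero mass above the diagonal
assert np.allclose(out[:, :, 0, :], v[:, :, 0, :])  # position 0 only sees itself
```

np.triu acts on the last two axes of a 4-D array, so one call checks every batch item and head.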
2.5 Numerical walkthrough with real numbers
Let sequence length be four. Let value dimension be two. Suppose the raw attention scores are:
scores =
[[ 2.0, 1.0, 0.0, -1.0],
[ 1.0, 3.0, 2.0, 0.0],
[ 2.0, 1.0, 4.0, 3.0],
[ 0.0, 1.0, 2.0, 3.0]]
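Focus on row two, the query at position two: [2.0, 1.0, 4.0, 3.0]. Position three lies in its future, so the mask replaces 3.0 with -inf. Among the allowed scores, the largest is 4.0.
So the final weights for that row are approximately:
[0.114, 0.042, 0.844, 0.000]
Good. The future token got exactly zero weight. That is the entire point.
Now compute the context vector. It is the weighted sum of the allowed values only:
context_2 ≈ 0.114 · v_0 + 0.042 · v_1 + 0.844 · v_2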
Notice something lovely here. The giant future value [9, 9] never enters. The mask prevented leakage. Not by luck. By design.
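The walkthrough can be reproduced in a few lines. The score matrix is the one above. Only the future value [9, 9] comes from the text; the other three value rows are hypothetical placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

scores = np.array([[2.0, 1.0, 0.0, -1.0],
                   [1.0, 3.0, 2.0, 0.0],
                   [2.0, 1.0, 4.0, 3.0],
                   [0.0, 1.0, 2.0, 3.0]])
mask = np.tril(np.ones((4, 4), dtype=bool))
weights = softmax(np.where(mask, scores, -np.inf), axis=-1)
print(np.round(weights[2], 3))  # approximately [0.114, 0.042, 0.844, 0.0]

# Value rows 0-2 are hypothetical placeholders; row 3 is the future value [9, 9].
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [9.0, 9.0]])
context = weights @ V  # [4, 2]
print(context[2])      # built only from rows 0-2 of V
```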
2.6 ASCII shape picture
Keep this picture in your head. It saves hours.
Q: [B, H, T, d]
K: [B, H, T, d]
V: [B, H, T, d]
K^T: [B, H, d, T]
Q @ K^T -> [B, H, T, T] attention scores
mask -> [1, 1, T, T] lower triangle
softmax -> [B, H, T, T] attention weights
weights @ V -> [B, H, T, d]
2.7 Common masking bugs
Here are the usual ways engineers break causal attention.
1. Using np.triu instead of np.tril.
2. Applying the mask after softmax.
3. Forgetting to broadcast over batch and heads.
4. Transposing the wrong axes of K.
5. Using a float mask with wrong values.
6. Accidentally allowing future positions in packed batches.
7. Confusing training mask shape with decode-time cache shape.
Two bugs deserve special warning.
Bug one: wrong triangle orientation. If you use the upper triangle, you literally allow the future and block the past. Generation becomes absurd.
Bug two: all-masked rows. This can produce NaNs. Standard causal self-attention avoids that, because the diagonal remains visible. Each token can always see itself.
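A quick sketch of the all-masked case. With -inf, the row max is itself -inf, so the stability shift computes -inf minus -inf and every entry becomes NaN. With -1e9, the row stays finite but turns uniform, which silently spreads attention over blocked positions instead.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

with np.errstate(invalid="ignore"):       # silence the expected inf - inf warning
    w_inf = softmax(np.full(4, -np.inf))  # every position masked with -inf
w_big = softmax(np.full(4, -1e9))         # every position masked with -1e9

print(w_inf)  # [nan nan nan nan]
print(w_big)  # [0.25 0.25 0.25 0.25]
```

Neither result is what you want, which is why keeping the diagonal visible matters.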
2.8 ELI5 callback
Return to the exam hall. The exam rule is not a philosophy. It is a stencil. You place that stencil over every score sheet. Then softmax only grades visible answers. That is all causal masking is. A very disciplined stencil.
Chapter 3: Q/K/V projections from scratch
3.1 One token, three jobs
A token embedding enters attention once. But attention asks it to do three different jobs.
- Ask a question.
- Advertise what it contains.
- Carry actual content forward.
One raw vector is not ideal for all three. So we learn three separate linear projections.
Given input tensor X, we compute: Q = X @ W_Q, K = X @ W_K, V = X @ W_V.
3.2 Linear layers and shapes
Suppose the input tensor X has shape [B, T, D].
Where:
- B = batch size.
- T = sequence length.
- D = model dimension, often called d_model.
If each projection keeps the same width, then each weight matrix has shape [D, D].
So the outputs Q, K, and V each have shape [B, T, D].
Simple picture: [B, T, D] @ [D, D] -> [B, T, D].
When multi-head attention appears, we later split that final dimension D into H heads. For now, just hold the single-head view first.
3.3 Small NumPy code for the projections
import numpy as np
def linear(x, w, b=None):
    # x: [..., D_in], w: [D_in, D_out], optional b: [D_out]
    y = x @ w
    if b is not None:
        y = y + b
    return y
B, T, D = 2, 4, 8
X = np.random.randn(B, T, D)
W_Q = np.random.randn(D, D)
W_K = np.random.randn(D, D)
W_V = np.random.randn(D, D)
Q = linear(X, W_Q)
K = linear(X, W_K)
V = linear(X, W_V)
3.4 Numerical walkthrough with one token
Let one token embedding x have four features (d_model = 4).
Let the projection matrices be:
W_Q =
[[1, 0],
[0, 1],
[1, 1],
[0, 1]]
W_K =
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 1, -1]]
W_V =
[[0, 1],
[1, 0],
[1, 1],
[1, 0]]
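Treat x as a row vector of length four. Now multiply.
For the query: q = x @ W_Q, a vector of length two.
For the key: k = x @ W_K, a vector of length two.
For the value: v = x @ W_V, a vector of length two.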
So the same token became three different vectors. Not because the token changed. Because the job changed.
This is the heart of attention engineering. Represent the same object differently for different subproblems.
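The embedding values are not given above, so here is a sketch with a hypothetical x = [1, 2, 0, 1] and the three matrices from this section. Because x is a row vector, each product is just a weighted sum of matrix rows.

```python
import numpy as np

x = np.array([1, 2, 0, 1])  # hypothetical token embedding, d_model = 4

W_Q = np.array([[1, 0], [0, 1], [1, 1], [0, 1]])
W_K = np.array([[1, 1], [1, 0], [0, 1], [1, -1]])
W_V = np.array([[0, 1], [1, 0], [1, 1], [1, 0]])

q = x @ W_Q  # 1*row0 + 2*row1 + 0*row2 + 1*row3 = [1, 3]
k = x @ W_K  # [4, 0]
v = x @ W_V  # [3, 1]
print(q, k, v)  # [1 3] [4 0] [3 1]
```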
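3.5 Two-token picture
Now add one more token. Suppose the sequence now holds two token embeddings, x₁ and x₂.
After projection, each token has its own query, key, and value: (q₁, k₁, v₁) and (q₂, k₂, v₂). Now query q₂ compares itself with both keys, k₁ and k₂. That produces attention scores. Then those scores weight the values v₁ and v₂.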
Notice the separation carefully. Matching happens between Q and K. Content transfer happens through V.
If you collapse this distinction mentally, attention stays foggy.
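A two-token sketch, with hypothetical embeddings x₁ = [1, 2, 0, 1] and x₂ = [0, 1, 1, 0] and the W_Q, W_K, W_V matrices from 3.4. The scores come only from Q and K; the output is only a mix of V rows.

```python
import numpy as np

W_Q = np.array([[1, 0], [0, 1], [1, 1], [0, 1]])
W_K = np.array([[1, 1], [1, 0], [0, 1], [1, -1]])
W_V = np.array([[0, 1], [1, 0], [1, 1], [1, 0]])

X = np.array([[1, 2, 0, 1],   # x1 (hypothetical)
              [0, 1, 1, 0]])  # x2 (hypothetical)
Q, K, V = X @ W_Q, X @ W_K, X @ W_V

raw = Q[1] @ K.T                     # q2 . k1 and q2 . k2: the matching step
scores = raw / np.sqrt(K.shape[-1])  # scale by sqrt(d_k)
weights = np.exp(scores - scores.max())
weights /= weights.sum()             # softmax over the two visible tokens
out = weights @ V                    # content transfer comes only from V
print(raw, np.round(weights, 3), np.round(out, 3))
```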
3.6 Why not use one shared projection?
Good question. Many people ask it.
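Suppose you used one matrix W for everything. Then Q = K = V = X @ W. The score matrix Q @ K^T = (X @ W)(X @ W)^T becomes symmetric: token i scores token j exactly as strongly as token j scores token i. And what a token advertises as a key is forced to equal what it carries as a value. Separate projections remove both constraints.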
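A sketch of the symmetry argument, with random hypothetical X and weights: a shared W always yields a symmetric score matrix, while separate W_Q and W_K generally do not.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 5, 8
X = rng.standard_normal((T, D))    # hypothetical token embeddings

W = rng.standard_normal((D, D))    # one shared projection
S_shared = (X @ W) @ (X @ W).T     # Q @ K^T with Q = K = X @ W

W_Q = rng.standard_normal((D, D))  # separate projections
W_K = rng.standard_normal((D, D))
S_sep = (X @ W_Q) @ (X @ W_K).T

print(np.allclose(S_shared, S_shared.T))  # True: i scores j exactly as j scores i
print(np.allclose(S_sep, S_sep.T))        # False: matching can be one-directional
```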
3.7 Shapes before and after projection
Keep the shape story boring. Boring shapes are good.
Before projection, X is [B, T, D]. After projection, Q, K, and V are each [B, T, D].
Nothing fancy yet. The shape explosion comes when you split heads. That is the next chapter.
3.8 Batched coding pattern you should memorize
def project_qkv(X, W_Q, W_K, W_V, b_Q=None, b_K=None, b_V=None):
    # X: [B, T, D]; each W: [D, D]; each optional bias: [D]
    Q = X @ W_Q
    K = X @ W_K
    V = X @ W_V
    if b_Q is not None:
        Q = Q + b_Q
    if b_K is not None:
        K = K + b_K
    if b_V is not None:
        V = V + b_V
    return Q, K, V
3.9 Common Q/K/V bugs
- Mixing up D and d_head too early.
- Forgetting that K later needs a transpose.
- Reusing one weight matrix accidentally.
- Reshaping before projection instead of after.
- Dropping the batch dimension during experiments.
- Adding bias with the wrong broadcast shape.
When debugging, print shapes after every line, especially when writing from scratch. Experts do this too, not only beginners.
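A sketch of shape printing in practice, with random hypothetical inputs. It also shows one bias bug: a bias shaped [T] instead of [D] lines up with the wrong axis.

```python
import numpy as np

def project_qkv(X, W_Q, W_K, W_V, b_Q=None, b_K=None, b_V=None):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    if b_Q is not None:
        Q = Q + b_Q
    if b_K is not None:
        K = K + b_K
    if b_V is not None:
        V = V + b_V
    return Q, K, V

rng = np.random.default_rng(0)
B, T, D = 2, 4, 8
X = rng.standard_normal((B, T, D))
W_Q, W_K, W_V = (rng.standard_normal((D, D)) for _ in range(3))
b = rng.standard_normal(D)  # shape [D] broadcasts over [B, T, D]

Q, K, V = project_qkv(X, W_Q, W_K, W_V, b_Q=b, b_K=b, b_V=b)
for name, t in (("X", X), ("Q", Q), ("K", K), ("V", V)):
    print(name, t.shape)  # each prints (2, 4, 8)

# A bias shaped [T] meets the last axis D: 4 vs 8, so broadcasting fails.
bad_bias_rejected = False
try:
    X @ W_Q + rng.standard_normal(T)
except ValueError:
    bad_bias_rejected = True
print("bad bias rejected:", bad_bias_rejected)
```

The loud failure here is the lucky case. If T happened to equal D, the wrong bias would broadcast silently.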
3.10 ELI5 callback
Return to the classroom story. The student is still the same student. But she now uses three different lenses. One lens asks questions. One lens labels what earlier answers contain. One lens copies content into the new answer. That is all Q, K, and V are. Different lenses on the same answer sheet.
Chapter 4: Multi-head attention coding
4.1 Why multiple heads exist
Single-head attention can work. But it has one pattern budget. One similarity space. One kind of focus at a time. Real language needs more. One head may track nearby syntax. Another may track long-range subject agreement. Another may track quote boundaries. Another may track entities. So we split model dimension into smaller subspaces. Each subspace gets its own attention computation. Then we merge the results. That is multi-head attention. Parallel graders. Same exam sheet. Different grading criteria.
4.2 Head dimensions
Suppose D = 8 and H = 2.
Then each head gets d_head = D / H = 4. So after projection, we reshape each of Q, K, V from:
[B, T, 8]
into:
[B, 2, T, 4]
The 2 is the head axis. The 4 is the per-head width.
4.3 Split-head code
This helper is worth memorizing.
def split_heads(x, num_heads):
    # x: [B, T, D] -> [B, H, T, d_head]
    B, T, D = x.shape
    assert D % num_heads == 0, "D must be divisible by num_heads"
    d_head = D // num_heads
    x = x.reshape(B, T, num_heads, d_head)
    x = x.transpose(0, 2, 1, 3)
    return x
def merge_heads(x):
    # x: [B, H, T, d_head] -> [B, T, H * d_head]
    B, H, T, d_head = x.shape
    x = x.transpose(0, 2, 1, 3)
    x = x.reshape(B, T, H * d_head)
    return x
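A round-trip sketch on a small hypothetical tensor. It checks that merge undoes split, that each head owns a contiguous slice of features, and what goes wrong if you skip the transpose.

```python
import numpy as np

def split_heads(x, num_heads):
    B, T, D = x.shape
    assert D % num_heads == 0, "D must be divisible by num_heads"
    return x.reshape(B, T, num_heads, D // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

x = np.arange(2 * 3 * 8, dtype=float).reshape(2, 3, 8)  # B=2, T=3, D=8
h = split_heads(x, num_heads=2)                           # [2, 2, 3, 4]

assert h.shape == (2, 2, 3, 4)
assert np.array_equal(merge_heads(h), x)     # merge undoes split exactly
assert np.array_equal(h[:, 0], x[:, :, :4])  # head 0 = first 4 features of every token
assert np.array_equal(h[:, 1], x[:, :, 4:])  # head 1 = last 4 features

# The classic bug: reshaping straight to [B, H, T, d_head] without the transpose.
wrong = x.reshape(2, 2, 3, 4)
print(np.array_equal(wrong, h))  # False: rows now mix features from different tokens
```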
4.4 The full multi-head causal attention path
Here is the full from-scratch path. This is the code interviewers usually want.
import numpy as np
def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)
def split_heads(x, num_heads):
    B, T, D = x.shape
    d_head = D // num_heads
    return x.reshape(B, T, num_heads, d_head).transpose(0, 2, 1, 3)
def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)
def multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    B, T, D = X.shape
    d_head = D // num_heads
    Q = split_heads(X @ W_Q, num_heads)  # [B, H, T, d_head]
    K = split_heads(X @ W_K, num_heads)  # [B, H, T, d_head]
    V = split_heads(X @ W_V, num_heads)  # [B, H, T, d_head]
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [B, H, T, T]
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]
    scores = np.where(mask, scores, -1e9)
    weights = softmax(scores, axis=-1)  # [B, H, T, T]
    head_out = weights @ V              # [B, H, T, d_head]
    merged = merge_heads(head_out)      # [B, T, D]
    out = merged @ W_O                  # [B, T, D]
    return out, weights
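A strong behavioural test for causality: change only the last token and confirm that no earlier output moves. A sketch with random hypothetical weights, re-declaring the functions so it runs on its own:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def split_heads(x, num_heads):
    B, T, D = x.shape
    return x.reshape(B, T, num_heads, D // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

def multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    B, T, D = X.shape
    d_head = D // num_heads
    Q = split_heads(X @ W_Q, num_heads)
    K = split_heads(X @ W_K, num_heads)
    V = split_heads(X @ W_V, num_heads)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]
    weights = softmax(np.where(mask, scores, -1e9), axis=-1)
    return merge_heads(weights @ V) @ W_O, weights

rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 8, 2
X = rng.standard_normal((B, T, D))
W_Q, W_K, W_V, W_O = (rng.standard_normal((D, D)) for _ in range(4))

out_a, _ = multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, H)
X_changed = X.copy()
X_changed[:, -1, :] += 10.0  # edit only the last token
out_b, _ = multi_head_causal_attention(X_changed, W_Q, W_K, W_V, W_O, H)

print(np.allclose(out_a[:, :-1], out_b[:, :-1]))  # True: earlier rows never see the edit
print(np.allclose(out_a[:, -1], out_b[:, -1]))    # False: the last row does
```

If you delete the np.where line, the first check fails. That makes this a cheap regression test for the opening failure.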
4.5 ASCII tensor movie
Input X [B, T, D]
│
├─ X @ W_Q [B, T, D]
├─ X @ W_K [B, T, D]
└─ X @ W_V [B, T, D]
Split into heads
Q, K, V [B, H, T, d]
Scores = Q @ K^T [B, H, T, T]
Mask [1, 1, T, T]
Softmax weights [B, H, T, T]
Head outputs = weights @ V [B, H, T, d]
Merge heads [B, T, H*d] = [B, T, D]
Output projection @ W_O [B, T, D]
4.6 Numerical intuition with two heads
Suppose token two is generating its next representation. Each head produces its own weight row over tokens zero, one, and two.
These rows are not redundant. They are two different views. Say head zero strongly prefers token zero. Maybe it is tracking long-range topic. Head one strongly prefers token one. Maybe it is tracking local syntax. Each head outputs a vector of length d_head. Then we concatenate those vectors. So one token now carries multiple attended summaries, not one.
That is the true power of multi-head attention. It preserves multiple relations side by side.
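A sketch with hypothetical numbers for token two (d_head = 2, three visible tokens): each head mixes its own value rows, then the two head outputs are concatenated.

```python
import numpy as np

w_head0 = np.array([0.7, 0.2, 0.1])  # hypothetical: head zero prefers token zero
w_head1 = np.array([0.1, 0.8, 0.1])  # hypothetical: head one prefers token one

# Hypothetical per-head value vectors for tokens 0, 1, 2 (d_head = 2).
V_head0 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
V_head1 = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])

out0 = w_head0 @ V_head0                  # [d_head]
out1 = w_head1 @ V_head1                  # [d_head]
token_two = np.concatenate([out0, out1])  # [H * d_head] = [D]
print(out0, out1, token_two.shape)
```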
4.7 Why concatenate and then project again?
After all heads finish, we simply glue the head vectors together. That gives one wide vector.
But concatenation alone keeps head channels separate. So we add one more learned matrix W_O. That matrix mixes information across heads.
Without W_O, heads cannot easily coordinate downstream. With W_O, the model can recombine head-specific features into one shared representation.
So the full rhythm is: split heads -> attend per head -> concatenate -> project with W_O.
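A sketch of what W_O adds, using hypothetical head outputs and a hypothetical dense W_O. Perturb only head one: the concatenated vector keeps head zero's channels untouched, but after W_O those same output channels respond to head one too.

```python
import numpy as np

d_head = 2
head0 = np.array([1.0, 2.0])  # hypothetical head-zero output for one token
head1 = np.array([3.0, 4.0])  # hypothetical head-one output for the same token

# Hypothetical dense output projection: every output channel reads both heads.
W_O = np.array([[1.0, 0.0, 0.5, 0.0],
                [0.0, 1.0, 0.0, 0.5],
                [0.5, 0.0, 1.0, 0.0],
                [0.0, 0.5, 0.0, 1.0]])

concat = np.concatenate([head0, head1])
concat_changed = np.concatenate([head0, head1 + 1.0])  # perturb only head one

# Without W_O: the first d_head channels belong to head zero alone.
print(np.array_equal(concat[:d_head], concat_changed[:d_head]))  # True
# With W_O: those same channels now respond to head one as well.
print((concat_changed @ W_O)[:d_head] - (concat @ W_O)[:d_head])  # [0.5 0.5]
```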
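4.8 One transformer block shape trace
Module 05 will assume this is easy for you. So let us make it easy now. Suppose input hidden states are [B, T, D].
Inside a GPT-style block (post-norm ordering shown here; pre-norm blocks such as GPT-2's move each layer norm in front of its sublayer, and every shape stays the same):
attention out : [B, T, D]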
residual add : [B, T, D]
layer norm : [B, T, D]
ffn hidden (4*D) : [B, T, 4D]
ffn output : [B, T, D]
residual add : [B, T, D]
layer norm : [B, T, D]
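The trace can be checked with asserts. A minimal sketch, assuming a random stand-in for the attention output, layer norm without its learned gain and bias, and ReLU as a stand-in nonlinearity (GPT-2 itself uses GELU). Function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis D (learned gain and bias omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def post_norm_block_shapes(x, attn_out, W1, W2):
    # attn_out stands in for the multi-head attention output: [B, T, D].
    h = layer_norm(x + attn_out)      # residual add + layer norm: [B, T, D]
    hidden = np.maximum(0.0, h @ W1)  # ffn hidden with ReLU: [B, T, 4D]
    ffn_out = hidden @ W2             # ffn output: [B, T, D]
    y = layer_norm(h + ffn_out)       # residual add + layer norm: [B, T, D]
    return hidden, y

rng = np.random.default_rng(0)
B, T, D = 2, 5, 8
x = rng.standard_normal((B, T, D))
attn_out = rng.standard_normal((B, T, D))
W1 = rng.standard_normal((D, 4 * D))
W2 = rng.standard_normal((4 * D, D))
hidden, y = post_norm_block_shapes(x, attn_out, W1, W2)
print(hidden.shape, y.shape)  # (2, 5, 32) (2, 5, 8)
```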
4.9 Common multi-head bugs
- D not divisible by H.
- Reshaping before projection instead of after.
- Forgetting the transpose in split_heads.
- Forgetting the reverse transpose in merge_heads.
- Using sqrt(D) instead of sqrt(d_head).
- Applying one shared mask with the wrong shape.
- Returning [B, H, T, d] without merging.
- Forgetting the output projection W_O.
That fifth bug is sneaky. Scaling must use the per-head dimension, not the full model width, because each head computes its own dot products over d_head features.
4.10 ELI5 callback
Return to the parallel graders. Each grader reads the same answer sheet. But each grader has different attention. Then the head teacher combines their notes. That head teacher is the output projection. So multi-head attention is not duplication. It is division of labour.
Chapter 5: KV cache for inference
5.1 Why naive decoding wastes work
Training is parallel across positions. Inference is not. At inference, we generate one token at a time. Suppose you already generated these tokens: [The, cat, sat].
Now you want the next token. A naive implementation feeds the full prefix again at every step:
step 1: run attention on [The]
step 2: run attention on [The, cat]
step 3: run attention on [The, cat, sat]
step 4: run attention on [The, cat, sat, ?]
Each step recomputes keys and values for every earlier token, even though those tokens never changed. That repeated work is the waste.
5.2 What exactly gets cached
For each layer, we cache past K and V tensors. We usually do not cache past Q. Because each new decode step only needs the new query.
So at time t, we compute:
- q_t for the new token.
- k_t for the new token.
- v_t for the new token.
Then we append k_t and v_t to the cache. Now the new query can attend over all cached keys and values.
That is the memory shortcut. Very practical. Very important.
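5.3 Shapes inside the cache
For one layer, cache shapes usually look like this:
K_cache : [B, H, t, d_head]
V_cache : [B, H, t, d_head]
At the next step, the new k and v each have shape [B, H, 1, d_head]. After appending, the cache holds [B, H, t + 1, d_head]. Then the scores become [B, H, 1, t + 1]. Notice the beautiful simplification. The query length is one. Only the key length grows.
5.4 Clean cache code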
import numpy as np
# Reuses softmax, split_heads, and merge_heads from section 4.4.
class KVCache:
    def __init__(self):
        self.k = None  # [B, H, t, d_head] once filled
        self.v = None
    def append(self, new_k, new_v):
        # Grow along the time axis (axis=2).
        if self.k is None:
            self.k = new_k
            self.v = new_v
        else:
            self.k = np.concatenate([self.k, new_k], axis=2)
            self.v = np.concatenate([self.v, new_v], axis=2)
def decode_step(x_t, cache, W_Q, W_K, W_V, W_O, num_heads):
    B, T, D = x_t.shape
    assert T == 1  # exactly one new token per decode step
    d_head = D // num_heads
    q = split_heads(x_t @ W_Q, num_heads)      # [B, H, 1, d_head]
    new_k = split_heads(x_t @ W_K, num_heads)  # [B, H, 1, d_head]
    new_v = split_heads(x_t @ W_V, num_heads)  # [B, H, 1, d_head]
    cache.append(new_k, new_v)
    k_all = cache.k  # [B, H, t, d_head], t = tokens seen so far
    v_all = cache.v  # [B, H, t, d_head]
    # No mask here: the cache never contains future tokens.
    scores = q @ k_all.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [B, H, 1, t]
    weights = softmax(scores, axis=-1)  # [B, H, 1, t]
    head_out = weights @ v_all          # [B, H, 1, d_head]
    merged = merge_heads(head_out)      # [B, 1, D]
    out = merged @ W_O                  # [B, 1, D]
    return out
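The most useful cache test checks equivalence: decoding token by token with the cache must reproduce the masked full-sequence output. A sketch with random hypothetical weights, re-declaring compact versions of the helpers so it runs on its own:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def split_heads(x, num_heads):
    B, T, D = x.shape
    return x.reshape(B, T, num_heads, D // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

def multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    B, T, D = X.shape
    Q, K, V = (split_heads(X @ W, num_heads) for W in (W_Q, W_K, W_V))
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(D // num_heads)
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]
    weights = softmax(np.where(mask, scores, -1e9), axis=-1)
    return merge_heads(weights @ V) @ W_O

class KVCache:
    def __init__(self):
        self.k = None
        self.v = None

    def append(self, new_k, new_v):
        if self.k is None:
            self.k, self.v = new_k, new_v
        else:
            self.k = np.concatenate([self.k, new_k], axis=2)
            self.v = np.concatenate([self.v, new_v], axis=2)

def decode_step(x_t, cache, W_Q, W_K, W_V, W_O, num_heads):
    d_head = x_t.shape[-1] // num_heads
    q = split_heads(x_t @ W_Q, num_heads)
    cache.append(split_heads(x_t @ W_K, num_heads), split_heads(x_t @ W_V, num_heads))
    scores = q @ cache.k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # no mask needed
    return merge_heads(softmax(scores, axis=-1) @ cache.v) @ W_O

rng = np.random.default_rng(0)
B, T, D, H = 2, 6, 8, 2
X = rng.standard_normal((B, T, D))
W_Q, W_K, W_V, W_O = (rng.standard_normal((D, D)) for _ in range(4))

full = multi_head_causal_attention(X, W_Q, W_K, W_V, W_O, H)  # masked, all positions

cache = KVCache()
steps = [decode_step(X[:, t:t + 1, :], cache, W_Q, W_K, W_V, W_O, H) for t in range(T)]
cached = np.concatenate(steps, axis=1)  # [B, T, D]

print(np.allclose(full, cached))  # True
```

This only holds within one layer given the same inputs. In a real stack you would run it per layer, with one cache per layer.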
5.5 Do we still need a mask during one-token decoding?
Subtle question. Very good interview question. If your cache only contains past tokens plus the current token, then the future is absent by construction. So an explicit causal mask is not strictly necessary for that one-token step. But during prompt prefill, when you process a whole prompt in parallel, you still need the causal mask. So remember the distinction:
- Prefill phase: full prompt, masked attention.
- Decode phase: one token at a time; the cache usually removes the future automatically.
That distinction impresses interviewers, because it shows real inference understanding.
5.6 Why the speedup is large
Per decode step, naive full-prefix attention recomputes a growing square matrix. That gives roughly O(t^2) work for that step. With cache, you only score one query against t keys. That is roughly O(t) for that step.
Across a full generated sequence of n tokens, naive repeated recomputation piles up to roughly O(n^3) work. Cached decoding drops that total to roughly O(n^2).
The practical takeaway is simpler. Without cache, long generations feel painfully slow. With cache, latency becomes usable.
5.7 Numerical intuition with four steps
Assume one layer, one head, and ignore constants. Naive recomputation of keys and values:
step 1: compute 1 old token again -> 1 unit
step 2: compute 2 old tokens again -> 2 units
step 3: compute 3 old tokens again -> 3 units
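step 4: compute 4 old tokens again -> 4 units
With KV cache:
step 1: compute only new token K/V -> 1 unit
step 2: compute only new token K/V -> 1 unit
step 3: compute only new token K/V -> 1 unit
step 4: compute only new token K/V -> 1 unit
After four steps: 1 + 2 + 3 + 4 = 10 units naive versus 4 units cached. The gap widens with every step.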
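The same bookkeeping as a tiny sketch, counting both K/V projections (this section) and query-key dot products (the O(t^2) versus O(t) argument in 5.6). Function names are hypothetical and constants are ignored.

```python
def kv_work(n_steps):
    # Units of K/V computation, one unit per token projected.
    naive = sum(range(1, n_steps + 1))  # re-project the whole prefix every step
    cached = n_steps                    # project only the new token
    return naive, cached

def score_work(n_steps):
    # Number of query-key dot products computed across all steps.
    naive = sum(t * t for t in range(1, n_steps + 1))  # full [t, t] matrix each step
    cached = sum(range(1, n_steps + 1))               # one [1, t] row each step
    return naive, cached

print(kv_work(4))        # (10, 4)
print(score_work(4))     # (30, 10)
print(score_work(1000))  # (333833500, 500500)
```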
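5.8 Memory cost is the trade-off
Cache gives speed. It also consumes memory. You store keys and values for every layer, for every cached token (prompt plus generated), for every active sequence in the batch. A rough formula is:
KV cache bytes ≈ 2 × num_layers × batch_size × seq_len × num_heads × d_head × bytes_per_element
The factor 2 is for K and V. For standard multi-head attention, num_heads × d_head = D.
Every factor multiplies, so doubling the context length, the batch size, or the layer count doubles the cache.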
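A sketch of the formula with a hypothetical configuration: 32 layers, 32 heads of width 128 (so D = 4096), fp16 storage at 2 bytes per element, and 4096 cached tokens. The numbers are illustrative, not taken from a specific model card.

```python
def kv_cache_bytes(num_layers, batch_size, seq_len, num_heads, d_head, bytes_per_element):
    # Factor 2: one tensor for keys and one for values, in every layer.
    return 2 * num_layers * batch_size * seq_len * num_heads * d_head * bytes_per_element

size = kv_cache_bytes(32, 1, 4096, 32, 128, 2)
print(size / 2**30)  # 2.0 (GiB) for a single sequence

# Serving 8 such sequences at once multiplies the cache by 8.
print(kv_cache_bytes(32, 8, 4096, 32, 128, 2) / 2**30)  # 16.0
```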
So yes, cache is wonderful. And yes, cache can dominate inference memory.
This is why long-context serving is hard. Not only because of compute. Because of memory pressure.
5.9 Cache bugs to expect
- Forgetting to clear the cache between prompts.
- Appending along the wrong axis.
- Mismatching position ids with cache length.
- Reusing cache from a different batch ordering.
- Copying cache incorrectly during beam search.
- Storing cache on the wrong device.
- Forgetting that every layer needs its own cache.
That first bug is very real. You will see ghost tokens from previous prompts. It feels spooky. It is only stale cache.
5.10 ELI5 callback
Return to the exam hall again. The student does not reread every old answer word by word. She keeps sticky notes. Those sticky notes are keys and values. When writing the next line, she consults the sticky notes. Not the whole notebook. That is the memory shortcut. Simple. Powerful. Expensive in memory.
Retrieval prompts
Use these without notes. Speak them aloud. Then write them from memory.
1. Explain causal masking using the exam rule and one 4x4 matrix.
2. Write the exact shapes from X [B, T, D] to scores [B, H, T, T].
3. Explain why masking after softmax is wrong or incomplete.
4. Explain Q, K, and V using the sentence, "who talks, who matches, what gets said."
5. Derive why KV cache speeds up decoding but increases memory.
6. Write split_heads and merge_heads from memory.
Honest admission
This module makes attention feel clean. Real production attention is still messier. You will later meet:
- FlashAttention.
- Grouped-query attention.
- Multi-query attention.
- Rotary position embeddings.
- Paged KV cache.
- Prefix caching.
- Quantized KV cache.
Those are real systems topics. They matter. But if the basics above are shaky, those optimizations become memorized jargon. So be honest with yourself. If you still need paper for shapes, that is fine. Use the paper. Shape fluency comes from repetition. Not ego.
Chapter 6: Recap & application
6.1 Failure-fix chain
Here is the compressed debugging table. Read it repeatedly.
| Failure | What you see | Root cause | Fix |
|---|---|---|---|
| Loss looks amazing, generation is nonsense | Train/test mismatch | Future leakage | Apply causal mask before softmax |
| Token cannot attend to itself | NaNs or weak outputs | Diagonal accidentally masked | Keep lower triangle including diagonal |
| Attention weights do not sum to one | Odd scaling, unstable outputs | Mask applied after softmax | Mask logits first, then softmax |
| Shape error at score computation | Matrix multiply crash | Wrong K transpose | Use K.transpose(0, 1, 3, 2) |
| Output still has head axis | Downstream layer breaks | Forgot merge step | Transpose back and reshape to [B, T, D] |
| Head outputs feel isolated | Weak expressivity | Missing W_O projection | Concatenate heads, then project |
| Decode latency is terrible | Each token gets slower | Full prefix recomputed each step | Add KV cache |
| Decode output leaks previous prompt | Weird cross-prompt contamination | Stale cache reused | Reset cache per request |
| Attention too sharp or too soft | Softmax saturates, or weights flatten | No scaling, or sqrt(D) used instead of sqrt(d_head) | Divide by sqrt(d_head) |
| Interview answer sounds vague | Hand-wavy intuition | No shape tracking habit | Walk tensors line by line |
That table is the story of this module. Each mechanism exists because something broke.
6.2 Full transformer block shape walk
Let us trace a GPT-style block once. No skipping. Start with token ids: [B, T].
Embedding lookup gives:
token_embeddings : [B, T, D]
positional_embeddings : [1, T, D] or [B, T, D]
hidden_states : [B, T, D]
Q, K, V projections : [B, T, D]
split into heads : [B, H, T, d_head]
attention scores : [B, H, T, T]
causal mask : [1, 1, T, T]
attention weights : [B, H, T, T]
head outputs : [B, H, T, d_head]
merged attention output : [B, T, D]
output projection : [B, T, D]
residual add : [B, T, D]
layer norm : [B, T, D]
ffn hidden : [B, T, 4D]
nonlinearity : [B, T, 4D]
ffn output : [B, T, D]
residual add : [B, T, D]
layer norm : [B, T, D]
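Stack N such blocks. Each one maps [B, T, D] to [B, T, D]. A final projection to the vocabulary then turns the last hidden states into logits of shape [B, T, vocab_size].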
That full walk matters. Module 05 will talk about training these blocks. If shapes still wobble here, training discussions stay superficial.
6.3 Key points to remember
- Causal masking is an information firewall.
- The mask is lower triangular.
- Mask scores before softmax.
- Q, K, and V are separate learned projections.
- Multi-head attention splits model width into parallel subspaces.
- W_O mixes head outputs back together.
- KV cache stores past keys and values.
- Cache speeds decoding by avoiding repeated work.
- Cache also consumes serious memory.
- Shape tracking is not optional in transformer code.
6.4 Important interview questions
Practice these aloud. Not silently.
1. Why does missing causal masking sometimes produce low training loss but bad generation?
2. Why do we need three projections W_Q, W_K, and W_V instead of one shared matrix?
3. What are the exact tensor shapes through multi-head self-attention?
4. Why do we divide by sqrt(d_head) in scaled dot-product attention?
5. Why does multi-head attention help more than one large head?
6. What exactly does KV cache store?
7. Why does cache reduce latency but increase memory usage?
8. During one-token decoding, when is an explicit mask still needed or not needed?
9. What breaks if you forget the output projection W_O?
10. Walk me through one GPT block from input ids to logits.
If you can answer these cleanly, you are already above many candidates.
6.5 Production experience notes
Here is what real engineering experience starts to feel like.
First, you stop treating attention as one formula. You treat it as a pipeline of invariants. Every invariant is testable.
Second, you start writing tiny checks. Examples:
- mask shape equals [1, 1, T, T] for batched heads.
- masked probabilities on future positions are exactly zero.
- merged output returns to [B, T, D].
- cache length increases by one each decode step.
- cache resets between requests.
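Those checks fit in a handful of assertions. A minimal sketch with random hypothetical tensors; the check_* helper names are ours, and the cache class repeats the append logic from 5.4 so the block runs on its own.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(x)
    return exp / np.sum(exp, axis=axis, keepdims=True)

class KVCache:
    # Same append logic as the cache class in 5.4.
    def __init__(self):
        self.k = None
        self.v = None

    def append(self, new_k, new_v):
        if self.k is None:
            self.k, self.v = new_k, new_v
        else:
            self.k = np.concatenate([self.k, new_k], axis=2)
            self.v = np.concatenate([self.v, new_v], axis=2)

def check_masked_weights(scores):
    # scores: [B, H, T, T]
    T = scores.shape[-1]
    mask = np.tril(np.ones((T, T), dtype=bool))[None, None, :, :]
    assert mask.shape == (1, 1, T, T)
    weights = softmax(np.where(mask, scores, -1e9), axis=-1)
    assert np.all(np.triu(weights, k=1) == 0.0)  # future probabilities exactly zero
    return weights

def check_merge(head_out):
    # head_out: [B, H, T, d_head] -> merged [B, T, D]
    B, H, T, d_head = head_out.shape
    merged = head_out.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)
    assert merged.shape == (B, T, H * d_head)
    return merged

def check_cache_growth(cache, new_k, new_v):
    before = 0 if cache.k is None else cache.k.shape[2]
    cache.append(new_k, new_v)
    assert cache.k.shape[2] == before + 1  # exactly one position per decode step

rng = np.random.default_rng(0)
B, H, T, d_head = 2, 2, 4, 3
check_masked_weights(rng.standard_normal((B, H, T, T)))
check_merge(rng.standard_normal((B, H, T, d_head)))

cache = KVCache()
for _ in range(3):
    step = rng.standard_normal((B, H, 1, d_head))
    check_cache_growth(cache, step, step)
assert cache.k.shape == (B, H, 3, d_head)

cache = KVCache()  # a new request starts from a fresh, empty cache
assert cache.k is None and cache.v is None
```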
Third, you become suspicious of nice metrics. A silent leakage bug can make curves look beautiful. Always pair metrics with behavioural tests.
Fourth, you learn that inference is a systems problem. Attention math is only half the story. Latency, memory, batching, and cache management matter equally.
That is why this module is career-relevant. It sits right at the border of theory and engineering.
6.6 Apply now — graded exercises
Warm-up
- Write causal_mask(T) from memory.
- For T = 5, draw the exact allowed-attention matrix.
- Explain in one sentence why upper-triangular masking is wrong for GPT.
Build
- Implement single-head causal attention in NumPy.
- Add shape asserts for every intermediate tensor.
- Create a toy example where masking zeroes out a future token.
- Implement split_heads and merge_heads.
- Extend your code to full multi-head causal attention.
Stretch
- Add a tiny KV cache class for one decode step.
- Compare decode runtime with and without cache on longer sequences.
- Print cache memory usage for different context lengths.
- Trace every tensor shape through one transformer block on paper.
Interview drill
- Whiteboard the code without looking.
- Explain the same code to a beginner without using jargon.
- Explain the same code to a senior engineer using exact shapes.
If you can do the three interview drills, this module has worked.
6.7 Bridge to Module 05
Now the next question becomes obvious. You can code the transformer block. Fine. How does a real organisation train these blocks at scale? How do they move from raw text to a deployed model? Next module — 05_llm_training_pipeline — covers how these transformer blocks get trained at scale: pretraining, SFT, RLHF, and the full pipeline from raw text to deployed model. That bridge matters. This module taught the engine. The next module teaches the factory.