09. Transformer block shapes: keep the axes steady
~8 min read. Track shapes once. Then every decoder block feels predictable.
Built on the ELI5 in 00-eli5.md. The answer sheet keeps the time axis fixed through the block. The parallel graders split model width, then merge it back.
1) Start with ids, embeddings, and positions
A decoder begins with token ids.
input_ids : [B, T]
B is batch size.
T is token count on the answer sheet.
Embedding lookup appends a model-width axis, D.
token_emb : [B, T, D]
Position information matches the same time axis.
pos_emb : [1, T, D] or pos_emb : [B, T, D]
A [1, T, D] table broadcasts across the batch, so addition keeps the outer frame unchanged.
hidden_states = token_emb + pos_emb : [B, T, D]
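Here is a minimal pure-Python sketch of those two steps. It assumes toy sizes (B=2, T=3, D=4, a 10-token vocabulary) and uses random tables in place of learned weights. The names emb_table, pos_table, and shape are hypothetical; shape just reads nested-list lengths. The sketch shows the lookup turning [B, T] into [B, T, D], and one [1, T, D] position table being reused for every batch row.

```python
import random

def shape(x):
    """Return the shape of a rectangular nested list, e.g. [B, T, D]."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims

B, T, D, VOCAB = 2, 3, 4, 10  # toy sizes, assumed for illustration
random.seed(0)

input_ids = [[1, 5, 7], [2, 2, 9]]                                       # [B, T]
emb_table = [[random.random() for _ in range(D)] for _ in range(VOCAB)]  # [VOCAB, D], hypothetical
pos_table = [[[random.random() for _ in range(D)] for _ in range(T)]]    # [1, T, D], hypothetical

# Embedding lookup: each id is replaced by a row of width D.
token_emb = [[emb_table[i] for i in row] for row in input_ids]           # [B, T, D]

# Broadcast add: the single position slice pos_table[0] is reused for every batch row.
hidden_states = [
    [[tok[d] + pos_table[0][t][d] for d in range(D)] for t, tok in enumerate(seq)]
    for seq in token_emb
]

print(shape(input_ids), shape(token_emb), shape(pos_table), shape(hidden_states))
# [2, 3] [2, 3, 4] [1, 3, 4] [2, 3, 4]
```

Batch and time pass straight through. Only the lookup adds the last axis.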
The first habit is simple.
Batch stays outside.
Time stays outside.
Width lives in the last axis.
No heads exist yet.
No cache exists yet.
┌────────────────────────────────────────────────────┐
│ input_ids [B, T]                                   │
│   ▼                                                │
│ token_emb + pos_emb                                │
│   ▼                                                │
│ hidden_states [B, T, D]                            │
│   ▼                                                │
│ attention ──→ + residual ──→ norm                  │
│   ▼                                                │
│ FFN [D] ──→ [4D] ──→ [D] ──→ + residual ──→ norm   │
│   ▼                                                │
│ × N layers ──→ logits [B, T, vocab_size]           │
└────────────────────────────────────────────────────┘
2) Attention reshapes width, not token count
Attention starts from hidden_states : [B, T, D].
The projection lens, three learned linear maps from width D to width D, makes three views.
Q : [B, T, D]
K : [B, T, D]
V : [B, T, D]
Heads split width into smaller channels: reshape [B, T, D] to [B, T, H, d_head], then swap the T and H axes.
Q_heads : [B, H, T, d_head]
K_heads : [B, H, T, d_head]
V_heads : [B, H, T, d_head]
Usually D = H × d_head.
Nothing created extra tokens.
Nothing changed batch size.
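The split is a reshape followed by a transpose. Here is a sketch on nested lists, assuming toy sizes (B=2, T=3, D=8, H=4, so d_head=2). The function name split_heads is illustrative, not a library API. Each channel value encodes its own position, so you can see which slice each head receives.

```python
def shape(x):
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims

def split_heads(x, H):
    """[B, T, D] -> [B, H, T, d_head]: reshape D into (H, d_head), then swap T and H."""
    B, T, D = len(x), len(x[0]), len(x[0][0])
    if D % H != 0:
        raise ValueError(f"D={D} is not divisible by H={H}")
    d_head = D // H
    # Reshape: [B, T, D] -> [B, T, H, d_head]; head h owns channels h*d_head up to (h+1)*d_head - 1.
    reshaped = [[[tok[h * d_head:(h + 1) * d_head] for h in range(H)] for tok in seq] for seq in x]
    # Transpose: swap the T and H axes -> [B, H, T, d_head].
    return [[[reshaped[b][t][h] for t in range(T)] for h in range(H)] for b in range(B)]

B, T, D, H = 2, 3, 8, 4
# Each channel value encodes its position as b*100 + t*10 + d.
q = [[[float(b * 100 + t * 10 + d) for d in range(D)] for t in range(T)] for b in range(B)]
q_heads = split_heads(q, H)
print(shape(q), "->", shape(q_heads))   # [2, 3, 8] -> [2, 4, 3, 2]
print(q_heads[1][2][0])                 # [104.0, 105.0]: batch 1, head 2, token 0
```

No channel is copied or dropped. Each head gets a disjoint d_head-wide slice of the same T tokens.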
Scores compare each query position against every key position: Q_heads @ K_headsᵀ / √d_head, where the transpose swaps the last two axes of K_heads.
scores : [B, H, T, T]
The exam rule, a causal mask, broadcasts over batch and heads.
causal_mask : [1, 1, T, T]
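Here is a shape-only sketch of the score step, assuming the standard scaled dot-product form. matmul_shape is a hypothetical helper that applies the batched matrix-multiply rule [..., m, k] @ [..., k, n] to [..., m, n]. causal_mask builds the [T, T] pattern, and wrapping it in two size-1 axes gives the [1, 1, T, T] mask.

```python
import math

def matmul_shape(a, b):
    """Batched matmul shape rule: [..., m, k] @ [..., k, n] -> [..., m, n]."""
    assert a[:-2] == b[:-2] and a[-1] == b[-2], (a, b)
    return a[:-2] + [a[-2], b[-1]]

def causal_mask(T):
    """[T, T] booleans: query row i may see key columns 0..i (True = visible)."""
    return [[j <= i for j in range(T)] for i in range(T)]

B, H, T, d_head = 2, 8, 4, 64
q_heads = [B, H, T, d_head]
k_heads_t = [B, H, d_head, T]           # K_heads with its last two axes swapped
scores = matmul_shape(q_heads, k_heads_t)
scale = 1 / math.sqrt(d_head)           # scores are multiplied by this before masking
print(scores, scale)                    # [2, 8, 4, 4] 0.125

mask = [[causal_mask(T)]]               # two size-1 axes added -> [1, 1, T, T]
for row in mask[0][0]:
    print("".join("x" if visible else "." for visible in row))
# x...
# xx..
# xxx.
# xxxx
```

The size-1 axes let one mask serve every batch row and every head.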
Softmax runs along the last (key) axis, so it keeps the same score shape.
weights : [B, H, T, T]
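The next sketch shows why softmax keeps [T, T] and why masked keys get zero weight. It pushes one head's score matrix through a masked softmax in pure Python. The score values are made up. Setting masked entries to -inf before the softmax is one common way to apply the exam rule; the function name masked_softmax is hypothetical.

```python
import math

def masked_softmax(row, visible):
    """Softmax over the key axis; masked keys get -inf, so their weight is exactly 0."""
    logits = [s if ok else -math.inf for s, ok in zip(row, visible)]
    m = max(logits)                              # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in logits]
    total = sum(exps)
    return [e / total for e in exps]

T = 4
scores = [[0.5, 1.0, -0.3, 2.0] for _ in range(T)]       # one head's [T, T] scores, made-up values
mask = [[j <= i for j in range(T)] for i in range(T)]    # causal mask, True = visible
weights = [masked_softmax(r, v) for r, v in zip(scores, mask)]

print(len(weights), len(weights[0]))   # 4 4: same [T, T] shape as scores
print(weights[0])                      # [1.0, 0.0, 0.0, 0.0]: row one sees column one only
```

Every row still has T entries that sum to 1. The mask only zeroes some of them.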
Multiplying weights [B, H, T, T] by V_heads [B, H, T, d_head] returns per-head outputs.
context_heads : [B, H, T, d_head]
Merging heads swaps H and T back, then concatenates the H chunks of width d_head, restoring model width.
context : [B, T, D]
Output projection still returns [B, T, D].
That is the whole attention shape story.
Time compares with time.
Width splits, then rejoins.
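You can check the weighted-value step and the head merge the same way. This sketch applies the batched matmul shape rule (a hypothetical helper) to weights @ V_heads. It then merges toy heads whose entries record their own (b, h, t, channel) indices, so the concatenation order is visible. merge_heads is also an illustrative name. The output projection is left out because it maps D to D and does not change the shape.

```python
def matmul_shape(a, b):
    """Batched matmul shape rule: [..., m, k] @ [..., k, n] -> [..., m, n]."""
    assert a[:-2] == b[:-2] and a[-1] == b[-2], (a, b)
    return a[:-2] + [a[-2], b[-1]]

def merge_heads(x):
    """[B, H, T, d_head] -> [B, T, H * d_head]: swap H and T back, then concatenate heads in order."""
    B, H, T = len(x), len(x[0]), len(x[0][0])
    return [[[c for h in range(H) for c in x[b][h][t]] for t in range(T)] for b in range(B)]

# Weighted values: [B, H, T, T] @ [B, H, T, d_head] -> [B, H, T, d_head].
print(matmul_shape([2, 8, 4, 4], [2, 8, 4, 64]))      # [2, 8, 4, 64]

# Toy per-head outputs with B=2, H=4, T=3, d_head=2; each entry is its own (b, h, t, channel) index.
ctx_heads = [[[[(b, h, t, c) for c in range(2)] for t in range(3)] for h in range(4)] for b in range(2)]
ctx = merge_heads(ctx_heads)
print(len(ctx), len(ctx[0]), len(ctx[0][0]))          # 2 3 8
print(ctx[0][1][:4])   # [(0, 0, 1, 0), (0, 0, 1, 1), (0, 1, 1, 0), (0, 1, 1, 1)]
```

Within one token, head 0's channels come first, then head 1's, and so on. This is the exact inverse of the split.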
3) FFN widens channels, then shrinks them back
After attention, the residual stream stays [B, T, D].
The FFN acts independently at each token position.
It mixes channels, not sequence slots.
A common decoder pattern looks like this.
ffn_up : [B, T, 4D]
activation : [B, T, 4D]
ffn_down : [B, T, D]
residual_add : [B, T, D]
layer_norm : [B, T, D]
The temporary expansion happens only in width.
B stays fixed.
T stays fixed.
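A small pure-Python FFN makes the per-position claim concrete. It assumes D=2, a 4× expansion, random weights, and no biases. ReLU stands in for whatever activation a given model actually uses; GPT-2, for example, uses GELU. The names W_up, W_down, linear, and ffn are illustrative. Editing one token changes only that token's output.

```python
import random

random.seed(0)
D = 2
D_FF = 4 * D                                                               # the common 4x expansion
W_up = [[random.uniform(-1, 1) for _ in range(D_FF)] for _ in range(D)]    # [D, 4D]
W_down = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(D_FF)]  # [4D, D]

def linear(x, W):
    """x: [in] vector, W: [in, out] matrix -> [out] vector (bias omitted)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def ffn(x):
    """One position: [D] -> [4D] -> ReLU -> [D]. ReLU is a stand-in activation."""
    up = linear(x, W_up)
    act = [max(0.0, u) for u in up]
    return linear(act, W_down)

seq = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]    # one sequence, [T, D] with T=3
out = [ffn(tok) for tok in seq]                 # still [T, D]

# Replace token 0 only; the outputs at positions 1 and 2 are unchanged.
seq2 = [[-4.0, 9.0]] + seq[1:]
out2 = [ffn(tok) for tok in seq2]
print(len(out), len(out[0]), out[1:] == out2[1:])   # 3 2 True
```

The same weights are applied at every position, but no position reads another. Mixing across time is attention's job.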
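Each residual add needs matching shapes, so every sublayer must hand back [B, T, D]. That constraint keeps debugging manageable.
Layer norm normalizes over the last axis only, so it also preserves the outer frame.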
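Here is a sketch of the residual add followed by layer norm. It assumes toy [T, D] values and leaves out the learned scale and shift; layer_norm is an illustrative name. The explicit assert makes the shape requirement visible, because Python's zip would otherwise truncate silently.

```python
import math

def layer_norm(v, eps=1e-5):
    """Normalize one position's D channels (learned scale and shift omitted)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]

stream = [[1.0, 2.0, 3.0, 4.0], [10.0, 0.0, -10.0, 0.0]]      # residual stream, [T, D]
sublayer = [[0.5, 0.5, 0.5, 0.5], [1.0, -1.0, 1.0, -1.0]]     # sublayer output, must also be [T, D]

# Residual add is elementwise, so both shapes must match exactly.
assert len(stream) == len(sublayer) and all(len(a) == len(b) for a, b in zip(stream, sublayer))
res = [[a + b for a, b in zip(xr, sr)] for xr, sr in zip(stream, sublayer)]

# Layer norm works row by row over the last axis, so [T, D] stays [T, D].
normed = [layer_norm(row) for row in res]
print(len(normed), len(normed[0]))                               # 2 4
print(all(abs(sum(row) / len(row)) < 1e-9 for row in normed))    # True: every row has mean ~0
```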
Attention mixes across time.
The FFN mixes within each position.
4) One worked example with real numbers
Take B=2, T=4, D=512, H=8, d_head=64.
Start with ids.
input_ids : [2, 4]
Embedding lookup gives [2, 4, 512].
Positional terms can be [1, 4, 512].
After addition, hidden states stay [2, 4, 512].
The projection lens makes three tensors.
Q, K, V : [2, 4, 512]
Head splitting uses 512 = 8 × 64.
Q_heads, K_heads, V_heads : [2, 8, 4, 64]
Score matrices compare four query slots against four key slots.
scores : [2, 8, 4, 4]
Masking allows row one to see column one only.
Masking allows row four to see columns one through four.
Softmax keeps [2, 8, 4, 4].
Weighted values return [2, 8, 4, 64].
Merging heads returns [2, 4, 512].
The output projection keeps [2, 4, 512].
The FFN expands to [2, 4, 2048].
The FFN projects back to [2, 4, 512].
After many blocks, hidden states are still [2, 4, 512].
Vocabulary projection changes only the last axis.
logits : [2, 4, vocab_size]
For vocab_size = 32000, logits become [2, 4, 32000].
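The whole walk above can be replayed with shapes alone. This sketch assumes the attention-then-FFN block from sections 2 and 3, with 6 stacked blocks (an arbitrary choice). It uses hypothetical helpers (embed, linear, split_heads, merge_heads, transpose_last2, matmul) that only transform shape lists. Mask, softmax, and layer norms are not modeled because none of them changes a shape. The residual adds appear only as shape checks.

```python
def embed(shape, D):
    """Embedding lookup appends a width axis: [B, T] -> [B, T, D]."""
    return shape + [D]

def linear(shape, out_features):
    """A linear layer replaces only the last axis."""
    return shape[:-1] + [out_features]

def split_heads(shape, H):
    """[B, T, D] -> [B, H, T, d_head], requiring D = H * d_head."""
    B, T, D = shape
    assert D % H == 0, "need D = H * d_head"
    return [B, H, T, D // H]

def merge_heads(shape):
    """[B, H, T, d_head] -> [B, T, H * d_head]."""
    B, H, T, d_head = shape
    return [B, T, H * d_head]

def transpose_last2(shape):
    return shape[:-2] + [shape[-1], shape[-2]]

def matmul(a, b):
    """Batched matmul shape rule: [..., m, k] @ [..., k, n] -> [..., m, n]."""
    assert a[:-2] == b[:-2] and a[-1] == b[-2], (a, b)
    return a[:-2] + [a[-2], b[-1]]

def decoder_block(x, H, trace):
    """Shape-only decoder block; mask, softmax and norms keep shapes, so they are not modeled."""
    D = x[-1]
    q = k = v = split_heads(linear(x, D), H)
    trace.append(("Q/K/V heads", q))
    scores = matmul(q, transpose_last2(k))
    trace.append(("scores", scores))
    context_heads = matmul(scores, v)
    trace.append(("context_heads", context_heads))
    attn_out = linear(merge_heads(context_heads), D)
    assert attn_out == x, "residual add needs identical shapes"
    ffn_up = linear(attn_out, 4 * D)
    trace.append(("ffn_up", ffn_up))
    ffn_out = linear(ffn_up, D)
    assert ffn_out == x, "residual add needs identical shapes"
    trace.append(("block_out", ffn_out))
    return ffn_out

B, T, D, H, VOCAB, N_LAYERS = 2, 4, 512, 8, 32000, 6   # N_LAYERS is an arbitrary assumption
trace = [("input_ids", [B, T])]
x = embed([B, T], D)                   # pos_emb [1, T, D] broadcasts, so the add keeps this shape
trace.append(("hidden_states", x))
for layer in range(N_LAYERS):
    x = decoder_block(x, H, trace if layer == 0 else [])   # record only the first block
logits = linear(x, VOCAB)
trace.append(("logits", logits))

for name, shape in trace:
    print(f"{name:>14} : {shape}")
```

Changing B, T, H, or the vocabulary size reruns the same walk. The asserts flag any step that would break the [B, T, D] frame.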
The lesson is stable.
Track the outer frame first.
Then inspect temporary width changes.
Where this lives in the wild
- GPT-2 Small: shape tracing often starts here because D=768 and H=12 stay approachable.
- LLaMA-2 7B: serving teams track [B, T, 4096] streams and [B, 32, T, 128] heads constantly.
- Mistral 7B: sliding-window attention changes visible history, but block shape habits still hold.
- Phi-2: smaller width makes tensor debugging easier while preserving the same decoder logic.
- TensorRT-LLM: inference kernels are tuned around these exact split, score, merge, and FFN shapes.
Pause and recall
- Why do B and T usually stay unchanged through one decoder block?
- Why do attention scores need shape [B, H, T, T]?
- Where does the exam rule appear in the shape walk?
- Which tensors use the split-head form needed by the later memory shortcut?
Interview Q&A
Q1. Why do attention scores have shape [B, H, T, T]?
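Because, inside each head, every query position gets one score against every key position, giving a T × T grid per head and batch row. The mask then blocks future entries without removing them.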
Common wrong answer to avoid: Saying scores keep hidden width, so they stay [B, T, D].
Q2. What dimensions usually stay constant through one block?
B and T stay fixed, while width changes only inside attention views and the FFN.
Common wrong answer to avoid: Saying masking shrinks the sequence length after attention.
Q3. Why is D = H × d_head useful?
It lets the parallel graders partition width cleanly without inventing new tokens.
Common wrong answer to avoid: Saying each head receives a full copy of width D.
Apply now (5 min)
Quick exercise.
Take B=1, T=6, D=384, H=6.
Compute d_head.
Then write the major shapes from input_ids to logits.
Sketch from memory.
Draw the residual stream as [B, T, D].
Mark where Q, K, and V split into heads.
Mark where the causal mask lands on scores.
Underline the fixed answer sheet axis.
Add one note for where the later memory shortcut attaches.
Bridge. Good. The block shapes are now stable. Next we reuse split-head K and V instead of recomputing them. → 10-kv-cache.md