06. Multi-head split and merge: one wide stream becomes many narrow views
~10 min read. Most transformer shape bugs begin here. Get this reshape right once.
Built on the ELI5 in 00-eli5.md. The parallel graders appear because one wide tensor gets reorganised into several head-sized views.
1) Why one head is not enough
- A single head gets one attention pattern budget.
- Language usually needs several pattern budgets.
- One token may need local grammar.
- The same token may need long-range topic tracking.
- It may need entity matching too.
- It may need quote matching too.
- So one head becomes limiting.
- Multi-head attention solves that limit.
- In the exam hall picture,
- we hire several parallel graders.
- One grader can focus on syntax.
- Another can focus on entities.
- Another can focus on discourse callbacks.
- Another can focus on nearby rhythm.
- The exam rule stays identical for every head.
- No head may look forward.
- The answer sheet stays identical too.
- Every head sees the same token positions.
- The projection lens prepares Q, K, and V first.
- Splitting happens after that preparation.
- The memory shortcut arrives later during decoding.
- It will cache per-head keys and values.
- So multi-head does not replace earlier ideas.
- It builds on them.
2) Splitting means dividing width, not inventing width
- Suppose d_model = 8.
- Suppose num_heads = 2.
- Then d_head = 8 / 2 = 4.
- Nothing new was created.
- One length-8 vector became two length-4 slices.
- Each slice becomes one head view.
- The split is bookkeeping, not magic.
- The code is short.
- Reshape groups channels into head-sized chunks.
- Transpose moves the head axis forward.
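- Attention math wants that layout, because each head's [T, d_head] matrix then sits in the last two axes, where batched matrix multiplication expects it.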
- After split, the shape becomes [B, H, T, d_head].
- That gives each head one clean lane.
- Each parallel grader now has its own slice.
- The exam rule will later mask inside each lane.
- The answer sheet still contributes the same T positions.
- Forgetting the transpose is the classic bug.
- Using the wrong axis order is the second classic bug.
- If shapes feel confusing,
- write them after every line.
- A good habit is shape comments beside every tensor line.
- Another good habit is checking divisibility before reshaping anything.
- Head count changes layout, not sequence order.
- Token order survives the split unchanged.
- Channel order survives too.
- We simply expose a new indexing axis.
- That extra axis makes per-head math natural.
- It also makes later masking and caching cleaner.
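The reshape-then-transpose recipe above fits in a few lines. Below is a minimal sketch in NumPy, used here as a stand-in for whatever tensor library you use; the function name `split_heads` is illustrative (it matches the name in this chapter's interview questions, not any specific library's API).

```python
import numpy as np


def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """Regroup [B, T, D] into [B, H, T, d_head]; no values change."""
    B, T, D = x.shape
    # Check divisibility before reshaping anything.
    if D % num_heads != 0:
        raise ValueError(f"d_model={D} is not divisible by num_heads={num_heads}")
    d_head = D // num_heads
    # Reshape: group each token's D channels into H chunks of width d_head.
    x = x.reshape(B, T, num_heads, d_head)  # [B, T, H, d_head]
    # Transpose: move the head axis ahead of the token axis.
    return x.transpose(0, 2, 1, 3)          # [B, H, T, d_head]


x = np.zeros((2, 5, 8))                   # B = 2, T = 5, D = 8
print(split_heads(x, num_heads=2).shape)  # (2, 2, 5, 4)
```

The divisibility check runs before the reshape, so a bad head count fails with a readable error instead of a confusing reshape failure.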
3) One worked numerical walkthrough
- Take x with shape [1, 3, 8].
- So B = 1.
- So T = 3.
- So D = 8.
- Let the values be 1 through 24, filled row by row.
- Choose num_heads = 2.
- Then d_head = 4.
- First reshape to [1, 3, 2, 4].
- Each token row now has two chunks.
- Token 0 gives [1, 2, 3, 4] and [5, 6, 7, 8].
- Token 1 gives [9, 10, 11, 12] and [13, 14, 15, 16].
- Token 2 gives [17, 18, 19, 20] and [21, 22, 23, 24].
- Then swap the token and head axes, giving shape [1, 2, 3, 4].
- Now head 0 holds the first chunk of every token.
- Now head 1 holds the second chunk of every token.
Before split: [B, T, D]                After split: [B, H, T, d_head]
t0 │  1  2  3  4 │  5  6  7  8         head 0 │ [ 1  2  3  4] [ 9 10 11 12] [17 18 19 20]
t1 │  9 10 11 12 │ 13 14 15 16   ──→   head 1 │ [ 5  6  7  8] [13 14 15 16] [21 22 23 24]
t2 │ 17 18 19 20 │ 21 22 23 24
- No number changed anywhere.
- Only the view changed.
- That is the whole trick.
- The projection lens made the width first.
- The parallel graders now receive separate slices.
- Later, the memory shortcut can store those per-head tensors.
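To check the walkthrough, the sketch below rebuilds it in NumPy (again an assumed stand-in for your framework) with the values 1 to 24 from the example. It also shows the tempting shortcut of reshaping straight to [1, 2, 3, 4]: the shape is right, but the grouping is wrong, which is why reshape alone does not finish the split.

```python
import numpy as np

# The example tensor: B = 1, T = 3, D = 8, values 1..24 filled row by row.
x = np.arange(1, 25).reshape(1, 3, 8)

# Step 1: reshape to [1, 3, 2, 4]; each token row becomes two chunks of 4.
chunks = x.reshape(1, 3, 2, 4)
print(chunks[0, 0].tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]

# Step 2: swap the token and head axes, giving [1, 2, 3, 4].
heads = chunks.transpose(0, 2, 1, 3)
print(heads[0, 0].tolist())   # [[1, 2, 3, 4], [9, 10, 11, 12], [17, 18, 19, 20]]
print(heads[0, 1].tolist())   # [[5, 6, 7, 8], [13, 14, 15, 16], [21, 22, 23, 24]]

# Shortcut bug: reshaping straight to the target shape skips the transpose.
# The shape matches, but "head 0" now holds all of token 0 and half of token 1.
wrong = x.reshape(1, 2, 3, 4)
print(wrong[0, 0].tolist())   # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
```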
4) Merging is the exact reverse
- After attention, head outputs have shape [B, H, T, d_head].
- The next layer expects width D again.
- So we reverse the earlier reordering.
- First move the token axis back ahead of the head axis.
- Then flatten the head widths back together.
- Merge does not average heads.
- Merge concatenates them back into one vector.
- That detail matters in interviews.
- Suppose head 0 rows are [1, 2, 3, 4], [9, 10, 11, 12], [17, 18, 19, 20].
- Suppose head 1 rows are [5, 6, 7, 8], [13, 14, 15, 16], [21, 22, 23, 24].
- After transposing back, token 0 holds its two head chunks side by side.
- After reshape, token 0 becomes [1, 2, 3, 4, 5, 6, 7, 8].
- Token 1 becomes [9, 10, 11, 12, 13, 14, 15, 16].
- Token 2 becomes [17, 18, 19, 20, 21, 22, 23, 24].
- Merge restores the original channel order.
- The numbers survive untouched.
- The indexing changes back.
- That is why split and merge are perfect inverses.
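Merge can be sketched the same way. This NumPy version is an illustration under the same assumptions: `merge_heads` is a hypothetical helper name, and a compact `split_heads` is repeated so the block runs on its own. The round trip checks that the two are exact inverses on the walkthrough tensor.

```python
import numpy as np


def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """[B, T, D] -> [B, H, T, d_head] (assumes D is divisible by num_heads)."""
    B, T, D = x.shape
    return x.reshape(B, T, num_heads, D // num_heads).transpose(0, 2, 1, 3)


def merge_heads(h: np.ndarray) -> np.ndarray:
    """Reverse of split_heads: [B, H, T, d_head] -> [B, T, H * d_head]."""
    B, H, T, d_head = h.shape
    # Undo the transpose first: move the token axis back ahead of the heads.
    h = h.transpose(0, 2, 1, 3)         # [B, T, H, d_head]
    # Then concatenate the head chunks side by side (no averaging).
    return h.reshape(B, T, H * d_head)  # [B, T, D]


x = np.arange(1, 25).reshape(1, 3, 8)
merged = merge_heads(split_heads(x, num_heads=2))
print(merged[0, 0].tolist())      # [1, 2, 3, 4, 5, 6, 7, 8]
print(np.array_equal(merged, x))  # True
```

One framework detail: in PyTorch, the tensor after the merge transpose is usually non-contiguous, so `.view(B, T, D)` typically fails there and code reaches for `.reshape(...)` or `.contiguous().view(...)`. NumPy's `reshape` simply copies when it has to.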
5) Common bugs to catch fast
- d_model must be divisible by num_heads.
- Otherwise d_head = d_model / num_heads is not an integer.
- Splitting before the Q, K, and V projections is wrong.
- Heads should split projected features, not raw embedding chunks.
- Forgetting the split transpose breaks later score math.
- Forgetting the merge transpose scrambles channels.
- Applying the exam rule on the wrong axes also breaks shapes.
- Cached tensors in the memory shortcut must keep head layout consistent.
- The fastest debugging move is shape printing.
- Check projected tensors first.
- Check split tensors second.
- Check scores third.
- Check weights fourth.
- Check merged output last.
- If weights are not [B, H, T, T], stop immediately.
- If head outputs are not [B, H, T, d_head], stop immediately.
- If merged output is not [B, T, D], stop immediately.
- In multi-head code, shapes are logic.
- They are not decoration.
- Remember the full story.
- The projection lens prepares features.
- The parallel graders get separate lanes.
- The exam rule blocks future cheating.
- The answer sheet fixes the visible token count.
- The memory shortcut stores old per-head facts later.
- This is why head layout appears in every serious implementation.
- It is not framework ceremony.
- It is the contract that makes attention math line up.
- Once this contract is right,
- the rest of multi-head code becomes much easier.
- Once this contract is wrong,
- every later tensor looks suspicious.
- Learn this rearrangement once.
- Then reuse it everywhere.
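The checking order above can be turned into a short shape-assertion script. This is a sketch, not a full attention module: NumPy stands in for a tensor library, the projection weights are random stand-ins for learned matrices, and the sizes B = 2, T = 5, D = 8, H = 2 are arbitrary. The last lines show why shape checks are necessary but not sufficient: skipping the merge transpose still produces [B, T, D].

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 8, 2
assert D % H == 0, "d_model must be divisible by num_heads"
d_head = D // H

# Random stand-ins for learned projection weights (assumption, not trained).
W_q, W_k, W_v = (rng.standard_normal((D, D)) for _ in range(3))
x = rng.standard_normal((B, T, D))


def split(t: np.ndarray) -> np.ndarray:
    return t.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)


# 1) Projected tensors: project first, split second.
q, k, v = x @ W_q, x @ W_k, x @ W_v
assert q.shape == (B, T, D)

# 2) Split tensors.
q, k, v = split(q), split(k), split(v)
assert q.shape == (B, H, T, d_head)

# 3) Scores: one [T, T] grid per head.
scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)
assert scores.shape == (B, H, T, T)

# Causal mask on the last two axes (query, key); broadcasts over B and H.
future = np.triu(np.ones((T, T), dtype=bool), k=1)
scores = np.where(future, -np.inf, scores)

# 4) Weights: softmax over the key axis.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
assert weights.shape == (B, H, T, T)

# Head outputs.
out = weights @ v
assert out.shape == (B, H, T, d_head)

# 5) Merged output: transpose back, then flatten.
merged = out.transpose(0, 2, 1, 3).reshape(B, T, D)
assert merged.shape == (B, T, D)

# Classic merge bug: skipping the transpose still yields [B, T, D],
# so a shape check alone cannot catch it, but the channels are scrambled.
buggy = out.reshape(B, T, D)
print(buggy.shape == merged.shape, np.allclose(buggy, merged))  # True False
```

Because the buggy merge passes every shape assertion, a value-level round trip (split then merge on a known tensor) is a good companion test.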
Where this lives in the wild
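- GPT-2 (small): 12 heads split width 768 into d_head = 64 slices.
- BERT-base: 12 heads, each operating on a width-64 slice of the 768-wide projections.
- LLaMA-2 7B: 32 heads split width 4096 into d_head = 128 lanes.
- Vision Transformer (ViT-Base): each patch embedding's 768-wide features are split across 12 heads of width 64; the patches themselves are the T positions.
- Mistral 7B: 32 query heads of width 128 split a 4096-wide stream; keys and values use 8 heads (grouped-query attention), but the split-then-merge rhythm is the same.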
Pause and recall
- Why does reshape alone not finish the head split?
- In the example, what slice becomes head 1 for token 2?
- Why must d_model be divisible by num_heads?
- Why is merge not the same as averaging heads?
Interview Q&A
Q1. Why do we transpose after reshaping into heads?
- We move the head axis forward so each head gets a clean [T, d_head] view.
Common wrong answer to avoid: "Transpose is only for prettier printing."
Q2. What does split_heads do to the numbers?
- It changes indexing, not values.
- The same numbers are simply regrouped by head.
Common wrong answer to avoid: "split_heads learns new features by itself."
Q3. Why is merge_heads the reverse of split_heads?
- Attention runs per head, but the next layer expects one combined width again.
- So merge_heads undoes the split transpose and concatenates the head slices back into width D.
Common wrong answer to avoid: "merge_heads averages every head together."
Apply now (5 min)
- Quick exercise.
- Take a [1, 2, 6] tensor and split it into 3 heads.
- Write the reshaped tensor explicitly.
- Then write the transposed tensor explicitly.
- Merge it back and confirm the original order returns.
- Sketch from memory.
- Draw the width slices first.
- Then draw the head-major layout second.
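Once the sketch is done, a few NumPy lines can check it. The values 1 to 12 are an arbitrary choice for the [1, 2, 6] tensor, and NumPy is again an assumed stand-in; compare the printed lists with what you wrote by hand.

```python
import numpy as np

# The exercise tensor: B = 1, T = 2, D = 6, values 1..12 (any distinct values work).
x = np.arange(1, 13).reshape(1, 2, 6)
H = 3
d_head = x.shape[-1] // H                     # 6 / 3 = 2

reshaped = x.reshape(1, 2, H, d_head)         # [1, 2, 3, 2]
heads = reshaped.transpose(0, 2, 1, 3)        # [1, 3, 2, 2]
merged = heads.transpose(0, 2, 1, 3).reshape(1, 2, 6)

# Compare these against your hand-written tensors.
print(reshaped.tolist())
print(heads.tolist())
print(np.array_equal(merged, x))              # True
```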
Bridge. Good. We can now split and merge correctly. Next we code the full multi-head causal attention path. → 07-multi-head-coding.md