
06. Multi-head split and merge — one wide stream becomes many narrow views

~10 min read. Many transformer shape bugs start here. Get this reshape right once.

Built on the ELI5 in 00-eli5.md. The parallel graders appear because one wide tensor gets reorganised into several head-sized views.


1) Why one head is not enough

  • A single head gets one attention pattern budget.
  • Language usually needs several pattern budgets.
  • One token may need local grammar.
  • The same token may need long-range topic tracking.
  • It may need entity matching too.
  • It may need quote matching too.
  • So one head becomes limiting.
  • Multi-head attention solves that limit.
  • In the exam hall picture, we hire several parallel graders.
  • One grader can focus on syntax.
  • Another can focus on entities.
  • Another can focus on discourse callbacks.
  • Another can focus on nearby rhythm.
  • The exam rule stays identical for every head.
  • No head may look forward.
  • The answer sheet stays identical too.
  • Every head sees the same token positions.
  • The projection lens prepares Q, K, and V first.
  • Splitting happens after that preparation.
  • The memory shortcut arrives later during decoding.
  • It will cache per-head keys and values.
  • So multi-head does not replace earlier ideas.
  • It builds on them.

2) Splitting means dividing width, not inventing width

  • Suppose d_model = 8.
  • Suppose num_heads = 2.
  • Then d_head = 8 / 2 = 4.
  • Nothing new was created.
  • One length-8 vector became two length-4 slices.
  • Each slice becomes one head view.
  • The split is bookkeeping, not magic.
  • The code is short.
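    def split_heads(x, num_heads):
        # x: [B, T, D], a NumPy array (in PyTorch, use x.permute(0, 2, 1, 3) below)
        B, T, D = x.shape
        assert D % num_heads == 0, "D must divide evenly by num_heads"
        d_head = D // num_heads
        x = x.reshape(B, T, num_heads, d_head)  # [B, T, H, d_head]
        return x.transpose(0, 2, 1, 3)          # [B, H, T, d_head]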
    
  • Reshape groups channels into head-sized chunks.
  • Transpose moves the head axis forward.
  • Batched matmul works over the last two axes, so [T, d_head] there gives each head its own [T, T] scores.
  • After split, the shape becomes [B, H, T, d_head].
  • That gives each head one clean lane.
  • Each parallel grader now has its own slice.
  • The exam rule will later mask inside each lane.
  • The answer sheet still contributes the same T positions.
  • Forgetting the transpose is the classic bug.
  • Using the wrong axis order is the second classic bug.
  • If shapes feel confusing, write them after every line.
  • A good habit is shape comments beside every tensor line.
  • Another good habit is checking divisibility before reshaping anything.
  • Head count changes layout, not sequence order.
  • Token order survives the split unchanged.
  • Channel order survives too.
  • We simply expose a new indexing axis.
  • That extra axis makes per-head math natural.
  • It also makes later masking and caching cleaner.
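A minimal NumPy sketch of why the transpose matters. It mirrors the `split_heads` above, with an explicit divisibility check that raises `ValueError`, and uses random stand-in tensors for projected queries and keys (an assumption for illustration; `scores_ok` and `scores_bad` are hypothetical names). In head-major layout, a matmul over the last two axes gives one [T, T] score matrix per head; with the transpose forgotten, the same matmul compares heads with heads.

```python
import numpy as np

def split_heads(x, num_heads):
    # x: [B, T, D] -> [B, H, T, d_head]
    B, T, D = x.shape
    if D % num_heads != 0:
        raise ValueError(f"D={D} is not divisible by num_heads={num_heads}")
    d_head = D // num_heads
    x = x.reshape(B, T, num_heads, d_head)  # [B, T, H, d_head]
    return x.transpose(0, 2, 1, 3)          # [B, H, T, d_head]

rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 8, 2
q = rng.standard_normal((B, T, D))  # stand-in for projected queries
k = rng.standard_normal((B, T, D))  # stand-in for projected keys

# Correct: head axis before token axis, so the last two axes are [T, d_head].
qh, kh = split_heads(q, H), split_heads(k, H)
scores_ok = qh @ kh.swapaxes(-1, -2)
print(scores_ok.shape)   # (2, 2, 5, 5) -> [B, H, T, T]

# Bug: reshape only, no transpose. The last two axes are now [H, d_head],
# so the "scores" compare heads with heads instead of tokens with tokens.
q_bad = q.reshape(B, T, H, D // H)
k_bad = k.reshape(B, T, H, D // H)
scores_bad = q_bad @ k_bad.swapaxes(-1, -2)
print(scores_bad.shape)  # (2, 5, 2, 2) -> [B, T, H, H]
```

Both versions run without error, which is exactly why this bug survives: only a shape check reveals it.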

3) One worked numerical walkthrough

  • Take x with shape [1, 3, 8].
  • So B = 1.
  • So T = 3.
  • So D = 8.
  • Let the values be these.
    x[0] =
    [
      [ 1,  2,  3,  4,  5,  6,  7,  8],
      [ 9, 10, 11, 12, 13, 14, 15, 16],
      [17, 18, 19, 20, 21, 22, 23, 24]
    ]
    
  • Choose num_heads = 2.
  • Then d_head = 4.
  • First reshape to [1, 3, 2, 4].
  • Each token row now has two chunks.
  • Token 0 gives [1, 2, 3, 4] and [5, 6, 7, 8].
  • Token 1 gives [9, 10, 11, 12] and [13, 14, 15, 16].
  • Token 2 gives [17, 18, 19, 20] and [21, 22, 23, 24].
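  • Then transpose to shape [1, 2, 3, 4], that is [B, H, T, d_head].
  • Head 0 now holds the first chunk of every token: [1, 2, 3, 4], [9, 10, 11, 12], [17, 18, 19, 20].
  • Head 1 now holds the second chunk of every token: [5, 6, 7, 8], [13, 14, 15, 16], [21, 22, 23, 24].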
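    ┌────────────────────────────────┐
    │ Before split: [B, T, D]        │
    ├────┬─────────────┬─────────────┤
    │ t0 │  1  2  3  4 │  5  6  7  8 │
    │ t1 │  9 10 11 12 │ 13 14 15 16 │
    │ t2 │ 17 18 19 20 │ 21 22 23 24 │
    └────┴─────────────┴─────────────┘
                ↓ reshape to [1, 3, 2, 4], then transpose(0, 2, 1, 3)
    ┌────────────────────────────────────────────────────┐
    │ After split: [B, H, T, d_head]                     │
    ├────────┬───────────────────────────────────────────┤
    │ head 0 │ [ 1  2  3  4] [ 9 10 11 12] [17 18 19 20] │
    ├────────┼───────────────────────────────────────────┤
    │ head 1 │ [ 5  6  7  8] [13 14 15 16] [21 22 23 24] │
    └────────┴───────────────────────────────────────────┘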
    
  • No number changed anywhere.
  • Only the view changed.
  • That is the whole trick.
  • The projection lens produced these 8 channels first.
  • The parallel graders now receive separate slices.
  • Later, the memory shortcut can store those per-head tensors.
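The walkthrough can be replayed in a few lines. This sketch assumes NumPy semantics (as in the `split_heads` code above) and fills x with the values 1 to 24 via `np.arange`. `np.shares_memory` is used only to show that, for this contiguous input, reshape and transpose return views, so no value is copied or changed.

```python
import numpy as np

def split_heads(x, num_heads):
    B, T, D = x.shape
    d_head = D // num_heads
    return x.reshape(B, T, num_heads, d_head).transpose(0, 2, 1, 3)

x = np.arange(1, 25).reshape(1, 3, 8)  # the [1, 3, 8] example, values 1..24
heads = split_heads(x, num_heads=2)    # [1, 2, 3, 4]

print(heads.shape)   # (1, 2, 3, 4)
print(heads[0, 1])   # head 1: the second chunk of every token
# [[ 5  6  7  8]
#  [13 14 15 16]
#  [21 22 23 24]]

# "No number changed": heads reads the very same memory as x.
print(np.shares_memory(x, heads))  # True
```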

4) Merging is the exact reverse

  • After attention, head outputs have shape [B, H, T, d_head].
  • The next layer expects width D again.
  • So we reverse the earlier reordering.
    def merge_heads(x):
        # x: [B, H, T, d_head]
        B, H, T, d_head = x.shape
        x = x.transpose(0, 2, 1, 3)         # [B, T, H, d_head]
        return x.reshape(B, T, H * d_head)  # [B, T, D]: head slices concatenated
    
  • First move token axis ahead of heads.
  • Then flatten the head widths back together.
  • Merge does not average heads.
  • Merge concatenates them back into one vector.
  • That detail matters in interviews.
  • Suppose head 0 rows are [1, 2, 3, 4], [9, 10, 11, 12], [17, 18, 19, 20].
  • Suppose head 1 rows are [5, 6, 7, 8], [13, 14, 15, 16], [21, 22, 23, 24].
  • After transpose back, token 0 becomes two adjacent chunks.
  • After reshape, token 0 becomes [1, 2, 3, 4, 5, 6, 7, 8].
  • Token 1 becomes [9, 10, 11, 12, 13, 14, 15, 16].
  • Token 2 becomes [17, 18, 19, 20, 21, 22, 23, 24].
  • Merge restores the original channel order.
  • The numbers survive untouched.
  • The indexing changes back.
  • That is why split and merge are perfect inverses.
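A short NumPy check of the reverse path, assuming the same split and merge logic as above and the 1-to-24 example. It confirms the round trip, shows that merge concatenates head slices rather than averaging them, and shows what a merge without the transpose would produce (`bad` is a hypothetical name for the buggy result).

```python
import numpy as np

def split_heads(x, num_heads):
    B, T, D = x.shape
    return x.reshape(B, T, num_heads, D // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

x = np.arange(1, 25).reshape(1, 3, 8)
heads = split_heads(x, 2)
merged = merge_heads(heads)

# Round trip: merge(split(x)) returns exactly x.
print(np.array_equal(merged, x))    # True

# Merge concatenates per-head slices for each token; it does not average them.
token0 = np.concatenate([heads[0, 0, 0], heads[0, 1, 0]])
print(token0)                       # [1 2 3 4 5 6 7 8]
print(heads[0, :, 0].mean(axis=0))  # [3. 4. 5. 6.]  (what averaging would give)

# Bug: skipping the transpose keeps the [1, 3, 8] shape but mixes tokens.
bad = heads.reshape(1, 3, 8)
print(bad[0, 0])                    # [ 1  2  3  4  9 10 11 12]
```

Row 0 of `bad` holds head 0's slices of token 0 and token 1, so the shape check passes while the content is wrong.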

5) Common bugs to catch fast

  • d_model must divide evenly by num_heads.
  • Otherwise d_head is invalid.
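  • Splitting before the Q, K, and V projections is wrong.
  • Heads should split projected features, not raw embedding chunks.
  • Forgetting the split transpose makes Q @ K^T compare heads with heads ([B, T, H, H]) instead of tokens with tokens.
  • Forgetting the merge transpose keeps the [B, T, D] shape but mixes different tokens into one row.
  • Applying the exam rule on the wrong axes breaks the masking, sometimes silently through broadcasting; the [T, T] mask must align with the last two axes of the [B, H, T, T] scores.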
  • Cached tensors in the memory shortcut must keep head layout consistent.
  • The fastest debugging move is shape printing.
  • Check projected tensors first.
  • Check split tensors second.
  • Check scores third.
  • Check weights fourth.
  • Check merged output last.
  • If weights are not [B, H, T, T], stop immediately.
  • If head outputs are not [B, H, T, d_head], stop immediately.
  • If merged output is not [B, T, D], stop immediately.
  • In multi-head code, shapes are logic.
  • They are not decoration.
  • Remember the full story.
  • The projection lens prepares features.
  • The parallel graders get separate lanes.
  • The exam rule blocks future cheating.
  • The answer sheet fixes the visible token count.
  • The memory shortcut stores old per-head facts later.
  • This is why the same head layout shows up across implementations.
  • It is not framework ceremony.
  • It is the contract that makes attention math line up.
  • Once this contract is right, the rest of multi-head code becomes much easier.
  • Once this contract is wrong, every later tensor looks suspicious.
  • Learn this rearrangement once.
  • Then reuse it everywhere.
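The debugging order above (projected, split, scores, weights, merged) can be sketched as one NumPy function with an assert at each checkpoint. This is a minimal sketch under stated assumptions: random stand-in projection matrices `w_q`, `w_k`, `w_v` with no bias and no output projection, and standard scaled dot-product scores divided by sqrt(d_head). `attention_shape_check` and `softmax` are hypothetical helpers written for this illustration, not a library API.

```python
import numpy as np

def split_heads(x, num_heads):
    B, T, D = x.shape
    assert D % num_heads == 0, "d_model must divide evenly by num_heads"
    return x.reshape(B, T, num_heads, D // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_shape_check(x, w_q, w_k, w_v, num_heads):
    """Hypothetical helper: causal multi-head attention, asserting shapes at each stage."""
    B, T, D = x.shape
    d_head = D // num_heads

    # 1. Projected tensors (the projection lens): [B, T, D] each.
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    assert q.shape == k.shape == v.shape == (B, T, D)

    # 2. Split tensors (one lane per parallel grader): [B, H, T, d_head] each.
    q, k, v = split_heads(q, num_heads), split_heads(k, num_heads), split_heads(v, num_heads)
    assert q.shape == k.shape == v.shape == (B, num_heads, T, d_head)

    # 3. Scores: [B, H, T, T]. The [T, T] causal mask (the exam rule)
    #    broadcasts over the last two axes, i.e. query x key positions.
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d_head)
    assert scores.shape == (B, num_heads, T, T)
    future = np.triu(np.ones((T, T)), k=1).astype(bool)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)

    # 4. Weights: [B, H, T, T], each row sums to 1.
    weights = softmax(scores, axis=-1)
    assert weights.shape == (B, num_heads, T, T)

    # 5. Head outputs [B, H, T, d_head], then merged output [B, T, D].
    out = weights @ v
    assert out.shape == (B, num_heads, T, d_head)
    merged = merge_heads(out)
    assert merged.shape == (B, T, D)
    return merged, weights

rng = np.random.default_rng(0)
B, T, D, H = 2, 4, 8, 2
x = rng.standard_normal((B, T, D))
w_q, w_k, w_v = (rng.standard_normal((D, D)) for _ in range(3))  # stand-in weights
merged, weights = attention_shape_check(x, w_q, w_k, w_v, num_heads=H)
print(merged.shape)  # (2, 4, 8)
```

The first assert that fails points at the stage to inspect, which is the "stop immediately" rule expressed in code.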

Where this lives in the wild

  • GPT-2 small: 12 heads split width 768 into d_head = 64 slices.
  • BERT-base: 12 heads each operate on width 64 after projection.
  • LLaMA-2 7B: decoder layers split width 4096 into 32 per-head lanes of 128.
  • Vision Transformer: patch embeddings are projected, then split into head views the same way (ViT-Base uses 12 heads over width 768).
  • Mistral 7B: decoder attention keeps the same split-then-merge rhythm, with 32 query heads sharing 8 key/value heads (grouped-query attention).
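The d_head for each of these models is the same division as in the d_model = 8 example. The (width, heads) pairs below are the commonly published base configurations; treat them as assumptions to check against each model's own config file.

```python
# Assumed (d_model, num_heads) pairs from commonly published base configs.
configs = {
    "GPT-2 small": (768, 12),
    "BERT-base": (768, 12),
    "ViT-Base": (768, 12),
    "LLaMA-2 7B": (4096, 32),
    "Mistral 7B": (4096, 32),  # query heads; keys/values use fewer heads (GQA)
}

for name, (d_model, num_heads) in configs.items():
    assert d_model % num_heads == 0, f"{name}: width not divisible by heads"
    print(f"{name}: d_head = {d_model} // {num_heads} = {d_model // num_heads}")
```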

Pause and recall

  1. Why does reshape alone not finish the head split?
  2. In the example, what slice becomes head 1 for token 2?
  3. Why must d_model be divisible by num_heads?
  4. Why is merge not the same as averaging heads?

Interview Q&A

Q1. Why do we transpose after reshaping into heads?
  • We move the head axis ahead of the token axis, so each head gets a clean [T, d_head] view and batched matmul produces one [T, T] score matrix per head.
  • Common wrong answer to avoid: "Transpose is only for prettier printing."

Q2. What does split_heads do to the numbers?
  • It changes indexing, not values.
  • The same numbers are simply regrouped by head.
  • Common wrong answer to avoid: "split_heads learns new features by itself."

Q3. Why is merge_heads the reverse of split_heads?
  • Attention runs per head, but the next layer expects one combined width D again.
  • So merge_heads undoes split_heads in reverse order: transpose back, then reshape, concatenating the head slices.
  • Common wrong answer to avoid: "merge_heads averages every head together."


Apply now (5 min)

  • Quick exercise.
  • Take a [1, 2, 6] tensor and split it into 3 heads.
  • Write the reshaped tensor explicitly.
  • Then write the transposed tensor explicitly.
  • Merge it back and confirm the original order returns.
  • Sketch from memory.
  • Draw the width slices first.
  • Then draw the head-major layout second.
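After working the exercise by hand, a small NumPy check like this one can confirm your answer. It assumes the values 1 to 12 for the [1, 2, 6] tensor; any values work, since only their positions matter.

```python
import numpy as np

x = np.arange(1, 13).reshape(1, 2, 6)  # [B, T, D] = [1, 2, 6], values 1..12
num_heads = 3
d_head = x.shape[-1] // num_heads      # 6 // 3 = 2

reshaped = x.reshape(1, 2, num_heads, d_head)  # [1, 2, 3, 2]
heads = reshaped.transpose(0, 2, 1, 3)         # [1, 3, 2, 2]
merged = heads.transpose(0, 2, 1, 3).reshape(1, 2, num_heads * d_head)

print(reshaped[0])  # compare with your reshaped tensor
print(heads[0])     # compare with your transposed tensor
print(np.array_equal(merged, x))  # True
```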

Bridge. Good. We can now split and merge correctly. Next we code the full multi-head causal attention path. → 07-multi-head-coding.md