12. GQA and MQA: fewer notebooks for the same crews
Long contexts fail on memory before they fail on math. This file shows the fix.
Built on the ELI5 in 00-eli5.md. The parallel crews, the many heads asking different questions, now start sharing notebooks, so the cache stops eating the GPU.
Mental model: same questions, fewer notebooks
Picture the social bench first. In full multi-head attention, every parallel crew gets its own key notebook and its own value notebook. That feels clean. It also makes the cache huge during long decoding.
So what to do? Keep many query crews, but shrink the number of stored key-value notebooks. The questions stay diverse. The storage gets cheaper.
full MHA: many question-askers, many notebooks
GQA: many question-askers, fewer shared notebooks
MQA: many question-askers, one shared notebook
First see the three layouts
Take 8 query heads. Do not jump to formulas yet. Just see who shares what.
MHA
Q1 -> K1,V1
Q2 -> K2,V2
Q3 -> K3,V3
Q4 -> K4,V4
Q5 -> K5,V5
Q6 -> K6,V6
Q7 -> K7,V7
Q8 -> K8,V8
GQA with 4 KV heads
Q1,Q2 -> K1,V1
Q3,Q4 -> K2,V2
Q5,Q6 -> K3,V3
Q7,Q8 -> K4,V4
MQA
Q1,Q2,Q3,Q4,Q5,Q6,Q7,Q8 -> K1,V1
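The three layouts differ only in which KV head each query head reads. Below is a minimal Python sketch. It assumes the contiguous grouping drawn above, where consecutive query heads share a KV head. The helper names `kv_head_for` and `layout` are hypothetical and chosen for illustration.

```python
def kv_head_for(q_head: int, n_query_heads: int, n_kv_heads: int) -> int:
    """Return the 0-based KV head read by 0-based query head `q_head`.

    Hypothetical helper: assumes n_query_heads is a multiple of n_kv_heads
    and that consecutive query heads share one KV head, as in the diagrams.
    """
    if n_query_heads % n_kv_heads != 0:
        raise ValueError("n_query_heads must be a multiple of n_kv_heads")
    group_size = n_query_heads // n_kv_heads  # query heads per KV head
    return q_head // group_size


def layout(n_query_heads: int, n_kv_heads: int) -> dict[int, list[int]]:
    """Map each 1-based KV head to the 1-based query heads that share it."""
    groups: dict[int, list[int]] = {}
    for q in range(n_query_heads):
        kv = kv_head_for(q, n_query_heads, n_kv_heads)
        groups.setdefault(kv + 1, []).append(q + 1)
    return groups


for name, n_kv in [("MHA", 8), ("GQA", 4), ("MQA", 1)]:
    print(name, layout(8, n_kv))
# GQA line: GQA {1: [1, 2], 2: [3, 4], 3: [5, 6], 4: [7, 8]}
```

MHA and MQA fall out as the two extremes of the same rule: a group size of 1 and a group size equal to all query heads.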
What each variant means
In standard MHA, every head has its own Q, K, and V projections. Conceptually, head $i$ computes $Q_i = xW_i^Q$, $K_i = xW_i^K$, $V_i = xW_i^V$. So if n_heads = 32, then n_kv_heads = 32.
In MQA, all heads keep separate queries, but they share one key projection and one value projection. So n_kv_heads = 1. With the same 32 query heads, the cache shrinks by 32x versus full MHA.
In GQA, several query heads share one KV head. Example: 32 query heads, 8 KV heads, so 4 query heads per KV group. Cache shrinks by 4x versus full MHA, but keeps more specialization than MQA.
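Below is a minimal NumPy sketch of one decode step of grouped attention for a single layer. It assumes batch size 1 and that every cached position is visible to the new token. The function name and array layouts are illustrative assumptions, not any library's API. Setting `n_kv_heads = n_heads` recovers MHA, and `n_kv_heads = 1` gives MQA.

```python
import numpy as np


def grouped_decode_attention(q, k_cache, v_cache):
    """One decode step of GQA for a single layer (illustrative sketch).

    q:       (n_heads, d_head)             fresh queries for the new token
    k_cache: (n_kv_heads, seq_len, d_head) cached keys
    v_cache: (n_kv_heads, seq_len, d_head) cached values
    returns: (n_heads, d_head)             attention output per query head
    """
    n_heads, d_head = q.shape
    n_kv_heads = k_cache.shape[0]
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group_size = n_heads // n_kv_heads

    # Expand KV heads so each query head sees its group's K and V.
    # Order is [kv0, kv0, kv1, kv1, ...], matching contiguous grouping.
    k = np.repeat(k_cache, group_size, axis=0)  # (n_heads, seq_len, d_head)
    v = np.repeat(v_cache, group_size, axis=0)

    # Scaled dot-product scores against every cached position.
    scores = np.einsum("hd,hsd->hs", q, k) / np.sqrt(d_head)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over positions
    return np.einsum("hs,hsd->hd", weights, v)


rng = np.random.default_rng(0)
n_heads, n_kv_heads, seq_len, d_head = 8, 4, 5, 4
q = rng.standard_normal((n_heads, d_head))
k_cache = rng.standard_normal((n_kv_heads, seq_len, d_head))
v_cache = rng.standard_normal((n_kv_heads, seq_len, d_head))
out = grouped_decode_attention(q, k_cache, v_cache)
print(out.shape)  # (8, 4)
```

`np.repeat` materializes copies only to keep the sketch readable. The memory win comes from the cache holding just `n_kv_heads` rows per position. Optimized kernels typically read each shared KV head in place rather than duplicating it.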
Picture before math: why the cache is the bottleneck
During decoding, we generate one token at a time. For the new token, we compute fresh queries. That part is cheap enough. Then those queries must read old keys and values from every previous position, at every layer, across all KV heads. As context grows, memory traffic dominates.
new token xt
-> make fresh Q heads
-> read old K,V cache for all past positions
from all layers
from all KV heads
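In the decode loop above, the cache grows by one row per KV head per layer at each step. The fresh queries are used once and never stored. Below is a toy NumPy sketch of that loop. It assumes random matrices stand in for trained projection weights and reuses one hidden state across layers (a real model would transform it between layers). The `ToyKVCache` class is hypothetical.

```python
import numpy as np


class ToyKVCache:
    """Hypothetical per-sequence KV cache: only keys and values are stored."""

    def __init__(self, n_layers: int, n_kv_heads: int, d_head: int):
        shape = (n_kv_heads, 0, d_head)  # (KV heads, positions so far, head width)
        self.keys = [np.zeros(shape, np.float16) for _ in range(n_layers)]
        self.values = [np.zeros(shape, np.float16) for _ in range(n_layers)]

    def append(self, layer: int, k_new: np.ndarray, v_new: np.ndarray) -> None:
        """Add one position; k_new and v_new have shape (n_kv_heads, d_head)."""
        self.keys[layer] = np.concatenate(
            [self.keys[layer], k_new[:, None, :].astype(np.float16)], axis=1)
        self.values[layer] = np.concatenate(
            [self.values[layer], v_new[:, None, :].astype(np.float16)], axis=1)

    def nbytes(self) -> int:
        return sum(a.nbytes for a in self.keys + self.values)


rng = np.random.default_rng(0)
n_layers, d_model, n_heads, n_kv_heads = 2, 32, 8, 4
d_head = d_model // n_heads  # 4
# Random stand-ins for trained weights. W_K and W_V only produce
# n_kv_heads * d_head columns, which is where GQA/MQA save memory.
w_q = [rng.standard_normal((d_model, n_heads * d_head)) for _ in range(n_layers)]
w_k = [rng.standard_normal((d_model, n_kv_heads * d_head)) for _ in range(n_layers)]
w_v = [rng.standard_normal((d_model, n_kv_heads * d_head)) for _ in range(n_layers)]

cache = ToyKVCache(n_layers, n_kv_heads, d_head)
for step in range(5):  # decode 5 tokens
    x = rng.standard_normal(d_model)  # stand-in for the new token's hidden state
    for layer in range(n_layers):
        q = (x @ w_q[layer]).reshape(n_heads, d_head)  # used this step, never cached
        cache.append(layer,
                     (x @ w_k[layer]).reshape(n_kv_heads, d_head),
                     (x @ w_v[layer]).reshape(n_kv_heads, d_head))
        # Attention for this layer would now compare q against every cached row.

print(cache.keys[0].shape)  # (4, 5, 4)
print(cache.nbytes())       # 640
```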
The formula: only KV heads control cache size
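For an fp16 cache with batch size 1:
$$\text{KV cache bytes} = 2 \times n_{\text{layers}} \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{seq\_len} \times 2$$
The first 2 counts keys and values. The final 2 is the number of bytes per fp16 number. The key point is this: queries are not cached across time. Only keys and values are stored. So when we move from MHA to GQA or MQA, the cache scales with KV heads, not query heads.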
That one substitution changes serving economics.
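The formula fits in a few lines of Python. This is a minimal sketch that assumes fp16 (2 bytes per number) and batch size 1, as stated above. `kv_cache_bytes` is a hypothetical helper name. The number of query heads is deliberately not a parameter, because queries are never cached.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """KV cache size in bytes for one sequence (batch size 1).

    2 (keys and values) x layers x KV heads x head width x positions x bytes.
    Query heads do not appear: queries are never cached.
    """
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem


# Toy values: 1 layer, d_head = 4, seq_len = 5, and 8, 4, or 1 KV heads.
for name, n_kv in [("MHA", 8), ("GQA", 4), ("MQA", 1)]:
    print(name, kv_cache_bytes(n_layers=1, n_kv_heads=n_kv, d_head=4, seq_len=5))
# MHA 640
# GQA 320
# MQA 80
```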
Worked tiny example: 8 query heads
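Take this toy setup:
- n_layers = 1
- n_query_heads = 8
- d_head = 4
- seq_len = 5
- fp16 cache (2 bytes per number), batch size 1
- GQA uses 4 KV heads, MQA uses 1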
Then the stored bytes are:
MHA: 2 × 1 × 8 × 4 × 5 × 2 = 640 bytes
GQA: 2 × 1 × 4 × 4 × 5 × 2 = 320 bytes
MQA: 2 × 1 × 1 × 4 × 5 × 2 = 80 bytes
Worked numerical comparison: 32-layer, 4096-d model
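Now do the real serving-style comparison. Assume:
- n_layers = 32
- d_model = 4096
- n_query_heads = 32
- seq_len = 4096
- fp16 cache (2 bytes per number), batch size 1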
First find head width: d_head = d_model / n_query_heads = 4096 / 32 = 128.
Full MHA
Here n_kv_heads = 32.
per layer bytes
= 2 × 32 × 128 × 4096 × 2
= 67,108,864 bytes
≈ 64 MiB
total across 32 layers
= 64 MiB × 32
= 2048 MiB
= 2.0 GiB
GQA with 8 KV heads
Here n_kv_heads = 8. So 32 / 8 = 4 query heads share each KV group.
per layer bytes
= 2 × 8 × 128 × 4096 × 2
= 16,777,216 bytes
≈ 16 MiB
total across 32 layers
= 16 MiB × 32
= 512 MiB
MQA
Here n_kv_heads = 1.
per layer bytes
= 2 × 1 × 128 × 4096 × 2
= 2,097,152 bytes
≈ 2 MiB
total across 32 layers
= 2 MiB × 32
= 64 MiB
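Side-by-side summary
MHA (32 KV heads): 2.0 GiB total KV cache
GQA (8 KV heads): 512 MiB, 4x smaller
MQA (1 KV head): 64 MiB, 32x smaller
Same model depth. Same d_model. Same 32 query heads. Only KV sharing changed. The memory difference is brutal.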
Why this matters in production
Long-context serving hits the cache wall quickly. More users means bigger batches. Longer prompts mean longer caches. More layers mean the cache repeats again and again.
more seq_len -> more cached rows
more batch -> more copies of cache
more layers -> more cache stacks
result -> HBM pressure
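The arrows above multiply together. Below is a rough capacity sketch. It assumes the 32-layer, d_head = 128, fp16 model from the comparison above, seq_len = 4096, and a hypothetical 16 GiB of GPU memory set aside purely for the KV cache. Real servers also reserve memory for weights and activations. The budget and helper names are illustrative, not measured.

```python
MIB = 1024 ** 2
GIB = 1024 ** 3


def kv_cache_bytes(batch: int, seq_len: int, n_layers: int,
                   n_kv_heads: int, d_head: int, bytes_per_elem: int = 2) -> int:
    """KV cache bytes; every sequence in the batch holds its own full copy."""
    return 2 * batch * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem


def max_batch(budget_bytes: int, seq_len: int, n_layers: int,
              n_kv_heads: int, d_head: int) -> int:
    """How many full-length sequences fit in a KV-cache budget."""
    per_seq = kv_cache_bytes(1, seq_len, n_layers, n_kv_heads, d_head)
    return budget_bytes // per_seq


budget = 16 * GIB  # hypothetical cache budget, not a property of any real GPU
for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    per_seq_mib = kv_cache_bytes(1, 4096, 32, n_kv, 128) // MIB
    fits = max_batch(budget, 4096, 32, n_kv, 128)
    print(f"{name}: {per_seq_mib} MiB per sequence, {fits} sequences fit")
# MHA: 2048 MiB per sequence, 8 sequences fit
# GQA: 512 MiB per sequence, 32 sequences fit
# MQA: 64 MiB per sequence, 256 sequences fit
```

Doubling seq_len halves each count. That is the "more seq_len, more batch" pressure in numeric form.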
Why not always choose MQA?
Because sharing too aggressively can blur head specialization. In full MHA, every crew writes its own notebook. Very expressive. Very expensive. In MQA, all crews read from one shared notebook. Very cheap. Sometimes slightly worse quality.
GQA is the compromise. You keep many query heads, but only a small set of KV notebooks. So you save most of the memory without forcing every crew through one single bottleneck.
The exact gap depends on model size and training recipe. But the usual answer is stable: MQA saves the most memory, while GQA usually stays closer to MHA quality.
Where this lives in the wild
- Meta Llama 2 70B uses GQA, because 70B-scale decoding needs a much smaller KV cache to serve long prompts sanely.
- Google PaLM uses MQA, pushing the sharing idea to the extreme so decode-time memory drops sharply.
- Older GPT-style decoder stacks use full MHA, which keeps every head independent but makes KV cache memory much heavier.
- Mistral and Mixtral serving benefit from GQA, letting more long-context requests fit on one GPU.
- NVIDIA TensorRT-LLM and vLLM deployments love reduced KV heads, because smaller caches directly improve batching and throughput.
Interview Q&A
Q: What is the exact difference between MHA, GQA, and MQA?
A: MHA gives every query head its own keys and values. GQA lets several query heads share one KV head. MQA is the extreme case where all query heads share one KV head.
Common wrong answer to avoid: "GQA reduces the number of query heads." No. Query heads stay. KV heads shrink.
Q: Why do GQA and MQA help inference so much?
A: Because KV cache size scales with the number of KV heads. Long-context decoding is often memory-bandwidth-bound, so smaller caches mean better throughput, less OOM risk, and better batching.
Common wrong answer to avoid: "Because they reduce the attention formula from O(n²) to O(n)." No. They shrink cache memory. Flash-style kernels attack the compute path differently.
Q: Why is GQA usually preferred over pure MQA in large modern LLMs?
A: GQA keeps most of the cache savings while preserving more head specialization. MQA saves more memory, but quality can slip because all heads share the same keys and values.
Common wrong answer to avoid: "GQA and MQA are basically identical." No. GQA is a compromise, not the same design.
Apply now (5 min)
Take a decoder with n_layers = 16, n_query_heads = 16, d_model = 2048, seq_len = 2048, and fp16.
- Compute d_head.
- Compute the full-MHA KV cache size.
- Recompute it for GQA with 4 KV heads and for MQA.
- Say which design you would pick for quality-first serving and which for memory-first serving.
Then sketch from memory:
- The 8-query-head diagram for MHA, GQA, and MQA.
- The cache formula with n_kv_heads in the middle.
- The sentence: "queries stay many, notebooks become fewer."
If you can say why GQA is the compromise and MQA is the extreme, you own this file.
Bridge. GQA shrinks the cache. But even with smaller caches, attention is still O(n²) in sequence length. The next file tackles the other bottleneck — making the attention computation itself faster. Read 13-flash-attention.md next.