
12. GQA and MQA — fewer notebooks for the same crews

Long contexts fail on memory before they fail on math. This file shows the fix.

Built on the ELI5 in 00-eli5.md. The parallel crews — many heads asking different questions — now start sharing notebooks, so the cache stops eating the GPU.


Mental model — same questions, fewer notebooks

Picture the social bench first. In full multi-head attention, every parallel crew gets its own key notebook and its own value notebook. That feels clean. It also makes the cache huge during long decoding.

So what to do? Keep many query crews, but shrink the number of stored key-value notebooks. The questions stay diverse. The storage gets cheaper.

full MHA: many question-askers, many notebooks
GQA:      many question-askers, fewer shared notebooks
MQA:      many question-askers, one shared notebook
That is the whole topic.


First see the three layouts

Take 8 query heads. Do not jump to formulas yet. Just see who shares what.

MHA
Q1 -> K1,V1
Q2 -> K2,V2
Q3 -> K3,V3
Q4 -> K4,V4
Q5 -> K5,V5
Q6 -> K6,V6
Q7 -> K7,V7
Q8 -> K8,V8

GQA with 4 KV heads
Q1,Q2 -> K1,V1
Q3,Q4 -> K2,V2
Q5,Q6 -> K3,V3
Q7,Q8 -> K4,V4

MQA
Q1,Q2,Q3,Q4,Q5,Q6,Q7,Q8 -> K1,V1
Important. Query heads still stay separate. Crew 1 can ask one pattern. Crew 7 can ask another. What changes is only the stored keys and values. Simple, no?
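The sharing pattern in the diagrams reduces to a one-line index rule. Below is a minimal Python sketch. It assumes 0-indexed heads and contiguous grouping, meaning consecutive query heads share a KV head, as drawn above. The function names are illustrative and do not come from any library.

```python
def kv_head_for_query(q: int, n_query_heads: int, n_kv_heads: int) -> int:
    """Return the 0-indexed KV head that query head q reads from.

    Assumes contiguous grouping: consecutive query heads share one KV head.
    """
    if n_query_heads % n_kv_heads != 0:
        raise ValueError("n_query_heads must be divisible by n_kv_heads")
    group_size = n_query_heads // n_kv_heads  # query heads per KV head
    return q // group_size


def sharing_layout(n_query_heads: int, n_kv_heads: int) -> dict[int, list[int]]:
    """Map each KV head to the list of query heads that share it."""
    layout: dict[int, list[int]] = {}
    for q in range(n_query_heads):
        kv = kv_head_for_query(q, n_query_heads, n_kv_heads)
        layout.setdefault(kv, []).append(q)
    return layout


for name, n_kv in [("MHA", 8), ("GQA", 4), ("MQA", 1)]:
    print(name, sharing_layout(8, n_kv))
# GQA line prints: GQA {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
```

The index rule is the whole trick. MHA is group_size = 1, MQA is group_size = n_query_heads, and GQA sits in between. Because heads here are 0-indexed, Q1 and Q2 in the diagram correspond to query heads 0 and 1.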


What each variant means

In standard MHA, every head has its own Q, K, and V projections. Conceptually, head i computes $Q_i = xW_i^Q$, $K_i = xW_i^K$, $V_i = xW_i^V$. So if n_heads = 32, then n_kv_heads = 32.

In MQA, all heads keep separate queries, but they share one key projection and one value projection. So n_kv_heads = 1. With 32 query heads, the cache shrinks by 32x versus full MHA.

In GQA, several query heads share one KV head. Example: 32 query heads and 8 KV heads, so 4 query heads per KV group. The cache shrinks by 32 / 8 = 4x versus full MHA, and the model keeps more head specialization than MQA.
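In terms of weight shapes, only the output width of the K and V projections changes across the three variants. The sketch below makes two assumptions: d_head = d_model / n_query_heads, which is the same choice the 32-layer example below makes, and bias-free projections packed across heads. `projection_shapes` is a hypothetical helper.

```python
def projection_shapes(d_model: int, n_query_heads: int,
                      n_kv_heads: int) -> dict[str, tuple[int, int]]:
    """Weight shapes for Q, K, V projections, packed across heads (no bias)."""
    if d_model % n_query_heads or n_query_heads % n_kv_heads:
        raise ValueError("head counts must divide evenly")
    d_head = d_model // n_query_heads
    return {
        "W_Q": (d_model, n_query_heads * d_head),  # query heads never shrink
        "W_K": (d_model, n_kv_heads * d_head),     # shrinks with KV sharing
        "W_V": (d_model, n_kv_heads * d_head),     # shrinks with KV sharing
    }


for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    print(name, projection_shapes(4096, 32, n_kv))
```

W_K and W_V each produce n_kv_heads × d_head numbers per token. That width is exactly what gets appended to the cache at every decoding step.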


Picture before math — why the cache is the bottleneck

During decoding, we generate one token at a time. For the new token, we compute fresh queries. That part is cheap enough. Then those queries must read old keys and values from every previous position, at every layer, across all KV heads. As context grows, memory traffic dominates.

new token xt
   -> make fresh Q heads
   -> read old K,V cache for all past positions
      from all layers
      from all KV heads
So the pain is not only FLOPs. The pain is moving cache bytes again and again. That is why long-context inference is often memory-bound.
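The read pattern can be made concrete with a minimal single-layer decode step in plain Python. It computes fresh queries for the new token, then runs attention over the cached keys and values. The sketch assumes batch size 1, contiguous grouping, and scaled dot-product attention with softmax. The function and argument names are illustrative. With n_kv_heads equal to n_query_heads, the same code is MHA, and with n_kv_heads = 1 it is MQA.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)                               # subtract max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gqa_decode_step(q, k_cache, v_cache):
    """One decoding step of grouped-query attention for a single layer.

    q:       n_query_heads vectors of length d_head (fresh for the new token)
    k_cache: n_kv_heads lists of seq_len key vectors (read, not recomputed)
    v_cache: n_kv_heads lists of seq_len value vectors
    returns: one output vector of length d_head per query head
    """
    n_query_heads, n_kv_heads = len(q), len(k_cache)
    group = n_query_heads // n_kv_heads       # query heads per shared KV head
    d_head = len(q[0])
    outputs = []
    for h, q_h in enumerate(q):
        kv = h // group                       # the shared notebook this head reads
        keys, values = k_cache[kv], v_cache[kv]
        scores = [sum(a * b for a, b in zip(q_h, k)) / math.sqrt(d_head) for k in keys]
        weights = softmax(scores)
        outputs.append([sum(w * v[j] for w, v in zip(weights, values))
                        for j in range(d_head)])
    return outputs


# GQA: 4 query heads share 2 KV heads; 2 cached positions; d_head = 2.
k_cache = [[[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]]]
v_cache = [[[0.0, 0.0], [2.0, 2.0]], [[10.0, 10.0], [20.0, 20.0]]]
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [3.0, -1.0]]
print(gqa_decode_step(q, k_cache, v_cache))
# prints [[1.0, 1.0], [1.0, 1.0], [15.0, 15.0], [15.0, 15.0]]
```

Query heads 0 and 1 read KV head 0, and heads 2 and 3 read KV head 1. The queries differ, but each group reads the same cached rows. Production kernels index the shared head in place rather than looping in Python. This sketch only shows which bytes get read.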


The formula — only KV heads control cache size

For fp16 cache, batch size 1:

KV cache bytes
= 2 × n_layers × n_kv_heads × d_head × seq_len × 2 bytes
The leading 2 counts keys and values. The trailing 2 bytes is the size of one fp16 element. The key point is this: queries are not cached across time. Only keys and values are stored. So when we move from MHA to GQA or MQA, the cache scales with KV heads, not query heads.
MHA: n_kv_heads = n_query_heads
GQA: 1 < n_kv_heads < n_query_heads
MQA: n_kv_heads = 1
That one substitution changes serving economics.
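The formula translates directly into a small function. The `batch` parameter is an added assumption and is not part of the formula above. It reflects the fact that each sequence in a batch keeps its own cache, so total size grows linearly with batch. The usage line reuses the toy setup that follows: 1 layer, 8 query heads, d_head = 4, 5 tokens, fp16.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int, seq_len: int,
                   bytes_per_elem: int = 2, batch: int = 1) -> int:
    """KV cache size in bytes.

    The leading 2 counts keys and values; bytes_per_elem = 2 means fp16.
    Query heads do not appear anywhere: queries are never cached.
    `batch` is an assumption: each sequence keeps its own cache.
    """
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem * batch


for name, n_kv in [("MHA", 8), ("GQA", 4), ("MQA", 1)]:
    print(f"{name}: {kv_cache_bytes(1, n_kv, 4, 5)} bytes")
# prints MHA: 640 bytes, GQA: 320 bytes, MQA: 80 bytes
```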


Worked tiny example — 8 query heads

Take this toy setup:

n_layers  = 1
n_query   = 8
d_head    = 4
seq_len   = 5
precision = fp16
Then the stored bytes are:
MHA (8 KV heads): 2 × 1 × 8 × 4 × 5 × 2 = 640 bytes
GQA (4 KV heads): 2 × 1 × 4 × 4 × 5 × 2 = 320 bytes
MQA (1 KV head):  2 × 1 × 1 × 4 × 5 × 2 =  80 bytes
Same 8 query crews. Very different notebook count. See the pattern before the bigger example.


Worked numerical comparison — 32-layer, 4096-d model

Now do the real serving-style comparison. Assume:

n_layers       = 32
d_model        = 4096
n_query_heads  = 32
seq_len        = 4096
precision      = fp16
First find head width.
d_head = d_model / n_query_heads
       = 4096 / 32
       = 128

Full MHA

Here n_kv_heads = 32.

per layer bytes
= 2 × 32 × 128 × 4096 × 2
= 67,108,864 bytes
= 64 MiB

total across 32 layers
= 64 MiB × 32
= 2048 MiB
= 2.0 GiB

GQA with 8 KV heads

Here n_kv_heads = 8. So 32 / 8 = 4 query heads share each KV group.

per layer bytes
= 2 × 8 × 128 × 4096 × 2
= 16,777,216 bytes
= 16 MiB

total across 32 layers
= 16 MiB × 32
= 512 MiB

MQA

Here n_kv_heads = 1.

per layer bytes
= 2 × 1 × 128 × 4096 × 2
= 2,097,152 bytes
= 2 MiB

total across 32 layers
= 2 MiB × 32
= 64 MiB

Side-by-side summary

MHA          2.0 GiB      (2048 MiB)
GQA (8 KV)   0.5 GiB      (512 MiB)
MQA          0.0625 GiB   (64 MiB)
Same model depth. Same d_model. Same 32 query heads. Only KV sharing changed. The memory difference is brutal.
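The whole table can be reproduced in a few lines. Doing so is a quick way to check the unit conversions (MiB = 2^20 bytes, GiB = 2^30 bytes). The sketch uses the same assumptions as the worked example: 32 layers, d_model 4096, 32 query heads, 4096 tokens, fp16, batch size 1.

```python
MiB = 1024 ** 2
GiB = 1024 ** 3

n_layers, d_model, n_query_heads, seq_len = 32, 4096, 32, 4096
d_head = d_model // n_query_heads      # 128
bytes_per_elem = 2                     # fp16


def cache_bytes(n_kv_heads: int) -> int:
    # 2 (K and V) x layers x KV heads x head width x positions x bytes per element
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem


mha = cache_bytes(n_query_heads)
for name, n_kv in [("MHA", 32), ("GQA (8 KV)", 8), ("MQA", 1)]:
    total = cache_bytes(n_kv)
    per_layer = total // n_layers
    print(f"{name:11s} per layer {per_layer / MiB:3.0f} MiB  "
          f"total {total / GiB:.4f} GiB  (saving {mha // total}x vs MHA)")
```

The savings factor is simply n_query_heads / n_kv_heads: 1x, 4x, and 32x.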


Why this matters in production

Long-context serving hits the cache wall quickly. More users means bigger batches. Longer prompts mean longer caches. More layers mean the cache repeats again and again.

more seq_len  -> more cached rows
more batch    -> more copies of cache
more layers   -> more cache stacks
result        -> HBM pressure
So the GPU often waits on memory bandwidth, not only arithmetic. Smaller KV cache gives three immediate wins: more concurrent requests per GPU, longer safe context before OOM, and lower latency once memory traffic stops choking everything.
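The concurrency win can be estimated with the same formula. The sketch below asks how many full-length 4096-token requests fit in a KV-cache budget. The 40 GiB budget is a hypothetical number chosen for illustration. Real headroom depends on the GPU, the model weights, and the serving engine's allocator. The model dimensions follow the 32-layer example above. The sketch also reports roughly how many cache bytes each decoded token reads, since every new token re-reads its sequence's cached K and V at every layer.

```python
GiB = 1024 ** 3
MiB = 1024 ** 2

# Model from the 32-layer example; fp16 cache.
N_LAYERS, D_HEAD, SEQ_LEN, BYTES_PER_ELEM = 32, 128, 4096, 2
KV_BUDGET_BYTES = 40 * GiB   # hypothetical memory left for KV cache after weights


def per_request_cache_bytes(n_kv_heads: int) -> int:
    return 2 * N_LAYERS * n_kv_heads * D_HEAD * SEQ_LEN * BYTES_PER_ELEM


for name, n_kv in [("MHA", 32), ("GQA (8 KV)", 8), ("MQA", 1)]:
    per_req = per_request_cache_bytes(n_kv)
    max_requests = KV_BUDGET_BYTES // per_req
    # At full context, each decode step re-reads roughly per_req bytes per request.
    print(f"{name:11s} fits {max_requests:4d} full-length requests; "
          f"~{per_req / MiB:.0f} MiB of cache read per decoded token")
```

Under this made-up budget, the request count scales exactly with the KV-head reduction: 20, 80, and 640.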


Why not always choose MQA?

Because sharing too aggressively can blur head specialization. In full MHA, every crew writes its own notebook. Very expressive. Very expensive. In MQA, all crews read from one shared notebook. Very cheap. Sometimes slightly worse quality.

GQA is the compromise. You keep many query heads, but only a small set of KV notebooks. So you save most of the memory without forcing every crew through one single bottleneck.

quality:     MHA █████   GQA ████▊   MQA ████▍
cache size:  MHA █████   GQA █▎      MQA ▏
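The quality bars are schematic, not benchmark results, while the cache bars roughly follow the 32-query-head, 8-KV-head ratios computed above. The exact quality gap depends on model size and training recipe. But the usual answer is stable: MQA saves the most memory, while GQA usually stays closer to MHA quality.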


Where this lives in the wild

  • Meta Llama 2 70B uses GQA, because 70B-scale decoding needs a much smaller KV cache to serve long prompts sanely.
  • Google PaLM uses MQA, pushing the sharing idea to the extreme so decode-time memory drops sharply.
  • Older GPT-style decoder stacks use full MHA, which keeps every head independent but makes KV cache memory much heavier.
  • Mistral and Mixtral serving benefit from GQA, letting more long-context requests fit on one GPU.
  • NVIDIA TensorRT-LLM and vLLM deployments love reduced KV heads, because smaller caches directly improve batching and throughput.

Interview Q&A

Q: What is the exact difference between MHA, GQA, and MQA?
A: MHA gives every query head its own keys and values. GQA lets several query heads share one KV head. MQA is the extreme case where all query heads share one KV head.
Common wrong answer to avoid: "GQA reduces the number of query heads." No. Query heads stay. KV heads shrink.

Q: Why do GQA and MQA help inference so much?
A: Because KV cache size scales with the number of KV heads. Long-context decoding is often memory-bandwidth-bound, so smaller caches mean better throughput, less OOM risk, and better batching.
Common wrong answer to avoid: "Because they reduce the attention formula from O(n²) to O(n)." No. They shrink cache memory. Flash-style kernels attack the compute path differently.

Q: Why is GQA usually preferred over pure MQA in large modern LLMs?
A: GQA keeps most of the cache savings while preserving more head specialization. MQA saves more memory, but quality can slip because all heads share the same keys and values.
Common wrong answer to avoid: "GQA and MQA are basically identical." No. GQA is a compromise, not the same design.


Apply now (5 min)

Take a decoder with n_layers = 16, n_query_heads = 16, d_model = 2048, seq_len = 2048, and fp16.

  1. Compute d_head.
  2. Compute full-MHA KV cache size.
  3. Recompute it for GQA with 4 KV heads and for MQA.
  4. Say which design you would pick for quality-first serving and which for memory-first serving.

Then sketch from memory:

  • The 8-query-head diagram for MHA, GQA, and MQA.
  • The cache formula with n_kv_heads in the middle.
  • The sentence: "queries stay many, notebooks become fewer."

If you can say why GQA is the compromise and MQA is the extreme, you own this file.


Bridge. GQA shrinks the cache. But even with smaller caches, attention is still O(n²) in sequence length. The next file tackles the other bottleneck — making the attention computation itself faster. Read 13-flash-attention.md next.