11. KV cache — not rewriting yesterday's notes

Generation should add one new note, not redo the whole notebook. This file makes that obvious.

Built on the ELI5 in 00-eli5.md. The cache — stored past keys and values — now becomes latency math, memory math, and deployment reality.


Mental model — the inference waste problem

A decoder writes one token at a time. At step 1, it reads the prompt and predicts one token. At step 2, it reads the prompt plus one generated token. At step 3, it reads the prompt plus two generated tokens. So far, so good.

But now see the waste. The past tokens did not change. Their key vectors did not change. Their value vectors did not change. They cannot change: with a causal mask, position i never looks at anything after i, so its K and V at every layer are fixed once computed. Yet a naive decoder recomputes them every step. That is like rewriting yesterday's class notes before writing today's line.

So what to do? Keep the old notes. Store past K and V tensors. When a new token arrives, compute only the new token's Q, K, and V. Append the new K and V to the cache. Then let the new query attend over the whole stored history.

ASCII picture:

step t-1 cache:  K1 K2 K3 ... Kt-1
                 V1 V2 V3 ... Vt-1

new token xt  ->  Qt, Kt, Vt
                     |
                     +--> Qt attends to [K1 ... Kt]
                     +--> output uses [V1 ... Vt]

append:
cache <- K1 K2 K3 ... Kt
         V1 V2 V3 ... Vt
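This is an inference-only optimization. Training usually does not use it, because training processes the full sequence in one parallel pass: every position's K and V are computed exactly once anyway, so there is nothing to recompute.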
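To make "keep the old notes" concrete, here is a toy sketch in plain Python for one layer and one head. The weights, the token vectors, and the helper `project` are made up for illustration. It checks that appending one new K and V per step leaves the cache identical to what a naive decoder would get by recomputing K and V for the whole prefix, and it counts how many K/V projections each approach performs.

```python
# Toy illustration: one layer, one head, hand-picked 2-d weights (all hypothetical).

def project(x, w):
    """Row vector x times matrix w (both plain lists), i.e. x W."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

W_K = [[1.0, 0.0], [0.5, 1.0]]
W_V = [[0.0, 1.0], [1.0, 0.0]]
tokens = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [2.0, 2.0]]  # x1 ... x4

naive_projections = 0
cached_projections = 0
cache_k, cache_v = [], []

for t in range(1, len(tokens) + 1):
    # Naive: rewrite every note. Recompute K and V for the whole prefix.
    naive_k = [project(x, W_K) for x in tokens[:t]]
    naive_v = [project(x, W_V) for x in tokens[:t]]
    naive_projections += t

    # Cached: compute K and V only for the newest token, then append.
    x_t = tokens[t - 1]
    cache_k.append(project(x_t, W_K))
    cache_v.append(project(x_t, W_V))
    cached_projections += 1

    # The kept notes are exactly the rewritten ones.
    assert cache_k == naive_k and cache_v == naive_v

print(naive_projections, cached_projections)  # 10 4
```

With one layer, K and V depend only on each token's own input vector; in a deeper stack the same equality holds because of the causal mask argument above.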


The formula — cost without cache versus with cache

For one decoding step with current context length T:

Naive recomputation

Recompute attention for all T positions. Per layer, attention score work is roughly:

T queries × T keys = T^2
So per-step cost is O(T^2). (A causal mask skips the upper triangle, leaving T(T+1)/2 products, but that is still O(T^2).)

Cached decoding

Compute attention only for the new token. That means:

1 query × T keys = T
So per-step cost is O(T). That is the headline you must remember.
naive per step   = O(T^2)
cached per step  = O(T)
A small note for completeness. If you sum the per-step work over a whole generation, with context lengths 1, 2, ..., T, the total becomes:
naive total over T steps   = O(T^3)
cached total over T steps  = O(T^2)
But interviewers usually ask the per-step contrast. So start there.
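The totals can be checked by brute force. This sketch (helper names are hypothetical) sums the per-step counts used in this section, with the full T x T grid for the naive case, over context lengths 1 through T. The naive total is T(T+1)(2T+1)/6 and the cached total is T(T+1)/2, so their ratio is exactly (2T+1)/3: it grows linearly with T, which is the O(T^3) versus O(T^2) gap.

```python
def score_products_per_step(T, cached):
    """Score work for one decoding step at context length T (full T x T grid, as counted in the text)."""
    return T if cached else T * T

def total_over_steps(T, cached):
    """Sum the per-step work over steps with context lengths 1, 2, ..., T."""
    return sum(score_products_per_step(t, cached) for t in range(1, T + 1))

for T in (16, 128, 1024):
    naive = total_over_steps(T, cached=False)
    cached = total_over_steps(T, cached=True)
    # The ratio equals (2T + 1) / 3, so it keeps growing with T.
    print(T, naive, cached, round(naive / cached, 1))
```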


Worked numerical comparison — T = 4

Suppose the current prefix length is 4.

Without cache

You recompute a full 4 x 4 attention pattern. Per layer:

score products = 4^2 = 16

With cache

The new query attends to four stored keys. Per layer:

score products = 4
Savings:
16 -> 4
4x less score work
ASCII picture:
without cache:  q1->kkkk
                q2->kkkk
                q3->kkkk
                q4->kkkk

with cache:     q4->kkkk
Same answer for the new token. Much less repeated work.


Worked numerical comparison — T = 16

Now the prefix length is 16.

Without cache

score products = 16^2 = 256

With cache

score products = 16
Savings:
256 -> 16
16x less score work
See the pattern. The speedup factor grows with sequence length. At T = 16, you save 16x. At T = 128, it gets dramatic.


Worked numerical comparison — T = 128

Current prefix length is 128.

Without cache

score products = 128^2 = 16,384

With cache

score products = 128
Savings:
16,384 -> 128
128x less score work
Summary (score products per decoding step, per layer):
T      naive (T^2)    cached (T)    savings
4      16             4             4x
16     256            16            16x
128    16,384         128           128x
The exact latency also depends on kernels, memory bandwidth, and batching. But the direction is not subtle. Without a cache, long-context decoding gets slow quickly.


How the KV cache actually works

Inside one attention head, each token produces:

Q = x W_Q
K = x W_K
V = x W_V
During cached decoding at step t:
1. Take the new residual-stream vector x_t.
2. Compute Q_t, K_t, and V_t.
3. Append K_t and V_t to the stored cache.
4. Score Q_t against all cached keys K_1 ... K_t (scaled dot products, then softmax).
5. Take the weighted sum of all cached values V_1 ... V_t.
6. Write the result back into the residual stream (after the output projection).
Important detail. You still compute the new token's query. You never reuse or store old queries. Why? Because only the latest token needs a fresh output. Old outputs were already produced on earlier steps.
ASCII picture for one layer:
old cache:   K1 K2 K3
             V1 V2 V3

new token x4 -> Q4 K4 V4
                  |  |
                  |  +---- append to cache
                  |
                  +---- attend over K1 K2 K3 K4
                              use V1 V2 V3 V4
Now multiply this by every layer. Each layer keeps its own cache. That matters for memory.
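The six steps can be sketched for one head in plain Python. This is a toy under stated assumptions: made-up 2-dimensional weights and inputs, no output projection W_O, and step 6 (writing back to the residual stream) left out. At every step it checks that the cached output for the newest token matches a naive pass that recomputes Q, K, and V for the whole prefix and runs causal attention over it.

```python
import math

def matvec(x, w):
    """Row vector x times matrix w, i.e. x W."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def attend(q, keys, values):
    """softmax(q . k / sqrt(d)) over the given keys, then the weighted sum of values."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

# Hypothetical toy weights and token vectors.
W_Q = [[0.2, -0.1], [0.4, 0.3]]
W_K = [[1.0, 0.0], [0.5, 1.0]]
W_V = [[0.0, 1.0], [1.0, 0.0]]
xs = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [2.0, 2.0]]

def naive_last_output(prefix):
    """Recompute Q, K, V for the whole prefix, run causal attention for every row, keep the last row."""
    Q = [matvec(x, W_Q) for x in prefix]
    K = [matvec(x, W_K) for x in prefix]
    V = [matvec(x, W_V) for x in prefix]
    outputs = [attend(Q[i], K[: i + 1], V[: i + 1]) for i in range(len(prefix))]
    return outputs[-1]

cache_k, cache_v = [], []
for t, x_t in enumerate(xs, start=1):
    q_t = matvec(x_t, W_Q)                 # steps 1-2: new token's Q (K and V just below)
    cache_k.append(matvec(x_t, W_K))       # step 3: append K_t and V_t
    cache_v.append(matvec(x_t, W_V))
    out_t = attend(q_t, cache_k, cache_v)  # steps 4-5: attend over K_1..K_t, mix V_1..V_t
    ref = naive_last_output(xs[:t])
    assert all(abs(a - b) < 1e-12 for a, b in zip(out_t, ref))
```

Old queries never appear in the cached loop; each step builds only q_t, which is the "do not reuse old queries" point above.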


Memory formula — why speed costs RAM

The usual fp16 KV cache memory formula for batch size 1 is:

memory = 2 × n_layers × n_heads × d_head × seq_len × 2 bytes
Why the leading 2? One copy for keys. One copy for values. Why the trailing 2 bytes? Because fp16 stores each number in 2 bytes. So yes, the cache saves compute by spending memory. That trade is usually worth it. But at long context windows, memory becomes the next bottleneck.


Worked memory calculation — a 7B model

Take a common 7B-style decoder setup:

n_layers = 32
n_heads  = 32
d_head   = 128
seq_len  = 4096
precision = fp16
Now plug into the formula:
memory = 2 × 32 × 32 × 128 × 4096 × 2 bytes
Step by step:
32 × 32 = 1024
1024 × 128 = 131,072
131,072 × 4096 = 536,870,912
536,870,912 × 2 = 1,073,741,824   # K and V
1,073,741,824 × 2 bytes = 2,147,483,648 bytes
So the KV cache is:
= 2 GiB (exactly 2,147,483,648 bytes)
for batch size 1. Note that this is just the cache: not the model weights, not training activations, only cached keys and values. Now imagine batch size 8. The cache alone grows to 16 GiB. That is why serving engineers care so much about cache compression.
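The same arithmetic as a small function, a sketch that assumes the fp16 formula above plus an explicit batch factor (the name `kv_cache_bytes` is illustrative, not from any library). Plugging in the 7B-style numbers reproduces the 2 GiB and 16 GiB figures.

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    """2 (K and V) x layers x heads x d_head x tokens x batch x bytes per number (2 for fp16)."""
    return 2 * n_layers * n_heads * d_head * seq_len * batch * bytes_per_elem

GIB = 1024 ** 3
b1 = kv_cache_bytes(n_layers=32, n_heads=32, d_head=128, seq_len=4096)
b8 = kv_cache_bytes(n_layers=32, n_heads=32, d_head=128, seq_len=4096, batch=8)
print(b1, b1 / GIB)  # 2147483648 2.0
print(b8 / GIB)      # 16.0
```

Every factor is linear, so doubling the context length, the layer count, or the batch size doubles the cache.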


Cache eviction and sliding windows

What if context gets too long? You cannot keep everything forever. One answer is eviction. Drop old cache entries. Keep only the most recent window. This is the sliding-window idea. Mistral-style attention is the public example people know. Picture a fixed notebook that keeps only the latest pages. ASCII picture:

full history:   t1 t2 t3 t4 t5 t6 t7 t8
window size 4:              t5 t6 t7 t8
next step keeps:               t6 t7 t8 t9
This cuts memory growth. It also changes what the model can directly access. So the architecture must be trained with that rule in mind. Eviction is not free. You save RAM. You may lose very old exact context. Production systems choose this trade deliberately.
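A minimal sketch of the window rule, assuming a window of 4 and string labels standing in for each token's (K, V) pair. Python's `collections.deque` with `maxlen` drops the oldest entry automatically, which mirrors the picture above. It illustrates the eviction rule only, not how any particular model or serving engine lays out memory; a fixed-size ring buffer indexed by position modulo the window is a common alternative.

```python
from collections import deque

WINDOW = 4
cache = deque(maxlen=WINDOW)  # when full, appending evicts the oldest entry

for t in range(1, 9):         # tokens t1 ... t8
    cache.append(f"t{t}")     # stand-in for the (K_t, V_t) pair
print(list(cache))            # ['t5', 't6', 't7', 't8']

cache.append("t9")
print(list(cache))            # ['t6', 't7', 't8', 't9']
```

Memory stays at WINDOW entries per layer no matter how long generation runs, which is exactly the trade described above.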


GQA and MQA — cache compression by sharing

Now another trick. Do all query heads really need separate key and value heads? Not always.

Grouped-query attention

In GQA, many query heads share a smaller number of KV heads. Example:

32 query heads
8 KV heads
So the cache scales with 8, not 32. Memory drops by 4x. Using the earlier 7B example:
2.0 GiB / 4 = 0.5 GiB
That is exactly 512 MiB.

Multi-query attention

In MQA, all query heads share one KV head. Example:

32 query heads
1 KV head
Now memory drops by 32x relative to full multi-head KV storage. From the same example:
2.0 GiB / 32 = 64 MiB
So what changes conceptually? The parallel crews still ask different questions. But they share fewer notebooks of keys and values. That is why GQA and MQA matter so much for serving. They are not only modeling tricks. They are systems tricks.
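To see the sharing concretely, here is a sketch that maps each query head to the KV head it reads, assuming query heads are split into contiguous, equal-size groups (a common layout, assumed here; the helper `kv_head_for` is hypothetical). It also scales the 2 GiB full multi-head figure from the 7B example by n_kv_heads / n_query_heads.

```python
def kv_head_for(query_head, n_query_heads, n_kv_heads):
    """Index of the KV head that a query head reads, assuming contiguous equal-size groups."""
    assert n_query_heads % n_kv_heads == 0
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size

n_q = 32
for n_kv, name in ((32, "MHA"), (8, "GQA"), (1, "MQA")):
    mapping = [kv_head_for(h, n_q, n_kv) for h in range(n_q)]
    # Cache size scales with n_kv; 2048 MiB is the full multi-head figure from the 7B example.
    cache_mib = 2048 * n_kv // n_q
    print(name, mapping[:8], f"{cache_mib} MiB")
```

Query heads still compute different Q vectors; under GQA, heads 0 to 3 all score against the same cached K and V of KV head 0, and under MQA every head does.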


Where this lives in the wild

  • ChatGPT- and Claude-style streaming generation depends on KV caching because each token is generated from a growing prefix and latency would be unacceptable with full recomputation.
  • vLLM is built around cache management ideas like paged attention so large batches of user requests can share GPU memory more efficiently during decoding.
  • Mistral public models use sliding-window attention to bound memory and compute while keeping good recent-context performance.
  • Meta Llama-family serving stacks rely on KV caches at every decoder layer because decoder-only generation would otherwise waste enormous work on long prompts.
  • Falcon, Llama 3, and other public models use MQA or GQA variants because sharing KV heads cuts cache memory enough to matter in real deployments.

Interview Q&A

Q: What does a KV cache store?
A: Past keys and values for every decoder layer, usually for every attention head or KV head. On each new step, the model computes only the new token's Q, K, and V, appends K and V, and attends over the stored history.

Q: Why is cached decoding O(T) instead of O(T^2) per step?
A: Because only one new query is computed at the current step, and it attends over T cached keys. Without a cache, the model recomputes all T queries against all T keys again.
Common wrong answer to avoid: "Because the cache removes attention." No. Attention still happens. Repeated recomputation is what disappears.

Q: What is the main downside of a KV cache?
A: Memory. Long contexts, many layers, and large batches make the cache huge. That is why sliding windows, GQA, MQA, quantized cache formats, and paged memory matter.
Common wrong answer to avoid: "No downside, only faster." In production, cache memory is often the limiting resource.

Q: Difference between GQA and MQA?
A: GQA shares keys and values across groups of query heads. MQA is the extreme case where all query heads share one KV head. Both reduce cache memory; MQA reduces it more aggressively.


Apply now (5 min)

Take a decoder with current context length T = 10. Do four quick tasks:
1. Compute naive per-step score work: 10^2.
2. Compute cached per-step score work: 10.
3. Say the speedup factor out loud.
4. Sketch one-layer cache growth from step 1 to step 4.
Then sketch from memory:
- The formula O(T^2) versus O(T).
- The memory formula for fp16 KV cache.
- The idea of GQA as "many query heads, fewer KV heads".
If you can explain why a KV cache helps inference but not ordinary parallel training, you own this file.


Bridge. KV cache saves compute but eats memory. Modern models fight back by sharing KV heads across query heads. That trick — and the broader family of attention efficiency methods — is next. Read 12-gqa-mqa.md next.