10. KV cache: reuse old keys, skip old work
~8 min read. One new token should not recompute the whole prefix.
Built on the ELI5 in 00-eli5.md. The memory shortcut stores past keys and values for reuse. The answer sheet still grows one visible token at a time.
1) Why naive decoding wastes work
Autoregressive decoding adds one token per step.
The prefix already exists.
Old token representations are already fixed, because the causal rule never lets them see later tokens.
So old keys stay fixed.
Old values stay fixed.
Naive decoding still recomputes them every step.
That repeated prefix work hurts latency.
The projection lens must rebuild Q, K, and V repeatedly.
Every layer repeats that waste.
Every set of parallel graders repeats that waste.
┌──────────────────────┐ ┌──────────────────────┐
│ naive (step 3)       │ │ cached (step 3)      │
├──────────────────────┤ ├──────────────────────┤
│ recompute K,V for    │ │ compute new k3,v3    │
│ t0,t1,t2,t3          │ │ append to cache      │
│ then attend          │ │ then attend          │
├──────────────────────┤ ├──────────────────────┤
│ K,V work: 4 units    │ │ K,V work: 1 unit     │
│ grows each step      │ │ constant per step    │
└──────────────────────┘ └──────────────────────┘
The cache removes only the repeated K,V projection work; scoring the new query against the history still grows with its length.
The exam rule does not change.
A new token still looks backward only.
Cache changes efficiency, not attention legality.
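The claim that old keys stay fixed can be checked in a toy sketch. It assumes a single layer with one random matrix W_k standing in for the key projection, and it drops the batch and head axes, so time sits on axis 0 here; every name is illustrative. The naive path re-projects the whole prefix each step, while the cached path projects only the newest token and appends it.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, steps = 8, 4, 4
W_k = rng.standard_normal((d_model, d_head))     # hypothetical key projection
tokens = rng.standard_normal((steps, d_model))   # one fixed state row per token

cache_k = np.empty((0, d_head))                  # empty history; time is axis 0 in this toy
for t in range(steps):
    # Naive: re-project every token seen so far.
    naive_k = tokens[: t + 1] @ W_k
    # Cached: project only the newest token, then append it.
    cache_k = np.concatenate([cache_k, tokens[t : t + 1] @ W_k], axis=0)
    # Old rows never changed, so both paths hold the same keys.
    assert np.allclose(naive_k, cache_k)

print(cache_k.shape)  # (4, 4)
```

Because the old rows never change, the cached array is exactly what the naive path keeps rebuilding.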
2) What gets cached at each decode step
The cache stores past keys and values per layer.
It does not store past queries.
Each new token asks a fresh question.
So each new token needs a fresh query.
At decode step t, one token enters the block.
We compute fresh q_t, k_t, and v_t.
We append k_t and v_t to history.
Then q_t attends over all cached keys.
Then the softmax weights combine all cached values.
The core shapes look like this, with cache_k and cache_v shown before k_t and v_t are appended.
cache_k : [B, H, t_so_far, d_head]
cache_v : [B, H, t_so_far, d_head]
new_q : [B, H, 1, d_head]
new_k : [B, H, 1, d_head]
new_v : [B, H, 1, d_head]
scores : [B, H, 1, t_so_far + 1]
Notice what changed.
Query length is 1.
Visible history keeps growing with the answer sheet.
The parallel graders stay separated by head.
Only the cache time axis (axis 2) grows.
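In pure decode mode, the causal mask is usually trivial.
There is no future column beyond the new token, so its single query row may see every cached column.
Padding masks in batched decoding are the common exception.
Position ids still matter.
Without padding, the new token's position id equals the number of tokens already cached.
Wrong position ids leave every shape correct while the outputs go wrong.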
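The mask and position facts above fit in a few lines. This sketch assumes 0-based positions, no padding, and a prompt of five tokens already in the cache; next_position_id and causal_mask_rows are hypothetical helpers, not library functions.

```python
import numpy as np

def next_position_id(cache_len):
    # Positions are 0-based, so the new token sits right after the cached ones.
    return cache_len

def causal_mask_rows(query_positions, num_keys):
    # True where a query at position p may see the key at position j (j <= p).
    key_positions = np.arange(num_keys)
    return key_positions[None, :] <= np.asarray(query_positions)[:, None]

cache_len = 5                        # tokens 0..4 are already cached
pos = next_position_id(cache_len)    # the new token gets position 5
mask = causal_mask_rows([pos], cache_len + 1)
print(pos, mask.all())               # 5 True: one row, every column visible
```

For contrast, causal_mask_rows([0, 1, 2], 3) during prefill gives a lower-triangular mask, which is where the exam rule does real work.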
3) Minimal code for the core mechanism
The essential cache object is small.
Only the time axis should grow.
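import numpy as np

class KVCache:
    """Per-layer cache of keys and values shaped [B, H, T, d_head]."""

    def __init__(self):
        self.k = None  # no history yet
        self.v = None

    def append(self, new_k, new_v):
        # Grow only the time axis (axis 2); B, H, and d_head must match.
        self.k = new_k if self.k is None else np.concatenate([self.k, new_k], axis=2)
        self.v = new_v if self.v is None else np.concatenate([self.v, new_v], axis=2)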
Axis 2 is time for [B, H, T, d_head].
Appending elsewhere scrambles the layout.
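A short sketch of what appending elsewhere does, using placeholder arrays of arbitrary size. Concatenating on axis 1 (the head axis) does not raise an error here: it produces what looks like 16 heads holding one token, and the time axis never grows, which is exactly the scrambled layout described above.

```python
import numpy as np

B, H, d_head = 1, 8, 64
cache_k = np.zeros((B, H, 1, d_head))   # one cached token
new_k = np.ones((B, H, 1, d_head))      # the next token's key

good = np.concatenate([cache_k, new_k], axis=2)  # time axis grows
bad = np.concatenate([cache_k, new_k], axis=1)   # head axis grows instead

print(good.shape)  # (1, 8, 2, 64)
print(bad.shape)   # (1, 16, 1, 64)
```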
The decode step stays conceptually short.
Project the new token.
Split heads.
Append k and v.
Score fresh q against cached k.
Mix cached v with those weights.
Return the new output state.
That is the whole memory shortcut.
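The decode steps above can be sketched for one layer. This is a minimal illustration under stated assumptions: random matrices W_q, W_k, W_v, W_o stand in for trained projections, B=1, there is no position encoding, and decode_step and naive_last_output are hypothetical helpers. It repeats the KVCache class from above so the block runs on its own, and it checks that the cached output for each new token matches a naive pass that recomputes causal attention over the whole prefix.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, H = 16, 4
d_head = d_model // H
# Hypothetical random weights standing in for trained projections.
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

def split_heads(x):
    # [B, T, d_model] -> [B, H, T, d_head]
    B, T, _ = x.shape
    return x.reshape(B, T, H, d_head).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, d_model]
    B, _, T, _ = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # stabilize before exp
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

class KVCache:
    def __init__(self):
        self.k = None
        self.v = None

    def append(self, new_k, new_v):
        # Time is axis 2 for [B, H, T, d_head].
        self.k = new_k if self.k is None else np.concatenate([self.k, new_k], axis=2)
        self.v = new_v if self.v is None else np.concatenate([self.v, new_v], axis=2)

def decode_step(x_t, cache):
    # x_t: [B, 1, d_model], the single new token.
    q = split_heads(x_t @ W_q)                # fresh query, [B, H, 1, d_head]
    k = split_heads(x_t @ W_k)
    v = split_heads(x_t @ W_v)
    cache.append(k, v)                        # history now includes this token
    scores = q @ cache.k.swapaxes(-1, -2) / np.sqrt(d_head)  # [B, H, 1, t]
    out = softmax(scores) @ cache.v           # [B, H, 1, d_head]
    return merge_heads(out) @ W_o             # [B, 1, d_model]

def naive_last_output(x_prefix):
    # Recompute Q, K, V for the whole prefix under a causal mask; keep the last row.
    q, k, v = (split_heads(x_prefix @ W) for W in (W_q, W_k, W_v))
    T = x_prefix.shape[1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d_head)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    out = merge_heads(softmax(scores) @ v) @ W_o
    return out[:, -1:, :]

x = rng.standard_normal((1, 5, d_model))      # five token states, B=1
cache = KVCache()
for t in range(x.shape[1]):
    y_cached = decode_step(x[:, t : t + 1, :], cache)
    assert np.allclose(y_cached, naive_last_output(x[:, : t + 1, :]))

print(cache.k.shape)  # (1, 4, 5, 4)
```

The equality check is the "efficiency, not legality" point in code: the cached path computes the same last-row output with one projection per step. A real model keeps one such cache per layer.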
4) One worked example, plus the easy failure modes
Take four decode steps.
Let one unit mean one token's projection bundle.
Naive work is 1 + 2 + 3 + 4 = 10 units.
Cached work is 1 + 1 + 1 + 1 = 4 units.
The saving is 6 units.
That is a 60% reduction in this toy walk.
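The toy walk generalizes to n steps: naive work is 1 + 2 + ... + n = n(n+1)/2 units and cached work is n units, so the saving is 1 - 2/(n+1). A small sketch with the same unit as above (projection work only; the growing attention scores are not counted), where projection_units is a hypothetical helper:

```python
def projection_units(steps):
    # One unit = projecting one token into its K,V bundle.
    naive = steps * (steps + 1) // 2   # 1 + 2 + ... + steps
    cached = steps                     # 1 per step
    return naive, cached, 1 - cached / naive

for n in (4, 16, 256):
    naive, cached, saving = projection_units(n)
    print(n, naive, cached, f"{saving:.0%}")
# 4 10 4 60%
# 16 136 16 88%
# 256 32896 256 99%
```

The first row reproduces the 60% above; the saving climbs toward 100% as the answer sheet gets longer.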
Now add shapes.
Let B=1, H=8, and d_head=64.
At step 1, after the append, cache_k is [1, 8, 1, 64].
At step 2, cache_k is [1, 8, 2, 64].
At step 4, cache_k is [1, 8, 4, 64].
cache_v grows in the same way.
At step 4, new_q stays [1, 8, 1, 64].
At step 4, scores become [1, 8, 1, 4].
The pattern is stable.
One query row stays short.
History width grows with time.
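The shape pattern above can be traced by growing real arrays instead of writing shapes by hand. This sketch uses zero-filled placeholders, since only shapes matter, with B=1, H=8, and d_head=64 from the worked example; cache_v would grow the same way and is omitted.

```python
import numpy as np

B, H, d_head = 1, 8, 64
cache_k = np.zeros((B, H, 0, d_head))  # empty history before step 1

for step in range(1, 5):
    new_q = np.zeros((B, H, 1, d_head))                 # one fresh query row
    new_k = np.zeros((B, H, 1, d_head))                 # one fresh key
    cache_k = np.concatenate([cache_k, new_k], axis=2)  # time axis grows
    scores = new_q @ cache_k.swapaxes(-1, -2)           # [B, H, 1, step]
    print(step, cache_k.shape, new_q.shape, scores.shape)
```

Step 4 prints (1, 8, 4, 64) (1, 8, 1, 64) (1, 8, 1, 4), matching the shapes listed above.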
Common failures are easy to name.
- Forgetting to clear cache between prompts.
- Appending on the wrong axis.
- Using mismatched position ids.
- Caching Q even though decode never reuses old queries.
Check those first when outputs look strange.
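A debugging check can catch three of the four failures above. This is a hypothetical helper, not part of any library, and it assumes the caller tracks the current sequence length and the next position id. Caching Q wastes memory without breaking shapes, so it is not covered here.

```python
import numpy as np

def check_cache(cache_k, cache_v, expected_len, H, d_head, next_pos):
    """Hypothetical debugging helper: raise ValueError on common cache bugs."""
    for name, arr in (("cache_k", cache_k), ("cache_v", cache_v)):
        # A wrong-axis append changes H or d_head instead of the time axis.
        if arr.shape[1] != H or arr.shape[3] != d_head:
            raise ValueError(f"{name} has layout {arr.shape}")
        # A cache not cleared between prompts holds more tokens than this sequence.
        if arr.shape[2] != expected_len:
            raise ValueError(f"{name} holds {arr.shape[2]} tokens, expected {expected_len}")
    # The next token's position id should equal the number of cached tokens.
    if next_pos != expected_len:
        raise ValueError(f"position id {next_pos} != cache length {expected_len}")

B, H, d_head = 1, 8, 64
k = np.zeros((B, H, 3, d_head))
v = np.zeros((B, H, 3, d_head))
check_cache(k, v, expected_len=3, H=H, d_head=d_head, next_pos=3)  # no error

stale_k = np.zeros((B, H, 7, d_head))  # leftovers from a previous prompt
try:
    check_cache(stale_k, v, expected_len=3, H=H, d_head=d_head, next_pos=3)
except ValueError as err:
    print("caught:", err)  # caught: cache_k holds 7 tokens, expected 3
```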
Where this lives in the wild
- Hugging Face generate(): past_key_values carries old keys and values between token-by-token decoding steps.
- vLLM: paged KV cache storage (PagedAttention) allocates the cache in fixed-size blocks, so many concurrent decode streams share GPU memory efficiently.
- TensorRT-LLM: optimized kernels assume this append-and-reuse pattern for low latency.
- llama.cpp: local inference keeps a running KV history to avoid prefix recomputation.
- Apple MLX: incremental decoding uses cache reuse to stay fast on Apple Silicon.
Pause and recall
- Why does the cache store past K and V, but not past Q?
- Which axis grows when decoding adds another token?
- Why do scores become [B, H, 1, t] during decode?
- What two implementation bugs break cache correctness most often?
Interview Q&A
Q1. Why is KV cache more useful during decoding than full-sequence training? Because decoding adds one token at a time and would otherwise re-project the same prefix every step, while training usually computes the full sequence in one parallel pass, so no prefix work repeats. Common wrong answer to avoid: Saying cache matters only for bigger models.
Q2. What exactly lives inside a standard KV cache? Past keys and values for every layer, every head, and every processed position. Common wrong answer to avoid: Saying the cache stores logits, losses, or past queries.
Q3. Why can one new query attend safely over cached keys? Old token states are fixed, and the exam rule still allows backward visibility. Common wrong answer to avoid: Saying cached keys must change after every sampled token.
Apply now (5 min)
Quick exercise.
For B=2, H=4, and d_head=16, write the shapes at decode step 5, counting cache_k after the new token is appended.
Give shapes for new_q, cache_k, and scores.
Sketch from memory.
Draw one new token entering the block.
Show the projection lens making q_t, k_t, and v_t.
Show k_t and v_t joining the memory shortcut.
Mark the growing answer sheet axis.
Add one note saying the exam rule still forbids future peeking.
Bridge. Good. Cache saves compute, but it occupies memory immediately. → 11-kv-cache-memory.md