13. Flash Attention — same answer, far less memory traffic
Standard attention keeps spilling giant tables to slow memory. Flash Attention stops that waste.
Built on the ELI5 in 00-eli5.md. The social bench, where every token consults other tokens, now gets a faster seating plan: tiny fast scratch space instead of giant slow tables.
Mental model — stop drawing the full seating chart
Picture the social bench. Every token wants to look at many other tokens. In standard attention, we often build the full seating chart first. Who looked at whom? How strongly? That chart has N × N cells. For long contexts, it becomes huge.
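Then we write that giant chart to GPU HBM (high-bandwidth memory: big, but slow next to on-chip SRAM), read it back for softmax, write the softmax result out as a second giant chart, and read that back for the weighted sum. See the waste.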
Flash Attention says something simpler. Do not materialize the full table. Work on one small tile at a time, keep that tile in fast SRAM, update the answer, and discard the tile.
The standard path — why memory explodes
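The textbook computation is:
$$S = \frac{QK^\top}{\sqrt{d}}, \qquad P = \mathrm{softmax}(S), \qquad O = PV$$
with softmax applied row by row and d the head dimension.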
The math is fine. The implementation is the problem. For sequence length N, the score matrix S is N × N, and so is the probability matrix P. That means big writes, big reads, and heavy HBM traffic.
Q ----->
[ QK^T ] -> S in HBM -> softmax -> P in HBM -> multiply with V -> O
K ----->
V --------------------------------------------------------------^
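To make the waste concrete, here is a minimal pure-Python sketch of the standard path, assuming small list-of-lists matrices and the 1/√d scaling. The helper name `standard_attention` is hypothetical. It deliberately builds the full N × N score table S and probability table P before touching V, which is exactly the pair of tables the diagram spills to HBM.

```python
import math

def standard_attention(Q, K, V):
    """Textbook attention that materializes the full N x N tables S and P."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Stage 1: score table S = Q K^T / sqrt(d), N x N entries.
    S = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    # Stage 2: probability table P = row-wise softmax(S), another N x N entries.
    P = []
    for row in S:
        mx = max(row)                        # subtract the row max for stability
        e = [math.exp(s - mx) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    # Stage 3: O = P V, N x d_v entries.
    d_v = len(V[0])
    O = [[sum(p * v[c] for p, v in zip(prow, V)) for c in range(d_v)] for prow in P]
    return O, S, P


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O, S, P = standard_attention(Q, K, V)
print(len(S), "x", len(S[0]), "score cells")   # 3 x 3 score cells
```

With N = 3 the tables are harmless. The same code at N = 8192 would hold two tables of 67 million entries each.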
The core idea — tile the attention work
Break queries into blocks. Break keys and values into matching blocks. Process one Q block against one K,V block. Do the score calculation, the softmax update, and the weighted-value update while the tile is still hot in SRAM. Never store the whole N × N matrix.
Q blocks on rows
K blocks on columns
        K1      K2      K3      K4
     +-------+-------+-------+-------+
Q1   | tile  | tile  | tile  | tile  |
     +-------+-------+-------+-------+
Q2   | tile  | tile  | tile  | tile  |
     +-------+-------+-------+-------+
Q3   | tile  | tile  | tile  | tile  |
     +-------+-------+-------+-------+
Q4   | tile  | tile  | tile  | tile  |
     +-------+-------+-------+-------+
one tile at a time lives in SRAM
not the full grid
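The grid can be sketched as a double loop: Q blocks on the outside, K,V blocks on the inside. This is a minimal pure-Python illustration, not a GPU kernel. The names `tiled_attention`, `reference_attention`, and `block` are hypothetical, Python lists stand in for SRAM tiles, and the running max/sum rescaling it relies on is the one explained in the next section. At most `block` scores per row exist at any moment; no N × N table is ever built.

```python
import math

def tiled_attention(Q, K, V, block):
    """Exact attention, computed one (Q block, K/V block) tile at a time."""
    n, d, d_v = len(Q), len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    O = [[0.0] * d_v for _ in range(n)]
    for q0 in range(0, n, block):                  # walk the Q blocks (grid rows)
        rows = range(q0, min(q0 + block, n))
        m = {i: -math.inf for i in rows}            # running max per query row
        l = {i: 0.0 for i in rows}                  # running exp-sum per query row
        o = {i: [0.0] * d_v for i in rows}          # running unnormalized output
        for k0 in range(0, len(K), block):          # walk the K,V blocks (grid columns)
            cols = range(k0, min(k0 + block, len(K)))
            for i in rows:
                # Scores of row i against this K block only, never a full row of N.
                s = [scale * sum(a * b for a, b in zip(Q[i], K[j])) for j in cols]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)      # shrink old state onto the new max
                p = [math.exp(x - m_new) for x in s]
                l[i] = alpha * l[i] + sum(p)
                o[i] = [alpha * o[i][c] + sum(pj * V[j][c] for pj, j in zip(p, cols))
                        for c in range(d_v)]
                m[i] = m_new
        for i in rows:
            O[i] = [x / l[i] for x in o[i]]         # normalize once, after the last tile
    return O


def reference_attention(Q, K, V):
    """Plain row-by-row softmax attention, used only as a check."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out


N, d = 7, 4   # 7 is not a multiple of the block size, so edge tiles are partial
Q = [[math.sin(i + 2 * j) for j in range(d)] for i in range(N)]
K = [[math.cos(i - j) for j in range(d)] for i in range(N)]
V = [[i + 0.5 * j for j in range(d)] for i in range(N)]
tiled = tiled_attention(Q, K, V, block=3)
ref = reference_attention(Q, K, V)
err = max(abs(a - b) for ra, rb in zip(tiled, ref) for a, b in zip(ra, rb))
print(err < 1e-9)   # True
```

The design choice to normalize by l only once, after the last K,V block, is what lets each tile be discarded immediately.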
How softmax stays exact without the full matrix
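This part sounds magical at first. It is not. For each query-row block, we keep three running statistics per row in SRAM:
- m: the largest score seen so far
- l: the running sum of exp(score − m)
- o: the running weighted sum of V rows, not yet divided by l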
Then for each new K,V tile, we compute local scores, update the running max, rescale the old l and o whenever the max grows, and add the new tile's weighted values. After the last tile, o / l is the exact attention output for that row.
row block state in SRAM
m_old, l_old, o_old
new tile scores -> local max, local exp-sum, local weighted V
merge exactly:
m_new = max(m_old, m_tile)
l_new = rescaled(l_old) + rescaled(l_tile)
o_new = rescaled(o_old) + rescaled(o_tile)
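The "rescaled" terms in the diagram have an exact form. If m_tile, l_tile, and o_tile are computed from the new tile alone, then l_new = exp(m_old − m_new) · l_old + exp(m_tile − m_new) · l_tile, and o_new uses the same two factors on o_old and o_tile. A minimal pure-Python sketch for one query row, assuming scalar V entries (d = 1) to keep it short; `tile_state` and `merge` are hypothetical names.

```python
import math

def tile_state(scores, values):
    """Local statistics (m, l, o) for one tile; scalar values are an assumption."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    o = sum(w * v for w, v in zip(weights, values))
    return m, l, o

def merge(old, tile):
    """Exact merge of two partial softmax states onto a shared max."""
    m_old, l_old, o_old = old
    m_tile, l_tile, o_tile = tile
    m_new = max(m_old, m_tile)
    a = math.exp(m_old - m_new)    # rescale factor for the old state
    b = math.exp(m_tile - m_new)   # rescale factor for the new tile
    return m_new, a * l_old + b * l_tile, a * o_old + b * o_tile


scores = [0.3, 2.5, -1.0, 4.0, 1.2, 3.9]   # one query row against 6 keys
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]    # scalar V entries, one per key

state = (-math.inf, 0.0, 0.0)              # empty state: contributes nothing
for start in range(0, len(scores), 2):     # tiles of 2 keys
    state = merge(state, tile_state(scores[start:start + 2], values[start:start + 2]))
m, l, o = state
tiled = o / l

# Reference: ordinary softmax over the whole row, then the weighted sum.
mx = max(scores)
w = [math.exp(s - mx) for s in scores]
full = sum(wi * vi for wi, vi in zip(w, values)) / sum(w)
print(abs(tiled - full) < 1e-12)   # True
```

Starting from m = −∞ makes the first merge a no-op on the empty state, since exp(−∞) is 0.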
Picture before math — what gets fused
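Standard attention often feels like three big stages:
- compute S = QK^T and write it to HBM
- read S back, apply softmax, and write P to HBM
- read P back, multiply with V, and write O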
Flash Attention fuses them. One tile flows through all three steps while still hot in SRAM. Less spilling. Less rereading. Same answer. Faster wall clock.
The IO-complexity headline
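Let M be the available SRAM size for the kernel, with N the sequence length and d the head dimension. Then the HBM IO complexity for Flash Attention is roughly:
$$\Theta\!\left(\frac{N^2 d^2}{M}\right) \quad \text{HBM accesses, versus} \quad \Theta(Nd + N^2) \quad \text{for standard attention.}$$
Bigger N means more pain for standard attention. Bigger usable SRAM means fewer expensive trips to HBM for Flash Attention. For typical head dimensions (64 to 128) and on-chip SRAM sizes, d² is many times smaller than M, so the Flash Attention count sits well below the N² term.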
Worked numerical example — why 8K hurts
Take one head with:
- sequence length N = 8192
- head dimension d = 128
First count score cells: N² = 8192² = 67,108,864. If a standard kernel keeps the score buffer and the post-softmax buffer in fp32, each number needs 4 bytes. So one large buffer is:
67,108,864 × 4 bytes = 268,435,456 bytes = 256 MiB
Two such large buffers (S and P) live in the standard path, so the live footprint is about:
2 × 256 MiB = 512 MiB
That is just the giant attention tables. Not model weights. Not KV cache. Not other activations. See why long context becomes scary.
Worked numerical example — the 64 KB tile view
Now tile it. Take block size B = 64 and fp16 tiles for Q, K, and V. Each tile has shape 64 × 128.
One tile size is:
64 × 128 × 2 bytes = 16,384 bytes = 16 KB
Inside SRAM we can keep roughly:
Q tile 16 KB
K tile 16 KB
V tile 16 KB
stats/output scratch about 16 KB
------------------------------
total about 64 KB
Instead of 512 MiB of giant tables in HBM, we keep a tiny working set near the cores. Tile comes in, gets processed, gets discarded, and the next tile comes in. Full answer emerges. No giant N × N matrix is ever stored.
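Both worked examples can be rechecked with a few lines of arithmetic. This sketch assumes, as in the text, fp32 score and probability buffers and fp16 tiles; the helper names are hypothetical, and the four 16 KB SRAM slots mirror the rough budget above rather than any real kernel's layout.

```python
MiB = 1024 ** 2

def table_bytes(n, bytes_per_elem=4):
    """Bytes for one N x N attention table (scores or probabilities), fp32 by default."""
    return n * n * bytes_per_elem

def tile_bytes(block, d, bytes_per_elem=2):
    """Bytes for one block x d tile of Q, K, or V, fp16 by default."""
    return block * d * bytes_per_elem


N, d, B = 8192, 128, 64
one_table = table_bytes(N)
print(one_table // MiB, "MiB per N x N table")          # 256 MiB per N x N table
print(2 * one_table // MiB, "MiB for S and P")          # 512 MiB for S and P
print(tile_bytes(B, d) // 1024, "KB per tile")          # 16 KB per tile

# Rough SRAM budget: Q, K, V tiles plus about one tile's worth of stats/output scratch.
ratio = (2 * one_table) // (4 * tile_bytes(B, d))
print(ratio, "times smaller working set")               # 8192 times smaller working set
```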
Why wall-clock speed improves so much
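Less HBM traffic means the GPU spends more time doing useful math and less time waiting on memory. Fused kernels remove intermediate writes. Better tiling improves occupancy. That is why people often see 2x to 4x speedups, and why longer contexts stop running out of memory (OOM) so quickly: the extra memory now grows linearly with N instead of quadratically.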
standard attention
compute -> write -> read -> write -> read -> compute
flash attention
compute -> update -> compute -> update
FlashAttention 2 and 3 — same principle, better execution
FlashAttention 1 gave the big idea: tile and fuse. FlashAttention 2 improved work partitioning so more of the GPU stayed busy. FlashAttention 3 pushes further on newer hardware, targeting NVIDIA Hopper GPUs such as the H100, with better pipelining and more hardware-aware scheduling. The principle stays the same: do not materialize giant attention tables in HBM.
Where this lives in the wild
- vLLM relies on FlashAttention-style kernels so long-context batch serving does not drown in attention-memory traffic.
- Hugging Face TGI uses fused fast-attention kernels, because production decoding needs speed without changing model outputs.
- NVIDIA TensorRT-LLM ships highly optimized flash-style attention kernels for high-throughput GPU inference.
- PyTorch scaled dot-product attention routes to flash-style kernels on supported hardware, making the optimization mainstream.
- Serving stacks for modern LLMs from Meta, Mistral, and others depend on flash-style attention under the hood, because standard materialized attention wastes too much memory bandwidth.
Interview Q&A
Q: What problem does Flash Attention solve?
A: It solves the memory-traffic problem of standard attention. Instead of materializing full N × N score and probability matrices in HBM, it computes attention in SRAM-sized tiles and fuses the steps.
Common wrong answer to avoid: "It changes attention from O(n²) math to O(n log n)." No. The math is still dense attention. The implementation is more IO-efficient.
Q: Why is Flash Attention faster if the formula stays the same?
A: Because GPU performance is often limited by memory movement, not only arithmetic. Flash Attention reduces HBM reads and writes and keeps active tiles in fast SRAM.
Common wrong answer to avoid: "Because it skips softmax." No. Softmax is still there, just fused into the tiled kernel.
Q: Is Flash Attention approximate?
A: No. It is an exact attention algorithm. With proper running-max and running-sum updates, it produces the same output as standard attention (up to floating-point rounding); only the execution order is smarter.
Common wrong answer to avoid: "Yes, it trades accuracy for speed." That describes many approximate attention methods, not Flash Attention.
Apply now (5 min)
Take N = 4096 and answer fast.
- Compute N².
- If one score buffer is fp32, compute its size in bytes.
- Say out loud why storing both scores and probabilities is painful.
- Pick a tile with B = 32, d = 128, fp16, and compute one tile size.
Then sketch from memory:
- The big Q blocks × K blocks grid.
- The three running row statistics: m, l, and o.
- The sentence: "same answer, less HBM traffic."
If you can explain why Flash Attention is exact, not approximate, you own this file.
Bridge. GQA shrinks the cache. Flash Attention speeds the compute. Together they make modern long-context transformers practical. But this module still glossed over plenty. The next file admits what was simplified. Read 14-honest-admission.md next.