12. Batch norm vs layer norm — same cleanup, different axis
Why activations drift during training, and why two normalizers solve it in two different ways.
> Built on the ELI5 in 00-eli5.md. The rule pile — full of bends and nudges — gets hard to train when each layer keeps changing the scale of the next layer's inputs.
The practical problem — later layers chase a moving target
Picture first.
Layer 1 changes its weights today. Layer 2 now receives a different distribution today. Layer 3 must readjust. Then Layer 1 changes again. The whole rule pile keeps re-aiming.
Old phrase: internal covariate shift. Fine. But do not memorize the term. Memorize the pain.
early layer shifts → middle layer input scale shifts → later layer retunes
again → again → again
cat-robot feeling: "I just learned this target."
"Why did the target move again?"
Training still works. But it becomes slower and touchier. Learning rates get fussy. Deep piles wobble. The smart nudge spends effort just keeping up.
So what to do? Keep activations in a predictable range. Center them. Scale them. Then let the layer learn back whatever shift and scale turn out to be useful.
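Here is a minimal sketch of that cleanup in plain Python. It assumes one list of activations for a single feature and a small ε, and the activation numbers are made up. It shows why the target stops moving. If an earlier layer rescales its outputs by any positive factor and shifts them, the centered and scaled values come out the same, apart from the tiny ε.

```python
import math

def standardize(values, eps=1e-5):
    """Center and scale one feature's activations: subtract the mean, divide by the std."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

# Hypothetical activations arriving at a later layer.
acts = [0.5, 1.5, -0.2, 2.0]
# An earlier layer updates: its outputs get scaled by 3 and shifted by +10.
drifted = [3.0 * a + 10.0 for a in acts]

print([round(v, 3) for v in standardize(acts)])
print([round(v, 3) for v in standardize(drifted)])  # same values: the drift is cleaned away
```

A negative rescale would flip the signs, so this only cancels positive scaling plus any shift.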
Same matrix, two cleanup crews
Take one activation matrix X with shape [B, D]. B is batch size. D is feature count.
X = [B, D]
      feature1  feature2  feature3
ex1   x11       x12       x13
ex2   x21       x22       x23
ex3   x31       x32       x33
ex4   x41       x42       x43
Batch norm asks: "For one feature, how does it vary across examples?" It averages down each column.
Layer norm asks: "Inside one example, how do its features compare?" It averages across each row:
ex1 x11 ─ x12 ─ x13 →
ex2 x21 ─ x22 ─ x23 →
ex3 x31 ─ x32 ─ x33 →
ex4 x41 ─ x42 ─ x43 →
average across each row
Same idea. Different axis. That axis choice changes everything.
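Here are the two axes in a few lines of plain Python, as a sketch. The 4 × 3 matrix is hypothetical. Its feature 2 column and example 2 row match the worked example later in this file, and the other entries are made up. Batch norm produces one statistic per feature (D of them). Layer norm produces one per example (B of them).

```python
# Hypothetical [B, D] = [4, 3] activation matrix (rows = examples, columns = features).
X = [
    [1.0, 2.0, 3.0],
    [2.0, 5.0, 8.0],
    [0.0, 4.0, 2.0],
    [3.0, 1.0, 5.0],
]

def batch_norm_means(X):
    """Batch-norm direction: one mean per feature, averaged down the batch (column)."""
    return [sum(col) / len(col) for col in zip(*X)]

def layer_norm_means(X):
    """Layer-norm direction: one mean per example, averaged across its features (row)."""
    return [sum(row) / len(row) for row in X]

print(batch_norm_means(X))  # [1.5, 3.0, 4.5]       -> D = 3 numbers
print(layer_norm_means(X))  # [2.0, 5.0, 2.0, 3.0]  -> B = 4 numbers
```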
Batch norm — normalize each feature using the mini-batch
Picture before math. Think of feature 2 as one thermometer. Four examples arrive. Batch norm asks, "Across these four examples, what is normal for this thermometer?"
Then it recenters each example relative to that batch average. A very high feature value becomes positive. A very low one becomes negative. Middle stays near zero.
For feature j:
μ_j = (1/B) Σ_i x[i,j]
σ_j² = (1/B) Σ_i (x[i,j] - μ_j)²
x̂[i,j] = (x[i,j] - μ_j) / sqrt(σ_j² + ε)
y[i,j] = γ_j * x̂[i,j] + β_j
γ and β are learnable. Important. We do not force the feature to stay mean-zero forever. We just give the layer a stable starting scale. Then the layer can bend it back if useful.
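The four formulas above can be sketched as one plain-Python function for a [B, D] list of lists. This assumes training-mode statistics (the live batch) and uses the biased 1/B variance from the formula. The input matrix is hypothetical.

```python
import math

def batch_norm(X, gamma, beta, eps=1e-5):
    """Batch norm (training-mode statistics) on a [B, D] list of lists.

    For each feature j: mean and biased variance over the B examples,
    then x_hat = (x - mu_j) / sqrt(var_j + eps) and y = gamma[j] * x_hat + beta[j].
    """
    B, D = len(X), len(X[0])
    out = [[0.0] * D for _ in range(B)]
    for j in range(D):
        col = [X[i][j] for i in range(B)]
        mu = sum(col) / B
        var = sum((v - mu) ** 2 for v in col) / B
        for i in range(B):
            x_hat = (X[i][j] - mu) / math.sqrt(var + eps)
            out[i][j] = gamma[j] * x_hat + beta[j]
    return out

X = [[1.0, 2.0, 3.0], [2.0, 5.0, 8.0], [0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]  # hypothetical
Y = batch_norm(X, gamma=[1.0, 1.0, 1.0], beta=[0.0, 0.0, 0.0])
print(round(Y[1][1], 3))  # 1.265, the x[2,2] cell (row index 1, column index 1)
```

Setting gamma[j] = sqrt(var_j + eps) and beta[j] = mu_j would undo the normalization exactly. That is the "bend it back" freedom the learnable γ and β provide.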
In a simple [B, D] table, batch norm normalizes down the batch axis. In CNNs, the same idea is used per channel. There the average is usually across batch and spatial positions.
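For the CNN case, here is a sketch that assumes a nested-list tensor shaped [B][C][H][W] (batch, channels, height, width). Each channel gets one mean and one variance, pooled over every example and every spatial position. The random input is hypothetical.

```python
import math
import random

def batch_norm_2d(X, gamma, beta, eps=1e-5):
    """Per-channel batch norm for a nested list shaped [B][C][H][W].

    Channel c gets one mean and one variance, pooled over all B examples
    and all H * W positions; gamma[c] and beta[c] are per-channel as well.
    """
    B, C = len(X), len(X[0])
    stats = []
    for c in range(C):
        vals = [v for b in range(B) for row in X[b][c] for v in row]
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        stats.append((mu, math.sqrt(var + eps)))
    return [
        [
            [[gamma[c] * (v - stats[c][0]) / stats[c][1] + beta[c] for v in row]
             for row in X[b][c]]
            for c in range(C)
        ]
        for b in range(B)
    ]

# Hypothetical input: 4 examples, 2 channels, 3 x 3 spatial grid.
random.seed(0)
X = [[[[random.gauss(5.0, 2.0) for _ in range(3)] for _ in range(3)]
      for _ in range(2)] for _ in range(4)]
Y = batch_norm_2d(X, gamma=[1.0, 1.0], beta=[0.0, 0.0])
```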
Why people liked it in vision:
- Big image batches give reliable statistics.
- It keeps training stable at higher learning rates.
- The batch noise itself gives mild regularization.
Layer norm — normalize each example using its own features
Picture first. One example enters. Layer norm ignores the other examples completely. It asks, "Inside this one example, which features are high and low relative to this example's own average?"
So each row gets cleaned independently. No other row matters. The cat-robot does not peek at its neighbors.
For example i:
μ_i = (1/D) Σ_j x[i,j]
σ_i² = (1/D) Σ_j (x[i,j] - μ_i)²
x̂[i,j] = (x[i,j] - μ_i) / sqrt(σ_i² + ε)
y[i,j] = γ_j * x̂[i,j] + β_j
Same formula family. Different averaging direction.
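The layer norm formulas in the same plain-Python style, as a sketch. The input matrix is hypothetical. Only the loop direction changes from the batch norm version.

```python
import math

def layer_norm(X, gamma, beta, eps=1e-5):
    """Layer norm on a [B, D] list of lists.

    For each example i: mean and biased variance over its own D features,
    then y[i][j] = gamma[j] * (x - mu_i) / sqrt(var_i + eps) + beta[j].
    """
    out = []
    for row in X:
        D = len(row)
        mu = sum(row) / D
        var = sum((v - mu) ** 2 for v in row) / D
        denom = math.sqrt(var + eps)
        out.append([gamma[j] * (v - mu) / denom + beta[j] for j, v in enumerate(row)])
    return out

X = [[1.0, 2.0, 3.0], [2.0, 5.0, 8.0], [0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]  # hypothetical
Y = layer_norm(X, gamma=[1.0, 1.0, 1.0], beta=[0.0, 0.0, 0.0])
print(Y[1][1])  # 0.0: in example 2, the value 5 sits exactly at the row mean
```

The statistics are indexed by example i, but γ and β are still indexed by feature j, exactly as in the formula.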
Why this became the transformer default:
- Training and inference behave the same way.
- Batch size can be 1. Still fine.
- Sequence lengths can vary. Still fine.
- Streaming one token at a time. Still fine.
Simple, no? If your system cannot trust batch statistics, do not build on batch statistics.
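Here is what batch size 1 does to each normalizer, as a small sketch that assumes γ = 1 and β = 0. The example values are hypothetical. With live batch statistics, a batch of one example makes every feature equal to its own batch mean, so batch norm outputs all zeros (just β once γ and β are applied). Layer norm still sees the spread inside the row.

```python
import math

def normalize(values, eps=1e-5):
    """Subtract the mean, divide by sqrt(variance + eps)."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [(v - mu) / math.sqrt(var + eps) for v in values]

def batch_norm_train(X):
    """Batch norm with live batch stats (gamma = 1, beta = 0): normalize each column."""
    cols = [normalize(list(col)) for col in zip(*X)]
    return [list(row) for row in zip(*cols)]

def layer_norm(X):
    """Layer norm (gamma = 1, beta = 0): normalize each row on its own."""
    return [normalize(row) for row in X]

single = [[2.0, 5.0, 8.0]]        # hypothetical batch containing ONE example
print(batch_norm_train(single))   # [[0.0, 0.0, 0.0]]: every value equals its own batch mean
print(layer_norm(single))         # about [[-1.2247, 0.0, 1.2247]]: the row keeps its structure
```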
Worked example — 4 examples × 3 features
Use a 4 × 3 matrix whose feature 2 column (down the batch) is [2, 5, 4, 1] and whose example 2 row is [2, 5, 8]. Ignore ε for hand calculation. Also take γ = 1, β = 0.
Compute the normalized value for the same cell: x[2,2] = 5. Second example. Second feature.
Batch norm for x[2,2]
Look only at feature 2 down the batch. Values are [2, 5, 4, 1].
μ_feature2 = (2 + 5 + 4 + 1) / 4 = 3
σ²_feature2 = ((2-3)² + (5-3)² + (4-3)² + (1-3)²) / 4 = 2.5
σ_feature2 = sqrt(2.5) ≈ 1.581
x̂_BN[2,2] = (5 - 3) / 1.581 ≈ 1.265
So under batch norm, this cell is about 1.27 standard deviations above the batch mean for feature 2.
Layer norm for the same cell
Now ignore the batch. Look only at example 2 across features. Row values are [2, 5, 8].
μ_example2 = (2 + 5 + 8) / 3 = 5
σ²_example2 = ((2-5)² + (5-5)² + (8-5)²) / 3 = 6
σ_example2 = sqrt(6) ≈ 2.449
x̂_LN[2,2] = (5 - 5) / 2.449 = 0
Same raw number. Different question.
Batch norm asked, "How unusual is 5 for feature 2 across this batch?" Answer: quite high. Layer norm asked, "How unusual is 5 inside this example's own feature vector?" Answer: exactly average.
This is the whole difference in one cell.
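The same two hand calculations in Python, as a quick check. Like the hand version, it ignores ε and takes γ = 1 and β = 0.

```python
import math

feature2 = [2, 5, 4, 1]   # feature 2 down the batch
example2 = [2, 5, 8]      # example 2 across its features

def z(value, values):
    """Normalized value with eps ignored and gamma = 1, beta = 0, as in the hand calculation."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return (value - mu) / math.sqrt(var)

print(round(z(5, feature2), 3))  # 1.265  (batch norm view)
print(round(z(5, example2), 3))  # 0.0    (layer norm view)
```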
The interview gotcha — batch norm changes behavior at train and test time
During training, batch norm uses the current mini-batch mean and variance. So the same example can get slightly different normalized values in different batches.
During inference, that would be a mess. You may have batch size 1. Or a weird tiny batch. Predictions would wobble with batch composition.
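So batch norm keeps running estimates (exponential moving averages) during training. In the form below, momentum sits close to 1; Keras defaults to 0.99. PyTorch flips the convention: its momentum (default 0.1) is the weight on the new batch.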
running_mean ← momentum * old + (1 - momentum) * batch_mean
running_var ← momentum * old + (1 - momentum) * batch_var
At inference, batch norm uses those running statistics, not the live batch. This is the gotcha: batch statistics in training, yes; in inference, no. In PyTorch, that switch is model.train() versus model.eval(). Layer norm has no such switch. It computes its statistics from the current example in both training and inference, so the rule stays the same.
When to use which — and why transformers chose layer norm
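Here is a minimal batch norm layer with the train/inference switch, as a sketch. It uses plain Python lists with no gradients, and γ and β stay at their initial values. It follows the running-statistics update rule written above; momentum = 0.9 is an illustrative choice, not any library's default. The `training` flag plays the role of a framework's train/eval mode, and the batches are hypothetical.

```python
import math

class BatchNorm1D:
    """Minimal batch norm over [B, D] lists, keeping running statistics.

    Update rule: running <- momentum * running + (1 - momentum) * batch_stat,
    so a momentum close to 1 means slowly moving averages. No gradients here.
    """

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.gamma = [1.0] * num_features
        self.beta = [0.0] * num_features
        self.running_mean = [0.0] * num_features
        self.running_var = [1.0] * num_features
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def __call__(self, X):
        B, D = len(X), len(X[0])
        if self.training:
            # Live batch statistics, and fold them into the running estimates.
            mean = [sum(row[j] for row in X) / B for j in range(D)]
            var = [sum((row[j] - mean[j]) ** 2 for row in X) / B for j in range(D)]
            m = self.momentum
            self.running_mean = [m * r + (1 - m) * b for r, b in zip(self.running_mean, mean)]
            self.running_var = [m * r + (1 - m) * b for r, b in zip(self.running_var, var)]
        else:
            # Inference: frozen running statistics, so neighbors in the batch are irrelevant.
            mean, var = self.running_mean, self.running_var
        return [[self.gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + self.eps) + self.beta[j]
                 for j in range(D)] for row in X]

bn = BatchNorm1D(num_features=2)
for batch in ([[1.0, 2.0], [3.0, 4.0]], [[2.0, 0.0], [4.0, 6.0]]) * 50:
    bn(batch)  # training: normalize with live stats, update running stats
bn.training = False
print(bn([[1.0, 2.0]])[0] == bn([[1.0, 2.0], [100.0, -50.0]])[0])  # True
```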
Batch norm usually wins in CNNs and vision. Big image batches give reliable statistics, channel-wise averaging makes sense, batch noise gives mild regularization, and classic ResNet recipes lean on it.
Layer norm usually wins in transformers and NLP. Variable sequence lengths, batch-size-1 inference, and awkward micro-batches make any batch-dependent normalizer fragile. Attention blocks want a cleanup rule that does not depend on neighbors.
One-line memory trick:
Now ask the transformer question directly. Each token is its own feature vector. Batch norm would make that token depend on neighbors in the batch, which is bad for padding, variable lengths, and autoregressive decoding.
Layer norm stays local to the token. So attention sees a stable scale at train time, test time, and one-token-at-a-time serving. This is why BERT, GPT, and friends settled there.
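Here is a sketch of the padding point, assuming a sequence stored as a list of token feature vectors ([T, D]) and all-zero padding tokens. The token values are made up. `batch_norm_tokens` is a hypothetical strawman that normalizes each feature across every token in the list, padding included.

```python
import math

def normalize(values, eps=1e-5):
    """Subtract the mean, divide by sqrt(variance + eps)."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [(v - mu) / math.sqrt(var + eps) for v in values]

def layer_norm_tokens(tokens):
    # Each token's feature vector is normalized on its own.
    return [normalize(t) for t in tokens]

def batch_norm_tokens(tokens):
    # Hypothetical strawman: each feature normalized across all tokens, padding included.
    cols = [normalize(list(c)) for c in zip(*tokens)]
    return [list(r) for r in zip(*cols)]

sentence = [[0.2, 1.0, -0.5, 0.3], [1.5, -0.7, 0.0, 0.4]]  # two real tokens, D = 4
padded = sentence + [[0.0, 0.0, 0.0, 0.0]] * 3             # same tokens plus three padding tokens

print(layer_norm_tokens(padded)[:2] == layer_norm_tokens(sentence))  # True
print(batch_norm_tokens(padded)[:2] == batch_norm_tokens(sentence))  # False
```

Under layer norm, padding cannot leak into the real tokens. Under the batch-style version, adding padding changes what the real tokens look like.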
The modern variants
Three names come up a lot now. Simple, no? Same cleanup idea. Different trade-off.
RMSNorm. Drop the mean-centering step. Only divide by RMS: x / √(mean(x²) + ε) * γ.
Why? Slightly cheaper. No mean subtraction. Empirically it works just as well in many LLMs. LLaMA, Gemma, and Mistral use it.
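Here is RMSNorm next to plain layer norm, as a sketch for a single feature vector. ε = 1e-6 is an assumed value (implementations pick their own), and the vectors are made up. A constant vector shows the difference: layer norm removes the mean, while RMSNorm keeps it and only rescales.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm for one feature vector: x / sqrt(mean(x^2) + eps) * gamma. No mean subtraction, no beta."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gamma, x)]

def layer_norm(x, eps=1e-6):
    """Plain layer norm (gamma = 1, beta = 0) for comparison."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

ones = [1.0, 1.0, 1.0]
shifted = [2.0, 2.0, 2.0]        # a constant vector: mean 2, no spread
print(rms_norm(shifted, ones))   # about [1.0, 1.0, 1.0]: the mean survives
print(layer_norm(shifted))       # [0.0, 0.0, 0.0]: the mean is removed
```

When the input already has zero mean, the RMS equals the standard deviation, and the two give the same result.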
Group Norm. For each example, split channels into groups, then normalize over the channels and spatial positions inside each group. No batch statistics are involved, so it works with any batch size. That is why it shows up in diffusion UNets like Stable Diffusion.
Pre-norm vs post-norm. Pre-norm means normalize the input before attention or the FFN: x + sublayer(norm(x)). Post-norm means normalize after the sublayer output is added back to the residual: norm(x + sublayer(x)). The original transformer used post-norm. Modern deep transformers usually choose pre-norm because it trains more stably at depth. LLaMA is pre-norm. So are later GPT-style stacks.
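Here is a group norm sketch for one example shaped [C][H][W], assuming contiguous channel groups and per-channel γ and β. The input values are made up. Nothing here looks at other examples, which is why batch size stops mattering.

```python
import math

def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """Group norm for ONE example shaped [C][H][W] (nested lists).

    Channels are split into num_groups contiguous groups; each group's mean and
    variance pool its channels and all spatial positions. No other example is used.
    """
    C = len(x)
    assert C % num_groups == 0, "channels must divide evenly into groups"
    size = C // num_groups
    out = []
    for g in range(num_groups):
        chans = x[g * size:(g + 1) * size]
        vals = [v for ch in chans for row in ch for v in row]
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        denom = math.sqrt(var + eps)
        for k, ch in enumerate(chans):
            c = g * size + k  # global channel index for gamma/beta
            out.append([[gamma[c] * (v - mu) / denom + beta[c] for v in row] for row in ch])
    return out

# Hypothetical example: 4 channels, 2 x 2 spatial grid, 2 groups of 2 channels.
x = [[[float(c * 10 + h * 2 + w) for w in range(2)] for h in range(2)] for c in range(4)]
y = group_norm(x, num_groups=2, gamma=[1.0] * 4, beta=[0.0] * 4)
```

Two edge cases help place it. num_groups = 1 pools every channel, which is layer norm over (C, H, W). num_groups = C gives one group per channel, known as instance norm.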
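Here are the two placements side by side, as a sketch for one token vector. `sublayer` is a hypothetical stand-in for attention or the FFN, and layer norm uses γ = 1 and β = 0 for brevity.

```python
import math

def layer_norm(x, eps=1e-5):
    """Layer norm for one token vector, gamma = 1 and beta = 0."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    # Hypothetical stand-in for attention or FFN: any map from D floats to D floats.
    return [0.5 * v + 1.0 for v in x]

def post_norm_block(x):
    # Original transformer: add the residual first, then normalize the sum.
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_norm_block(x):
    # Pre-norm: normalize only the sublayer input; the residual path stays untouched.
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

x = [1.0, 2.0, 3.0, 4.0]
print(post_norm_block(x))  # output is re-centered: mean about 0
print(pre_norm_block(x))   # x plus a branch computed from the normalized input
```

In the pre-norm version, the residual path carries x forward untouched, and only the branch input is normalized. That clean identity path is the usual explanation for why pre-norm trains more stably at depth. Pre-norm stacks typically add one final normalization after the last block, since nothing else normalizes the residual stream.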
Pause and recall
- Which axis does batch norm average over?
- Why does batch norm fail at inference with batch size 1?
- When do you pick layer norm over batch norm?
If any answer felt slow, scroll back.
Where this lives in the wild
- Google Photos — batch norm in CNN backbones. Image classifiers and detectors depend on stable channel statistics so very deep conv stacks train quickly.
- Tesla Autopilot — batch norm in vision towers. Camera frames come in large training batches, so batch norm stabilizes the conv rule pile used for lanes, cars, and signs.
- OpenAI GPT models — layer norm around attention and feed-forward blocks. Each token is normalized without depending on other examples, which is crucial for next-token decoding.
- Google Translate — layer norm in transformer layers. Variable sentence lengths and serving-time batch changes make layer norm the safe choice.
- Microsoft 365 Copilot — layer norm in transformer language models underneath the assistant. The model must serve wildly different prompt lengths without depending on batch neighbors.
Interview Q&A
Q: What is the practical problem normalization is solving in deep nets?
A: Later layers keep seeing their input distribution shift as earlier layers update. So the rule pile keeps chasing a moving target. Normalization recenters and rescales activations so the smart nudge sees a more stable landscape.
Common wrong answer to avoid: "It only prevents overfitting." Batch norm may regularize a bit, but the core job is training stability and scale control.
Q: What is the axis difference between batch norm and layer norm?
A: Batch norm normalizes each feature across the examples in a batch. Layer norm normalizes all features within one example. Same formula family, different averaging axis.
Common wrong answer to avoid: "Batch norm is across layers, layer norm is across batches." The actual distinction is batch axis versus feature axis.
Q: Why is batch norm tricky at inference?
A: Because live inference batches can be tiny or inconsistent. So batch norm does not use current batch statistics at test time. It uses running mean and variance accumulated during training.
Common wrong answer to avoid: "Inference just uses the current batch too." That makes predictions depend on batch composition, which is exactly what we avoid.
Q: Why did transformers prefer layer norm?
A: Transformers often train with variable-length sequences and serve with batch size 1 or streaming decode. Layer norm stays fully local to one token's feature vector, so it works the same in all these settings.
Common wrong answer to avoid: "Because layer norm normalizes better." The real reason is not magical quality. It is batch independence.
Apply now (5 min)
Take this matrix:
Do two tiny tasks.
1. Compute batch-norm stats for feature 1. Mean first. Then variance. Then normalize the value x[4,1] = 3.
2. Compute layer-norm stats for example 4. Mean first. Then variance. Then normalize the same cell x[4,1] = 3.
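To check your hand calculation, here is a hypothetical helper, sketched under the same setup as the worked example (ε ignored, γ = 1, β = 0, 1-indexed cells). The matrix below is a made-up placeholder, not the exercise matrix. Its feature 2 column and example 2 row match the worked example, so it reproduces those numbers. Paste the exercise matrix in its place.

```python
import math

def bn_and_ln_at(X, i, j):
    """Return (batch-norm value, layer-norm value) for cell x[i, j], 1-indexed like the text.

    eps is ignored and gamma = 1, beta = 0; a constant column or row would
    divide by zero here, which is exactly what eps guards against in practice.
    """
    def z(value, values):
        mu = sum(values) / len(values)
        sd = math.sqrt(sum((v - mu) ** 2 for v in values) / len(values))
        return (value - mu) / sd

    column = [row[j - 1] for row in X]  # feature j down the batch
    row = X[i - 1]                      # example i across its features
    value = X[i - 1][j - 1]
    return z(value, column), z(value, row)

# Placeholder matrix (hypothetical): replace it with the exercise matrix.
X = [
    [1.0, 2.0, 3.0],
    [2.0, 5.0, 8.0],
    [0.0, 4.0, 2.0],
    [3.0, 1.0, 5.0],
]
print(bn_and_ln_at(X, 2, 2))  # about (1.265, 0.0), the worked-example cell
```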
Then do the real retrieval step. Without looking, sketch the [B, D] matrix from memory and mark which way each normalizer averages.
And say one sentence aloud: "Batch norm compares me to other examples on one feature. Layer norm compares my features inside one example."
If you can do that in one minute, you own the axis picture.
Bridge. Normalization keeps the rule pile stable. But even with all these tools — bends, nudges, regularization, normalization — some things remain genuinely mysterious. The next file admits what we still do not know. Read 13-honest-admission.md.