04. Weight initialization: where the rule pile starts
Three minutes. Two failure modes before training even begins. The variance-preserving fix.
Built on the ELI5 in 00-eli5.md. The rule pile needs starting numbers. Pick them wrong and the nudge never helps. Pick them right and the pile begins balanced.
The mental model
Weights are starting positions. Before any nudge flows back, every weight in the rule pile sits at some number. That number is not innocent.
A bad start kills training before it begins. The pile either freezes (no learning) or detonates (loss = NaN on step 1). Both happen on epoch 0.
So the question is: where on the number line do we drop the weights so that the forward pass produces sensible signals, not too big, not too small, not all identical?
Failure 1: all zeros. Symmetry kills the pile.
A natural first thought. Set every weight to 0. Then the network starts unbiased. Let the nudge push them apart.
It does not work. The pile cannot break free.
We will prove it on a tiny 2-2-1 net. Two inputs, one hidden layer of two ReLU neurons, one output. All weights start at 0. Bias also 0.
Input (x1, x2) = (1, 1). Target = 1.
Attempt 1: forward pass with all zeros
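Hidden pre-activations: z1 = w11·x1 + w12·x2 + b1 = 0·1 + 0·1 + 0 = 0, and likewise z2 = 0.
After ReLU: h1 = max(0, z1) = 0 and h2 = max(0, z2) = 0.
Output: y = v1·h1 + v2·h2 = 0·0 + 0·0 = 0.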
So predicted = 0, target = 1. Loss = 1. Now we nudge.
Attempt 2: what does backprop send back?
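Gradient of loss with respect to v1, using squared error L = (y − target)²: ∂L/∂v1 = 2·(y − target)·h1 = 2·(0 − 1)·0 = 0. The same holds for v2, because h2 = 0.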
Output weights stay frozen at 0. The first nudge does nothing.
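What about hidden weights? Gradient flows back through v1, v2. Both are 0. So: ∂L/∂w11 = 2·(y − target)·v1·ReLU′(z1)·x1 = 2·(−1)·0·ReLU′(0)·1 = 0, and the same for w12, w21, w22 and both hidden biases.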
Every gradient is 0. Every weight stays exactly where it started.
Attempt 3: try a different input, many steps
Pick (x1, x2) = (2, 3), target = 1. Same all-zero init. Forward pass still gives h1 = h2 = 0 and y = 0. Same gradient = 0 everywhere. Train for 1000 steps. Nothing moves. The loss curve is a flat line at 1.0.
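The three attempts can be replayed in a few lines of NumPy. This is a minimal sketch, not the chapter's own code: it assumes squared-error loss, plain gradient descent, and no output bias (with a trainable output bias, that single parameter would receive a non-zero gradient). The function name `train_from_zeros` and the learning rate are illustrative choices.

```python
import numpy as np

def train_from_zeros(x, target, steps=1000, lr=0.1):
    """Run gradient descent on a 2-2-1 ReLU net whose weights all start at 0."""
    W = np.zeros((2, 2))  # W[j, i]: weight into hidden neuron j+1 from input i+1
    b = np.zeros(2)       # hidden biases
    v = np.zeros(2)       # output weights v1, v2 (no output bias: an assumption)
    losses = []
    for _ in range(steps):
        z = W @ x + b                 # hidden pre-activations
        h = np.maximum(z, 0.0)        # ReLU
        y = v @ h                     # prediction
        losses.append((y - target) ** 2)

        dy = 2.0 * (y - target)       # dL/dy for squared error
        dv = dy * h                   # zero because h is zero
        dz = dy * v * (z > 0)         # zero because v is zero
        W -= lr * np.outer(dz, x)
        b -= lr * dz
        v -= lr * dv
    return np.array(losses), W, b, v

losses, W, b, v = train_from_zeros(np.array([2.0, 3.0]), target=1.0)
print("first loss:", losses[0], "last loss:", losses[-1])
print("any weight moved:", bool(W.any() or b.any() or v.any()))
```

Changing the input or the step count does not help: the two zero factors (h and v) block every gradient path.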
The structural reason
See. Every neuron in a layer computes the same function on the same input. Same output. Same gradient. The nudge that arrives at w11 is identical to the nudge at w21. They move together forever.
The two hidden neurons are mathematically tied. They are one neuron pretending to be two. Effective width: one. The pile collapses to a point.
This is the rule pile collapsing, but a worse version than chapter 1. There the pile collapsed because all rules were straight. Here it collapses because all rules are identical. Symmetry must break, and zero init cannot break it.
Same problem if every weight starts at any single constant — say all 0.5. Same outputs, same gradients, same tying. Symmetry must break.
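The constant-init case is worth seeing separately, because here the gradients are not zero and the loss does go down, yet the neurons never separate. A small sketch under the same assumptions as the walkthrough (2-2-1 ReLU net, squared error, no output bias); the learning rate and step count are arbitrary illustrative values.

```python
import numpy as np

x, target, lr = np.array([2.0, 3.0]), 1.0, 0.01
W = np.full((2, 2), 0.5)   # W[j, i]: weight into hidden neuron j+1 from input i+1
b = np.zeros(2)
v = np.full(2, 0.5)

losses = []
for _ in range(200):
    z = W @ x + b
    h = np.maximum(z, 0.0)
    y = v @ h
    losses.append((y - target) ** 2)

    dy = 2.0 * (y - target)
    dv = dy * h                  # identical entries: h1 == h2
    dz = dy * v * (z > 0)        # identical entries: v1 == v2
    W -= lr * np.outer(dz, x)
    b -= lr * dz
    v -= lr * dv

print("loss:", losses[0], "->", losses[-1])
print("hidden neuron 1 weights:", W[0])
print("hidden neuron 2 weights:", W[1])
```

The loss shrinks, but both rows of W and both entries of v stay equal at every step, so the layer behaves like a single neuron.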
Failure 2: too-large random. Saturation and explosion.
OK. So we use random numbers. Symmetry breaks. But what scale?
Try big random. W ~ Uniform(-10, 10). Tanh activation, 256-wide layer.
Forward pass through one layer
Input x has 256 entries, each ~ N(0, 1). Pre-activation for one neuron: z = w1·x1 + w2·x2 + … + w256·x256.
For Uniform(−10, 10), Var(w) = 20²/12 ≈ 33.3. So Var(z) ≈ 256 · Var(w) · Var(x) = 256 · 33.3 · 1 ≈ 8500.
So z typically lands around magnitude √8500 ≈ 92. Now apply tanh: tanh(92) ≈ 1 and tanh(−92) ≈ −1.
Every neuron output is pinned at ±1. The bend has flattened.
What is the gradient?
Tanh derivative: 1 - tanh²(z). At z = 92: 1 − tanh²(92) ≈ 4·e⁻¹⁸⁴ ≈ 5 × 10⁻⁸⁰, which float64 rounds to exactly 0.
The local gradient is microscopically small. The nudge dies the moment it touches this layer. Stack five such layers and five of these factors multiply together. Effectively zero.
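The saturation claim is easy to check numerically. The sketch below assumes the setup described above (256 inputs drawn from N(0, 1), a 256-wide tanh layer, weights from Uniform(−10, 10)); the batch size, the seed, and the 0.999 saturation threshold are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, width, batch = 256, 256, 1000

x = rng.standard_normal((batch, fan_in))             # inputs ~ N(0, 1)
W = rng.uniform(-10.0, 10.0, size=(fan_in, width))   # Var(w) = 20**2 / 12 ≈ 33.3

z = x @ W                      # pre-activations
h = np.tanh(z)                 # layer output
local_grad = 1.0 - h ** 2      # tanh derivative at each pre-activation

z_std = z.std()
saturated = np.mean(np.abs(h) > 0.999)
median_grad = np.median(local_grad)

print(f"std of z: {z_std:.1f}")
print(f"fraction of outputs with |tanh| > 0.999: {saturated:.3f}")
print(f"median local gradient: {median_grad:.3g}")
```

A few pre-activations land near zero by chance and keep a usable gradient, but the large majority sit on the flat part of the curve.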
Or with ReLU: explosion instead
Same big-random weights, ReLU activation. ReLU does not saturate; it passes positives straight through. So z ≈ 92 becomes h = 92. Next layer multiplies by another 256-wide weight matrix. Now z ≈ 92 · √8500 ≈ 8500. Five layers later h is in the billions. The loss and its gradients are astronomically large, and the first update typically drives the weights to inf or NaN.
Either way — saturation or explosion — too-large init kills training before any nudge lands.
The variance-preserving rule
So what to do? We want the forward pass to keep signal magnitude steady through every layer. Not growing. Not shrinking.
Picture. Imagine the signal is a sound wave traveling through the rule pile. Each layer is a room. We want the sound to leave each room at the same volume it entered. Not amplified into feedback. Not muffled into silence.
Math says, for a layer with fan_in inputs and weights drawn from Normal(0, σ²): Var(output) = fan_in · σ² · Var(input).
To keep Var(output) = Var(input), we need σ² = 1 / fan_in.
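That gives Xavier (Glorot) init, which works for tanh, sigmoid, linear: W ~ Normal(0, 1/fan_in). (Glorot's original formula averages the two fan sizes, σ² = 2/(fan_in + fan_out), which is the same thing when fan_in = fan_out.)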
But ReLU kills half the signal: it zeros out everything negative. So the surviving variance is only half. To compensate, double σ². That gives He (Kaiming) init: W ~ Normal(0, 2/fan_in).
Same picture, one factor of 2 to account for the half-killed bend.
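Both formulas fall out of a short variance calculation. A sketch, assuming the weights are independent of the inputs with zero mean, the inputs share a common second moment, and (for the ReLU line) the pre-activation z is symmetric around zero:

```latex
\begin{aligned}
z &= \sum_{i=1}^{\mathrm{fan\_in}} w_i x_i,
  \qquad w_i \sim \mathcal{N}(0, \sigma^2) \\
\operatorname{Var}(z) &= \mathrm{fan\_in} \cdot \sigma^2 \cdot \mathbb{E}[x^2] \\
\text{linear or tanh near 0:}\quad
\operatorname{Var}(z) = \mathbb{E}[x^2]
  \;&\Longrightarrow\; \sigma^2 = \frac{1}{\mathrm{fan\_in}} \\
\text{ReLU:}\quad
\mathbb{E}\big[\max(0, z)^2\big] = \tfrac{1}{2}\operatorname{Var}(z) = \mathbb{E}[x^2]
  \;&\Longrightarrow\; \sigma^2 = \frac{2}{\mathrm{fan\_in}}
\end{aligned}
```

For zero-mean inputs the second moment equals Var(input), which recovers the rule as stated in the text.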
Concrete example: 256 → 256 ReLU layer
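He init: σ² = 2/256 ≈ 0.0078, so σ ≈ 0.088.
Each weight ~ Normal(0, σ²) with σ ≈ 0.088. Most weights between −0.18 and +0.18. Forward pass: Var(z) = 256 · 0.0078 · 1 = 2, and ReLU keeps half of that, so the mean square of h is about 1.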
Signal stays at magnitude 1 through every layer. Yes? Gradient also stays at magnitude 1 going back. The nudge survives.
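The numbers in this example can be checked with a quick NumPy draw. A minimal sketch, assuming a unit-variance Gaussian input; the batch size and seed are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out, batch = 256, 256, 2000

sigma = np.sqrt(2.0 / fan_in)                        # He standard deviation
W = rng.normal(0.0, sigma, size=(fan_in, fan_out))   # rng.normal takes the std, not the variance

x = rng.standard_normal((batch, fan_in))             # unit-variance input
z = x @ W                                            # pre-activations
h = np.maximum(z, 0.0)                               # ReLU output

within_2_sigma = np.mean(np.abs(W) < 2 * sigma)
z_var = z.var()
h_mean_square = np.mean(h ** 2)

print(f"sigma = {sigma:.4f}")
print(f"share of weights within 2 sigma: {within_2_sigma:.3f}")
print(f"Var(z) = {z_var:.2f}")
print(f"mean square of h = {h_mean_square:.2f}")
```

The pre-activation variance comes out near 2 and the post-ReLU mean square near 1, which is the sense in which the signal "stays at magnitude 1".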
Variance through 5 layers: bar chart
Three init schemes. Same 5-layer ReLU net. We track the magnitude of activations layer by layer.
init scheme      │ L0  │ L1   │ L2   │ L3   │ L4   │ L5  │ verdict
all zeros        │ 1.0 │ 0.0  │ 0.0  │ 0.0  │ 0.0  │ 0.0 │ dead
U(-10, 10)       │ 1.0 │ 90   │ 8e3  │ 7e5  │ 6e7  │ 5e9 │ explodes
U(-0.001, 0.001) │ 1.0 │ 0.01 │ 1e-4 │ 1e-6 │ 1e-8 │ ~0  │ vanishes
He init          │ 1.0 │ 1.0  │ 1.0  │ 1.0  │ 1.0  │ 1.0 │ steady
Only He init keeps the signal at order 1 across depth. The other three each fail differently — but they all fail before training begins.
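The chart can be approximated by simulation. The sketch below assumes a 256-wide, 5-layer ReLU stack with no biases and a Gaussian input; `layer_rms` and the `schemes` dictionary are hypothetical helper names. Measured values will differ somewhat from the rounded chart entries, because ReLU zeroes roughly half the pre-activations and so lowers each per-layer growth factor by about √2.

```python
import numpy as np

def layer_rms(sample_weights, depth=5, width=256, batch=512, seed=0):
    """Return the RMS activation at L0..L{depth} of a bias-free ReLU stack."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((batch, width))       # L0: unit-scale input
    rms = [float(np.sqrt(np.mean(h ** 2)))]
    for _ in range(depth):
        W = sample_weights(rng, (width, width))   # fresh weights for each layer
        h = np.maximum(h @ W, 0.0)                # linear map followed by ReLU
        rms.append(float(np.sqrt(np.mean(h ** 2))))
    return rms

schemes = {
    "all zeros":        lambda rng, shape: np.zeros(shape),
    "U(-10, 10)":       lambda rng, shape: rng.uniform(-10.0, 10.0, shape),
    "U(-0.001, 0.001)": lambda rng, shape: rng.uniform(-1e-3, 1e-3, shape),
    "He init":          lambda rng, shape: rng.normal(0.0, np.sqrt(2.0 / shape[0]), shape),
}

results = {name: layer_rms(fn) for name, fn in schemes.items()}
for name, rms in results.items():
    print(f"{name:18s}", "  ".join(f"{r:9.2e}" for r in rms))
```

The qualitative picture is the point: one row is flat at zero, one climbs by orders of magnitude per layer, one shrinks by orders of magnitude, and the He row stays near 1.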
Pause and recall. Without scrolling — why does all-zero init fail in a deeper way than just "training is slow"? What's the difference between He and Xavier?
Where this lives in the wild
Variance-preserving init ships in every serious deep learning stack:
- PyTorch torch.nn.init.kaiming_normal_. torchvision's ResNet calls it explicitly on every nn.Conv2d weight (mode="fan_out", nonlinearity="relu"); the bare nn.Conv2d and nn.Linear defaults use a related Kaiming-uniform scheme. Every convolution layer in a pretrained torchvision ResNet-50 was born from He init. That is the starting position the rule pile built on.
- Hugging Face transformers GPT-2 / LLaMA _init_weights. Linear layers get a normal init with standard deviation 0.02 (the initializer_range setting), a fixed scale in the same range Xavier would give at typical transformer widths. Embeddings get the same. Badly scaled init is a known source of early divergence in large-scale pretraining.
- JAX flax.linen.Dense default init. Uses lecun_normal, Var = 1/fan_in, which is Xavier's cousin. Flax-based training stacks inherit this default unless they override it.
- Documented choice in the ResNet paper. The authors initialize weights with He init and train from scratch, alongside batch norm. VGG-19, which predates both, had to be bootstrapped from shallower pretrained nets.
- Deep encoder-decoder conv stacks such as UNets. Orthogonal init, another norm-preserving scheme, is sometimes used for convolution kernels so gradient flow stays well-behaved.
The pattern. Frameworks ship a variance-preserving init as the default because hand-rolled init is a common cause of "loss = NaN on step 1" tickets.
Interview Q&A
Q: Why isn't all-zero init just slow training?
A: Because every neuron in a layer computes the identical function on the identical input. Their gradients are identical. They move together forever. The hidden layer of width N is mathematically a hidden layer of width 1. No amount of training breaks symmetry.
Common wrong answer to avoid: "the gradient is zero so it can't move." Partly true for the first step — but the deeper failure is that even with non-zero gradients, all neurons in a layer would still move identically. Symmetry, not gradient magnitude, is the killer.
Q: Why a different formula for ReLU vs tanh?
A: ReLU zeros out half of all activations. So the variance flowing forward is halved. To compensate, He init doubles the weight variance (2/fan_in instead of 1/fan_in). Tanh is symmetric around zero, no signal is killed, so Xavier's 1/fan_in is enough.
Common wrong answer to avoid: "He is just better than Xavier." It is not — it is matched to ReLU. Use Xavier with ReLU and your activations vanish layer by layer.
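The mismatch is easy to demonstrate. A sketch assuming a 20-layer, 256-wide, bias-free ReLU stack; `final_rms` is a hypothetical helper, and the depth is chosen only to make the effect obvious.

```python
import numpy as np

def final_rms(weight_var_numerator, depth=20, width=256, batch=256, seed=0):
    """RMS activation after `depth` ReLU layers with Var(W) = numerator / fan_in."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((batch, width))          # unit-scale input
    std = np.sqrt(weight_var_numerator / width)      # fan_in equals width here
    for _ in range(depth):
        W = rng.normal(0.0, std, (width, width))
        h = np.maximum(h @ W, 0.0)
    return float(np.sqrt(np.mean(h ** 2)))

xavier_rms = final_rms(1.0)   # Var(W) = 1 / fan_in
he_rms = final_rms(2.0)       # Var(W) = 2 / fan_in
print(f"after 20 ReLU layers: 1/fan_in -> {xavier_rms:.2e}, 2/fan_in -> {he_rms:.2e}")
```

With 1/fan_in the mean square halves at every ReLU layer, so the magnitude falls by roughly √2 per layer; with 2/fan_in it stays at order 1.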
Q: What does He init actually preserve through the layers?
A: The variance (or magnitude) of activations from one layer to the next. Input has variance 1, layer 1 output has variance ~1, layer 5 output has variance ~1. Same in reverse for gradients. This keeps both forward signal and backward nudge at workable scale through depth.
Common wrong answer to avoid: "random normal always works." Without variance scaling, deep nets collapse or explode within 5-10 layers.
Q: If batch norm exists, do I still need careful init?
A: Yes, but less. Batch norm rescales each layer's pre-activation to unit variance using the current batch's statistics, so it papers over moderate init mistakes. But it only protects the layers it wraps, and a sufficiently extreme init can overflow to inf before the normalization is applied. And transformers, which use layer norm rather than BN, still rely on carefully scaled init.
Common wrong answer to avoid: "batch norm makes init irrelevant." It helps, but the very first forward pass can still explode, and many transformer stacks do not use batch norm at all.
Apply now (5 min)
Pick a 3-layer net. Input dim 4, hidden dim 4, hidden dim 4, output dim 1. ReLU activations. Input vector x = (1, 1, 1, 1).
Compute by hand the variance of the activations after each layer under three init schemes:
- All zeros. All activations = 0. Verdict: dead.
- U(0, 1). Each weight averages 0.5. Layer 1 pre-activation ≈ 4 · 0.5 · 1 = 2. After ReLU, h ≈ 2. Layer 2: 4 · 0.5 · 2 = 4. Layer 3: 4 · 0.5 · 4 = 8. Growing. It would explode in a deeper net.
- He init. σ = √(2/4) ≈ 0.71, so σ² = 0.5. Layer 1 pre-activation variance: 4 · 0.5 · 1 = 2, after ReLU mean square ≈ 1. Same at every layer. Steady.
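Once you have done the arithmetic by hand, a few lines of Python can confirm it. This sketch follows the same shortcuts as the exercise (every U(0, 1) weight replaced by its mean 0.5, and the He case tracked through the expected mean square rather than sampled); it treats all three layers alike, and the variable names are illustrative.

```python
fan_in = 4

# U(0, 1): every weight replaced by its mean 0.5, input x = (1, 1, 1, 1).
mean_w = 0.5
signal = 1.0
uniform_means = []
for _ in range(3):
    signal = fan_in * mean_w * signal    # mean pre-activation; ReLU leaves positives unchanged
    uniform_means.append(signal)

# He init: track q, the mean square of the activations entering each layer.
he_var = 2.0 / fan_in                    # sigma^2 = 0.5, so sigma ≈ 0.71
q = 1.0
he_mean_squares = []
for _ in range(3):
    var_z = fan_in * he_var * q          # pre-activation variance
    q = var_z / 2.0                      # ReLU keeps half of it
    he_mean_squares.append(q)

print("U(0, 1) mean activations per layer:", uniform_means)
print("He init mean squares per layer:", he_mean_squares)
```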
Then — without looking — sketch from memory:
- The two failure modes (all-zero, too-large random) and what they break.
- The two init formulas: Var(W) = 1/fan_in for tanh, Var(W) = 2/fan_in for ReLU.
- One sentence: what He init preserves.
If you can do it in 90 seconds, you own this idea.
Bridge. Now the rule pile starts balanced. Forward pass produces sensible signals. But how do we measure how wrong the output is? That is the loss function. Read 05-loss-functions.md next.