04. Weight initialization — where the rule pile starts

Three minutes. Two failure modes before training even begins. The variance-preserving fix.

Built on the ELI5 in 00-eli5.md. The rule pile needs starting numbers. Pick them wrong and the nudge never helps. Pick them right and the pile begins balanced.


The mental model

Weights are starting positions. Before any nudge flows back, every weight in the rule pile sits at some number. That number is not innocent.

A bad start kills training before it begins. The pile either freezes (no learning) or detonates (loss = NaN on step 1). Both happen on epoch 0.

So the question is. Where on the number line do we drop the weights so that the forward pass produces sensible signals — not too big, not too small, not all identical?


Failure 1 — all zeros. Symmetry kills the pile.

A natural first thought. Set every weight to 0. Then the network starts unbiased. Let the nudge push them apart.

It does not work. The pile cannot break free.

We will prove it on a tiny 2-2-1 net. Two inputs, one hidden layer of two ReLU neurons, one output. All weights start at 0. Bias also 0.

    x1 ──w11── h1 ──v1──┐
       \   /            ├── y
        \ /             |
         X              |
        / \             |
       /   \            |
    x2 ──w22── h2 ──v2──┘

Input (x1, x2) = (1, 1). Target = 1.

Attempt 1 — forward pass with all zeros

Hidden pre-activations:

h1_pre = w11·x1 + w12·x2 + b1 = 0·1 + 0·1 + 0 = 0
h2_pre = w21·x1 + w22·x2 + b2 = 0·1 + 0·1 + 0 = 0

After ReLU:

h1 = ReLU(0) = 0
h2 = ReLU(0) = 0

Output:

y = v1·h1 + v2·h2 + b_out = 0·0 + 0·0 + 0 = 0

So predicted = 0, target = 1. Squared-error loss = (0 - 1)² = 1. Now we nudge.

Attempt 2 — what does backprop send back?

Gradient of loss with respect to v1 and v2 (the constant factor 2 from the squared error is dropped):

∂L/∂v1 = (y - target) · h1 = (-1) · 0 = 0
∂L/∂v2 = (y - target) · h2 = (-1) · 0 = 0

Output weights stay frozen at 0. The first nudge does nothing.

What about hidden weights? Gradient flows back through v1, v2. Both are 0. So:

∂L/∂w11 = (signal through v1) · x1 = 0
∂L/∂w22 = (signal through v2) · x2 = 0

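Every weight gradient is 0. Every weight stays exactly where it started. Only the output bias sees a non-zero gradient (∂L/∂b_out = y - target = -1), and a bias alone can only learn a constant.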

Attempt 3 — try a non-zero input, many steps

Pick (x1, x2) = (2, 3), target = 1. Same all-zero init. Forward pass still gives h1 = h2 = 0 and y = 0. Same gradient = 0 on every weight. Train for 1000 steps. No weight moves. With the biases held at 0, the loss curve is a flat line at 1.0.
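The three attempts above can be replayed in a few lines. This is a minimal sketch, not a reference implementation: the function name train_tiny_net, the learning rate, and the choice to hold biases at 0 are assumptions made for illustration.

```python
import numpy as np

def train_tiny_net(W, v, x, target, steps=1000, lr=0.1):
    """Gradient descent on a 2-2-1 ReLU net with loss (y - target)^2.

    Biases are held at 0 (an assumption of this sketch), so only weights can move.
    """
    W, v = W.copy(), v.copy()
    for _ in range(steps):
        z = W @ x                       # hidden pre-activations
        h = np.maximum(z, 0.0)          # ReLU
        err = v @ h - target            # y - target
        grad_v = 2.0 * err * h          # dL/dv
        # dL/dW, with the ReLU slope taken as 0 at z = 0
        grad_W = np.outer(2.0 * err * v * (z > 0), x)
        v -= lr * grad_v
        W -= lr * grad_W
    y = v @ np.maximum(W @ x, 0.0)
    return W, v, (y - target) ** 2

# All-zero init, input (2, 3), target 1, 1000 steps.
W, v, loss = train_tiny_net(np.zeros((2, 2)), np.zeros(2),
                            np.array([2.0, 3.0]), target=1.0)
print("hidden weights:", W.ravel())
print("output weights:", v)
print("final loss:", loss)
```

Changing steps or lr does not help: every update is a multiple of zero.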

The structural reason

See. Every neuron in a layer computes the same function on the same input. Same output. Same gradient. The nudge that arrives at w11 is identical to the nudge at w21. They move together forever.

The two hidden neurons are mathematically tied. They are one neuron pretending to be two. Effective width: one. The pile collapses to a point.

This is the rule pile collapsing, but a worse version than chapter 1. There the pile collapsed because all rules were straight. Here it collapses because all rules are identical. Symmetry must break, and zero init cannot break it.

Same problem if every weight starts at any single constant — say all 0.5. Same outputs, same gradients, same tying. Symmetry must break.
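The constant-init case is easy to check numerically. A small sketch under assumed settings (input (2, 3), target 1, learning rate 0.01, biases held at 0): here the gradients are not zero, the weights do move, yet the two hidden neurons stay copies of each other.

```python
import numpy as np

x, target, lr = np.array([2.0, 3.0]), 1.0, 0.01
W = np.full((2, 2), 0.5)   # hidden weights: every entry the same constant
v = np.full(2, 0.5)        # output weights: same constant

for _ in range(200):
    z = W @ x
    h = np.maximum(z, 0.0)                          # ReLU
    err = v @ h - target                            # y - target
    grad_v = 2.0 * err * h
    grad_W = np.outer(2.0 * err * v * (z > 0), x)
    v -= lr * grad_v
    W -= lr * grad_W

rows_tied = bool(np.allclose(W[0], W[1]))   # neuron 1 and neuron 2 share weights
outs_tied = bool(np.isclose(v[0], v[1]))    # and share output weights
moved = not np.allclose(W, 0.5)             # training did change the weights

print("weights moved:", moved)
print("hidden neurons still identical:", rows_tied and outs_tied)
```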


Failure 2 — too-large random. Saturation and explosion.

OK. So we use random numbers. Symmetry breaks. But what scale?

Try big random. W ~ Uniform(-10, 10). Tanh activation, 256-wide layer.

Forward pass through one layer

Input x has 256 entries, each ~ N(0, 1). Pre-activation for one neuron:

z = Σ w_i · x_i  where w_i ~ U(-10, 10)

For U(-10, 10), Var(w) = 20²/12 ≈ 33.3. So Var(z) ≈ 256 · Var(w) · Var(x) = 256 · 33.3 · 1 ≈ 8500.

So z typically lands around magnitude √8500 ≈ 92. Now apply tanh:

tanh(92) ≈ 1.0000000...   (saturated)
tanh(-78) ≈ -1.0000000... (saturated)

Every neuron output is pinned at ±1. The bend has flattened.

What is the gradient?

Tanh derivative: 1 - tanh²(z). At z = 92:

1 - tanh²(92) ≈ 4·e⁻¹⁸⁴ ≈ 5 × 10⁻⁸⁰

The local gradient is microscopically small. In float64, tanh(92) rounds to exactly 1, so the computed derivative is exactly 0. The nudge dies the moment it touches this layer. Stack five such layers and the local factors multiply to roughly (10⁻⁷⁹)⁵. Effectively zero.
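A quick numerical check of the saturation story. This is a sketch under the same assumptions as the text (256 inputs drawn from N(0, 1), weights from U(-10, 10)); the seed and the 0.99 saturation threshold are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, n_units = 256, 256

x = rng.standard_normal(fan_in)                         # input entries ~ N(0, 1)
W = rng.uniform(-10.0, 10.0, size=(n_units, fan_in))    # too-large init
z = W @ x                                               # pre-activations
a = np.tanh(z)
local_grad = 1.0 - a ** 2                               # tanh derivative per unit

z_rms = float(np.sqrt(np.mean(z ** 2)))
frac_saturated = float(np.mean(np.abs(a) > 0.99))
median_grad = float(np.median(local_grad))

print(f"typical |z|          : {z_rms:.1f}")
print(f"fraction with |a|>.99: {frac_saturated:.3f}")
print(f"median local gradient: {median_grad:.3e}")
```

A handful of units land near z = 0 by chance and keep a usable gradient. The rest are pinned.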

Or with ReLU — explosion instead

Same big-random weights, ReLU activation. ReLU does not saturate — it passes positives straight through. So z ≈ 92 becomes h = 92. Next layer multiplies by another 256-wide weight matrix. Now z ≈ 92 · √8500 ≈ 8500. Five layers later h is in the billions. Loss = NaN on step 1.

Either way — saturation or explosion — too-large init kills training before any nudge lands.


The variance-preserving rule

So what to do? We want the forward pass to keep signal magnitude steady through every layer. Not growing. Not shrinking.

Picture. Imagine the signal is a sound wave traveling through the rule pile. Each layer is a room. We want the sound to leave each room at the same volume it entered. Not amplified into feedback. Not muffled into silence.

Math says — for a layer with fan_in inputs and weights drawn from Normal(0, σ²):

Var(output) = fan_in · σ² · Var(input)

To keep Var(output) = Var(input), we need σ² = 1 / fan_in.

That gives Xavier (Glorot) init — works for tanh, sigmoid, linear:

Var(W) = 1 / fan_in        (or 2/(fan_in + fan_out) for symmetric flow)

But ReLU kills half the signal — it zeros out everything negative. So the surviving variance is only half. To compensate, double σ². That gives He (Kaiming) init:

Var(W) = 2 / fan_in        (for ReLU and ReLU-family activations)

Same picture, one factor of 2 to account for the half-killed bend.
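The two rules reduce to one-line helpers. The function names xavier_std and he_std are made up for this sketch; frameworks expose their own initializers under different names.

```python
import math

def xavier_std(fan_in, fan_out=None):
    """Std for Xavier/Glorot init.

    Var(W) = 1 / fan_in, or 2 / (fan_in + fan_out) when fan_out is given.
    """
    var = 1.0 / fan_in if fan_out is None else 2.0 / (fan_in + fan_out)
    return math.sqrt(var)

def he_std(fan_in):
    """Std for He/Kaiming init: Var(W) = 2 / fan_in."""
    return math.sqrt(2.0 / fan_in)

for n in (4, 256, 4096):
    print(f"fan_in={n:5d}  xavier std={xavier_std(n):.4f}  he std={he_std(n):.4f}")
```

For a square layer (fan_in = fan_out) the two Xavier variants coincide.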

Concrete example — 256 → 256 ReLU layer

He init:

σ = √(2 / 256) = √0.0078 ≈ 0.088

Each weight ~ Normal(0, σ²) with σ ≈ 0.088. About 95% of weights land between −0.18 and +0.18. Forward pass:

Var(z) = 256 · 0.0078 · 1 = 2.0
After ReLU (kills half): mean square E[h²] ≈ 1.0

Signal stays at magnitude 1 through every layer. Yes? Gradient also stays at magnitude 1 going back. The nudge survives.
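The single-layer numbers can be checked empirically. A minimal sketch, assuming unit-variance Gaussian inputs and a batch of 4096 samples (both arbitrary choices). It tracks the mean square of the ReLU output, which is the quantity the next layer's pre-activation variance depends on.

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out, batch = 256, 256, 4096

sigma = np.sqrt(2.0 / fan_in)                       # He init std, about 0.088
W = rng.normal(0.0, sigma, size=(fan_out, fan_in))
x = rng.standard_normal((batch, fan_in))            # inputs with variance 1

z = x @ W.T                                         # pre-activations
h = np.maximum(z, 0.0)                              # ReLU

var_z = float(z.var())
ms_h = float(np.mean(h ** 2))
frac_within_2sigma = float(np.mean(np.abs(W) < 2 * sigma))

print(f"Var(z)                  : {var_z:.3f}")
print(f"mean square after ReLU  : {ms_h:.3f}")
print(f"weights within +/- 0.18 : {frac_within_2sigma:.3f}")
```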


Variance through 5 layers — bar chart

Four init schemes. Same 5-layer ReLU net. We track the magnitude of activations layer by layer.

                    L0    L1    L2    L3    L4    L5
all zeros       :  1.0 │ 0.0 │ 0.0 │ 0.0 │ 0.0 │ 0.0    (dead)
U(-10, 10)      :  1.0 │ 90  │ 8e3 │ 7e5 │ 6e7 │ 6e9    (explodes)
U(-0.001, 0.001):  1.0 │ 0.01│ 1e-4│ 1e-6│ 1e-8│ ~0     (vanishes)
He init         :  1.0 │ 1.0 │ 1.0 │ 1.0 │ 1.0 │ 1.0    (steady)

Only He init keeps the signal at order 1 across depth. The other three each fail differently — but they all fail before training begins.
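The table can be regenerated with a short simulation. A sketch under assumed settings: width 256, one random input vector, RMS as the measure of magnitude, and a helper named activation_rms invented for this example. Expect the measured values to differ from the rounded table entries by small constant factors, since ReLU zeroes roughly half the units at each layer.

```python
import numpy as np

def activation_rms(init, width=256, depth=5, seed=0):
    """RMS of activations at L0..L5 for a ReLU net whose weights come from `init`."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal(width)
    h /= np.sqrt(np.mean(h ** 2))          # normalise so the L0 RMS is exactly 1
    out = [1.0]
    for _ in range(depth):
        W = init(rng, width)
        h = np.maximum(W @ h, 0.0)         # linear layer followed by ReLU
        out.append(float(np.sqrt(np.mean(h ** 2))))
    return out

schemes = {
    "all zeros": lambda rng, n: np.zeros((n, n)),
    "U(-10, 10)": lambda rng, n: rng.uniform(-10.0, 10.0, (n, n)),
    "U(-0.001, 0.001)": lambda rng, n: rng.uniform(-1e-3, 1e-3, (n, n)),
    "He init": lambda rng, n: rng.normal(0.0, np.sqrt(2.0 / n), (n, n)),
}

results = {name: activation_rms(init) for name, init in schemes.items()}
for name, rms in results.items():
    print(f"{name:<17}:", "  ".join(f"{v:8.1e}" for v in rms))
```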


Pause and recall. Without scrolling — why does all-zero init fail in a deeper way than just "training is slow"? What's the difference between He and Xavier?


Where this lives in the wild

Variance-preserving init ships in every serious deep learning stack:

  • PyTorch torch.nn.init.kaiming_normal_. torchvision's ResNet applies it to every nn.Conv2d weight when the model is constructed (mode="fan_out", nonlinearity="relu"). Plain nn.Linear and nn.Conv2d layers default to the related kaiming_uniform_. Every convolution layer in a pretrained ResNet-50 was born from He init: that is the starting position the rule pile built on.
  • Hugging Face transformers GPT-2 / LLaMA _init_weights. Linear layers get a zero-mean normal with standard deviation 0.02 (the initializer_range config value), a fixed small scale in the same range as 1/√fan_in at typical hidden sizes. Embeddings get the same. Without a small, controlled scale like this, large-scale pretraining can diverge early.
  • JAX flax.linen.Dense default init. Uses lecun_normal, Var = 1/fan_in, which is Xavier's cousin. Flax-based training code that keeps this default relies on it to keep the forward pass stable at scale.
  • Documented choice in the ResNet paper. The paper initializes weights with He init and trains every net from scratch, with batch norm after each convolution. The original VGG-19, by contrast, was bootstrapped from a shallower trained net because deep stacks were hard to start from random weights.
  • Orthogonal init (torch.nn.init.orthogonal_). A norm-preserving alternative sometimes used for deep convolutional or recurrent stacks so gradient flow stays well-behaved.

The pattern. Every major framework ships a variance-scaled init as the default because hand-rolled init is a common cause of "loss = NaN on step 1" tickets.


Interview Q&A

Q: Why isn't all-zero init just slow training?
A: Because every neuron in a layer computes the identical function on the identical input. Their gradients are identical. They move together forever. The hidden layer of width N is mathematically a hidden layer of width 1. No amount of training breaks symmetry.
Common wrong answer to avoid: "the gradient is zero so it can't move." Partly true for the first step — but the deeper failure is that even with non-zero gradients, all neurons in a layer would still move identically. Symmetry, not gradient magnitude, is the killer.

Q: Why a different formula for ReLU vs tanh?
A: ReLU zeros out half of all activations. So the variance flowing forward is halved. To compensate, He init doubles the weight variance (2/fan_in instead of 1/fan_in). Tanh is symmetric around zero, no signal is killed, so Xavier's 1/fan_in is enough.
Common wrong answer to avoid: "He is just better than Xavier." It is not — it is matched to ReLU. Use Xavier with ReLU and your activations vanish layer by layer.
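The "Xavier with ReLU vanishes" claim is easy to see in numbers. A sketch under assumed settings (width 512, depth 10, Gaussian weights, Var(W) = 1/fan_in standing in for Xavier); the helper name mean_square_by_layer is invented here. Both runs share a seed, so they differ only in weight scale.

```python
import numpy as np

def mean_square_by_layer(var_scale, width=512, depth=10, seed=0):
    """Mean square of ReLU activations per layer with Var(W) = var_scale / fan_in."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal(width)
    out = [float(np.mean(h ** 2))]
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(var_scale / width), (width, width))
        h = np.maximum(W @ h, 0.0)
        out.append(float(np.mean(h ** 2)))
    return out

xavier = mean_square_by_layer(1.0)   # Var(W) = 1 / fan_in
he = mean_square_by_layer(2.0)       # Var(W) = 2 / fan_in
ratio = he[-1] / xavier[-1]

for layer, (a, b) in enumerate(zip(xavier, he)):
    print(f"L{layer:<2d}  xavier={a:.5f}  he={b:.5f}")
print(f"he / xavier at the last layer: {ratio:.0f}")
```

Each ReLU layer under the 1/fan_in rule costs a factor of about 2 in mean square, so ten layers cost about 2¹⁰.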

Q: What does He init actually preserve through the layers?
A: The variance (or magnitude) of activations from one layer to the next. Input has variance 1, layer 1 output has variance ~1, layer 5 output has variance ~1. Same in reverse for gradients. This keeps both forward signal and backward nudge at workable scale through depth.
Common wrong answer to avoid: "random normal always works." Without variance scaling, deep nets collapse or explode within 5-10 layers.

Q: If batch norm exists, do I still need careful init?
A: Yes, but less. Batch norm rescales each layer's pre-activation to unit variance, so it papers over moderate init mistakes. But it only helps where it sits: an extreme init can still overflow to inf or NaN before the signal reaches a BN layer, and layers without BN get no protection. And in transformers without BN, init carries much more of the load.
Common wrong answer to avoid: "batch norm makes init irrelevant." It helps, but an extreme init can still blow up the forward pass, and many transformer stacks do not use batch norm at all.


Apply now (5 min)

Pick a 3-layer net. Input dim 4, hidden dim 4, hidden dim 4, output dim 1. ReLU activations. Input vector x = (1, 1, 1, 1).

Compute by hand the scale (mean or variance) of the activations after each layer under three init schemes:

  1. All zeros. All activations = 0. Verdict: dead.
  2. U(0, 1). Each weight averages 0.5. Layer 1 pre-activation ≈ 4 · 0.5 · 1 = 2. After ReLU, h ≈ 2. Layer 2: 4 · 0.5 · 2 = 4. Layer 3: 8. Growing — would explode in a deeper net.
  3. He init. σ = √(2/4) ≈ 0.71. Layer 1 pre-activation variance: 4 · 0.5 · 1 = 2, after ReLU ≈ 1. Same at every layer. Steady.
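To check the hand computation, the three schemes can be averaged over many random draws. A sketch only: the helper name layer_scales, the 20000 trials, and the use of Gaussian weights for He init are choices made for this example.

```python
import numpy as np

def layer_scales(sample_weights, trials=20000, seed=0):
    """Average activation mean and mean square after each layer of a 4-4-4-1 ReLU net.

    The input is fixed at x = (1, 1, 1, 1); weights are redrawn for every trial.
    """
    rng = np.random.default_rng(seed)
    h = np.ones((trials, 4))
    means, mean_squares = [], []
    for fan_out in (4, 4, 1):
        W = sample_weights(rng, (trials, fan_out, 4))
        h = np.maximum(np.einsum("toi,ti->to", W, h), 0.0)   # per-trial W @ h, then ReLU
        means.append(float(h.mean()))
        mean_squares.append(float(np.mean(h ** 2)))
    return means, mean_squares

zeros_mean, _ = layer_scales(lambda rng, shape: np.zeros(shape))
u01_mean, _ = layer_scales(lambda rng, shape: rng.uniform(0.0, 1.0, shape))
_, he_ms = layer_scales(lambda rng, shape: rng.normal(0.0, np.sqrt(2.0 / 4), shape))

print("all zeros, mean activation per layer:", zeros_mean)
print("U(0, 1),   mean activation per layer:", [round(m, 2) for m in u01_mean])
print("He init,   mean square per layer    :", [round(m, 2) for m in he_ms])
```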

Then — without looking — sketch from memory:

  1. The two failure modes (all-zero, too-large random) and what they break.
  2. The two init formulas: Var(W) = 1/fan_in for tanh, Var(W) = 2/fan_in for ReLU.
  3. One sentence: what He init preserves.

If you can do it in 90 seconds, you own this idea.


Bridge. Now the rule pile starts balanced. Forward pass produces sensible signals. But how do we measure how wrong the output is? That is the loss function. Read 05-loss-functions.md next.