
02. Why Neural Networks Work — Narrative Explainer

Companion to 03_study_material.md. The study material gives the formulas. This gives the why — and the picture in your head.


Table of contents

  • ELI5 — the whole thing in kid words (start here)
  • Chapter 1: One rule is not enough
  • 1.1 The first failure — XOR
  • 1.2 Why this matters to you
  • Chapter 2: Many rules with bends
  • 2.1 Stacking alone does not help — the rule pile collapses
  • 2.2 The bend — activation function
  • Chapter 3: How weights get values
  • 3.1 Weight initialization — where the rule pile starts
  • 3.2 The nudge — backpropagation
  • 3.3 Batches, epochs, and stochastic gradient descent
  • 3.4 Loss functions — measuring wrongness
  • Chapter 4: Why deep nets are hard
  • 4.1 Deep rule pile kills the nudge — vanishing gradients
  • 4.2 Smart nudging — the optimizer
  • Chapter 5: Why bigger works (and when it doesn't)
  • 5.1 Scaling laws
  • 5.2 Overfitting, regularization, and the generalization mystery
  • 5.3 Honest admission
  • Chapter 6: Recap & application
  • 6.1 The failure-fix chain
  • 6.2 Key points to remember
  • 6.3 Important interview questions
  • 6.4 Production experience
  • 6.5 Apply now — graded exercises

ELI5 — the whole thing in kid words

Imagine you want to teach a robot to spot cats in photos.

You write one rule. "If pixel here is fluffy, say cat." This is one perceptron. Too dumb. Many things are fluffy. Not all are cats.

So you write many rules and stack them up. Call this the rule pile. But here is the trick — if all rules are straight, the rule pile collapses into one big straight rule. No improvement.

So you put a tiny bend between each rule. Now they cannot collapse. The robot can see shapes, not just lines.

Where do the rule numbers come from? The robot guesses, checks the answer, then nudges every rule a tiny bit toward the right answer. Do this a million times. The robot gets very good. This nudging is called backpropagation.

But "nudge a tiny bit" is harder than it sounds. If the rule pile is too deep, the nudges shrink to nothing before reaching the early rules. So we use special bends (ReLU, GELU) that keep nudges alive.

Plain nudging is also dumb. Smart nudging — remember past direction, use different speeds for different rules — is much better. This smart nudging is called Adam.

Finally — bigger rule pile + more cat photos = smarter robot. There is a math rule for the best ratio. About 20 photos per rule.

That is it. That is a neural network.

The chapters that follow walk every piece. Each piece exists because something specific broke. Keep the cat-robot picture in your head. Every technical name maps to something there.


Chapter 1: One rule is not enough

1.1 The first failure — XOR

What a perceptron actually computes

Take the simplest cat-robot. One rule. Two inputs x₁, x₂. Two weights w₁, w₂. One bias b.

The perceptron computes a score:

score = w₁·x₁ + w₂·x₂ + b

if score > 0:  output = 1   ("yes")
else:          output = 0   ("no")

Three knobs: w₁, w₂, b. That is the entire machine.

A perceptron is a knife in 2D

The boundary is where score = 0. That is a straight line:

w₁·x₁ + w₂·x₂ + b = 0

Weights rotate the line. Bias slides the line. In 2D it is a line. In 3D, a plane. In nD, a hyperplane. Always one straight cut.

AND works fine — concrete numbers

Pick w₁ = 1, w₂ = 1, b = -1.5. Compute all four inputs:

x₁    x₂    score = x₁ + x₂ − 1.5    output
0     0     −1.5                     0
0     1     −0.5                     0
1     0     −0.5                     0
1     1     +0.5                     1

This is AND. One line slices off the top-right corner:

   x₂
 1 │  ○         ●
   │     ╲
   │       ╲    ← line: x₁ + x₂ = 1.5
   │         ╲
 0 │  ○        ○ ╲
   └──────────────→ x₁
      0        1

OR is the same idea with b = -0.5. Both linearly separable. One knife works.
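The AND and OR knives can be checked in a few lines. This is a minimal NumPy sketch; the helper name `perceptron` is made up for illustration, and the weights and biases are the ones quoted above.

```python
import numpy as np

def perceptron(X, w, b):
    """Return 1 where w·x + b > 0, else 0, for each row x of X."""
    return (X @ w + b > 0).astype(int)

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
w = np.array([1.0, 1.0])

and_out = perceptron(X, w, b=-1.5)   # AND: only (1,1) clears the line
or_out = perceptron(X, w, b=-0.5)    # OR: everything except (0,0) clears it
print(and_out, or_out)
```

Same weights, different bias. Sliding the line is all it takes to go from AND to OR.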

XOR — the truth table

XOR ("exclusive or"): output 1 when exactly one input is 1.

x₁ x₂ y
0 0 0
0 1 1
1 0 1
1 1 0

Plot:

   x₂
 1 │  ●        ○
 0 │  ○        ●
   └──────────────→ x₁
      0        1

   ● = output 1
   ○ = output 0

The ●s are at opposite corners. Now we try to cut.

Three knives — every one fails

Attempt A: vertical knife (w₁=1, w₂=0, b=-0.5 → cut at x₁=0.5)

point    should be    knife says    result
(0,0)    0            0
(0,1)    1            0             wrong
(1,0)    1            1
(1,1)    0            1             wrong

Two wrong.

Attempt B: diagonal knife (w₁=1, w₂=1, b=-0.5 → x₁+x₂=0.5)

point    should be    knife says    result
(0,0)    0            0
(0,1)    1            1
(1,0)    1            1
(1,1)    0            1             wrong

Three right, one wrong. Closest yet.

Attempt C: any other line. Same problem. At least one point lands on the wrong side.

Why every knife must fail

Walk around the square in order:

  (0,0) → (0,1) → (1,1) → (1,0) → back to (0,0)
   ○       ●       ○       ●

   no  →  yes  →   no  →  yes

The output alternates at every corner.

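A straight line splits the plane into two sides. To get XOR right, both ●s must sit on one side and both ○s on the other. But each side of a line is convex: if two points are on it, so is everything between them. The midpoint of the two ●s is (0.5, 0.5). The midpoint of the two ○s is also (0.5, 0.5). That one point would have to lie on both sides at once. Geometry forbids it.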

To cut a checkerboard, we need a bend. One straight cut cannot. So one perceptron is not enough. We need a rule pile.
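A brute-force search makes the failure tangible. The sketch below tries every knife on a coarse grid of weights and biases and records the best score. The grid range and step size are arbitrary choices for illustration. A grid search is evidence, not a proof; the geometric argument above is the proof.

```python
import itertools
import numpy as np

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_xor = np.array([0, 1, 1, 0])

# Hypothetical search grid: w1, w2, b each from -3 to 3 in steps of 0.25.
grid = np.arange(-3.0, 3.25, 0.25)

best_correct = 0
for w1, w2, b in itertools.product(grid, repeat=3):
    pred = (X @ np.array([w1, w2]) + b > 0).astype(int)
    best_correct = max(best_correct, int((pred == y_xor).sum()))

print("best number of corners classified correctly:", best_correct)
```

The search tops out at three corners out of four, matching Attempt B.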

1.2 Why this matters to you

You will write or fine-tune neural networks soon — likely transformers. Things will go wrong. Loss will not go down. Output will be garbage. Training will diverge.

To debug those, you need to reason about the stack: activation choice, init, optimizer, learning rate, batch size, regularization. Every concept in this module is one knob in that debugging space.

You are also interviewing for Lead AI Engineer roles. The interviewer will ask: why GELU and not ReLU? Why AdamW and not SGD? Why Chinchilla ratio? If you only memorized definitions, your answer is shallow. If you understand which failure each choice fixes, your answer is senior-level.

Each formula in the textbook is a fix for a specific failure. Learn the failures. The formulas will follow.


Chapter 2: Many rules with bends

2.1 Stacking alone does not help — the rule pile collapses

Mental model: stretching paper

Picture a flat sheet of paper. Each layer applies two operations: stretch it (multiply by W) and shift it (add b). Both keep the paper flat.

Layer 1: h = W₁x + b₁ → stretched, shifted. Still flat.
Layer 2: y = W₂h + b₂ → stretched again, shifted again. Still flat.

Worked example with numbers

Pick W₁ = 2, b₁ = 3 and W₂ = 4, b₂ = 1. Try input x = 5.

Layer 1: h = 2·5 + 3 = 13
Layer 2: y = 4·13 + 1 = 53

Now substitute layer 1 into layer 2 algebraically:

y = W₂(W₁x + b₁) + b₂ = W₂W₁·x + (W₂b₁ + b₂) = 8x + 13

Check with x = 5: y = 8·5 + 13 = 53. Same answer.

Two layers acted exactly like one layer with W = 8, b = 13. No new power. Same single line.
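The same collapse holds for matrices, not just single numbers. A minimal NumPy sketch, assuming arbitrary layer sizes (3 → 4 → 2) and random weights chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Two linear layers with arbitrary (hypothetical) sizes: 3 -> 4 -> 2.
W1, b1 = rng.normal(size=(4, 3)), rng.normal(size=4)
W2, b2 = rng.normal(size=(2, 4)), rng.normal(size=2)
x = rng.normal(size=3)

two_layers = W2 @ (W1 @ x + b1) + b2

# The single equivalent layer: W = W2·W1, b = W2·b1 + b2.
W, b = W2 @ W1, W2 @ b1 + b2
one_layer = W @ x + b

print(np.allclose(two_layers, one_layer))
```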

Why XOR cannot be solved this way

Section 1.1 proved one line cannot do XOR. Two stacked linear layers = one effective line. So two stacked linear layers also cannot do XOR. Stack 100 linear layers — still one line. The pile collapses.

We must bend the paper. Not stretch it.

2.2 The bend — activation function

Mental model: origami

Paper has to fold to make any shape. ReLU is the simplest fold:

  ReLU(x) = max(0, x)

   y
   │       ╱
   │     ╱
   │   ╱
   │ ╱
───┴──────────────→ x
   0

Negative inputs flatten to zero. Positive inputs pass through. One fold at zero.

Compare the activation family

sigmoid:           tanh:              ReLU:              GELU:

  1 ┐ ___           1 ┐ ___              ┐    ╱             ┐    ╱
    │/                │/                 │   ╱              │   ╱
  0 ┤              0 ─┤                  │ ╱                │ ╱
    │/                │                  │╱                 │╱  (smooth bend)
    │_____         -1 ┘___           ────┴──            ────┴──

flat ends         flat ends          gradient = 1       gradient ≈ 1
gradient → 0      gradient → 0       in positive        smooth everywhere
                                     region

Worked example — XOR solved with two ReLU units

Hand-pick weights for a 2-layer MLP. Architecture: 2 inputs → 2 hidden units (ReLU) → 1 output.

Hidden layer:

  • h₁ = ReLU(x₁ + x₂) (fires when at least one input is on)
  • h₂ = ReLU(x₁ + x₂ − 1) (fires when both inputs are on)

Output layer:

  • y = h₁ − 2·h₂

Compute for all four inputs:

x₁    x₂    h₁ = ReLU(x₁+x₂)    h₂ = ReLU(x₁+x₂−1)    y = h₁ − 2h₂    want
0     0     ReLU(0) = 0         ReLU(−1) = 0          0 − 0 = 0       0 ✓
0     1     ReLU(1) = 1         ReLU(0) = 0           1 − 0 = 1       1 ✓
1     0     ReLU(1) = 1         ReLU(0) = 0           1 − 0 = 1       1 ✓
1     1     ReLU(2) = 2         ReLU(1) = 1           2 − 2 = 0       0 ✓

All four correct. XOR solved.

What just happened? h₁ is essentially OR. h₂ is essentially AND. Output is OR − 2·AND = "exactly one is on". The bend at zero (ReLU) let h₂ stay silent for OR-only cases but fire for AND. Without the bend, h₂ could not have on-off behavior — linear functions cannot switch.
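The hand-picked network is small enough to run directly. A minimal NumPy sketch of the table above; the matrix layout (one column per hidden unit) is a choice made here, not something the text fixes.

```python
import numpy as np

def relu(z):
    return np.maximum(0.0, z)

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)

# Hidden layer: both units read x1 + x2; only the biases differ.
W_hidden = np.array([[1.0, 1.0],
                     [1.0, 1.0]])   # column j holds the weights of hidden unit j
b_hidden = np.array([0.0, -1.0])
w_out = np.array([1.0, -2.0])       # y = h1 - 2*h2

H = relu(X @ W_hidden + b_hidden)
y = H @ w_out
print(y)
```

Delete the `relu` call and the output becomes a plain linear function of x₁ + x₂, which is the collapse from section 2.1 all over again.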

Universal approximation

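The universal approximation theorem: with enough hidden units and almost any non-linearity (anything that is not a polynomial), an MLP with one hidden layer can approximate any continuous function on a bounded input range as closely as you like. Origami can build any shape.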

But the rule pile is empty. Random weights will not give XOR. How do we find the right weights?


Chapter 3: How weights get values

3.1 Weight initialization — where the rule pile starts

Before training, every weight needs a starting value. What should it be?

Bad idea: all zeros

If all weights start at zero, every neuron in a layer computes the same thing on the same input. The whole pile becomes one big neuron pretending to be many. Symmetry must break.

Bad idea: any single constant

Same problem. Symmetry must break.

Good idea: small random numbers — but how small?

Two failure modes from the wrong scale:

  • Too large. Forward pass values explode. After 5 layers h is in the thousands. ReLU passes them through unchanged. Output is unstable. Loss = NaN on step 1.
  • Too small. Forward pass values vanish. After 5 layers h is near zero. The network sees no signal. Gradients also vanish — the doom of section 4.1 starts at step 0.

Two practical rules

Xavier (Glorot) init. For sigmoid/tanh networks. Variance of W:

Var(W) = 2 / (fan_in + fan_out)

He init. For ReLU networks. Variance of W:

Var(W) = 2 / fan_in

He compensates for the fact that ReLU kills half the activations (the negative half). Need 2× variance to keep signal magnitude stable.

Concrete example

A linear layer mapping 256 inputs to 256 outputs. He init:

std(W) = √(2 / 256) ≈ 0.088

Each weight is sampled from Normal(0, 0.088). Most weights are between −0.18 and +0.18. Activations stay around magnitude 1.0 through every layer.

PyTorch's nn.Linear defaults to a uniform variant of this (Kaiming-style, scaled by fan_in). Keras's Dense defaults to Glorot (Xavier) uniform instead. If you ever see loss = NaN on step 1 and you wrote your own init, init is the first suspect.
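The effect of the init scale can be seen without any training. The sketch below pushes random inputs through a stack of ReLU layers and reports the size of the final activations. The depth, width, batch size and the three scale multipliers are assumptions picked for illustration; scale = 1.0 corresponds to He init as defined above.

```python
import numpy as np

def final_rms(scale, depth=10, width=256, batch=64, seed=0):
    """RMS activation after `depth` ReLU layers.

    Weights are drawn from Normal(0, scale * sqrt(2 / fan_in));
    scale = 1.0 is He init. Biases are omitted for simplicity.
    """
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(batch, width))          # unit-variance inputs
    for _ in range(depth):
        W = rng.normal(scale=scale * np.sqrt(2.0 / width), size=(width, width))
        h = np.maximum(0.0, h @ W)
    return float(np.sqrt(np.mean(h ** 2)))

for scale in (0.1, 1.0, 10.0):
    print(f"scale {scale:>4}: final RMS activation = {final_rms(scale):.3g}")
```

With these settings the too-small run should shrink toward zero, the too-large run should blow up by many orders of magnitude, and the He run should stay near magnitude 1.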

3.2 The nudge — backpropagation

Mental model: river flowing backward

Forward pass: input flows downhill through the network, gets mixed at each layer, produces a prediction. Compare to truth. Get loss L.

x ──→[W₁]──→ h₁ ──→[W₂]──→ h₂ ──→[W₃]──→ ŷ ──→ L
                                                  │ "you are wrong by this much"
       ∂L/∂W₁  ←──── ∂L/∂W₂  ←──── ∂L/∂W₃  ←─────┘
       (how much        (how much        (how much
        did W₁           did W₂           did W₃
        cause this?)     cause this?)     cause this?)

Backward pass: blame walks upstream. Each weight learns its share of the error.

The chain rule

For a weight in layer 1:

∂L/∂W₁ = ∂L/∂ŷ · ∂ŷ/∂h₂ · ∂h₂/∂h₁ · ∂h₁/∂W₁

Read left to right: "the loss-blame at the end" × "how the end depends on h₂" × "how h₂ depends on h₁" × "how h₁ depends on W₁".

Multiply the chain. Get the gradient. Step downhill:

W₁ := W₁ − α · ∂L/∂W₁

α is the learning rate. Too big → overshoot the valley. Too small → stuck in fog.

Worked example — one full forward + backward step

Tiny network: 1 input → 1 hidden ReLU → 1 output. MSE loss.

Initial: W₁ = 3, b₁ = −1, W₂ = 0.5, b₂ = 0. Input x = 2. Target y = 1. Learning rate α = 0.1.

Forward pass:

z₁ = W₁·x + b₁ = 3·2 − 1   = 5
h  = ReLU(z₁)              = 5
ŷ  = W₂·h + b₂ = 0.5·5 + 0 = 2.5
L  = (ŷ − y)² = (2.5 − 1)² = 2.25

Backward pass (chain rule, going right to left):

∂L/∂ŷ  = 2(ŷ − y)              = 3
∂L/∂W₂ = (∂L/∂ŷ) · h           = 3 · 5     = 15
∂L/∂b₂ = (∂L/∂ŷ) · 1           = 3
∂L/∂h  = (∂L/∂ŷ) · W₂          = 3 · 0.5   = 1.5
∂L/∂z₁ = (∂L/∂h) · ReLU'(z₁)   = 1.5 · 1   = 1.5    (ReLU' = 1 since z₁ > 0)
∂L/∂W₁ = (∂L/∂z₁) · x          = 1.5 · 2   = 3
∂L/∂b₁ = (∂L/∂z₁) · 1          = 1.5

Update:

W₁ := 3   − 0.1 · 3   = 2.7
b₁ := −1  − 0.1 · 1.5 = −1.15
W₂ := 0.5 − 0.1 · 15  = −1.0
b₂ := 0   − 0.1 · 3   = −0.3

That is one training step. Forward to compute the loss. Backward to compute every gradient via chain rule. Update each weight. Repeat millions of times. This is gradient descent. Same algorithm trains GPT-4.
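The whole step fits in plain Python. This sketch reproduces the numbers above and then runs a second forward pass with the updated weights; the variable names are chosen here for readability. With these particular numbers the loss after the update comes out near 30.8, higher than 2.25. That is the "too big → overshoot" case in action: the direction is right, the step is too long.

```python
# Values from the worked example above.
W1, b1, W2, b2 = 3.0, -1.0, 0.5, 0.0
x, y, lr = 2.0, 1.0, 0.1

def forward(W1, b1, W2, b2):
    z1 = W1 * x + b1
    h = max(0.0, z1)            # ReLU
    y_hat = W2 * h + b2
    return z1, h, y_hat, (y_hat - y) ** 2

z1, h, y_hat, loss = forward(W1, b1, W2, b2)

# Backward pass, right to left.
dL_dyhat = 2 * (y_hat - y)
dL_dW2 = dL_dyhat * h
dL_db2 = dL_dyhat
dL_dh = dL_dyhat * W2
dL_dz1 = dL_dh * (1.0 if z1 > 0 else 0.0)   # ReLU derivative
dL_dW1 = dL_dz1 * x
dL_db1 = dL_dz1

# Gradient-descent update.
W1_new, b1_new = W1 - lr * dL_dW1, b1 - lr * dL_db1
W2_new, b2_new = W2 - lr * dL_dW2, b2 - lr * dL_db2
new_loss = forward(W1_new, b1_new, W2_new, b2_new)[3]

print(loss, dL_dW1, dL_dW2)
print(W1_new, b1_new, W2_new, b2_new, new_loss)
```

Setting `lr = 0.01` instead makes the same step reduce the loss.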

But the river dies in deep networks. We will see why in chapter 4. First — how do we actually iterate over data?

3.3 Batches, epochs, and stochastic gradient descent

Section 3.2 showed one update step on one input. Real training has millions of inputs. How do we iterate?

Three options

Method           Inputs per gradient    Steps per pass    Speed        Noise
Full-batch GD    all (e.g., 1M)         1                 very slow    none
Stochastic GD    1                      1M                fast         very noisy
Mini-batch GD    32–256 (a "batch")     thousands         fast         moderate noise

Mini-batch wins. Practical default. Compute gradient on the batch, average it, take one step.

Vocabulary

  • Batch size: how many inputs per gradient step. Typical: 32, 64, 128, 256, sometimes much larger for LLMs.
  • Iteration / step: one gradient update.
  • Epoch: one full pass over the entire training set.

If training set has 1M examples and batch size is 256, then one epoch ≈ 3,900 steps.
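One epoch of mini-batches can be sketched as a shuffle followed by slicing. The helper name `minibatch_indices` is made up here. It yields index arrays rather than data, so it works with any dataset that supports indexing.

```python
import numpy as np

def minibatch_indices(n_examples, batch_size, rng):
    """Yield shuffled index batches covering the dataset once (one epoch)."""
    order = rng.permutation(n_examples)
    for start in range(0, n_examples, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
batches = list(minibatch_indices(1_000_000, 256, rng))
print(len(batches), len(batches[0]), len(batches[-1]))
```

The exact count is 3,907 because 1M is not a multiple of 256: the last batch holds only 64 examples. Many training loops simply drop that final partial batch.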

Why "stochastic"?

Each batch is sampled randomly. Each step's gradient is a noisy estimate of the true full-data gradient.

Why noise is good

  1. Speed. Full-batch is too slow.
  2. Escapes shallow local minima. Noise kicks the optimizer out of bad spots.
  3. Finds flatter minima. Empirically, flat minima generalize better. Noise prefers them.

Batch size tradeoffs

  • Larger batch → less noise, more memory, often need higher learning rate (linear scaling rule).
  • Smaller batch → more noise, less memory, sometimes generalizes better.

Production: batch size is one of the most-tuned hyperparameters. Each architecture has a sweet spot. LLM pretraining uses batch sizes of millions of tokens.

3.4 Loss functions — measuring wrongness

Section 3.2 used MSE: L = (ŷ − y)². That works for regression. For classification, we need different losses.

Three common cases

Task                                         Output activation    Loss
Regression (predict a number)                none                 MSE
Binary classification (yes/no)               sigmoid              binary cross-entropy
Multi-class classification (cat/dog/bird)    softmax              cross-entropy

Softmax — turning scores into probabilities

For K classes, the network outputs K raw scores (logits). Softmax converts them to probabilities:

p_i = exp(z_i) / Σ_j exp(z_j)

All p_i are positive. They sum to 1. Largest logit gets the largest probability.

Worked example. Logits = [2, 1, 0].

exp(2) = 7.39
exp(1) = 2.72
exp(0) = 1.00
sum    = 11.11

p_0 = 7.39 / 11.11 = 0.665
p_1 = 2.72 / 11.11 = 0.245
p_2 = 1.00 / 11.11 = 0.090

Sum = 1.0. Most probability on class 0.
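In code, softmax is three lines. The sketch below subtracts the largest logit before exponentiating. That step is not in the formula above; it is a common numerical-stability trick that leaves the result unchanged, because the shift cancels in the ratio.

```python
import numpy as np

def softmax(logits):
    """Softmax with the max subtracted first, which avoids overflow in exp."""
    shifted = logits - np.max(logits)
    exps = np.exp(shifted)
    return exps / exps.sum()

p = softmax(np.array([2.0, 1.0, 0.0]))
print(p.round(3), p.sum())
```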

Cross-entropy

If the true class is c, the loss is:

L = −log(p_c)

p_c (probability on the true class)    loss = −log(p_c)
1.0 (perfect)                          0
0.5                                    0.69
0.1                                    2.30
0.01                                   4.61
0                                      ∞

Right answer with high confidence → tiny loss. Wrong answer with high confidence → huge loss.

The clean gradient miracle

Compute the gradient of softmax + cross-entropy together. After some calculus:

∂L/∂z_i = p_i − y_i

where y_i = 1 for the true class, 0 elsewhere. The gradient is just predicted probability minus true probability. Beautiful, clean, no saturation.
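The "some calculus" can be sanity-checked numerically. This sketch compares p − y against a finite-difference estimate of the gradient. The logits reuse the worked softmax example; the choice of true class is arbitrary.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def cross_entropy(z, true_class):
    return -np.log(softmax(z)[true_class])

z = np.array([2.0, 1.0, 0.0])
true_class = 1
y = np.eye(3)[true_class]           # one-hot target

analytic = softmax(z) - y           # the claimed gradient p - y

# Central finite differences, one logit at a time.
eps = 1e-6
numeric = np.zeros_like(z)
for i in range(len(z)):
    step = np.zeros_like(z)
    step[i] = eps
    numeric[i] = (cross_entropy(z + step, true_class)
                  - cross_entropy(z - step, true_class)) / (2 * eps)

print(analytic.round(4), numeric.round(4))
```

The same finite-difference check is a handy debugging tool for any hand-written backward pass.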

Why MSE with sigmoid is bad

If you use MSE loss with a sigmoid output, the gradient picks up an extra factor:

∂L/∂z = 2(ŷ − y) · σ'(z)

That σ'(z) factor is at most 0.25, often near zero in flat regions. So when the model is very wrong (z very large with the wrong sign), the gradient is tiny. Training crawls.

Cross-entropy avoids this — the σ' factor cancels out. Wrong predictions get strong gradients. Right predictions get weak gradients. Exactly what you want.
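The difference shows up clearly with numbers. The sketch below assumes a true label of 1 and a few hand-picked pre-activations z, then prints both gradients with respect to z. At z = −8 the model is confidently wrong, which is exactly the case described above.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

y = 1.0                                  # true label
for z in (-8.0, -2.0, 0.0, 2.0):
    y_hat = sigmoid(z)
    grad_mse = 2 * (y_hat - y) * y_hat * (1 - y_hat)   # d/dz of (y_hat - y)^2
    grad_bce = y_hat - y                               # d/dz of binary cross-entropy
    print(f"z={z:+.0f}  y_hat={y_hat:.4f}  MSE grad={grad_mse:+.5f}  BCE grad={grad_bce:+.5f}")
```

At z = −8 the MSE gradient is close to zero while the cross-entropy gradient is close to −1.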

For LLMs

Loss is cross-entropy on next-token prediction. Softmax over the vocabulary (e.g., 50K tokens). Picks the next token. Same math as multi-class classification — just with K = 50,000.


Pause and recall before chapter 4. We covered four things: init (where weights start), backprop (the chain rule going backward), batches (how we iterate), and losses (how we measure wrongness). Without scrolling: what is He init? What is one training step (forward + backward)? Why is cross-entropy with softmax cleaner than MSE? If any link is fuzzy, scroll back.


Chapter 4: Why deep nets are hard

4.1 Deep rule pile kills the nudge — vanishing gradients

Mental model: whisper game

Ten people in a line. Person 10 hears a number, whispers to 9, who whispers to 8, and so on. Each person muffles the message.

 message strength as it travels back:
 1.0 → 0.1 → 0.01 → 0.001 → ... → 10⁻¹⁰
 layer 10  9     8      7              1

By person 1, nothing is left.

Why sigmoid muffles — concrete numbers

Sigmoid: σ(x) = 1 / (1 + e⁻ˣ). Its derivative: σ'(x) = σ(x)·(1−σ(x)).

Compute at typical activation values:

x     σ(x)    σ'(x) = σ(x)(1−σ(x))
0     0.5     0.5 · 0.5 = 0.25 (peak)
1     0.73    0.73 · 0.27 = 0.197
2     0.88    0.88 · 0.12 = 0.105
4     0.98    0.98 · 0.02 = 0.020
−2    0.12    0.12 · 0.88 = 0.105

Maximum is 0.25, only at x=0. Anywhere else, smaller. Often much smaller.

Backprop through 10 sigmoid layers. Gradient at layer 1 = product of 10 derivatives. Best case (all 0.25):

0.25¹⁰ = 9.5 × 10⁻⁷

Realistic case (mix, average 0.1):

0.1¹⁰ = 10⁻¹⁰

The gradient that reaches early layers is one ten-billionth of the gradient at the top. Early weights essentially do not move. Network never learns basic features. This is the vanishing gradient problem.

Why ReLU saves the river

sigmoid derivative:              ReLU derivative:

 0.25┐ ___                          1 ┐    ────────
     │   ╲                            │
     │     ╲                          │
   0 ┴───────                       0 ┴────╱
       far from 0:                     positive: 1
       derivative ≈ 0                  negative: 0

ReLU's derivative is 1 in the active region. Multiply ten of those: 1¹⁰ = 1. No muffling. Full gradient reaches every layer.
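The two products can be compared side by side. This sketch only multiplies activation derivatives, as the text does. It ignores the weight matrices, which also scale the gradient in a real network.

```python
import numpy as np

def sigmoid_grad(x):
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 - s)

depth = 10
best_case = sigmoid_grad(0.0) ** depth        # every layer sits exactly at x = 0
typical = 0.1 ** depth                        # average derivative of 0.1 per layer
relu_active = 1.0 ** depth                    # ReLU derivative is 1 when the unit is active

print(f"{best_case:.2e}  {typical:.0e}  {relu_active}")
```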

GELU is the smooth version. Modern transformers use GELU. Llama uses SiLU (same idea).

Callback: the same bend we added in section 2.2 to break linearity is also what saves nudges in deep piles. One choice. Two jobs.

But the nudge itself is also dumb. We need smart nudging.


Halfway-point retrieval. Without scrolling — name every failure so far in order. What broke? What fixed it? You should have at least 7 items by now. If any link is fuzzy, scroll back.


4.2 Smart nudging — the optimizer

Mental model: hiker in a canyon

Loss surface is a landscape. Minimum is at the bottom. Plain gradient descent walks downhill.

Problem: the valley is a canyon. Steep walls in one direction. Gentle floor in another. Hiker bounces wall to wall instead of walking down the floor.

Worked example — why plain SGD fails in a canyon

Take a 2D loss L(w₁, w₂) = 10·w₁² + w₂². Steep in w₁ direction. Flat in w₂ direction.

Gradient: (20·w₁, 2·w₂).

Start at (w₁, w₂) = (1, 1). Apply plain SGD with α = 0.1:

step 0: (1.0, 1.0)     gradient = (20, 2)    update = (−2, −0.2)
step 1: (−1.0, 0.8)    gradient = (−20, 1.6) update = (+2, −0.16)
step 2: (1.0, 0.64)    gradient = (20, 1.28) update = (−2, −0.128)
step 3: (−1.0, 0.512)  ...

w₁ oscillates between +1 and −1 forever. Never converges in the steep direction. w₂ creeps slowly toward 0.

Shrink α to fix w₁ oscillation → w₂ becomes glacial. Trade-off has no winner.
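The trade-off is easy to reproduce. A minimal sketch of plain gradient descent on the canyon, run at the original step size and at a ten times smaller one; the 50-step horizon is an arbitrary choice.

```python
def canyon_grad(w1, w2):
    """Gradient of L(w1, w2) = 10*w1**2 + w2**2."""
    return 20.0 * w1, 2.0 * w2

def gradient_descent(lr, steps, start=(1.0, 1.0)):
    w1, w2 = start
    path = [(w1, w2)]
    for _ in range(steps):
        g1, g2 = canyon_grad(w1, w2)
        w1, w2 = w1 - lr * g1, w2 - lr * g2
        path.append((w1, w2))
    return path

for lr in (0.1, 0.01):
    w1, w2 = gradient_descent(lr, steps=50)[-1]
    print(f"lr={lr}: after 50 steps w1={w1:+.4f}, w2={w2:+.4f}")
```

At lr = 0.1, w₁ is still bouncing at magnitude 1. At lr = 0.01, w₁ settles but w₂ has covered only part of the distance to 0.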

   plain SGD:                  with momentum:

   ╱╲╱╲╱╲╱╲                    ─────────→
   walls bouncing               glides down floor
   wastes steps                 bounces cancel

Three optimizers, what each adds

SGD with momentum. Update keeps a running average of past gradients: v := β·v + g, then w := w − α·v. Bounces in w₁ (alternating signs) cancel out. Drift in w₂ (consistent sign) accumulates.

Adam. Tracks two running averages per parameter: gradient (m) and squared gradient (v). Effective update is m / √v. Steep direction has large v, so update is dampened. Flat direction has small v, so update is amplified. Each parameter gets its own effective learning rate.

AdamW. Adam + decoupled weight decay. In Adam, weight decay was wrongly multiplied by the adaptive scaling, which broke regularization. AdamW separates them. Default for transformers.

This matters more than people expect. Same architecture, same data — switching SGD to AdamW can be the difference between converging in a day and never converging.
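Both fixes can be tried on the same canyon, L(w₁, w₂) = 10·w₁² + w₂². The sketch below is a bare-bones version of each update rule, not a library implementation. The momentum rule follows the formula quoted above. The Adam rule adds the usual bias correction and a small eps in the denominator, neither of which the text spells out. β = 0.9, β₂ = 0.999, eps = 1e-8 and 200 steps are assumed settings.

```python
import math

def canyon_grad(w):
    """Gradient of L(w1, w2) = 10*w1**2 + w2**2."""
    return [20.0 * w[0], 2.0 * w[1]]

def canyon_loss(w):
    return 10.0 * w[0] ** 2 + w[1] ** 2

def run_momentum(lr=0.1, beta=0.9, steps=200):
    w, v = [1.0, 1.0], [0.0, 0.0]
    for _ in range(steps):
        g = canyon_grad(w)
        v = [beta * vi + gi for vi, gi in zip(v, g)]        # v := beta*v + g
        w = [wi - lr * vi for wi, vi in zip(w, v)]          # w := w - lr*v
    return w

def run_adam(lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=200):
    w, m, v = [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]
    for t in range(1, steps + 1):
        g = canyon_grad(w)
        m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
        v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
        m_hat = [mi / (1 - beta1 ** t) for mi in m]          # bias correction
        v_hat = [vi / (1 - beta2 ** t) for vi in v]
        w = [wi - lr * mh / (math.sqrt(vh) + eps)
             for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w

for name, w in (("momentum", run_momentum()), ("adam", run_adam())):
    print(f"{name:8s} w1={w[0]:+.2e} w2={w[1]:+.2e} loss={canyon_loss(w):.2e}")
```

Both runs use the same α = 0.1 that left plain gradient descent bouncing forever. Because m / √v does not change when a gradient coordinate is rescaled, Adam moves w₁ and w₂ almost in lockstep here, even though one direction is ten times steeper.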

Now we have a working rule pile, with bends, with init, with nudges, with smart nudging. Why is GPT-4 still better than a small model built the same way?


Chapter 5: Why bigger works (and when it doesn't)

5.1 Scaling laws

Mental model: log-log ramp

Loss decreases as a power law in three things — model size N, data D, compute C.

L(N) ≈ a · N^(−α) (similar form for D and C; α is small, typically 0.05–0.1)

Why power law = straight line on log-log paper

Take logs of both sides:

log L = log a − α · log N

This is y = c − α·x — a straight line with slope −α. Plot on log-log axes:

        log(loss)
            │ ●
            │  ●
            │   ●  ← slope = −α (small but consistent)
            │    ●
            │     ●
            │      ● ●
            │         ●
            │           ●
            └─────────────────→ log(compute)

Concrete prediction

Suppose α = 0.05. Double the model size. New loss = old loss · 2⁻⁰·⁰⁵ ≈ old loss · 0.966.

Each doubling improves loss by ~3.4%. Sounds small. But over 20 doublings (1M× more parameters): 0.966²⁰ ≈ 0.5. Loss halves.

Translation: double the pile → predictable improvement. No mysterious wall. Line keeps going.
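The arithmetic and the straight-line claim can both be checked quickly. The data points below are synthetic, generated from the formula itself with an arbitrary constant a = 10. They are not measurements from any real model.

```python
import numpy as np

alpha = 0.05                     # exponent used in the example above
per_doubling = 2.0 ** (-alpha)   # loss multiplier for one doubling of N
after_20 = per_doubling ** 20    # 20 doublings = 2**20, about a million times more parameters

# Synthetic points on L = a * N**(-alpha); a = 10 is an arbitrary illustrative constant.
a = 10.0
N = np.logspace(6, 12, num=7)    # 1e6 ... 1e12 parameters
L = a * N ** (-alpha)

# A straight-line fit in log-log space recovers slope = -alpha.
slope, intercept = np.polyfit(np.log(N), np.log(L), deg=1)

print(f"{per_doubling:.4f}  {after_20:.3f}  slope={slope:.3f}")
```

On real training runs the points are noisy and the fitted slope is an estimate, which is why scaling laws are called empirical.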

Training became a budgeting problem. Tell me your compute, I tell you your loss.

The Chinchilla balance

Early advice was wrong. It said: make the pile very big, use less data. So GPT-3 = 175B parameters, only 300B tokens. Over-sized.

Picture a balance scale:

           ┌─────────┴─────────┐
           │                   │
       parameters           tokens
       (model size)         (training data)

       ratio should be ~ 1 : 20

GPT-3 ratio: 1 : 1.7 (top-heavy)
Chinchilla ratio: 1 : 20 (balanced)

Chinchilla = 70B params, 1.4T tokens. Less than half the size of GPT-3. Still beat it.

Surprise: smaller pile on more photos beats bigger pile on fewer.

This is why Llama 3 8B was trained on 15T tokens. Even pushed past Chinchilla — the ratio became more like 1:1875. Why? Because Chinchilla optimizes training compute. For a model that will be served to billions, inference cost matters more — and a smaller model is cheaper to serve forever. Over-training small models is the modern frontier-lab move.
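The three ratios are one division each. The figures below are the ones quoted in the text above; nothing else is assumed.

```python
# (parameters, training tokens) as quoted in the text above.
models = {
    "GPT-3": (175e9, 300e9),
    "Chinchilla": (70e9, 1.4e12),
    "Llama 3 8B": (8e9, 15e12),
}

ratios = {name: tokens / params for name, (params, tokens) in models.items()}
for name, r in ratios.items():
    print(f"{name:12s} {r:8.1f} tokens per parameter")
```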

5.2 Overfitting, regularization, and the generalization mystery

The pile gets bigger. Training loss goes down. But the model is tested in the wild on photos it has never seen. How does it do there?

Two failure modes

training loss    validation loss    diagnosis
high             high               underfitting (pile too small)
low              low                good
low              high               overfitting (memorizing)

Underfitting fix: make the pile bigger. Train longer.

Overfitting fix: several knobs.

Regularization toolkit

Technique            What it does                                   When
More data            Harder to memorize, must generalize            Always preferred when possible
Smaller model        Less capacity to memorize specifics            When data is fixed and small
Dropout              Randomly turn off some neurons each step       Wide use in older nets; less in transformers
Weight decay         Add λ·Σ W² to loss → prefer small weights      Everywhere; AdamW does this cleanly
Early stopping       Stop when validation loss starts rising        Always monitor, sometimes act
Data augmentation    Random crops, flips, noise → effective data    Vision and audio; less in pure text

Dropout intuition

During training, randomly zero out 10–50% of neurons each forward pass. The network cannot rely on any single neuron — it must distribute responsibility across many. At test time, all neurons are on.

Mental model: imagine a study group where random members skip each meeting. Everyone has to know the material. The group becomes robust.
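Mechanically, dropout is a random mask. The sketch below uses the "inverted" variant, which is an assumption beyond the text: survivors are scaled up by 1 / (1 − p) during training so that the average activation matches what the next layer sees at test time, when every neuron is on.

```python
import numpy as np

def dropout(h, p_drop, rng, training=True):
    """Inverted dropout: zero each unit with probability p_drop during training.

    Survivors are scaled by 1 / (1 - p_drop) so the expected activation is
    unchanged; at test time the input passes through untouched.
    """
    if not training or p_drop == 0.0:
        return h
    keep_mask = rng.random(h.shape) >= p_drop
    return h * keep_mask / (1.0 - p_drop)

rng = np.random.default_rng(0)
h = np.ones((1000, 100))
h_train = dropout(h, p_drop=0.3, rng=rng)
h_test = dropout(h, p_drop=0.3, rng=rng, training=False)

print((h_train == 0).mean(), h_train.mean(), h_test.mean())
```

Roughly 30% of the training activations come out as zero, while the mean stays close to 1 in both modes.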

Weight decay intuition

Big weights memorize specific examples. Small weights generalize better. Weight decay pulls weights toward zero — but only if the data does not push them harder. The data wins if the signal is real. Noise loses.

The double descent surprise

Classical theory: "more parameters than data → overfit." For a long time this was assumed.

But empirically, when you keep adding parameters past the data count, validation loss starts improving again. Two valleys, separated by a peak. Modern over-parameterized networks (LLMs, vision transformers) are in the second valley.

Why? Open problem. Theories: implicit regularization from SGD, the lottery ticket hypothesis, neural tangent kernel limits. None fully complete.

5.3 Honest admission

Three uncomfortable truths.

One: we do not fully know why deep rule piles generalize. A pile with billions of rules can memorize the training photos and still do well on unseen photos. Theories exist. None complete.

Two: scaling laws are empirical, not derived. Curves were measured, not proved. We do not know if they bend somewhere we have not yet looked.

Three: hyperparameter choices (lr, batch size, init scale) interact in ways that are still partly empirical. There is no closed-form recipe. Best practice is "what frontier labs published most recently."

Be honest about these in interviews. It builds trust.


Final retrieval before recap. Chapter 5 covered scaling, overfitting, and what we don't know. Without scrolling — what is the Chinchilla ratio? What does dropout actually do mechanically? What is double descent? If any link is fuzzy, scroll back.


Chapter 6: Recap & application

6.1 The failure-fix chain

Visual recap. Each step shows the failure picture and the fix:

#     Failure                                               Fix
1     One rule fails at XOR: knife can't cut diagonals      Rule pile (multiple layers)
2     Stacked rules collapse: flat paper stays flat         The bend (non-linearity)
3     Random or zero weights break or vanish at step 0      Xavier / He init
4     Cannot find rule numbers by hand                      The nudge (backpropagation)
5     Full-batch is too slow                                Mini-batch SGD
6     MSE saturates with sigmoid output                     Cross-entropy + softmax
7     Deep pile kills the nudge: whisper game muffles       ReLU / GELU keeps gradient = 1
8     Plain nudging oscillates: bounces canyon walls        Smart nudging (Adam, AdamW)
9     Want predictable improvement                          Scaling laws (log-log ramp)
10    Wrong pile-to-photo ratio                             Chinchilla (~1:20)
11    Big pile memorizes the photos                         Regularization (dropout, weight decay)

Every formula in 03_study_material.md fixes one of these. Learn the failures. The formulas memorize themselves.

6.2 Key points to remember

  • The bend (non-linearity) is what makes deep learning possible. Without it, paper stays flat.
  • Init matters at step 0. Wrong init → NaN or vanishing on the very first forward pass. Use He for ReLU, Xavier for sigmoid/tanh.
  • The nudge (backprop) is only chain rule. Blame flowing upstream.
  • Mini-batch SGD balances speed and noise. Noise actually helps generalization.
  • Cross-entropy + softmax has a clean gradient (p − y). That is why it is the canonical classification pair.
  • Vanishing gradient is fixed by activation choice (ReLU/GELU) and supported by good init and residual connections.
  • AdamW is default for transformers. Per-parameter adaptive learning rate plus decoupled weight decay.
  • Scaling is empirical. Power-law ramp on log-log paper. Chinchilla rule: ~20 tokens per parameter (more for inference-heavy models).
  • Overfitting is the next failure after the pile gets big. Mitigate with data, dropout, weight decay, early stopping.
  • Generalization is partly mysterious. Be honest about it.

6.3 Important interview questions

Q: Why GELU and not ReLU in modern transformers?
A: ReLU has a dead-neuron problem: a unit whose pre-activation is negative for every input gets zero gradient and never recovers. GELU is smooth, has no hard zero region, and performs slightly better empirically. Llama uses SiLU, which is very similar.
Common wrong answer to avoid: "GELU is just smoother." That is true but not the point. The point is smoothness preserves gradient flow even for slightly-negative pre-activations, while ReLU zeroes them out permanently.

Q: AdamW vs Adam — what changed?
A: AdamW decouples weight decay from the gradient update. In Adam, weight decay was scaled by the adaptive learning rate, which broke regularization. AdamW fixes that. Default for transformers.
Common wrong answer to avoid: "AdamW is Adam with weight decay." Adam already had weight decay. The fix is decoupled weight decay. Saying just "with weight decay" reveals you read a summary, not the paper.

Q: Why cross-entropy for classification, not MSE?
A: Cross-entropy is the maximum-likelihood loss for categorical distributions. With softmax, its gradient is (ŷ − y): clean, no saturation. MSE on a softmax or sigmoid output gives tiny gradients when the prediction is confidently wrong, so training is slow.

Q: What is the vanishing gradient problem? How do you detect it in practice?
A: Gradient at early layers = product of derivatives through the network. With sigmoid (max derivative 0.25), gradient shrinks exponentially with depth. Detect: monitor gradient norms per layer. If early layers have norms 100× smaller than late layers, you have vanishing gradient. Fix: ReLU/GELU, residual connections, layer norm, careful init.

Q: Universal approximation says one hidden layer is enough. Why do we use deep networks?
A: "Enough" is theoretical. With one hidden layer you may need exponentially many units to approximate a complex function. Deep networks compose features hierarchically — exponentially fewer parameters for the same approximation. Depth is parameter efficiency.

Q: Why does Chinchilla beat GPT-3 with fewer parameters?
A: GPT-3 was over-parameterized for its training data. For a fixed compute budget, model size and data should scale together, at about 20 tokens per parameter. GPT-3 had ~1.7 tokens per parameter. Chinchilla had ~20. Compute better spent: fewer parameters, far more tokens.
Bonus answer that distinguishes a Lead candidate: "Llama 3 went past Chinchilla — ~1875 tokens per parameter at 8B. That's intentional. Chinchilla optimizes training compute; production models optimize total cost including inference, where smaller models win indefinitely."

Q: What is the role of learning rate warmup?
A: Early in training, weights are random. Big steps cause instability — gradient estimates are noisy, Adam moments are not yet calibrated. Warmup ramps lr from near-zero up to peak over a few thousand steps. Then cosine or linear decay. Skipping warmup → training divergence on transformers.

Q: How do you choose batch size?
A: Largest that fits in memory, usually. Bigger batch → less gradient noise, often higher achievable learning rate. But too big → diminishing returns and worse generalization. LLMs use millions of tokens per batch, reached through data parallelism across many GPUs plus gradient accumulation.

Q: Why does the same model trained twice give different results — same code, same data, same seed?
A: Floating-point non-associativity. On GPU, parallel reductions can complete in different orders depending on scheduling, and (a + b) + c ≠ a + (b + c) in floating point, most visibly at fp16/bf16 precision. Document hardware in experiment logs.
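Non-associativity is easy to see on a CPU too. The first pair below uses ordinary Python floats (float64). The second uses NumPy float16 values picked so that adding 1 to 2048 is lost to rounding. Both sets of numbers are chosen purely for illustration.

```python
import numpy as np

left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right, left, right)

# The gap grows at lower precision. In float16, adding 1 to 2048 is lost to rounding.
a, b, c = np.float16(2048), np.float16(1), np.float16(1)
print((a + b) + c, a + (b + c))
```

A GPU reduction that sums thousands of values in a scheduling-dependent order hits the same effect at scale.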

6.4 Production experience — what this looks like when you ship

  • Loss not decreasing. First suspect: vanishing or exploding gradients. Check gradient norms per layer. If exploding, add gradient clipping (max_norm=1.0 is common). If vanishing, check activation, init, residuals.

  • Loss = NaN at step 1. Almost always init or learning rate. Check init scale. Lower lr. Add gradient clipping defensively.

  • Mixed precision (fp16 / bf16). Saves memory and speeds up training. But fp16 has small dynamic range — gradients can underflow to zero. Use loss scaling (multiply loss by 2¹⁵ before backward, divide gradients after). bf16 has same range as fp32, no scaling needed. bf16 is now default on H100/TPU.

  • Gradient clipping. Common in transformer training. Clip the global gradient norm to 1.0. Prevents single bad batch from blowing up training.

  • Learning rate matters more than you think. For transformer fine-tuning: lr too high → catastrophic forgetting. Too low → no learning. Standard starting points: pretraining ~3e-4, instruction-tuning ~2e-5, LoRA ~1e-4.

  • Batch size and lr are coupled. Bigger batch → can use bigger lr (linear scaling rule, roughly). Halving batch size? Halve lr too, or training may become unstable.

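  • AdamW memory cost. Adam stores two extra tensors per parameter (first and second moment). For a 7B model, those two moments alone are ~56 GB in fp32 (7B × 2 × 4 bytes). Mixed-precision training usually keeps an fp32 master copy of the weights as well, which brings the total to ~84 GB. This is why optimizer sharding (ZeRO, FSDP) exists.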

  • Dead ReLU detection. If a fraction of your hidden neurons output zero on every input, those neurons are dead — their gradient is permanently zero. Diagnose: log the fraction of zero-activations per layer. Fix: switch to GELU/SiLU, lower learning rate, better init.

  • Loss curves never go to zero. If training loss is at 0, you are overfitting or have a label leak. Healthy LLM pretraining loss settles around 1.5–2.5 nats.

  • Validation loss diverging from training loss = overfitting. Add regularization, get more data, or stop training (early stopping).

  • Reproducibility is hard. Same code, same data, same seed — different GPU → different results. Floating-point non-associativity. Document your hardware in experiment logs.
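Global-norm gradient clipping, mentioned twice in the list above, can be sketched in NumPy as follows. The function name `clip_global_norm` and the small 1e-6 guard in the denominator are choices made here. Deep-learning frameworks ship their own versions, and real training code should use those.

```python
import numpy as np

def clip_global_norm(grads, max_norm=1.0):
    """Scale all gradients by one shared factor so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm

# Hypothetical gradients for two parameter tensors.
grads = [np.array([3.0, 4.0]), np.array([[12.0]])]
clipped, norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float(np.sum(g ** 2)) for g in clipped))
print(norm_before, norm_after)
```

Every tensor is scaled by the same factor, so the direction of the overall update is preserved and only its length is capped. Logging `norm_before` each step doubles as the per-step gradient-norm monitor recommended above.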

Next module: transformers. We will see what failure of MLPs they fix — and you will recognize the same Linear → activation → Linear pattern from this module sitting inside every transformer's feed-forward block.

6.5 Apply now — graded exercises

Easy (5 minutes)

Open Python:

import numpy as np
X = np.array([[0,0],[0,1],[1,0],[1,1]])
y = np.array([0, 1, 1, 0])
print((X @ np.array([1.0, 1.0]) > 0.5).astype(int))

You just tried XOR with one linear classifier. Output is not [0, 1, 1, 0]. Try any weights you like. No combination works. Feel the failure.

Medium (15 minutes)

Take the tiny network from section 3.2 (W₁ = 3, b₁ = −1, W₂ = 0.5, b₂ = 0, input x = 2, target y = 1).

  1. Compute the forward pass by hand. Confirm L = 2.25.
  2. Compute the backward pass. Confirm ∂L/∂W₁ = 3, ∂L/∂W₂ = 15.
  3. Apply one update with α = 0.1.
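  4. Compute the new forward pass. With α = 0.1 the loss goes up, to about 30.8: the step overshoots the valley (the "too big" case from section 3.2).
  5. Redo step 3 from the original weights with α = 0.01 and confirm the loss drops to about 0.48.

If the loss did not decrease at α = 0.01, you have a sign error. Find it.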

Hard (45 minutes)

Implement the XOR-solving 2-layer MLP from section 2.2 in numpy from scratch. No PyTorch.

  1. Define forward pass with manual ReLU.
  2. Define MSE loss.
  3. Implement backward pass — derive every gradient by hand.
  4. Train on the four XOR inputs for 5,000 iterations with α = 0.1.
  5. Confirm final outputs match [0, 1, 1, 0] to within 0.1.
  6. Now scramble the init and try again. Sometimes it converges. Sometimes it gets stuck in a local minimum where two of the four points stay wrong. Why? Hint: section 3.1 + the optimization landscape.

Final retrieval

Without looking, draw the failure-fix table from section 6.1 — failure picture in one column, fix picture in the other. All eleven rows. If you can sketch all eleven from memory — the module is yours.