Backpropagation — the nudge in detail
The nudge from ELI5 made physical. Blame flows backward. Each weight gets a memo.
Built on: 00-eli5.md. The "nudge" placeholder from that page gets filled in here. Backpropagation is the nudge; every other word is dressing.
The picture before the math
See. The cat-robot guesses. The guess is wrong. Now what?
Someone has to take the blame. The output neuron contributed. The hidden neurons before it contributed. The weights stitching them together contributed. Each one gets blamed in proportion to how much it helped produce the wrong answer.
Mental model — blame flows upstream. The output was wrong by some amount. Every weight that contributed gets a share of the blame. Bigger contribution → bigger share.
Each weight gets a small memo. The memo says: shift this much, in this direction, and the loss would have been smaller. That memo is the gradient. The whole act of computing memos for every weight is backpropagation.
This is the nudge from ELI5. Nothing more. Nothing mystical.
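To make the memo concrete, here is a minimal sketch that estimates one by literally nudging a weight. The one-weight model (prediction = w · x, squared-error loss) and the names `loss` and `memo_by_nudging` are hypothetical and exist only for illustration. This is finite differencing, not backpropagation.

```python
# Hypothetical one-weight model: prediction = w * X, loss = (prediction - Y) ** 2.
X, Y = 2.0, 1.0


def loss(w):
    return (w * X - Y) ** 2


def memo_by_nudging(w, eps=1e-6):
    # Central difference: nudge w a hair each way and see how much the loss moves.
    return (loss(w + eps) - loss(w - eps)) / (2 * eps)


w = 0.8
memo = memo_by_nudging(w)          # by hand: 2 * (w*X - Y) * X
w_nudged = w - 0.01 * memo         # shift against the memo
print(f"memo = {memo:.4f}")
print(f"loss before = {loss(w):.4f}, after = {loss(w_nudged):.4f}")
```

Nudging one weight at a time costs a forward pass per weight. Backpropagation gets the same memos for every weight from a single backward pass.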
ASCII picture: forward and backward on a 2-2-1 net
Two inputs. Two hidden neurons. One output. Forward arrows go right. Backward gradient arrows go left. Each input feeds both hidden neurons, so x₁ and x₂ are drawn twice.
```
FORWARD (compute prediction)
═══════════════════════════════════════════════►

x₁ ── w₁₁ ──┐
x₂ ── w₂₁ ──┴──→ h₁ ── w₃ ──┐
                            ├── (+ b₃) ──→ ŷ ──→ L
x₁ ── w₁₂ ──┐               │
x₂ ── w₂₂ ──┴──→ h₂ ── w₄ ──┘

◄═══════════════════════════════════════════════
BACKWARD (compute gradient memos)

∂L/∂w₁₁ ←─┐
∂L/∂w₂₁ ←─┴─ ∂L/∂h₁ ←─┐
                      ├─ ∂L/∂ŷ ←─ L
∂L/∂w₁₂ ←─┐           │
∂L/∂w₂₂ ←─┴─ ∂L/∂h₂ ←─┘
∂L/∂w₃  ←───────────── ∂L/∂ŷ
∂L/∂w₄  ←───────────── ∂L/∂ŷ
```
Two passes. Forward fills the neurons with numbers. Backward fills the weights with memos.
Why "chain" rule? The picture
A weight in layer 1 is far from the loss. Its effect on the loss is indirect. It nudges its own neuron. That neuron nudges the next layer. That nudges the output. The output changes the loss.
So the weight's effect on loss = its effect on its own neuron × that neuron's effect on the next × ... × the last layer's effect on loss.
Multiply along the path. That is the chain.
```
weight ──→ neuron ──→ next neuron ──→ ... ──→ output ──→ loss
      ∂h/∂w     ∂h'/∂h            ...   ∂ŷ/∂h''     ∂L/∂ŷ

∂L/∂w = product of every link on the path
```
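Here is a quick sketch that checks whether "multiply along the path" really gives the derivative. The three-link chain (w → h = w·x → a = tanh(h) → L = (a - y)²) is a hypothetical example, not the 2-2-1 net, and tanh is an assumed activation. The product of the local derivatives should match a brute-force nudge of w.

```python
import math

# Hypothetical three-link chain: w -> h = w*x -> a = tanh(h) -> L = (a - y)**2
x, y = 1.5, 0.2


def forward(w):
    h = w * x
    a = math.tanh(h)
    return h, a, (a - y) ** 2


w = 0.4
h, a, L = forward(w)

# Local derivative of each link on the path.
dh_dw = x                          # h = w * x
da_dh = 1.0 - math.tanh(h) ** 2    # slope of tanh
dL_da = 2.0 * (a - y)              # slope of the squared error

chain = dL_da * da_dh * dh_dw      # multiply along the path

# Cross-check by nudging w directly (central difference).
eps = 1e-6
numeric = (forward(w + eps)[2] - forward(w - eps)[2]) / (2 * eps)
print(f"chain rule: {chain:.6f}   nudge check: {numeric:.6f}")
```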
If a single link is small, the whole product is small, and a deep network multiplies many links. That is why very deep networks can be hard to train: link rot. We will see that in 08-vanishing-gradients.md.
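To see link rot in numbers, here is a sketch that multiplies sigmoid slopes along a path. Every assumption is for illustration only. The activation is sigmoid, swapped in because the worked example below uses ReLU. Every weight on the path is 1. Every neuron sits at z = 0, where the sigmoid slope reaches its maximum of 0.25. Even in this best case, the product shrinks by a factor of 4 per link.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_slope(z):
    s = sigmoid(z)
    return s * (1.0 - s)            # peaks at 0.25 when z = 0


def chain_product(depth, z=0.0):
    """Product of `depth` sigmoid links; weights on the path are assumed to be 1."""
    product = 1.0
    for _ in range(depth):
        product *= sigmoid_slope(z)
    return product


for depth in (1, 5, 10, 20):
    print(depth, chain_product(depth))
```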
Worked numerical example: a 2-2-1 net by hand
Purpose: compute every gradient with real numbers. No hand-waving.
Network. Two inputs. Two hidden neurons with ReLU. One linear output. Squared-error loss.
Initial weights (wᵢⱼ connects input xᵢ to hidden neuron hⱼ):
```
Layer 1 (input → hidden):      Layer 2 (hidden → output):
w₁₁ = 0.5,  w₁₂ = 0.3          w₃ = 1.0
w₂₁ = 0.4,  w₂₂ = 0.2          w₄ = -0.5
b₁  = 0.0,  b₂  = 0.0          b₃ = 0.1
```
Input: x₁ = 1, x₂ = 2. Target: y = 1. Learning rate: α = 0.1.
Forward pass
z₁ = w₁₁·x₁ + w₂₁·x₂ + b₁ = 0.5·1 + 0.4·2 + 0 = 1.3
z₂ = w₁₂·x₁ + w₂₂·x₂ + b₂ = 0.3·1 + 0.2·2 + 0 = 0.7
h₁ = ReLU(z₁) = 1.3
h₂ = ReLU(z₂) = 0.7
ŷ = w₃·h₁ + w₄·h₂ + b₃ = 1.0·1.3 + (-0.5)·0.7 + 0.1 = 1.05
L = (ŷ - y)² = (1.05 - 1)² = 0.0025
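Here is the forward pass above as a short Python sketch. The variable names mirror the text: `w12` connects x₁ to h₂, matching the z₂ equation. Nothing here comes from a library.

```python
# Worked-example values; w12 means input x1 -> hidden h2, w21 means x2 -> h1.
x1, x2, y = 1.0, 2.0, 1.0
w11, w12, w21, w22 = 0.5, 0.3, 0.4, 0.2
b1, b2 = 0.0, 0.0
w3, w4, b3 = 1.0, -0.5, 0.1


def relu(z):
    return max(0.0, z)


z1 = w11 * x1 + w21 * x2 + b1
z2 = w12 * x1 + w22 * x2 + b2
h1, h2 = relu(z1), relu(z2)
y_hat = w3 * h1 + w4 * h2 + b3
L = (y_hat - y) ** 2
print(f"z1={z1:.2f} z2={z2:.2f} h1={h1:.2f} h2={h2:.2f} y_hat={y_hat:.2f} L={L:.4f}")
```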
Small loss. Good. Now compute every memo.
Backward pass: chain rule, link by link
Start at the loss. Walk backward.
Loss with respect to the prediction:
∂L/∂ŷ = 2(ŷ - y) = 2 · (1.05 - 1) = 0.10
Output-layer weights, direct dependence:
∂L/∂w₃ = (∂L/∂ŷ) · h₁ = 0.10 · 1.3 = 0.130
∂L/∂w₄ = (∂L/∂ŷ) · h₂ = 0.10 · 0.7 = 0.070
∂L/∂b₃ = (∂L/∂ŷ) · 1 = 0.10
Push the blame back into hidden activations:
∂L/∂h₁ = (∂L/∂ŷ) · w₃ = 0.10 · 1.0 = 0.10
∂L/∂h₂ = (∂L/∂ŷ) · w₄ = 0.10 · (-0.5) = -0.05
Through the bend (ReLU' = 1 since z₁, z₂ > 0):
∂L/∂z₁ = ∂L/∂h₁ · ReLU'(z₁) = 0.10 · 1 = 0.10
∂L/∂z₂ = ∂L/∂h₂ · ReLU'(z₂) = -0.05 · 1 = -0.05
Layer-1 weights, where the chain reaches all the way back:
∂L/∂w₁₁ = ∂L/∂z₁ · x₁ = 0.10 · 1 = 0.10
∂L/∂w₂₁ = ∂L/∂z₁ · x₂ = 0.10 · 2 = 0.20
∂L/∂w₁₂ = ∂L/∂z₂ · x₁ = -0.05 · 1 = -0.05
∂L/∂w₂₂ = ∂L/∂z₂ · x₂ = -0.05 · 2 = -0.10
Each line is the chain. Multiply along the path. The memo for w₁₁ says: the loss depends on you through z₁, then through h₁ (via ReLU, which is alive), then through w₃, then through ŷ. Total influence = 0.10. Shift accordingly.
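Here is the same backward pass as code, a minimal sketch that reproduces every memo above. `forward` caches the intermediate values, and `backward` then reads them link by link. Function and dictionary names are illustrative. Taking ReLU's slope at exactly z = 0 as 0 is an assumed convention; it does not matter here because both z values are positive.

```python
# Worked-example parameters; names mirror the text (w12: x1 -> h2, w21: x2 -> h1).
params = dict(w11=0.5, w12=0.3, w21=0.4, w22=0.2, b1=0.0, b2=0.0,
              w3=1.0, w4=-0.5, b3=0.1)
x1, x2, y = 1.0, 2.0, 1.0


def forward(p):
    """Run the net and cache every intermediate the backward pass will need."""
    z1 = p["w11"] * x1 + p["w21"] * x2 + p["b1"]
    z2 = p["w12"] * x1 + p["w22"] * x2 + p["b2"]
    h1, h2 = max(0.0, z1), max(0.0, z2)            # ReLU
    y_hat = p["w3"] * h1 + p["w4"] * h2 + p["b3"]
    return dict(z1=z1, z2=z2, h1=h1, h2=h2, y_hat=y_hat, L=(y_hat - y) ** 2)


def backward(p, c):
    """Walk the chain from the loss back to every weight, reusing cached values."""
    g = {}
    d_yhat = 2.0 * (c["y_hat"] - y)                  # dL/dy_hat, computed once
    g["w3"], g["w4"], g["b3"] = d_yhat * c["h1"], d_yhat * c["h2"], d_yhat
    d_h1, d_h2 = d_yhat * p["w3"], d_yhat * p["w4"]  # blame on hidden activations
    d_z1 = d_h1 * (1.0 if c["z1"] > 0 else 0.0)      # through the ReLU bend
    d_z2 = d_h2 * (1.0 if c["z2"] > 0 else 0.0)
    g["w11"], g["w21"], g["b1"] = d_z1 * x1, d_z1 * x2, d_z1
    g["w12"], g["w22"], g["b2"] = d_z2 * x1, d_z2 * x2, d_z2
    return g


cache = forward(params)
grads = backward(params, cache)
for name in ("w3", "w4", "b3", "w11", "w21", "w12", "w22"):
    print(f"dL/d{name} = {grads[name]:+.3f}")
```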
Three weight updates by hand
Apply w := w - α · ∂L/∂w.
w₃  := 1.0 - 0.1 · 0.130   = 0.987   (output weight)
w₁₁ := 0.5 - 0.1 · 0.10    = 0.490   (input → hidden, top path)
w₂₂ := 0.2 - 0.1 · (-0.10) = 0.210   (input → hidden, bottom path)
Each weight got its memo. Each weight shifted a tiny step toward less wrong. This is the nudge from ELI5, fully unpacked.
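Here are the three updates above, written as one gradient-descent step in a short sketch. The memo values are copied from the worked example, not recomputed.

```python
alpha = 0.1
# Current weights and their memos, copied from the worked example.
weights = {"w3": 1.0, "w11": 0.5, "w22": 0.2}
memos = {"w3": 0.130, "w11": 0.10, "w22": -0.10}

# Gradient-descent step: w := w - alpha * dL/dw.
# A negative memo (w22) means raising that weight lowers the loss, so it goes up.
updated = {name: weights[name] - alpha * memos[name] for name in weights}
for name, value in updated.items():
    print(f"{name}: {weights[name]} -> {value:.3f}")
```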
The gradient computation graph: ASCII tree
```
L
└── ∂L/∂ŷ
    ├── ∂L/∂w₃
    ├── ∂L/∂w₄
    ├── ∂L/∂b₃
    ├── ∂L/∂h₁ ── ReLU' ── ∂L/∂z₁
    │                      ├── ∂L/∂w₁₁
    │                      └── ∂L/∂w₂₁
    └── ∂L/∂h₂ ── ReLU' ── ∂L/∂z₂
                           ├── ∂L/∂w₁₂
                           └── ∂L/∂w₂₂
```
Each leaf is a weight memo. Each branch reuses what was computed above it. Reuse is the whole reason backprop is fast — the shared sub-expressions are cached, not recomputed.
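To see the reuse in action, here is a minimal reverse-mode autodiff sketch. `Value` is a hypothetical teaching class, not any library's API. During the forward pass, each operation records its parents and its local derivative. `backward()` orders the graph once and pushes each node's memo upstream exactly once. Running it on the 2-2-1 net reproduces the hand-computed memos.

```python
class Value:
    """A scalar node in a computation graph (a teaching sketch, not a real library)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents            # nodes this one was computed from
        self.local_grads = local_grads    # d(self)/d(parent), saved during forward

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        # The local links are forward values, which is why they must be cached.
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def relu(self):
        return Value(max(0.0, self.data), (self,), (1.0 if self.data > 0 else 0.0,))

    def backward(self):
        # Topological order: a node's memo is complete before it is pushed upstream.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):                      # each node handled once
            for parent, local in zip(node.parents, node.local_grads):
                parent.grad += node.grad * local          # upstream memo x local link


# The 2-2-1 net from the worked example.
x1, x2, y = 1.0, 2.0, 1.0
w11, w12, w21, w22 = Value(0.5), Value(0.3), Value(0.4), Value(0.2)
b1, b2 = Value(0.0), Value(0.0)
w3, w4, b3 = Value(1.0), Value(-0.5), Value(0.1)

h1 = (w11 * x1 + w21 * x2 + b1).relu()
h2 = (w12 * x1 + w22 * x2 + b2).relu()
err = w3 * h1 + w4 * h2 + b3 + (-y)
L = err * err
L.backward()
for name, v in [("w3", w3), ("w4", w4), ("w11", w11), ("w21", w21),
                ("w12", w12), ("w22", w22)]:
    print(f"dL/d{name} = {v.grad:+.3f}")
```

One design choice to notice: `__mul__` stores each operand's forward value as the other operand's local derivative. That is forward-value caching in miniature.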
Forward-mode vs reverse-mode
Same graph. Two ways to carry derivatives.
- Forward-mode. Push tangent information forward. One pass gives the derivatives of every output along one input direction, so the full Jacobian costs O(number of inputs) passes. Good when inputs are few and outputs are many.
- Reverse-mode. Push adjoints backward. One pass gives the derivatives of one output with respect to every input, so the full Jacobian costs O(number of outputs) passes. Good when inputs are many and outputs are few.
- JVP vs VJP. Forward-mode naturally gives Jacobian-vector products (JVPs). Reverse-mode naturally gives vector-Jacobian products (VJPs).
- Why backprop wins. Neural nets have millions of weights but usually one scalar loss. Forward-mode would need one pass per weight; reverse-mode gets every memo from a single backward pass. That reverse-mode walk on a computation graph is backprop.
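Here is a sketch of forward-mode using dual numbers, an implementation technique chosen here for brevity. Each `Dual` carries a value and a tangent. One pass through the hypothetical function `f` yields the derivative along one input direction, so a full gradient needs one pass per input. With millions of weights, that means millions of passes, which is exactly the cost reverse-mode avoids.

```python
class Dual:
    """A number carrying a value and a tangent (directional derivative)."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, o):
        o = o if isinstance(o, Dual) else Dual(o)
        return Dual(self.val + o.val, self.tan + o.tan)

    __radd__ = __add__

    def __mul__(self, o):
        o = o if isinstance(o, Dual) else Dual(o)
        # Product rule carried alongside the value.
        return Dual(self.val * o.val, self.val * o.tan + self.tan * o.val)

    __rmul__ = __mul__


def f(x1, x2):
    # Hypothetical scalar function: two inputs, one output.
    return x1 * x1 * x2 + 3 * x2


def grad_forward_mode(f, xs):
    """One forward pass per input: seed tangent 1 on input i and 0 elsewhere."""
    grads = []
    for i in range(len(xs)):
        args = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(xs)]
        grads.append(f(*args).tan)
    return grads


print(grad_forward_mode(f, [2.0, 5.0]))   # [df/dx1, df/dx2], two passes
```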
Pause and recall. Without scrolling: why is it called the "chain rule"? What is ∂L/∂w for the last-layer weights specifically?
Where this lives in the wild
The memo machinery ships in every major deep learning framework. It is not folklore; it is heavily engineered, right down to the GPU kernels:
- PyTorch autograd engine. Builds a dynamic computation graph during the forward pass. `loss.backward()` walks that graph in reverse and fills `.grad` on every leaf tensor created with `requires_grad=True`. That flag is the on-switch for memo tracking.
- JAX `grad` transform. Trace-based. You write the forward function; `jax.grad(f)` returns a new function that returns memos. Pure-functional, JIT-compilable, scales to TPU pods.
- TensorFlow `GradientTape`. Records operations inside a `with tf.GradientTape() as tape:` block. Calling `tape.gradient(loss, weights)` plays the tape backward. Same idea, different ergonomics.
- Triton kernel-level backward. When you write a custom GPU kernel for a fused operation, autograd cannot see inside it, so you also write its hand-tuned backward kernel. Triton lets ML engineers write both in its Python-embedded language instead of dropping to CUDA.
- FlashAttention's fused backward. Standard attention keeps the N×N attention matrix from the forward pass for use in the backward pass, which blows up memory for long sequences. FlashAttention keeps only small per-row softmax statistics and recomputes attention blocks inside the backward pass, trading extra FLOPs for less memory traffic. The chain rule, optimized by hand for the GPU memory hierarchy.
The pattern. Forward is what users see. Backward is where ML engineering livelihoods are made.
Interview Q&A
Q: Why does autograd need to cache the forward-pass values?
A: Each backward link needs forward activations. Look at ∂L/∂w₃ = ∂L/∂ŷ · h₁. We need h₁. So forward values are stashed in memory until backward consumes them. This is a big reason training needs much more memory than inference, which can drop each activation as soon as the next layer has used it.
Common wrong answer to avoid: "to recompute the forward". No: caching avoids recomputation. Gradient checkpointing trades cache for recompute when memory is tight.
Q: Why is backprop O(forward), not O(forward²)?
A: Each chain-rule node is computed exactly once and reused everywhere it appears downstream. The reverse-mode walk visits each edge once, so one backward pass costs a small constant multiple of one forward pass. Naive numerical differentiation would re-run the forward once per parameter, which is O(forward × num_params) and prohibitive.
Common wrong answer to avoid: "because matrix multiplies are fast". The speed is structural (sub-expression reuse), not hardware.
Q: What gets stored vs recomputed during backprop?
A: By default, all forward activations are stored. Gradients flow once and weights update. Under gradient checkpointing, only a subset of activations is stored; the rest are recomputed during backward. Trades memory for compute. Standard at scale.
Common wrong answer to avoid: "Backward never recomputes anything." It usually uses cached activations, yes, but checkpointing deliberately recomputes parts to save memory.
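Here is a sketch of the checkpointing trade on a toy chain of scalar tanh layers. The layers, the weights, and the `every=4` spacing are all hypothetical. The full-cache version stores every activation. The checkpointed version stores only every fourth one and recomputes each segment during the backward pass. Both produce identical memos. The printed count covers only the stored checkpoints; during backward, the checkpointed version also holds one recomputed segment at a time.

```python
import math


def layer(a, w):
    return math.tanh(w * a)


def layer_grads(a, out, w, upstream):
    """Chain-rule step for out = tanh(w * a). Returns (dL/da, dL/dw)."""
    local = upstream * (1.0 - out * out)     # through the tanh bend, using the cached output
    return local * w, local * a


def grads_full_cache(x, ws):
    acts = [x]                               # store every activation
    for w in ws:
        acts.append(layer(acts[-1], w))
    upstream, grads = 1.0, [0.0] * len(ws)   # L is the final activation, so dL/dL = 1
    for i in reversed(range(len(ws))):
        upstream, grads[i] = layer_grads(acts[i], acts[i + 1], ws[i], upstream)
    return grads, len(acts)


def grads_checkpointed(x, ws, every=4):
    checkpoints, a = {0: x}, x               # store only every `every`-th activation
    for i, w in enumerate(ws):
        a = layer(a, w)
        if (i + 1) % every == 0 and i + 1 < len(ws):
            checkpoints[i + 1] = a
    upstream, grads = 1.0, [0.0] * len(ws)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, len(ws))
        seg = [checkpoints[start]]           # recompute this segment's activations
        for i in range(start, end):
            seg.append(layer(seg[-1], ws[i]))
        for i in reversed(range(start, end)):
            upstream, grads[i] = layer_grads(seg[i - start], seg[i - start + 1], ws[i], upstream)
    return grads, len(checkpoints)


ws = [0.9, 1.1, -0.7, 0.5, 1.3, 0.8, -1.2, 0.6, 1.0, 0.4]
full, stored_full = grads_full_cache(0.3, ws)
ckpt, stored_ckpt = grads_checkpointed(0.3, ws)
print(f"stored activations: {stored_full} vs {stored_ckpt}")
print("same memos:", all(abs(g1 - g2) < 1e-12 for g1, g2 in zip(full, ckpt)))
```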
Q: If a ReLU's input was negative on the forward pass, what is its gradient memo on the backward pass?
A: Zero. ReLU' is 0 for negative inputs, so the chain rule multiplies by 0 at that link. The weights and bias feeding that neuron get no memo through it for this example. If the input is negative for every example, the neuron is "dead": its incoming weights never receive a memo and stop learning.
Common wrong answer to avoid: "a small but nonzero memo". ReLU's slope for negative inputs is exactly 0, not just small.
Q: Is backprop specific to neural networks?
A: No. It is reverse-mode autodiff, and it applies to any computational graph.
Common wrong answer to avoid: "Backprop is unique to neural networks."
Apply now (5 min)
Take a 2-1-1 net with a sigmoid hidden neuron and a linear output. Two inputs, one hidden neuron, one output. Pick:
w₁ = 0.5, w₂ = -0.3, b₁ = 0.1 (input → hidden)
w₃ = 0.8, b₂ = 0.0 (hidden → output)
σ(z) = 1 / (1 + e⁻ᶻ), σ'(z) = σ(z)(1 - σ(z))
x₁ = 1, x₂ = 1, target y = 0, MSE loss.
By hand:
- Forward — compute z₁, h = σ(z₁), ŷ, L.
- Backward — compute ∂L/∂ŷ, ∂L/∂w₃, ∂L/∂h, ∂L/∂z₁, ∂L/∂w₁, ∂L/∂w₂.
- Update each weight with α = 0.1.
Then close the notebook. From memory, sketch the computation graph. Draw forward arrows in one direction, gradient memos flowing back. Label each memo with the chain it multiplied. If you can do this without peeking, the nudge is yours.
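Once you have worked the exercise by hand, here is a sketch to check your numbers. It assumes the output neuron is linear (ŷ = w₃·h + b₂), as in the worked example; the exercise leaves the output activation implicit. Only the three weights named in the exercise are updated.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Exercise parameters; the output is assumed linear: y_hat = w3 * h + b2.
w1, w2, b1 = 0.5, -0.3, 0.1
w3, b2 = 0.8, 0.0
x1, x2, y = 1.0, 1.0, 0.0
alpha = 0.1

# Forward
z1 = w1 * x1 + w2 * x2 + b1
h = sigmoid(z1)
y_hat = w3 * h + b2
L = (y_hat - y) ** 2

# Backward
dL_dyhat = 2.0 * (y_hat - y)
dL_dw3 = dL_dyhat * h
dL_dh = dL_dyhat * w3
dL_dz1 = dL_dh * h * (1.0 - h)      # sigma'(z1) = sigma(z1) * (1 - sigma(z1))
dL_dw1 = dL_dz1 * x1
dL_dw2 = dL_dz1 * x2

# Update
new_w1 = w1 - alpha * dL_dw1
new_w2 = w2 - alpha * dL_dw2
new_w3 = w3 - alpha * dL_dw3

print(f"z1={z1:.4f} h={h:.4f} y_hat={y_hat:.4f} L={L:.4f}")
print(f"dL/dy_hat={dL_dyhat:.4f} dL/dw3={dL_dw3:.4f} dL/dh={dL_dh:.4f} dL/dz1={dL_dz1:.4f}")
print(f"dL/dw1={dL_dw1:.4f} dL/dw2={dL_dw2:.4f}")
print(f"w1={new_w1:.4f} w2={new_w2:.4f} w3={new_w3:.4f}")
```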
Bridge. One forward + one backward + one update = one training step. But you cannot afford one step per training example with a billion examples. So we batch. And we shuffle. And we loop. Next:
07-sgd-batches-epochs.md.