07. SGD, batches, epochs — how the nudge actually happens

The training loop, demystified. Why mini-batch is the only option that ships, and what every knob does.

Builds on the nudge from 00-eli5.md. Backprop in 06-backpropagation.md gave us the gradient for one example. Now we ask: when do we actually apply it?


The setup question

See. Backprop tells us the direction to nudge each weight to make one example less wrong. Fine. But you have a million examples. So what to do?

Three honest options:

  1. Compute the gradient on all million examples. Average it. Take one nudge.
  2. Compute the gradient on one example. Take a nudge. Repeat one million times.
  3. Compute the gradient on a small handful — say 64 examples. Average. Nudge. Repeat.

Option 1 is full-batch GD. Option 2 is online SGD (stochastic gradient descent in its purest form, one example per update). Option 3 is mini-batch SGD. Option 3 wins almost everywhere in practice. The rest of this file explains why.


The mental model — noisy peeks at a hidden map

Imagine the loss landscape is a real mountain range you cannot see. You only feel one tiny patch of it at a time, with your foot.

  • Full-batch. You step on every square inch of the mountain, average the slope, then take one careful step. Perfect map. But by the time you compute it once, the sun has set.
  • Online. You feel one square inch and step. Then another, and step. Each peek is wildly noisy — sometimes the patch points uphill even though the mountain trends down. You wobble, but you cover ground.
  • Mini-batch. You feel 64 square inches at once, average them, step. Still noisy — but the average of 64 patches is much closer to the truth than one. And 64 patches fit on a GPU at the same time, so it costs you almost the same as one.

Each batch is a noisy peek at the real loss surface. Many noisy peeks, in the right direction on average, give a good map of where to descend. That is the whole idea.
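
To put a number on "much closer to the truth", here is a small simulation sketch. It assumes made-up per-example gradients: a true slope of 0.5 plus unit Gaussian noise. Those values, and the helper name batch_mean_spread, are illustrative only and not from any real model. Under that assumption, the spread of the batch average should shrink by roughly 1/sqrt(64) = 1/8 when going from 1 example to 64.

```python
import random
import statistics

def batch_mean_spread(batch_size, trials=2000, seed=0):
    """Std-dev of the mini-batch average gradient across many random batches."""
    rng = random.Random(seed)
    means = []
    for _ in range(trials):
        # Hypothetical per-example gradients: true slope 0.5 plus unit Gaussian noise.
        batch = [0.5 + rng.gauss(0.0, 1.0) for _ in range(batch_size)]
        means.append(sum(batch) / batch_size)
    return statistics.stdev(means)

spread_1 = batch_mean_spread(1)     # one patch per peek
spread_64 = batch_mean_spread(64)   # 64 patches averaged per peek
print(f"batch size 1:  spread ~ {spread_1:.3f}")
print(f"batch size 64: spread ~ {spread_64:.3f}")
```

The averaged peek is still noisy, just far less so. Real gradients are not Gaussian, so treat the 1/sqrt(B) figure as a rule of thumb.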


Trajectory pictures — three styles of descent

   Full-batch GD            Online SGD              Mini-batch SGD
   (smooth, slow)        (zig-zag, fast)         (some wobble, fast)

   ●────●────●────●       ●  ●                    ●──●
              ╲             ╲╱ ╲  ●                  ╲
               ●             ●   ╲╱                   ●─●
                ●            ╱╲  ●                       ╲
                 ●          ●  ╲╱ ╲                       ●─●
                  ●  goal    ●  ●  ●  goal                   ●  goal

Full-batch traces a clean curve down. Online bounces around it like a drunk. Mini-batch wobbles but trends true. Mini-batch is goldilocks — enough noise to escape shallow traps, enough averaging to point in roughly the right direction.


Worked numerical — one weight, three update styles

Tiny dataset. Four examples. We are training a single weight w, currently w = 1.0. Each example contributes a per-example gradient g_i:

  Example   g_i
  1         +0.4
  2         -0.2
  3         +0.6
  4         -0.8

Learning rate η = 0.1. Update rule: w ← w − η · (avg gradient).

Full-batch — one nudge using all 4

Average gradient = (0.4 − 0.2 + 0.6 − 0.8) / 4 = 0.0. Update: w = 1.0 − 0.1 · 0.0 = 1.000. One step. No movement. The signal cancelled itself out at this scale.

Mini-batch (size 2) — two nudges per pass

Batch A = examples {1, 2}. Avg = (0.4 − 0.2)/2 = +0.1. Update: w = 1.0 − 0.1 · 0.1 = 0.990.

Batch B = examples {3, 4}. Avg = (0.6 − 0.8)/2 = −0.1. Update: w = 0.990 − 0.1 · (−0.1) = 1.000.

Two steps. Net same place, but w moved along the way. On a real loss surface, that movement is what can carry the weight out of a flat region.

Online (size 1) — four nudges per pass

  Step   g      w after
  1      +0.4   0.960
  2      −0.2   0.980
  3      +0.6   0.920
  4      −0.8   1.000

Four steps. Maximum wobble. Same final point here — the example was rigged for that — but the path swung wider. On real data, that wobble can both help (escape) and hurt (instability).

Same data. Three update budgets. Each is the nudge from ELI5, just applied at different granularities.
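
The three hand calculations above can be replayed with one small function. This is a sketch, not a training framework: run_epoch is a hypothetical helper name, and the per-example gradients are held fixed as in the worked example (in real training they would change after every update).

```python
def run_epoch(w, grads, batch_size, lr=0.1):
    """Apply w <- w - lr * (batch average gradient) over one pass; return every w visited."""
    path = [w]
    for start in range(0, len(grads), batch_size):
        batch = grads[start:start + batch_size]
        w = w - lr * sum(batch) / len(batch)
        path.append(round(w, 6))   # round only for display; w itself stays unrounded
    return path

grads = [0.4, -0.2, 0.6, -0.8]     # per-example gradients from the table above

full_batch = run_epoch(1.0, grads, batch_size=4)
mini_batch = run_epoch(1.0, grads, batch_size=2)
online = run_epoch(1.0, grads, batch_size=1)

print(full_batch)   # [1.0, 1.0]
print(mini_batch)   # [1.0, 0.99, 1.0]
print(online)       # [1.0, 0.96, 0.98, 0.92, 1.0]
```

Only batch_size changes between the three calls. That is the whole difference between the three options.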


Why full-batch breaks at scale

Now switch from 4 examples to 100,000. Forward pass on one example costs, say, 1 ms on a GPU. Backward pass — roughly 2 ms. So one example round-trip ≈ 3 ms.

Full-batch needs all 100,000 examples processed before one weight update:

100,000 examples × 3 ms = 300,000 ms = 5 minutes per single nudge

Training typically needs tens of thousands of nudges. At 5 minutes each, 10,000 nudges take about 35 days and 50,000 take about 6 months, for one model, assuming examples are processed one after another as above. Now scale to GPT-pretraining-size data: trillions of tokens. Full-batch is structurally dead. You cannot afford to wait for the perfect map. You need to start descending now, even with a smudgy map.

Mini-batch flips this. With batch size 256, you get about 390 nudges per pass over 100k examples: the same compute, but 390 times as many weight updates. Plus the GPU stays full because 256 examples fit in parallel. Full-batch spends an entire pass on a single update; online under-fills the GPU, because one example cannot occupy hardware built for parallel work.
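
The arithmetic in this section, written out as a sketch. The 3 ms per example is the text's own rough assumption, and the calculation treats examples as processed one after another. The variable names are illustrative.

```python
import math

n_examples = 100_000
ms_per_example = 3      # 1 ms forward + 2 ms backward, as assumed in the text
batch_size = 256

# Full-batch: every example must be processed before a single update.
full_batch_minutes = n_examples * ms_per_example / 1000 / 60

# Mini-batch: one update per batch over the same pass.
updates_full_batches_only = n_examples // batch_size        # drops the last partial batch
updates_with_partial = math.ceil(n_examples / batch_size)   # keeps it

# Wall clock for 10,000 full-batch updates, in days.
days_for_10k_updates = 10_000 * full_batch_minutes / (60 * 24)

print(full_batch_minutes)           # 5.0
print(updates_full_batches_only)    # 390
print(updates_with_partial)         # 391
print(round(days_for_10k_updates, 1))
```

Whether you count 390 or 391 depends on whether the final short batch is kept. Either way, the ratio is what matters.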


Vocabulary — pin these down

  • Batch size. How many examples go into one gradient average. Common: 32, 64, 128, 256. LLM pretraining: ~4 million tokens per batch.
  • Step / iteration. One weight update. One nudge.
  • Epoch. One full pass over the whole training set. With 1M examples and batch size 256, one epoch = ~3,900 steps.
  • Shuffling. Re-randomize order before each epoch so batches do not see the same neighbours every time.

Is the epoch concept still useful? For small data — yes, count epochs. For LLM-scale — no, count tokens seen. You may not even reach one full epoch before training ends.
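
The vocabulary can be pinned down in a few lines. This sketch uses a hypothetical helper, epoch_batches, that deals out example indices; a real data loader does the same job with more machinery.

```python
import math
import random

def epoch_batches(n_examples, batch_size, rng):
    """Yield index batches covering the dataset once, in a fresh random order."""
    order = list(range(n_examples))
    rng.shuffle(order)                       # shuffling: new order every epoch
    for start in range(0, n_examples, batch_size):
        yield order[start:start + batch_size]   # one batch = one step

# 1M examples at batch size 256, counting the final partial batch.
steps_per_epoch = math.ceil(1_000_000 / 256)
print(steps_per_epoch)   # 3907

# Toy dataset of 10 examples, batch size 4: three steps per epoch.
rng = random.Random(0)
epoch_1 = list(epoch_batches(10, 4, rng))
epoch_2 = list(epoch_batches(10, 4, rng))
print(epoch_1)
print(epoch_2)
```

Each epoch visits every index exactly once, but the batches contain different neighbours each time.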


Learning rate — the step-size dial

The nudge has two parts. Direction (from backprop) and size (the learning rate η). Picture the loss surface as a bowl:

  loss ↑
       │       too large η: bouncing
       │       ●╲    ╱●╲    ╱●
       │        ╲  ╱   ╲  ╱
       │         ●       ●
       │      goal: bottom of bowl
       │       just right η: smooth descent
       │       ●╲
       │         ╲●╲
       │            ╲●╲
       │              ╲●● goal
       │       too small η: crawling
       │       ●●●●●●● ... still going ... ●●●● goal
       └────────────────────────────────────────→ steps

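Too small: training crawls and may not finish in budget. Too large: each step overshoots the minimum, and the loss can grow until it overflows to inf or NaN. Just right: fast, steady descent (with mini-batch noise the curve still jitters, so expect a downward trend rather than a strictly monotonic line).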
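
The bowl picture can be tried on the simplest possible bowl. This sketch assumes loss(w) = w², whose gradient is 2w and whose minimum sits at w = 0. The three learning rates are picked for illustration only; what counts as "too large" depends entirely on the loss surface.

```python
def descend(lr, steps=20, w=1.0):
    """Plain gradient descent on loss(w) = w**2, whose gradient is 2*w."""
    for _ in range(steps):
        w = w - lr * 2 * w      # each step multiplies w by (1 - 2*lr)
    return w

too_small = descend(0.001)   # factor 0.998 per step: barely moves
just_right = descend(0.3)    # factor 0.4 per step: shrinks quickly toward 0
too_large = descend(1.1)     # factor -1.2 per step: overshoots further each time

print(too_small, just_right, too_large)
```

With lr = 1.1 the sign of w flips on every step and its size grows. That is the bouncing panel in the diagram.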

Three tricks dominate practice:

  • Warmup. Start at near-zero η, ramp up over the first few thousand steps. Why? At init, the rule pile is random — early gradients are huge and chaotic. A big η on a chaotic gradient blows up. Warmup gives the network a moment to settle.
  • Decay. As loss flattens, smaller steps refine. Cosine decay (smooth ramp from peak to ~10% of peak over training) is the modern default.
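  • Linear scaling rule. Double the batch size → double η. Less noise per step means each step is more trustworthy, so you can take bigger ones. Treat it as a heuristic: it tends to break down at very large batch sizes, so confirm on the loss curve.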
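
A minimal sketch of warmup plus cosine decay, with the linear scaling rule as a one-liner. The defaults here (peak_lr = 3e-4, 2,000 warmup steps, 100,000 total steps) are assumptions for illustration, not recommendations from the text; only the "decay to ~10% of peak" shape comes from the bullet above.

```python
import math

def lr_schedule(step, peak_lr=3e-4, warmup_steps=2000,
                total_steps=100_000, floor_frac=0.1):
    """Linear warmup to peak_lr, then cosine decay down to floor_frac * peak_lr."""
    if step < warmup_steps:
        # Warmup: ramp linearly from near zero up to the peak.
        return peak_lr * (step + 1) / warmup_steps
    # Decay: progress runs from 0 (end of warmup) to 1 (end of training).
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # goes 1 -> 0 smoothly
    floor = floor_frac * peak_lr
    return floor + (peak_lr - floor) * cosine

def scale_lr(base_lr, base_batch, new_batch):
    """Linear scaling rule (a heuristic): lr grows in proportion to batch size."""
    return base_lr * new_batch / base_batch

for s in (0, 1000, 2000, 50_000, 100_000):
    print(s, lr_schedule(s))
print(scale_lr(1e-3, 256, 512))   # doubling the batch doubles the rate
```

In the training loop, lr_schedule(step) would be called once per update, just before the nudge.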

Gradient clipping

Sometimes one bad batch creates a giant gradient spike. If that spike lands unchecked, one update can wreck otherwise good weights. So we clip the norm: if ‖g‖ > threshold: g = g · threshold/‖g‖. In transformer training, threshold = 1.0 is a common default. Clipping leaves the direction of g unchanged and only caps its length. It only stops disaster. Seatbelt, not steering wheel.
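
The clipping rule, as a plain-Python sketch. It treats the gradient as a flat list of numbers; clip_by_global_norm is a hypothetical helper written for this note, not a library call. Deep learning frameworks ship their own versions that operate on all parameter tensors at once.

```python
import math

def clip_by_global_norm(grads, threshold=1.0):
    """Rescale grads so their L2 norm is at most threshold; return (grads, original norm)."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > threshold:
        scale = threshold / norm
        grads = [g * scale for g in grads]   # same direction, shorter length
    return grads, norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0])   # norm 5.0, gets scaled by 1/5
untouched, _ = clip_by_global_norm([0.3, 0.4])           # norm 0.5, passes through

print(norm_before, clipped)
print(untouched)
```

The ratio between components is preserved (3:4 before and after). Only the length changes.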


The training loop in pseudocode

step = 0
for epoch in 1..E:
    shuffle(dataset)                          # break order correlations
    for batch in chunks(dataset, B):
        loss, grads = forward_backward(batch) # backprop, averaged over the batch
        η = lr_schedule(step)                 # warmup → cosine decay
        weights ← weights − η · grads         # the nudge
        step += 1

That is it. The whole training script. Everything else (gradient clipping, mixed precision, checkpointing) is engineering on top of these few lines.
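
For readers who want to run the loop, here is a minimal Python sketch under strong simplifying assumptions: one weight, a toy dataset where y = 2x exactly, squared-error loss, and a constant learning rate in place of a schedule. The function name train and all the numbers are illustrative, not taken from the text.

```python
import random

def train(xs, ys, epochs=20, batch_size=4, lr=0.05, seed=0):
    """Mini-batch SGD fitting y ~ w * x with squared-error loss."""
    rng = random.Random(seed)
    w = 0.0
    step = 0
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)                              # break order correlations
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # loss = mean((w*x - y)**2), so d(loss)/dw = mean(2 * (w*x - y) * x)
            grad = sum(2 * (w * xs[i] - ys[i]) * xs[i] for i in batch) / len(batch)
            w = w - lr * grad                             # the nudge
            step += 1
    return w, step

xs = [0.1 * i for i in range(1, 17)]   # 16 toy inputs: 0.1 .. 1.6
ys = [2.0 * x for x in xs]             # targets generated with true weight 2.0

w, steps = train(xs, ys)
print(w, steps)
```

Sixteen examples at batch size 4 give four steps per epoch, so 20 epochs is 80 nudges. The learned w should land close to 2.0.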


Pause and recall. Without scrolling — what is one epoch? Why is full-batch dead at 100k examples? What does warmup protect against? Where in the pseudocode is the nudge from ELI5?


Where this lives in the wild

  • GPT-style pretraining. Frontier LLMs use batch sizes of ~4 million tokens per step. That is hundreds of GPUs working in parallel, each with a slice. Smaller batches starve the GPU; much larger batches stop helping (gradient already low-noise enough).
  • Anthropic Claude training. The exact recipe is not public, but frontier-scale runs of this kind generally build effective batches of millions of tokens by combining data parallelism across many GPUs with gradient accumulation. The update is still one nudge, just assembled from many micro-batches.
  • DeepSpeed and PyTorch FSDP gradient accumulation. When the desired batch is too big to fit in memory, you accumulate gradients across many micro-batches, then apply one nudge. Lets a 70B model train with effective batch 4M tokens on hardware that can hold only 64k tokens at once.
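  • PyTorch DataLoader (shuffle=True). The setting used for the training loader in nearly every tutorial. Note that DataLoader's own default is shuffle=False, so it has to be passed explicitly. Shuffling each epoch breaks accidental ordering: if your dataset is sorted by class and left unshuffled, every batch is one class and the gradient is biased. Subtle, common, breaks training silently.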
  • NVIDIA's cosine LR schedule cookbook. Linear warmup over ~2k steps, then cosine decay to 10% of peak η over the rest of training. Variants of this recipe are a common starting point for transformer training.

The pattern. Real systems live in option 3 — mini-batch with engineered batch size, warmup, and decay. Options 1 and 2 are textbook artifacts.
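
The gradient accumulation idea from the list above rests on one fact: averaging the gradients of equal-sized micro-batches gives the same number as averaging the whole batch at once. A sketch with made-up per-example gradients; accumulated_gradient is a hypothetical helper, and real frameworks accumulate into parameter tensors rather than a single float.

```python
def mean(values):
    return sum(values) / len(values)

def accumulated_gradient(per_example_grads, micro_batch_size):
    """Average equal-sized micro-batch gradients into one effective-batch gradient."""
    assert len(per_example_grads) % micro_batch_size == 0, "sketch assumes equal micro-batches"
    n_micro = len(per_example_grads) // micro_batch_size
    total = 0.0
    for k in range(n_micro):
        micro = per_example_grads[k * micro_batch_size:(k + 1) * micro_batch_size]
        total += mean(micro) / n_micro    # scale each piece so the sum is an average
    return total                          # one nudge is applied with this value

grads = [0.4, -0.2, 0.6, -0.8, 0.1, 0.3, -0.5, 0.9]   # illustrative values only

one_big_batch = mean(grads)
accumulated = accumulated_gradient(grads, micro_batch_size=2)
print(one_big_batch, accumulated)
```

Up to floating-point rounding the two numbers agree. Memory only ever holds one micro-batch, yet the weights see the big-batch gradient.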


Interview Q&A

Q: Why does mini-batch SGD beat full-batch GD in practice?
A: Three reasons. Speed — many small nudges arrive faster than one perfect nudge. GPU utilization — a batch of 256 fills parallel hardware that one example cannot. Generalization — the gradient noise empirically pushes solutions toward flatter minima, which generalize better to test data.
Common wrong answer to avoid: "full-batch finds a better minimum". The full-batch minimum is often a sharp one that overfits — empirically worse on held-out data, even if loss looks lower on training.

Q: What is an epoch and is the concept still useful?
A: An epoch is one full pass over the training set. For supervised training on fixed datasets — yes, count epochs and watch validation curves. For LLM pretraining on internet-scale data — no, count tokens; you may train for less than one epoch deliberately to avoid memorization.
Common wrong answer to avoid: "more epochs always better". After a point, training loss keeps falling but validation loss rises — that is overfitting. Epoch count is a hyperparameter, not a goal.

Q: Why shuffle data each epoch?
A: To break order correlations. If the dataset is sorted by class or by time, an unshuffled batch is biased: its gradient does not represent the full distribution. Shuffling each epoch makes every batch behave like a random sample of the whole dataset, so the noisy gradient is unbiased on average.
Common wrong answer to avoid: "Shuffling is optional if the dataset is clean." No — even a clean dataset can have order structure that biases every batch.

Q: What is the right batch size?
A: There is no universal answer. Start at 32–256 for vision, larger for text, much larger for LLMs. Two practical bounds — must fit in GPU memory; should not be so large that gradient noise vanishes (you lose the regularization benefit). Tune via the linear scaling rule: when you double batch size, double learning rate, then check that loss curves still descend cleanly.
Common wrong answer to avoid: "Larger batches always train faster." Beyond a critical batch size, returns diminish and generalization can suffer.


Apply now (5 min)

Take the XOR dataset from 01-xor-problem.md — four examples. Pick batch size 2, so two batches per epoch. Imagine the per-example gradients on weight w1 are [+0.3, -0.1, +0.5, -0.7] and η = 0.1, starting w1 = 0.5.

Compute by hand:

  1. Batch 1 (examples 1+2): avg gradient, new w1.
  2. Batch 2 (examples 3+4): avg gradient, new w1.
  3. That is one epoch. Two nudges. Where did w1 end up?

Then — without looking — sketch the loss-vs-step curve from memory for full-batch, online SGD, and mini-batch. Three side-by-side panels. Smooth line, jagged jumpy line, wobbly-but-trending line. If the shapes are right, the mental model is in.


Bridge. The nudge now happens — many times per epoch, on noisy mini-batches, with a tuned step size. But there is a deeper failure waiting. In a deep rule pile, the nudge shrinks to nothing before it reaches the early layers. They never learn. Read 08-vanishing-gradients.md next.