
03. Next-Token Training Loop — the tiny contract repeated billions of times

What curriculum explains and what still needs mechanics

In chapter 1, we saw that a base model can miss the product contract even when it is fluent. In chapter 2, we saw that the curriculum shapes what the model treats as normal before later alignment starts.

But "the model practiced this text" is still not mechanical enough. We need to see how one piece of text turns into a weight update. Otherwise loss curves, checkpoints, and training runs stay mysterious.

This chapter opens the workbench: shifted tokens become labels, labels become loss, loss becomes gradients, and gradients move weights. Once you can trace one batch, the giant lifecycle has a concrete center.

What this file solves

Training can look like a giant mysterious run instead of a small repeated update. This file traces one shifted sequence from input tokens to logits, loss, gradients, optimizer step, and checkpoint state so the whole lifecycle has a mechanical center.

Why the giant run needs a tiny loop

Without the inner loop, "pretraining" sounds like a cloud of compute. The useful mental move is to reduce the run to one repeated contract: predict the next token, measure surprise, move weights, save enough state to continue.

When loss falls but the update is still a mystery

The naive fix is to trust a falling loss curve. But if you cannot trace one batch through labels, logits, loss, gradients, and optimizer state, you cannot tell whether the run is learning the intended task or merely moving numbers.

One shifted sentence is the training game

Text: Rollback started at 14:18
Input: Rollback started at
Target: 14:18
Training nudges the model so 14:18 becomes less surprising next time.

Rule: lower surprise one token at a time

Training changes weights by making the correct next token less surprising.

Why the tiny loop works. Training is simple at the center: guess the next token, measure how wrong the guess was, move the weights a little, repeat.


1) Hook — the whole run hides inside one shifted sequence

Take:

Rollback started at 14:18

The model sees tokens and learns this shifted contract (token boundaries here are illustrative; a real tokenizer may split "14:18" differently):

input:  Rollback started at 14:
target: started  at      14: 18
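The shift is just slicing the same token list two ways. A minimal sketch in plain Python, assuming the illustrative word-level split above (real tokenizers emit subword ids, and `make_shifted_pair` is a hypothetical helper name):

```python
def make_shifted_pair(tokens):
    """Hypothetical helper: inputs drop the last token, targets drop the first."""
    return tokens[:-1], tokens[1:]

tokens = ["Rollback", "started", "at", "14:", "18"]  # illustrative split, not a real tokenizer
inputs, targets = make_shifted_pair(tokens)

# At position i the model has seen tokens[0..i] and is scored on targets[i].
for i, target in enumerate(targets):
    context = " ".join(tokens[: i + 1])
    print(f"{context!r:32} -> {target!r}")
```

Each target is scored against the whole prefix before it, not just the single token beside it; that is what makes the objective "given previous tokens".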

No one labels "incident management." The local objective says: given previous tokens, make the next token more likely.

That is the counterintuitive hook: a model learns incident language without ever seeing an "incident-management" label. The label is hidden in sequence structure. The world teaches through continuation pressure.

2) Mental model — conveyor belt through the workbench

tokens ──→ batch ──→ model forward ──→ logits
  ▲                                      │
  │                                      ▼
  └──── optimizer step ◀── gradients ◀── loss

The workbench is not magic. It is tensor shapes, an autograd graph, a scalar loss, a backward pass, and an optimizer update.

local token error → gradient direction → tiny weight movement → changed future probabilities

3) Running example — one incident sentence

Suppose the target next token is 14:18. Before training, the model assigns it probability 0.02. Cross-entropy loss is -ln(0.02) ≈ 3.91. After enough similar operational text, it assigns 0.20, and the loss becomes -ln(0.20) ≈ 1.61. That drop is the learning signal.
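The arithmetic can be checked directly. This sketch uses only the two illustrative probabilities from the text; with a full vocabulary, cross-entropy on logits first applies a softmax and then takes this same negative log of the true token's probability.

```python
import math

def token_loss(p_true):
    """Cross-entropy at one position: negative log-probability of the true next token."""
    return -math.log(p_true)

before = token_loss(0.02)  # untrained: "14:18" gets 2% of the probability mass
after = token_loss(0.20)   # after similar operational text: 20%
print(f"before={before:.2f} after={after:.2f} drop={before - after:.2f}")
# before=3.91 after=1.61 drop=2.30
```

The drop equals ln(10) ≈ 2.30 because the probability rose tenfold: loss tracks log-probability ratios, not raw probability differences.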

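import torch.nn.functional as F
logits = model(input_ids).logits                   # (batch, seq, vocab)
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),    # predictions at positions 0..T-2
    input_ids[:, 1:].reshape(-1),                   # the token that actually came next
)
loss.backward()        # autograd fills .grad for every parameter
optimizer.step()       # move weights along those gradients
optimizer.zero_grad()  # clear .grad before the next batch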

If you cannot write that loop, Trainer output is theater.
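To watch the loop move numbers end to end, here is a self-contained sketch. It assumes PyTorch is installed and swaps the real transformer for a hypothetical `TinyLM` (an embedding plus a linear head, so each prediction sees only the current token); the vocabulary and token ids are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyLM(nn.Module):
    """Hypothetical stand-in for a language model: embedding -> linear head.
    It is only a bigram model, but the training loop is the same."""
    def __init__(self, vocab_size, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # logits: (batch, seq, vocab)

vocab = ["Rollback", "started", "at", "14:", "18"]  # made-up word-level vocabulary
input_ids = torch.tensor([[0, 1, 2, 3, 4]])           # "Rollback started at 14: 18"

model = TinyLM(len(vocab))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

losses = []
for step in range(50):
    logits = model(input_ids)
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),  # predictions for positions 0..T-2
        input_ids[:, 1:].reshape(-1),                 # the token that actually came next
    )
    loss.backward()        # autograd fills p.grad for every parameter
    optimizer.step()       # AdamW moves weights along those gradients
    optimizer.zero_grad()  # clear gradients so the next step starts clean
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Swapping `TinyLM` for a real causal transformer changes the forward pass, not the five lines inside the loop.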

4) Raw loop versus Trainer for seeing the update

  • Raw PyTorch loop is strongest for learning mechanics, custom research, and debugging; its cost is boilerplate and distributed complexity.
  • HF Trainer is strongest for standard fine-tunes, logging, and checkpoints; its trap is hiding shape bugs and scheduling details.
  • Accelerate loop keeps custom code while helping with multi-GPU launch; it still requires distributed awareness.
  • DeepSpeed/FSDP config helps under large-model memory pressure; config errors can look like model failures.

Use the highest abstraction whose failure modes you can still debug.

5) Gradient accumulation delays the step

If a GPU fits batch 4 but you want effective batch 32, accumulate 8 microbatches:

Microbatch   Backward?   Optimizer step?
1            yes         no
2            yes         no
...          yes         no
8            yes         yes

The pressure relieved is memory. The new pressures are longer step latency, loss scaling (each microbatch loss must be divided by 8 so the summed gradient matches a batch-32 mean), and scheduler accounting mistakes.
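A sketch of accumulation, assuming PyTorch, a made-up 10-token vocabulary, random token ids, and a toy embedding-plus-linear model standing in for a transformer. It checks the claim behind accumulation: dividing each microbatch loss by 8 and summing gradients reproduces the batch-32 gradient, provided every microbatch has the same number of target tokens (no padding). The scheduler steps once per optimizer step, not per microbatch; the constant schedule is only a placeholder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 10, 6
micro_batch, accum_steps = 4, 8  # effective batch = 4 * 8 = 32
data = torch.randint(0, vocab_size, (micro_batch * accum_steps, seq_len))  # fake token ids

def lm_loss(model, ids):
    logits = model(ids)
    return F.cross_entropy(logits[:, :-1].reshape(-1, vocab_size), ids[:, 1:].reshape(-1))

def make_model():
    torch.manual_seed(1)  # identical initial weights for every copy
    return nn.Sequential(nn.Embedding(vocab_size, 8), nn.Linear(8, vocab_size))

# Path A: one real batch of 32 (what we wish fit in memory).
big = make_model()
lm_loss(big, data).backward()

# Path B: 8 microbatches of 4, gradients summing in .grad until the step.
small = make_model()
optimizer = torch.optim.SGD(small.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)  # placeholder schedule
optimizer_steps = 0
for i, micro in enumerate(data.split(micro_batch)):
    # Divide so the summed gradient equals the mean over all 32 sequences.
    (lm_loss(small, micro) / accum_steps).backward()
    if (i + 1) % accum_steps == 0:
        # Compare before stepping: the accumulated gradient should match the big batch.
        match = all(torch.allclose(a.grad, b.grad, atol=1e-6)
                    for a, b in zip(big.parameters(), small.parameters()))
        optimizer.step()
        scheduler.step()       # count optimizer steps, not microbatches
        optimizer.zero_grad()
        optimizer_steps += 1

print(f"gradients match: {match}, optimizer steps: {optimizer_steps}")
```

Forgetting the division multiplies the effective learning rate by 8; calling `scheduler.step()` every microbatch runs warmup and decay 8 times too fast.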

6) Zeroing gradients too early breaks the batch

Wrong order:

loss.backward()
optimizer.zero_grad()
optimizer.step()

This erases the gradients before the optimizer sees them. The run may "work" in the sense that code executes, while loss does not improve.

Mini-FAQ. "Won't the framework warn me?" Usually no. The operations are legal. The graph did exactly what you asked.
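A small demonstration, assuming PyTorch and plain SGD on a toy linear layer: the wrong order runs with no error or warning, and the weights come out exactly as they went in.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = [p.detach().clone() for p in model.parameters()]

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.zero_grad()   # WRONG: gradients are cleared here...
optimizer.step()        # ...so SGD has nothing to apply

unchanged = all(torch.equal(b, p) for b, p in zip(before, model.parameters()))
print("weights unchanged:", unchanged)
```

Swapping the last two calls back to `step()` then `zero_grad()` makes the weights move; a quick "did any parameter change?" check after the first step catches this class of bug cheaply.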

7) What memory and stability tricks fix and cost

  • bf16 mixed precision roughly halves activation size and relieves memory/bandwidth pressure, but creates numeric edge cases.
  • Gradient accumulation over 8 microbatches relieves batch memory, but makes optimizer updates less frequent.
  • Gradient clipping at 1.0 caps exploding steps, but can hide a bad learning rate.
  • A 2k-step warmup relieves early instability, but adds schedule tuning.
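The warmup and clipping arithmetic can be sketched in plain Python. This is an illustration of the rules, assuming the 2k-step warmup and 1.0 clip values above; in PyTorch the same ideas are usually expressed through `torch.optim.lr_scheduler.LambdaLR` (which takes a multiplier function like `warmup_multiplier`) and `torch.nn.utils.clip_grad_norm_`, but this code is not their implementation.

```python
import math

def warmup_multiplier(step, warmup_steps=2000):
    """Linear warmup: scale the base LR from ~0 up to 1 over the first warmup_steps optimizer steps."""
    return min(1.0, (step + 1) / warmup_steps)

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a flat list of gradient values so their global L2 norm is at most max_norm.
    Returns the clipped values and the pre-clip norm (the number worth logging)."""
    total = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [g * scale for g in grads], total

print(warmup_multiplier(0), warmup_multiplier(999), warmup_multiplier(5000))
clipped, norm = clip_by_global_norm([3.0, 4.0])
print(norm, clipped)
```

The clipped vector keeps its direction and only loses length, which is why clipping can mask a learning rate that is too high: every step gets shortened to the same size.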

8) Signals that the training loop is learning or drifting

  • Healthy: training loss falls smoothly, validation loss follows then plateaus.
  • First metric to degrade: gradient norm often spikes before loss explodes.
  • Misleading beginner metric: GPU utilization alone.
  • Expert graph: loss, LR, grad norm, tokens/sec, and skipped AMP steps on the same timeline.
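In PyTorch, `torch.nn.utils.clip_grad_norm_` returns the total pre-clip gradient norm, which makes it a convenient value to log every step. Turning that log into an alert is a separate choice; the sketch below is a hypothetical heuristic (window and factor are made-up defaults) run on a made-up log, not a standard tool.

```python
from statistics import median

def flag_grad_norm_spikes(grad_norms, window=20, factor=5.0):
    """Hypothetical heuristic: flag steps whose grad norm exceeds `factor` times
    the median of the previous `window` steps."""
    spikes = []
    for step in range(window, len(grad_norms)):
        baseline = median(grad_norms[step - window:step])
        if grad_norms[step] > factor * baseline:
            spikes.append(step)
    return spikes

# Made-up log: steady norms around 0.5, then one spike at step 30.
norms = [0.5] * 30 + [4.0] + [0.5] * 10
print(flag_grad_norm_spikes(norms))
```

A median baseline keeps one spike from raising the threshold for the steps right after it, so a second spike soon after is still caught.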

9) Where next-token loss explains training and where it stops

This loop explains pretraining and SFT mechanics. It becomes incomplete for preference methods where the loss compares responses or policies. It breaks operationally when the model, optimizer state, or activations exceed one device; then the GPU kitchen becomes the main story.

10) Wrong model: the library is what trains the model

Wrong model: "The library trains the model."

Replacement: the library orchestrates the same primitive loop. The supply shelf saves time, but the workbench still determines what can be debugged.

11) Other ways the update contract silently breaks

  • labels not shifted correctly
  • masking user tokens incorrectly during SFT
  • LR scheduler steps per microbatch instead of optimizer step
  • eval accidentally runs with dropout on
  • gradients accumulate across intended boundaries
  • mixed precision silently skips unstable steps
  • checkpoint resumes optimizer but not scheduler
  • dataloader repeats shards in distributed runs
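The SFT masking bug is easiest to see when the labels are built by hand. A sketch, assuming made-up token ids and a hypothetical `build_sft_labels` helper: `-100` is the default `ignore_index` of PyTorch's `F.cross_entropy`, and Hugging Face causal-LM models use the same convention and apply the one-position shift internally when given `labels`, so the labels here stay aligned with `input_ids`.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy

def build_sft_labels(prompt_ids, response_ids):
    """Hypothetical helper: concatenate prompt and response, and mask the prompt
    so only response tokens contribute to the loss."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels

# Made-up ids: prompt "When did rollback start?" -> response "14:18 <eos>"
input_ids, labels = build_sft_labels([101, 7, 42, 9], [55, 56, 2])
print(input_ids)
print(labels)
```

After the shift, the last prompt position is scored on the first response token, so the response is fully trained while the prompt only serves as context. Masking one position too many or too few still produces a finite loss, which is why this bug hides.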

12) The same update-contract shape in systems work

This mirrors database transaction logs: the simple primitive is easy; durability and orchestration create most production failures. It also echoes distributed training later: every abstraction must preserve the same update contract across devices.

13) Quick test: can you trace one batch end to end?

  • Can you show input and target tokens after shifting?
  • Do you know which tokens are masked from loss?
  • Does the LR schedule count optimizer steps correctly?
  • Can you resume from checkpoint without changing training dynamics?
  • Can you explain a loss spike using logs, not vibes?
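Resuming "without changing training dynamics" can be tested directly. A sketch, assuming PyTorch on CPU, a toy linear model, and hypothetical helpers `make_run`, `train_step`, `save_checkpoint`, and `load_checkpoint`; it saves to an in-memory buffer instead of a file. A GPU run would also need CUDA RNG state (for example `torch.cuda.get_rng_state_all()`) and likely the data loader's position.

```python
import io
import random
import torch
import torch.nn as nn

def make_run():
    model = nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: min(1.0, (s + 1) / 10))
    return model, optimizer, scheduler

def train_step(model, optimizer, scheduler):
    x = torch.randn(8, 4)  # draws from the torch RNG, so RNG state is part of the run
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

def save_checkpoint(buf, model, optimizer, scheduler, step):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # AdamW moment estimates and step counts
        "scheduler": scheduler.state_dict(),  # position in the warmup schedule
        "torch_rng": torch.get_rng_state(),   # next batch / dropout randomness
        "python_rng": random.getstate(),      # any Python-side shuffling
        "step": step,
    }, buf)

def load_checkpoint(buf, model, optimizer, scheduler):
    # weights_only=False: this trusted checkpoint also stores Python RNG state.
    ckpt = torch.load(buf, weights_only=False)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    torch.set_rng_state(ckpt["torch_rng"])
    random.setstate(ckpt["python_rng"])
    return ckpt["step"]

# Run A: 3 steps, checkpoint, then 2 more steps.
torch.manual_seed(0)
run_a = make_run()
for _ in range(3):
    train_step(*run_a)
buf = io.BytesIO()
save_checkpoint(buf, *run_a, step=3)
for _ in range(2):
    train_step(*run_a)

# Run B: fresh objects, resume from the checkpoint, replay the same 2 steps.
run_b = make_run()
buf.seek(0)
resumed_step = load_checkpoint(buf, *run_b)
for _ in range(2):
    train_step(*run_b)

same = all(torch.equal(a, b) for a, b in zip(run_a[0].parameters(), run_b[0].parameters()))
print("resumed run matches:", same, "from step", resumed_step)
```

Dropping any one entry from the saved dict (optimizer, scheduler, or RNG) is enough to make `same` come out false, which turns "resume exactly" into a check rather than a hope.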

Where the next-token loop appears in real tooling

  • PyTorch Autograd — records tensor operations and computes gradients.
  • torch.nn.Module — packages parameters and forward behavior.
  • Transformers Trainer — wraps loop, metrics, checkpoints, and distributed launch.
  • Accelerate — keeps custom loops while handling devices and process groups.
  • Weights & Biases / TensorBoard — makes loss/LR/grad timelines inspectable.
  • AMP/bf16 training — trades numeric precision for memory and throughput.
  • Gradient checkpointing — recomputes activations to reduce memory.
  • CI smoke fine-tunes — catch loop regressions before expensive runs.
  • nanoGPT-style scripts — expose the whole loop in a few readable files.
  • Language-model pretraining runs — repeat this contract across trillions of tokens.
  • Instruction tuning — changes the sequences, not the primitive next-token machinery.
  • Tokenizer debugging notebooks — reveal whether labels are shifted and masked correctly.
  • Learning-rate sweeps — test whether optimizer steps are too timid or too violent.
  • Checkpoint resumes — prove that optimizer, scheduler, and RNG state are part of the loop.
  • Distributed launchers — must preserve the same update semantics across ranks.

What you should remember

This chapter explained why the huge training lifecycle has a small mechanical center. The important idea is that training repeats one contract: feed shifted tokens, predict the next token, measure surprise, backpropagate the error, step the optimizer, and save enough state to continue.

You learned to trace one batch through token ids, logits, labels, loss, gradients, optimizer state, and checkpoint artifacts. That solves the mystery of a giant run because it gives you a concrete path to inspect when loss moves but the behavior or update looks wrong.

Carry this diagnostic forward: if a training run looks mysterious, reduce it to one batch. If you cannot explain the labels, masks, gradient accumulation, optimizer step, and saved state for that batch, you do not yet know what the run is training.

Remember:

  • Training is one small update repeated many times.
  • Inputs and labels are shifted views of the same token stream.
  • Loss says how surprising the right next token was.
  • Gradients say how to move weights so that token is less surprising.
  • Optimizer state and RNG state matter if a checkpoint must resume exactly.
  • A wrapper is safe only if it preserves this update contract.

Check your understanding of the training loop

  • What are the input and target sequences for next-token loss?
  • Why can gradient accumulation change scheduler behavior?
  • What graph would reveal exploding gradients before the run collapses?
  • Why must engineers understand the raw loop even when using Trainer?
  • Why is the objective local even when the learned behavior looks global?
  • What state must survive a checkpoint for the next step to be the same step?

Interview Q&A

Q. What are the five operations in a minimal training loop?
A. Forward pass, loss computation, backward pass, optimizer step, and gradient reset.
Common wrong answer to avoid: "Call trainer.train()."

Q. Why does SFT still use next-token loss?
A. The examples change from raw text to instruction-response sequences; the local objective still predicts target tokens.
Common wrong answer to avoid: "SFT uses a completely different objective by default."

Q. What is the risk of gradient accumulation?
A. It reduces memory pressure but can confuse scheduler steps, logging, clipping, and effective-batch assumptions.
Common wrong answer to avoid: "It is exactly identical to a large batch in every way."

Q. Why does next-token loss create broad capabilities instead of only autocomplete?
A. Predicting text across diverse contexts forces compression of syntax, facts, procedures, style, and latent structure that help reduce token surprise.
Common wrong answer to avoid: "The model only memorizes the next word."

Q. Why is label shifting a high-severity bug?
A. If inputs and targets are misaligned, the model receives pressure for the wrong prediction problem while the code can still run and log a loss.
Common wrong answer to avoid: "Shape-compatible tensors mean the objective is correct."

Q. What does a training abstraction have to preserve?
A. It must preserve the forward/loss/backward/step/reset contract, correct masks, scheduler semantics, checkpoint state, and distributed data uniqueness.
Common wrong answer to avoid: "Any wrapper that runs is equivalent."

Apply now (10 min)

  1. Model the exercise: write the shifted input/target pair for a six-token sentence.
  2. Your turn: list what must be checkpointed to resume a run exactly.
  3. Reproduce from memory: sketch the conveyor belt loop and annotate where gradients exist.

Bridge. The loop is small until the tensors are too large. The next pressure is not understanding the loss; it is fitting the loss, activations, gradients, and optimizer state inside the GPU kitchen. → 04-memory-and-parallelism.md