04. Memory and Parallelism — training breaks where bytes multiply

What the training loop explains and what still does not fit

In chapter 3, we reduced training to one repeated update: tokens, logits, loss, gradients, optimizer step, checkpoint. That explains how weights change.

The new problem is that the loop has to store much more than weights. The model, gradients, optimizer state, activations, temporary buffers, and communication state all compete for the GPU kitchen.

This chapter teaches memory accounting before parallelism names. If you can name which pile does not fit, you can choose whether to recompute, shard, offload, split layers, or split tensor operations.

What this file solves

A model can fit for inference but fail during training because training stores much more than weights. This file shows how to estimate parameters, gradients, optimizer state, and activations separately before choosing checkpointing, sharding, or offload.

Why memory planning comes before launching

Large training runs fail when bytes multiply faster than intuition. Before choosing ZeRO, FSDP, checkpointing, or offload, the engineer has to name which pile is too large and which subsystem will pay the cost.

When inference math lies about training fit

The tempting calculation is "weights fit, so training fits." Training also stores gradients, optimizer state, activations, temporary buffers, and communication state, so the first estimate can be off by several times.

When weights fit but training does not

A 7B model may need about 14GB just for bf16 weights.
Training also needs gradients, optimizer state, activations, and temporary buffers.
So "the model fits" does not mean "training fits."

Rule: training memory is more than weights

Training breaks when weights, gradients, optimizer state, and activations cannot all fit cheaply.

Why fitting weights is not enough. If the training state does not fit on one GPU, you either store less, recompute more, split work across machines, or move data around. There is no free option.


1) Hook — the model fits, training does not

A 7B model in bf16 has about 14GB of parameters. An engineer sees an 80GB GPU and assumes training fits. Then Adam states, gradients, activations, and fragmentation push the run out of memory.

Teacher voice. Inference memory is a suitcase. Training memory is a moving apartment.

The hook is that "model size" is the wrong noun. Training stores the model, the model's shadow copies, the model's recent thoughts, and the bookkeeping needed to change it safely.

2) Mental model — four piles in the GPU kitchen

┌──────────────── GPU memory ────────────────┐
│ parameters     gradients     optimizer      │
│ activations    temp buffers   fragmentation │
└─────────────────────────────────────────────┘

For Adam-style training, rough memory can be:

params bf16       ≈ 2 bytes/parameter
gradients bf16    ≈ 2 bytes/parameter
Adam moments fp32 ≈ 8 bytes/parameter
master weights    ≈ 4 bytes/parameter (sometimes)

At 12 to 16 bytes per parameter, that 7B model can ask for roughly 84 to 112GB before any activations are stored, far more than 14GB.
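The per-parameter figures above can be totalled in a few lines. This is a minimal sketch: the function name `static_training_bytes` and the decimal-GB convention are illustrative assumptions, and activations, temporary buffers, and fragmentation are deliberately left out.

```python
# Bytes per parameter, taken from the table above.
BYTES_PER_PARAM = {
    "params_bf16": 2,
    "grads_bf16": 2,
    "adam_moments_fp32": 8,    # two fp32 moments, 4 bytes each
    "master_weights_fp32": 4,  # only present in some mixed-precision setups
}


def static_training_bytes(n_params: int, master_weights: bool = True) -> dict:
    """Return bytes per pile plus a total, excluding activations and buffers."""
    piles = {
        name: n_params * nbytes
        for name, nbytes in BYTES_PER_PARAM.items()
        if master_weights or name != "master_weights_fp32"
    }
    piles["total"] = sum(piles.values())
    return piles


if __name__ == "__main__":
    for with_master in (False, True):
        piles = static_training_bytes(7_000_000_000, master_weights=with_master)
        print(f"master_weights={with_master}: {piles['total'] / 1e9:.0f} GB")
```

Even the smaller of the two totals exceeds an 80GB card before a single activation is saved, which is why the estimate has to be made per pile rather than per model.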

fit problem appears as OOM
      ├─ too many stored activations?  → checkpoint/recompute
      ├─ too much optimizer state?     → ZeRO/FSDP/8-bit optimizer
      ├─ too much parameter memory?    → sharding/quant/adapters
      └─ too slow after sharding?      → communication topology problem
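The decision tree above can be read as a lookup from the largest pile to a first remedy. The sketch below is only an illustration of that reading: the pile names, the helper `first_remedy`, and the sample byte counts are hypothetical, and a real diagnosis would use measured numbers.

```python
# Remedies copied from the decision tree above; the pile names are illustrative.
REMEDIES = {
    "activations": "checkpoint/recompute",
    "optimizer_state": "ZeRO/FSDP/8-bit optimizer",
    "parameters": "sharding/quant/adapters",
}


def first_remedy(pile_bytes: dict) -> tuple:
    """Return (largest pile, suggested first remedy) for measured pile sizes."""
    largest = max(pile_bytes, key=pile_bytes.get)
    return largest, REMEDIES.get(largest, "measure further before changing config")


if __name__ == "__main__":
    # Hypothetical measurements in bytes, not taken from a real run.
    measured = {"parameters": 14e9, "optimizer_state": 56e9, "activations": 30e9}
    print(first_remedy(measured))
```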

3) Running example — incident summarizer fine-tune

We fine-tune a 7B model with sequence length 4096. Parameters fit. The first OOM appears during backward because activations saved from the forward pass are needed to compute gradients.

Attempt A: lower batch size until it runs. This sacrifices throughput and shrinks the effective batch, which can destabilize optimization unless gradient accumulation restores it.

Attempt B: combine bf16, gradient accumulation, activation checkpointing, and sharding.
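Gradient accumulation is the piece of Attempt B that keeps the effective batch constant while the per-step activation footprint shrinks. The arithmetic can be sketched as follows; the helper names and the batch sizes are assumptions chosen for illustration.

```python
def effective_batch(micro_batch: int, accum_steps: int, world_size: int = 1) -> int:
    """Number of sequences contributing to one optimizer step."""
    return micro_batch * accum_steps * world_size


def accum_steps_for(target_batch: int, micro_batch: int, world_size: int = 1) -> int:
    """Accumulation steps needed to reach target_batch exactly."""
    per_step = micro_batch * world_size
    if target_batch % per_step != 0:
        raise ValueError("target batch is not a multiple of micro_batch * world_size")
    return target_batch // per_step


if __name__ == "__main__":
    # Hypothetical case: an OOM forced the micro-batch from 8 down to 1 on one GPU.
    steps = accum_steps_for(target_batch=8, micro_batch=1)
    print(steps, effective_batch(1, steps))
    # Tokens held in saved activations at once, before and after the change.
    print(8 * 4096, 1 * 4096)
```

The optimizer still sees 8 sequences per step, but only one sequence of 4096 tokens has its activations resident at a time.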

4) Choosing how to move the memory bill

  • Data Parallel / DDP splits batches and improves throughput, but adds all-reduce bandwidth pressure.
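  • ZeRO-1 shards optimizer state and relieves optimizer memory, but each rank updates only its own partition, so updated parameters must be all-gathered after the optimizer step.
  • ZeRO-2 shards optimizer state plus gradients and saves more memory, but swaps gradient all-reduce for reduce-scatter; communication volume stays roughly at DDP levels while the schedule becomes harder to reason about.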
  • ZeRO-3 / FSDP shards parameters, gradients, and optimizer state, but creates frequent all-gather pressure.
  • Tensor parallelism splits matrix operations for huge layers, but needs a tight interconnect.
  • Pipeline parallelism splits layers and depth memory, but creates bubbles and scheduling complexity.

DDP is simple when each GPU can hold the model. FSDP/ZeRO exists when the model itself becomes the memory problem.
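To see how far each sharding stage moves the static bill, the per-GPU total can be estimated with one function. This is a rough sketch under stated assumptions: bf16 parameters and gradients, 12 bytes per parameter of optimizer state (fp32 Adam moments plus fp32 master weights), even sharding, and no activations or transient all-gather buffers. The function and strategy names are illustrative, not a library API.

```python
def per_gpu_static_bytes(n_params: int, n_gpus: int, strategy: str,
                         param_b: int = 2, grad_b: int = 2, optim_b: int = 12) -> float:
    """Per-GPU bytes for parameters, gradients, and optimizer state only."""
    p = n_params * param_b
    g = n_params * grad_b
    o = n_params * optim_b
    if strategy == "ddp":      # everything replicated on every GPU
        return p + g + o
    if strategy == "zero1":    # optimizer state sharded
        return p + g + o / n_gpus
    if strategy == "zero2":    # optimizer state and gradients sharded
        return p + (g + o) / n_gpus
    if strategy == "zero3":    # parameters, gradients, optimizer state sharded
        return (p + g + o) / n_gpus
    raise ValueError(f"unknown strategy: {strategy}")


if __name__ == "__main__":
    for name in ("ddp", "zero1", "zero2", "zero3"):
        gb = per_gpu_static_bytes(7_000_000_000, 4, name) / 1e9
        print(f"{name}: {gb:.1f} GB per GPU")
```

For a 7B model on 4 GPUs, these assumptions give 112.0, 49.0, 38.5, and 28.0 GB per GPU. The figures cover resident state only; the gathered parameters that ZeRO-3 and FSDP materialize layer by layer come on top.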

5) Activation checkpointing trades memory for compute

normal backward:
forward saves activations ──→ backward reuses them

checkpointed backward:
forward drops some activations ──→ backward recomputes them

If checkpointing cuts activation memory by 40% but adds 25% compute, it is a good trade only when memory is the blocking constraint.
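The trade can be checked numerically before committing to it. The sketch below uses the 40% and 25% figures from the sentence above as defaults; the function name, the GB inputs, and the example run are hypothetical.

```python
def checkpointing_tradeoff(static_gb: float, activation_gb: float,
                           step_seconds: float, capacity_gb: float,
                           activation_saving: float = 0.40,
                           compute_overhead: float = 0.25) -> dict:
    """Compare one training step with and without activation checkpointing."""
    baseline_mem = static_gb + activation_gb
    ckpt_mem = static_gb + activation_gb * (1 - activation_saving)
    return {
        "baseline_fits": baseline_mem <= capacity_gb,
        "checkpointed_fits": ckpt_mem <= capacity_gb,
        "checkpointed_mem_gb": ckpt_mem,
        "checkpointed_step_seconds": step_seconds * (1 + compute_overhead),
    }


if __name__ == "__main__":
    # Hypothetical run: 40 GB of static state, 50 GB of activations, 80 GB card.
    result = checkpointing_tradeoff(40, 50, step_seconds=1.0, capacity_gb=80)
    for key, value in result.items():
        print(key, value)
```

In this example the baseline does not fit and the checkpointed run does, so the slower step is the price of running at all. If the baseline had already fit, the same settings would only have cost throughput.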

6) Sharding can hide the real OOM

Under FSDP, an OOM may happen during all-gather before a layer runs, not when the model is first loaded. The stack trace points at a linear layer, but the cause is shard materialization plus activations plus temporary buffers.

sequenceDiagram
  participant G0 as GPU 0
  participant G1 as GPU 1
  participant L as Layer
  G0->>G1: all-gather parameter shards
  G1->>G0: all-gather parameter shards
  G0->>L: materialized full weight
  L-->>G0: forward activations
  Note over G0: OOM from overlap, not static params
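The overlap in the diagram can be approximated as a sum of things that are live at the same moment. This is a rough sketch, not a model of any particular FSDP implementation: the function name, the `units_in_flight` parameter (used here to stand in for prefetching the next unit early), and all byte counts are assumptions.

```python
def fsdp_peak_bytes(resident_shard_bytes: float, unit_param_bytes: float,
                    activation_bytes: float, temp_bytes: float,
                    units_in_flight: int = 1) -> float:
    """Rough per-GPU peak while gathered FSDP units are materialized.

    resident_shard_bytes: this rank's shards of params, grads, optimizer state.
    unit_param_bytes: full (unsharded) parameter size of one wrapped unit.
    units_in_flight: gathered units alive at once, e.g. 2 when prefetching.
    """
    gathered = unit_param_bytes * units_in_flight
    return resident_shard_bytes + gathered + activation_bytes + temp_bytes


if __name__ == "__main__":
    capacity = 80e9
    # Hypothetical numbers: 28 GB of resident shards, 2 GB per gathered unit.
    for in_flight in (1, 2):
        peak = fsdp_peak_bytes(28e9, 2e9, activation_bytes=45e9,
                               temp_bytes=4e9, units_in_flight=in_flight)
        print(in_flight, peak / 1e9, "OOM" if peak > capacity else "fits")
```

With one unit gathered the example sits at 79 GB; with a second unit prefetched it reaches 81 GB and fails, even though the resident shards never changed.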

7) What parallelism fixes and makes slower

  • 1x24GB full fine-tune of a 7B model usually does not fit; optimizer state and activations dominate.
  • 1x80GB with a small batch may fit, but long-context activations become the bottleneck.
  • 4x80GB DDP works if the model fits on each GPU, but gradient all-reduce becomes visible.
  • 4x80GB FSDP fits through sharding, but all-gather and config complexity become the cost.
  • QLoRA on 1x24GB often fits, but adapter capacity and quantization overhead become the tradeoff.

Numbers vary by stack, but the accounting categories do not.
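The QLoRA row can be made concrete with the same accounting. The sketch below counts only the quantized base weights and the adapter training state; quantization constants, activations, and dequantization buffers are left out. The layer shapes (hidden size 4096, 32 layers, 4 adapted projections per layer), the rank, and the 16 bytes per adapter parameter are hypothetical values chosen for illustration.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """A rank-r adapter adds two matrices: (rank x d_in) and (d_out x rank)."""
    return rank * (d_in + d_out)


def qlora_static_bytes(base_params: int, adapter_params: int,
                       base_bits: int = 4, adapter_bytes_per_param: int = 16) -> dict:
    """Frozen quantized base plus full training state for the adapters only."""
    base = base_params * base_bits // 8
    adapters = adapter_params * adapter_bytes_per_param
    return {"base": base, "adapters": adapters, "total": base + adapters}


if __name__ == "__main__":
    # Hypothetical architecture: 32 layers, 4 adapted 4096x4096 projections each.
    per_matrix = lora_param_count(4096, 4096, rank=16)
    adapter_params = per_matrix * 4 * 32
    piles = qlora_static_bytes(7_000_000_000, adapter_params)
    print(adapter_params, f"{piles['total'] / 1e9:.2f} GB")
```

Under these assumptions the static bill is under 4 GB, which is why activations and quantization overhead, not weights, decide whether the run fits in 24GB.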

8) Signals that memory, not math, is the bottleneck

  • Healthy: stable allocated/reserved memory, predictable tokens/sec, low retry or skipped-step count.
  • First degrading metric: memory spikes at specific sequence lengths or checkpoint boundaries.
  • Misleading beginner metric: static model parameter size.
  • Expert graph: max allocated memory by phase: load, forward, backward, optimizer, checkpoint save.
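The per-phase graph needs one peak reading per phase. A minimal sketch of such a recorder is below; `PhasePeaks` and its two callables are hypothetical names. On a CUDA machine with PyTorch, `torch.cuda.max_memory_allocated` and `torch.cuda.reset_peak_memory_stats` could be passed in as `read_peak` and `reset_peak`; the demo uses a stand-in allocator so it runs without a GPU.

```python
from contextlib import contextmanager


class PhasePeaks:
    """Record the highest peak-memory reading seen for each named phase."""

    def __init__(self, read_peak, reset_peak):
        self._read_peak = read_peak    # callable returning the current peak in bytes
        self._reset_peak = reset_peak  # callable resetting the peak counter
        self.peaks = {}

    @contextmanager
    def phase(self, name):
        self._reset_peak()
        try:
            yield
        finally:
            # Keep the maximum across repeated visits to the same phase.
            self.peaks[name] = max(self.peaks.get(name, 0), self._read_peak())


if __name__ == "__main__":
    # Stand-in allocator so the sketch runs without a GPU.
    fake = {"current": 0, "peak": 0}

    def alloc(n):
        fake["current"] += n
        fake["peak"] = max(fake["peak"], fake["current"])

    def free(n):
        fake["current"] -= n

    tracker = PhasePeaks(read_peak=lambda: fake["peak"],
                         reset_peak=lambda: fake.update(peak=fake["current"]))
    with tracker.phase("forward"):
        alloc(30)
    with tracker.phase("backward"):
        alloc(25)
        free(40)
    print(tracker.peaks)
```

Wrapping load, forward, backward, optimizer step, and checkpoint save in separate phases gives exactly the five bars the expert graph calls for.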

9) Where sharding helps and where communication wins

Memory techniques are strongest when quality is acceptable and the blocker is fit or throughput. They become pathological when configuration complexity exceeds team debugging capacity. They hit a hard limit when interconnect bandwidth cannot feed the chosen parallelism.

10) Wrong model: if weights fit, training fits

Wrong model: "If weights fit, training fits."

Replacement: training stores several weight-sized companions plus sequence-dependent activations. The GPU kitchen must hold the whole cooking process, not just the ingredients.

11) Other ways training bytes multiply

  • optimizer state doubles or triples expected memory
  • activation memory grows with sequence length
  • gradient accumulation accidentally changes effective LR
  • rank-0 checkpoint saves duplicate full weights
  • DDP dataloaders repeat examples across ranks
  • FSDP all-gather spikes memory
  • CPU/NVMe offload makes training crawl
  • pipeline bubbles destroy utilization
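The repeated-examples failure in the list above comes from every rank iterating the same indices. One common fix is a strided split, sketched below with a hypothetical helper `indices_for_rank`. PyTorch's `DistributedSampler` serves the same purpose and additionally handles shuffling and padding, which this sketch does not.

```python
def indices_for_rank(n_examples: int, rank: int, world_size: int) -> list:
    """Strided split: rank r takes r, r + world_size, r + 2 * world_size, ..."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return list(range(rank, n_examples, world_size))


if __name__ == "__main__":
    world_size = 4
    shards = [indices_for_rank(10, r, world_size) for r in range(world_size)]
    for rank, shard in enumerate(shards):
        print(rank, shard)
    # Every example is seen exactly once across all ranks.
    assert sorted(i for shard in shards for i in shard) == list(range(10))
```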

12) The same cost-moving shape in distributed systems

This mirrors cache pressure in inference serving: static model size is only one part of runtime memory. It also echoes distributed systems: sharding relieves storage pressure by creating coordination pressure. The invariant is conservation of cost across subsystems.

13) Quick test: can you account for every memory pile?

  • Can you estimate params, gradients, optimizer state, and activations separately?
  • Does each rank see unique data?
  • Do checkpoints save without materializing extra full copies?
  • Is communication or compute the current bottleneck?
  • Can a smaller sequence length prove the OOM source?
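The last question can be answered with two measurements. The sketch below fits a straight line through peak memory at two sequence lengths to separate the part that does not depend on sequence length from the part that does. It assumes roughly linear growth, which will not hold if attention scores are materialized and add a quadratic term. The function name and the measurements are hypothetical.

```python
def split_static_and_per_token(seq_a: int, peak_a: float,
                               seq_b: int, peak_b: float) -> tuple:
    """Fit peak = static + per_token * seq_len through two measurements."""
    if seq_a == seq_b:
        raise ValueError("need two different sequence lengths")
    per_token = (peak_b - peak_a) / (seq_b - seq_a)
    static = peak_a - per_token * seq_a
    return static, per_token


if __name__ == "__main__":
    # Hypothetical peaks in GB at sequence lengths 1024 and 2048.
    static_gb, per_token_gb = split_static_and_per_token(1024, 40.0, 2048, 52.0)
    predicted = static_gb + per_token_gb * 4096
    print(static_gb, predicted)
```

If the sequence-dependent share dominates the prediction at the target length, the OOM is activation-driven and checkpointing is the first lever; if the static share dominates, sharding or adapters come first.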

Where memory pressure shapes real training stacks

  • PyTorch DDP — replicates model, shards batch, all-reduces gradients.
  • Accelerate — launches distributed jobs with less boilerplate.
  • FSDP — PyTorch-native full sharding for parameters and states.
  • DeepSpeed ZeRO — optimizer, gradient, and parameter sharding configurations.
  • Megatron-LM — tensor and pipeline parallelism for very large models.
  • QLoRA — avoids full fine-tune memory by training low-rank adapters on quantized bases.
  • Gradient checkpointing — common long-context memory relief.
  • Cluster schedulers — turn bad memory estimates into wasted queue time.
  • NCCL traces — reveal communication becoming the real bottleneck.
  • Sequence-length sweeps — isolate activation-driven OOMs.
  • Optimizer swaps — trade convergence behavior against state memory.
  • Checkpoint save hooks — prevent rank-0 from materializing full duplicate weights.
  • NVMe offload systems — extend capacity while threatening throughput.
  • Long-context tuning — makes activations dominate even when parameters fit.
  • Multi-node training — turns sharding decisions into network-topology decisions.

What you should remember

This chapter explained why training can fail even when the model weights fit on the GPU. The important idea is that training memory includes weights, gradients, optimizer state, activations, temporary buffers, fragmentation, and communication state.

You learned to account for each memory pile before choosing checkpointing, sharding, offload, or parallelism. That solves the opening failure because it turns "the model fits" into a real training budget and shows which subsystem pays when you save memory.

Carry this diagnostic forward: when a training run OOMs, do not ask only "which model size?" Ask which pile grew, which pile can move, and what new cost appears: recompute, communication, CPU transfer, slower throughput, or harder debugging.

Remember:

  • Inference memory is not training memory.
  • Training stores weights, gradients, optimizer state, activations, buffers, and communication state.
  • Every memory trick moves cost somewhere else.
  • Checkpointing saves memory by recomputing.
  • Sharding saves memory by communicating.
  • The right question after OOM is "which pile grew?"

Check your understanding of training memory

  • Why can a model fit for inference but fail for training?
  • What does ZeRO-3/FSDP shard that DDP does not?
  • When is activation checkpointing a good trade?
  • Which memory phase would you graph first after an OOM?
  • What new pressure does each memory-saving mechanism create?
  • Why can the stack trace point at a layer while the real cause is all-gather overlap?

Interview Q&A

Q. Why does Adam make training memory much larger than parameter memory?
A. It keeps optimizer moment estimates and sometimes master weights in addition to parameters and gradients.
Common wrong answer to avoid: "Only activations matter."

Q. When is DDP insufficient?
A. When each GPU cannot hold a full model plus training state, or when memory rather than throughput is the blocker.
Common wrong answer to avoid: "DDP automatically splits the model."

Q. What cost does FSDP introduce?
A. It saves memory by sharding but adds all-gather/reduce-scatter communication and more complex checkpointing.
Common wrong answer to avoid: "FSDP is free memory."

Q. Why is activation memory sequence-length sensitive?
A. Backward needs intermediate values from forward, and longer sequences create more token-level intermediate states to store or recompute.
Common wrong answer to avoid: "Only parameter count affects memory."

Q. How do you choose between checkpointing and sharding?
A. First identify the memory pile causing failure: activations point toward checkpointing; parameters, gradients, or optimizer state point toward sharding or adapters.
Common wrong answer to avoid: "Turn on every memory option at once."

Q. Why can offload make a run technically fit but practically fail?
A. Moving state to CPU or NVMe relieves GPU memory but can starve compute with transfer latency and bandwidth limits.
Common wrong answer to avoid: "If it no longer OOMs, the solution is good."

Apply now (10 min)

  1. Model the exercise: estimate parameter, gradient, and Adam-state memory for a 3B bf16 model.
  2. Your turn: choose DDP, FSDP, or QLoRA for a one-GPU 24GB constraint and explain why.
  3. Reproduce from memory: draw the four memory piles in the GPU kitchen.

Bridge. Once memory pressure is visible, tooling choices stop being cosmetic. PyTorch and Hugging Face are the interfaces through which lifecycle decisions become runnable systems. → 05-pytorch-hf-tooling.md