03. ZeRO sharding and FSDP — stop storing 64 copies of the same state

~19 min read. Sixty-four data-parallel GPUs, each holding an identical 840 GB of optimizer state. That is 63 redundant copies of the same numbers. This file deletes the redundancy — and discovers that every byte of memory you free has to be paid back in communication.

Built on data parallelism. We made GPUs cooperate via the all-reduce, but every replica still holds the full 1,120 GB. This file attacks that redundancy directly with ZeRO sharding and FSDP, and in doing so meets the module's central tension head-on: the communication-vs-memory tradeoff — free more memory, pay more communication, and decide which subsystem absorbs the cost.


What replication wasted, and what we still cannot fit

Data parallelism gave us throughput but left a glaring inefficiency. Across 64 GPUs running the 70B model, each GPU stores the same 140 GB of parameters, the same 140 GB of gradients, and the same 840 GB of optimizer state. The optimizer state alone is replicated 64 times: 64 × 840 GB ≈ 53.8 TB of GPU memory, of which 63 copies (about 52.9 TB) are bit-identical duplicates. And despite all that memory, the model still doesn't fit, because each individual GPU must hold the full 1,120 GB and one H100 has only 80 GB.
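The arithmetic behind those figures can be checked in a few lines. This is a minimal sketch that assumes the byte counts implied by the numbers above (2 bytes per parameter for weights, 2 for gradients, 12 for optimizer state) and 1 GB = 1e9 bytes; the constant and function names are illustrative, not taken from any library.

```python
# Bytes per parameter implied by the figures in the text
# (140 GB / 140 GB / 840 GB for a 70B-parameter model).
BYTES_PER_PARAM = {"parameters": 2, "gradients": 2, "optimizer_state": 12}


def static_state_gb(num_params: float) -> dict:
    """Per-replica static memory in GB for each of the three populations."""
    return {name: num_params * b / 1e9 for name, b in BYTES_PER_PARAM.items()}


def redundant_tb(per_replica_gb: float, num_replicas: int) -> float:
    """Memory spent on every copy beyond the first one, in TB."""
    return per_replica_gb * (num_replicas - 1) / 1e3


state = static_state_gb(70e9)
total_gb = sum(state.values())
print(state)
print(f"per-replica total: {total_gb:.0f} GB")
print(f"optimizer state over 64 replicas: {state['optimizer_state'] * 64 / 1e3:.2f} TB")
print(f"redundant part: {redundant_tb(state['optimizer_state'], 64):.2f} TB")
```

Changing the replica count or the model size shows how quickly the redundant share grows: it scales with both.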

So data parallelism solved the wrong problem. It scaled throughput while replicating the memory wall 64 times. The previous file's bridge named the obvious question: why store 64 copies of the optimizer state when 64 GPUs could each store one-sixty-fourth of it? This file answers it — and the answer is not free, because a shard you don't hold is a shard you must fetch.

What this file solves

A data-parallel run wastes memory by replicating the optimizer state, gradients, and parameters on every GPU, so the model still won't fit even though 64 GPUs have 5.1 TB between them. This file shows how ZeRO partitions those three populations across the data-parallel group — stage 1 shards optimizer state, stage 2 adds gradients, stage 3 adds parameters — turning the all-reduce into a reduce-scatter plus all-gather, so each GPU holds 1/N of the state. It also shows the bill: stage 3 fits any model but adds an all-gather of parameters on every forward and backward, and you must know when that communication cost is worth the memory it buys.

The key realization: most state is idle most of the time

Look at when each population is actually used during a step:

   forward pass:    needs PARAMETERS (read weights, compute activations)
   backward pass:   needs PARAMETERS (recompute) + produces GRADIENTS
   optimizer step:  needs GRADIENTS + OPTIMIZER STATE + master weights → updates params

The optimizer state (the 840 GB term, at 12 bytes per parameter) is touched only during the optimizer step, a brief moment at the end of each iteration. For the entire forward and backward pass, those 840 GB sit idle on every GPU, replicated 64 times, doing nothing. The gradients are needed only from when the backward pass produces them until the optimizer consumes them. Only the parameters are needed continuously, throughout forward and backward.

This timing asymmetry is the lever. If a population is idle during most of the step, why does every GPU need its own full copy? Each GPU could hold one shard, and the GPUs could gather the pieces only at the moment they're needed. That is the entire idea of ZeRO — the Zero Redundancy Optimizer. Shard the state across the data-parallel group; reconstruct each piece on demand; throw it away after.

The load-bearing rule. ZeRO replaces replication with sharding plus on-demand reconstruction. Each GPU holds 1/N of a population and fetches the rest with a collective only when that population is needed. Memory drops by up to N×; the cost is the extra collectives that fetch the missing shards.

The naive attempt and the visible break

The naive instinct, once you see the redundancy, is to just shard the optimizer state and leave everything else replicated — store 1/64 of the moments on each GPU and call it done. This is ZeRO stage 1, and it works, but if you stop there you've left most of the memory on the table. Try to fit the 70B model with only stage 1:

   ZeRO-1 per GPU (N=64):
     parameters (full)      140 GB
     gradients (full)       140 GB
     optimizer state (1/64)  840/64 ≈ 13 GB
     ───────────────────────────────────
     per-GPU static         ≈ 293 GB    → still 3.7× an 80 GB card. STILL OOMs.

Stage 1 cut the 1,120 GB to 293 GB — a huge win — but the parameters (140) and gradients (140) are still fully replicated and still overflow the card. So the real problem is not "the optimizer state is too big"; it is that all three populations are replicated, and freeing only the largest one is not enough when the other two together (280 GB) still exceed the card. Which leads to the question: can we shard the gradients and the parameters too, accepting that we'll have to gather parameter shards back during forward and backward?

The three ZeRO stages — sharding more, communicating more

ZeRO comes in three stages, each sharding one more population and paying one more collective. This is the cleanest illustration in the module of the communication-vs-memory tradeoff, so walk all three with the 70B numbers.

   POPULATION SHARDED       per-GPU static (70B, N=64)     extra communication added
   ─────────────────────    ──────────────────────────    ───────────────────────────
   ZeRO-1  optimizer        140 + 140 + 13   = 293 GB      reduce-scatter grads at step
            state only                                      (replaces all-reduce, same cost)
   ZeRO-2  + gradients      140 +  2.2 + 13  = 155 GB      gradients reduce-scattered, not
                                                            all-reduced — each GPU keeps its shard
   ZeRO-3  + parameters     2.2 + 2.2 + 13   ≈ 17.5 GB     all-gather params before fwd AND bwd,
            (= FSDP)                                        reshard after — per-layer collective

ZeRO stage 1 — shard optimizer state. Each GPU keeps the full params and gradients but only 1/N of the optimizer moments and master weights. After the backward pass, instead of an all-reduce (which gives everyone the full averaged gradient), you do a reduce-scatter: each GPU ends up with only the averaged gradient for the parameters it owns the optimizer state for. It updates its shard, then an all-gather broadcasts the updated parameters. Communication volume is roughly the same as plain data-parallel all-reduce — stage 1 is nearly free in communication.
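The claim that reduce-scatter plus all-gather lands every rank on the same parameters as an all-reduce can be checked with a toy simulation. This is a sketch in plain Python lists, with no real communication: every function name is illustrative, and a plain SGD update stands in for the real optimizer (an assumption made only to keep the example short).

```python
from typing import List


def all_reduce_mean(grads: List[List[float]]) -> List[List[float]]:
    """Every rank ends up with the full element-wise mean gradient."""
    n = len(grads)
    mean = [sum(col) / n for col in zip(*grads)]
    return [list(mean) for _ in range(n)]


def reduce_scatter_mean(grads: List[List[float]]) -> List[List[float]]:
    """Rank r ends up with only the r-th shard of the mean gradient."""
    n = len(grads)
    size = len(grads[0])
    assert size % n == 0, "sketch assumes the gradient length divides evenly"
    shard = size // n
    mean = [sum(col) / n for col in zip(*grads)]
    return [mean[r * shard:(r + 1) * shard] for r in range(n)]


def all_gather(shards: List[List[float]]) -> List[List[float]]:
    """Every rank ends up with the concatenation of all ranks' shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def ddp_step(params: List[float], grads: List[List[float]], lr: float = 0.1):
    """Plain data parallelism: all-reduce, then every rank updates everything."""
    full = all_reduce_mean(grads)
    return [[p - lr * g for p, g in zip(params, full[r])] for r in range(len(grads))]


def zero1_step(params: List[float], grads: List[List[float]], lr: float = 0.1):
    """ZeRO-1 shape: reduce-scatter, update only the owned shard, all-gather."""
    n = len(grads)
    shard = len(params) // n
    grad_shards = reduce_scatter_mean(grads)
    new_shards = []
    for r in range(n):
        owned = params[r * shard:(r + 1) * shard]
        new_shards.append([p - lr * g for p, g in zip(owned, grad_shards[r])])
    return all_gather(new_shards)


# Four ranks, eight parameters, a different local gradient on each rank.
params = [1.0] * 8
grads = [[float(r + i) for i in range(8)] for r in range(4)]
print(zero1_step(params, grads) == ddp_step(params, grads))
```

Both paths perform the same arithmetic on the same values, so the results match exactly; the only difference is that in the ZeRO-1 path each rank needs optimizer state for just its own shard.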

ZeRO stage 2 — also shard gradients. Once a GPU only updates 1/N of the parameters, it only needs to keep 1/N of the gradients. So gradients are reduce-scattered as they're produced and each GPU retains only its shard — dropping the gradient term from 140 GB to ~2.2 GB. Still no extra communication beyond stage 1's; you're just throwing away gradient shards you don't own as soon as they're reduced.

ZeRO stage 3 — also shard parameters. Now each GPU holds only 1/N of the weights. But the forward pass needs the full weights of a layer to compute it. So before computing each layer, the GPUs all-gather that layer's parameters, compute, then immediately discard the gathered copy (reshard). The backward pass does the same. This is the expensive stage: an all-gather of parameters on every layer, every forward and every backward. In exchange, the per-GPU static state collapses to ~17.5 GB, so each GPU's share of the 70B model finally fits on an 80 GB H100 with room for activations.
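The per-GPU figures for all three stages follow from one rule: divide a population by N once its stage is enabled. A minimal sketch, assuming the 2 / 2 / 12 bytes-per-parameter split implied by the 140 / 140 / 840 GB figures; the function name is hypothetical and stage 0 is used here to mean plain data parallelism.

```python
def zero_static_gb(num_params: float, n_gpus: int, stage: int) -> dict:
    """Per-GPU static state in GB for stage 0 (plain DP) through ZeRO stage 3."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    params = num_params * 2 / 1e9    # working weights
    grads = num_params * 2 / 1e9     # gradients
    optim = num_params * 12 / 1e9    # optimizer state (moments, master weights)
    if stage >= 1:
        optim /= n_gpus              # stage 1 shards optimizer state
    if stage >= 2:
        grads /= n_gpus              # stage 2 also shards gradients
    if stage >= 3:
        params /= n_gpus             # stage 3 also shards parameters
    return {"params": params, "grads": grads, "optimizer": optim,
            "total": params + grads + optim}


for stage in range(4):
    row = zero_static_gb(70e9, 64, stage)
    print(f"stage {stage}: params {row['params']:.1f}  grads {row['grads']:.1f}  "
          f"optimizer {row['optimizer']:.1f}  total {row['total']:.1f} GB")
```

The same function reproduces other configurations by changing the parameter count and group size.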

Teacher voice. Each stage trades the same currency: memory for communication. Stage 1 is almost free because the optimizer state was idle anyway. Stage 3 is expensive because parameters are needed continuously, so reconstructing them on demand means gathering them constantly. The art is using the cheapest stage that fits — not reflexively reaching for stage 3.

A picture of the tradeoff

The canonical mental model for this file — the same step, three sharding depths, memory falling as communication rises:

   MEMORY PER GPU                                 COMMUNICATION PER STEP
   (70B, N=64)                                    (what gets fetched)
   ════════════                                   ══════════════════════

   data-parallel  ████████████████████  1120 GB   all-reduce grads once
   ZeRO-1         █████▏                  293 GB   reduce-scatter grads (≈ same)
   ZeRO-2         ██▊                      155 GB   reduce-scatter grads (≈ same)
   ZeRO-3 / FSDP  ▍                       17.5 GB   all-gather params EVERY layer,
                                                    fwd + bwd, then reshard
                  └──────────────────────────┘
                  memory ↓ 64×             communication ↑ (stage 3 ≈ 1.5× DP traffic)

   The line to remember: stage 1→2 frees memory almost for free.
   Stage 3 frees the rest but adds a per-layer all-gather you pay forever.

The asymmetry that surprises people: going from data-parallel (1,120 GB) to ZeRO-2 (155 GB) is a 7× memory reduction for essentially no extra communication, because optimizer state and gradients were idle outside the update step. The last step, ZeRO-2 to ZeRO-3, frees another 9× of memory but is where the communication bill arrives. Most of the memory win is cheap; the expensive part is only the parameters.
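One way to see where the roughly 1.5× figure comes from is to count collectives. The sketch below assumes ring-style collectives, where a reduce-scatter or an all-gather each makes a GPU send (N-1)/N of the buffer and an all-reduce is one of each. It counts only parameter-sized traffic and ignores overlap and latency; the names are illustrative.

```python
def ring_factor(n_gpus: int) -> float:
    """Fraction of the buffer each GPU sends in one ring reduce-scatter or all-gather."""
    return (n_gpus - 1) / n_gpus


# Number of parameter-sized collectives per training step.
COLLECTIVES_PER_STEP = {
    "dp": 2,     # all-reduce = reduce-scatter + all-gather
    "zero1": 2,  # reduce-scatter grads + all-gather updated params
    "zero2": 2,  # same two collectives as ZeRO-1
    "zero3": 3,  # all-gather params (fwd) + all-gather params (bwd) + reduce-scatter grads
}


def comm_gb_per_step(model_gb: float, n_gpus: int, scheme: str) -> float:
    """Approximate GB sent per GPU per step under the assumptions above."""
    return COLLECTIVES_PER_STEP[scheme] * ring_factor(n_gpus) * model_gb


dp = comm_gb_per_step(140.0, 64, "dp")
z3 = comm_gb_per_step(140.0, 64, "zero3")
print(f"DP: {dp:.1f} GB, ZeRO-3: {z3:.1f} GB, ratio {z3 / dp:.2f}")
```

Under this count stages 1 and 2 move the same volume as plain data parallelism, and stage 3 adds exactly one more parameter-sized collective.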

FSDP — the same idea, named by PyTorch

PyTorch's Fully Sharded Data Parallel (FSDP) is ZeRO stage 3 with a PyTorch-native API. The modern version, FSDP2 (fully_shard), shards every parameter as a DTensor across the data-parallel group and registers forward/backward hooks that all-gather a module's parameters just before it runs and reshard them just after. The mechanism is identical to ZeRO-3; the vocabulary differs.

The one tunable worth knowing: reshard_after_forward. After the forward pass gathers a layer's parameters, FSDP can either keep them (so the backward pass doesn't re-gather) or discard them and gather again in the backward pass. Keeping them spends memory to save a gather; discarding them spends a gather to save memory — the communication-vs-memory tradeoff appearing again, now at the granularity of a single layer.

   FSDP2 per layer:
     [all-gather layer params] → [compute fwd] → reshard?  yes → discard, re-gather in bwd
                                                            no  → keep, skip bwd gather (more mem)

For our 70B run on 64 GPUs, FSDP2 with reshard_after_forward=True is the configuration that takes the per-GPU footprint to ~17.5 GB of static state, leaving ~60 GB for activations and communication buffers — finally a model that trains.
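The reshard_after_forward tradeoff can be put into rough numbers with a toy model. This sketch does not call PyTorch. It assumes equal-sized layers, ignores prefetch buffers, and uses a hypothetical count of 80 transformer blocks for the 70B model; the function name is illustrative.

```python
def reshard_tradeoff(model_gb: float, n_layers: int, n_gpus: int,
                     reshard_after_forward: bool) -> dict:
    """Toy estimate of parameter memory and all-gather count for one step."""
    shard_gb = model_gb / n_gpus      # this GPU's shard, always resident
    layer_gb = model_gb / n_layers    # one fully gathered layer
    if reshard_after_forward:
        peak_gathered = layer_gb      # only the layer currently being computed
        all_gathers = 2 * n_layers    # once in forward, once again in backward
    else:
        peak_gathered = model_gb      # every layer is still gathered when forward ends
        all_gathers = n_layers        # backward reuses the copy kept from forward
    return {"peak_param_gb": shard_gb + peak_gathered, "all_gathers": all_gathers}


# 140 GB of parameters, 80 blocks (hypothetical), 64 GPUs.
discard = reshard_tradeoff(140.0, 80, 64, reshard_after_forward=True)
keep = reshard_tradeoff(140.0, 80, 64, reshard_after_forward=False)
print("reshard after forward:", discard)
print("keep after forward:   ", keep)
```

Under these assumptions, keeping gathered parameters halves the number of all-gathers but puts the full 140 GB back on each GPU by the end of the forward pass, which is why the 70B run needs the resharding setting.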

Mini-FAQ. "Is FSDP the same as DeepSpeed ZeRO-3?" Functionally yes: both shard all three populations and gather on demand. They differ in implementation, integration (FSDP is native PyTorch; ZeRO ships with DeepSpeed), and extras (DeepSpeed adds CPU/NVMe offload through ZeRO-Offload and ZeRO-Infinity; FSDP2 integrates with torch.compile and DTensor). Pick by ecosystem, not by capability, because the memory math is the same.

The 70B run, finally fitting

Thread the running example to its resolution. The model that crashed at step zero on one GPU (file 01), and that data parallelism couldn't help (file 02), now fits:

   70B, 64 × H100 (80 GB), FSDP / ZeRO-3:
     parameters  (1/64)   2.2 GB
     gradients   (1/64)   2.2 GB
     optimizer   (1/64)  13.1 GB
     ─────────────────────────────
     static per GPU      17.5 GB
     activations + buffers (batch-dependent)  up to ~60 GB headroom
     ─────────────────────────────
     → FITS on 80 GB, trains end to end

The cost it pays, relative to plain data parallelism: an all-gather of each layer's parameters in both the forward and backward pass. Measured, ZeRO-3 / FSDP communication is roughly 1.5× the volume of plain data-parallel all-reduce. So you buy a 64× memory reduction with a ~50% communication increase. On NVLink-rich nodes that trade is overwhelmingly worth it: the model that didn't fit at all now trains at ~85–90% of data-parallel throughput. On slow inter-node links, the extra all-gathers can hurt, which is the seam where tensor and pipeline parallelism (file 04) take over.

Why this instead of just buying GPUs with more memory

The plausible alternative: skip sharding, rent H200s (141 GB) or B200s. But the static state is 1,120 GB, so even with 141 GB cards it only fits if it is split across eight of them, which is already sharding. Plain data parallelism across those eight cards would put the full 1,120 GB on each one again. Sharding is not avoidable by bigger cards; it is more valuable on bigger cards, because it lets each card hold a fraction and use the rest for a larger batch or longer sequences.

A second alternative: CPU offload (ZeRO-Infinity / ZeRO-Offload), which pushes the optimizer state to host RAM. This fits enormous models on few GPUs, but PCIe bandwidth (~64 GB/s) is more than an order of magnitude slower than HBM (TB/s), so every step stalls on offload transfers. It is the right tool when you are GPU-poor and time-rich (fitting a 70B model on a single 8-GPU node for research), and the wrong tool when you have the cluster and need throughput. The decision is governed by whether you have GPUs to shard across or must fall back to slow host memory.
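The bandwidth gap is easiest to feel as a transfer time. This sketch only divides a buffer size by a bandwidth. The 64 GB/s PCIe figure is the one quoted above; the HBM figure of 3,000 GB/s is an assumption used purely for scale. Which buffers actually cross PCIe on each step depends on the offload implementation, so the 105 GB buffer (840 GB of optimizer state split over 8 GPUs) is just an example size.

```python
PCIE_GB_PER_S = 64.0     # figure quoted in the text
HBM_GB_PER_S = 3000.0    # assumption: a few TB/s, used only for comparison


def transfer_seconds(gigabytes: float, gb_per_second: float) -> float:
    """Time for one pass over a buffer at a given sustained bandwidth."""
    return gigabytes / gb_per_second


buffer_gb = 840.0 / 8    # example: one GPU's share of optimizer state on an 8-GPU node
over_pcie = transfer_seconds(buffer_gb, PCIE_GB_PER_S)
over_hbm = transfer_seconds(buffer_gb, HBM_GB_PER_S)
print(f"PCIe: {over_pcie:.2f} s, HBM: {over_hbm:.3f} s, ratio {over_pcie / over_hbm:.0f}x")
```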

Operational signals — is sharding helping or hurting?

Healthy behavior. Per-GPU memory drops to roughly static/N + activations and stays flat. Throughput within ~10–15% of what data parallelism would give if the model fit — meaning the all-gathers are overlapping well. Memory headroom that scales: adding GPUs to the FSDP group lowers per-GPU static state proportionally.

First metric to degrade. Step time, as you push to ZeRO-3 on a slow interconnect. The all-gathers that hid fine on NVLink stop hiding on InfiniBand; you see step time rise even though memory is comfortable. The signal: memory is fine but throughput dropped after enabling parameter sharding.

The misleading metric. Peak memory looking great while throughput quietly tanks — teams celebrate fitting the model and miss that they're now communication-bound. Watch the ratio of compute time to all-gather time in the profiler, not just whether it fits. Another trap: assuming ZeRO-3 is "better" than ZeRO-2 because it uses less memory — if ZeRO-2 already fits, ZeRO-3 just adds communication for memory you didn't need.

The graph an expert opens first. Profiler timeline showing the per-layer all-gather kernels and whether they overlap the previous layer's compute (FSDP prefetches the next layer's params while computing the current one). Exposed all-gather gaps mean prefetch isn't keeping up — increase the prefetch depth or move to a less aggressive stage.
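What the profiler timeline shows can be approximated with a small overlap model. This is a sketch under strong assumptions: a prefetch depth of one layer, perfect overlap between a layer's compute and the next layer's all-gather, and made-up per-layer timings. The function names are illustrative.

```python
from typing import List


def pass_time_ms(compute_ms: List[float], gather_ms: List[float], prefetch: bool) -> float:
    """Time for one pass over the layers.

    Without prefetch, each layer waits for its own all-gather.
    With prefetch, layer i+1's all-gather runs while layer i computes, so only
    the first gather and any gather longer than the compute it hides behind
    remain exposed.
    """
    if not compute_ms or len(compute_ms) != len(gather_ms):
        raise ValueError("need one compute time and one gather time per layer")
    if not prefetch:
        return sum(c + g for c, g in zip(compute_ms, gather_ms))
    total = gather_ms[0]  # nothing to hide the first gather behind
    for i, c in enumerate(compute_ms):
        next_gather = gather_ms[i + 1] if i + 1 < len(gather_ms) else 0.0
        total += max(c, next_gather)
    return total


def exposed_gather_ms(compute_ms: List[float], gather_ms: List[float], prefetch: bool) -> float:
    """Communication time that compute did not hide."""
    return pass_time_ms(compute_ms, gather_ms, prefetch) - sum(compute_ms)


compute = [10.0] * 4
fast_link = [4.0] * 4    # gather shorter than compute: hides well
slow_link = [15.0] * 4   # gather longer than compute: stays exposed
print(exposed_gather_ms(compute, fast_link, prefetch=True))
print(exposed_gather_ms(compute, slow_link, prefetch=True))
```

When the gather is shorter than the compute it overlaps, only the first gather is exposed. When it is longer, every layer leaves a gap, which is the pattern the timeline makes visible.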

Boundary of applicability — when to stop sharding

Where it shines. Models whose static state exceeds one GPU but whose single layer fits comfortably. Clusters with fast intra-node interconnect (NVLink) so the all-gathers hide. The 70B-on-64-H100s case is squarely in the sweet spot — ZeRO-3/FSDP is the default first choice.

Where it becomes pathological. When even a single layer's parameters plus its activations don't fit on one GPU — ZeRO-3 can't help, because it still has to materialize a full layer to compute it. Giant individual layers (very wide MLPs, huge attention) need tensor parallelism (file 04), which splits the layer itself. Also pathological: ZeRO-3 across slow inter-node links with small per-GPU compute, where all-gathers dominate.

The scale limit on intuition. "More sharding is always better" breaks at stage 3. Sharding optimizer state and gradients is nearly free; sharding parameters is not. And no ZeRO stage splits a single layer — that is a hard boundary. Past the point where one layer won't fit, you've left ZeRO's domain entirely and need a different cut.

The wrong model to carry, and the right one

The seductive-but-wrong intuition: "ZeRO-3 is strictly better than ZeRO-2 because it uses less memory, so always use the highest stage." No. Each stage up adds communication. If ZeRO-2 already fits your model with room for activations, jumping to ZeRO-3 trades throughput for memory you didn't need — you pay the per-layer all-gather for nothing.

The right model: use the lowest ZeRO stage that fits. Stage up only when the current stage OOMs. Memory you're not using is not a problem to solve; communication you're paying for is. The skill is matching stage to the actual memory shortfall, not maximizing sharding.
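The rule can be written as a short search from the lowest stage upward. A minimal sketch, assuming the 2 / 2 / 12 bytes-per-parameter split used throughout this file and a caller-supplied activation budget; both function names are hypothetical, and stage 0 stands for plain data parallelism.

```python
from typing import Optional


def per_gpu_static_gb(num_params: float, n_gpus: int, stage: int) -> float:
    """Per-GPU static state in GB for stage 0 (plain DP) through ZeRO stage 3."""
    params = num_params * 2 / 1e9
    grads = num_params * 2 / 1e9
    optim = num_params * 12 / 1e9
    if stage >= 1:
        optim /= n_gpus
    if stage >= 2:
        grads /= n_gpus
    if stage >= 3:
        params /= n_gpus
    return params + grads + optim


def lowest_stage_that_fits(num_params: float, n_gpus: int,
                           gpu_gb: float = 80.0,
                           activation_gb: float = 30.0) -> Optional[int]:
    """Return the first stage whose static state plus activation budget fits."""
    for stage in (0, 1, 2, 3):
        if per_gpu_static_gb(num_params, n_gpus, stage) + activation_gb <= gpu_gb:
            return stage
    return None  # even stage 3 does not fit: a different kind of split is needed


print(lowest_stage_that_fits(70e9, 64))   # the running example
print(lowest_stage_that_fits(13e9, 8))    # a smaller model on one node
print(lowest_stage_that_fits(70e9, 8))    # too few GPUs to shard across
```

A None result is the signal that raising the stage cannot help and that more GPUs, offload, or a different parallelism scheme is required.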

Other ways sharding shows up as a problem

  • Throughput drops after enabling ZeRO-3 — the per-layer all-gather isn't hiding; prefetch depth too low, or the group spans slow links. Try ZeRO-2 if it fits, or topology-aware grouping.
  • OOM persists at ZeRO-3 — a single layer plus its activations exceeds one GPU; you need tensor parallelism, not a higher ZeRO stage.
  • CPU-offload run crawls — optimizer state on host RAM, every step stalls on PCIe; only acceptable when GPU-starved.
  • Memory great but step time spiky — all-gather contending with the data loader or other PCIe traffic; isolate the comm streams.
  • FSDP wrapping too coarse — wrapping the whole model as one unit gathers everything at once (no memory win); wrap per transformer block instead.
  • FSDP wrapping too fine — wrapping every tiny module fires thousands of small all-gathers, launch overhead dominates; wrap at block granularity.
  • Mismatched dtype in gathered params — silent precision bug when the gather and compute dtypes disagree; a subtle correctness issue, not an OOM.
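The two wrapping bullets above describe opposite ends of one dial, which a toy model can make concrete. This sketch assumes equal-sized units that are re-gathered in the backward pass, and a fixed per-collective launch overhead of 20 µs. That overhead and the unit counts (80 blocks, 5,000 small modules) are hypothetical values chosen for illustration.

```python
def wrapping_profile(model_gb: float, n_units: int, per_call_us: float = 20.0) -> dict:
    """Peak gathered size and collective count when the model is split into n_units."""
    all_gathers = 2 * n_units                  # one per unit in forward, one in backward
    return {
        "peak_gathered_gb": model_gb / n_units,               # one unit materialized at a time
        "all_gathers_per_step": all_gathers,
        "launch_overhead_ms": all_gathers * per_call_us / 1000.0,
    }


for label, units in [("whole model", 1), ("per block", 80), ("per tiny module", 5000)]:
    print(label, wrapping_profile(140.0, units))
```

Whole-model wrapping materializes all 140 GB at once, while very fine wrapping keeps the gathered size tiny but multiplies the fixed per-call cost. Block-level wrapping sits between the two.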

Where this fits the larger systems map

  • Same constraint, different mechanism — caching vs recompute. ZeRO chooses not to store what it can reconstruct on demand (gather the shard when needed, discard after). This is the inverse of caching: instead of trading memory to avoid recomputation, it trades communication to avoid storage. Both are points on the store-vs-fetch curve.
  • Shared invariant — replicas stay identical. Like data parallelism (file 02), ZeRO keeps the logical model bit-identical across GPUs; it just stores it sharded. The all-reduce became a reduce-scatter + all-gather, but the convergence guarantee is unchanged.
  • Failure geometry — fetch-on-demand latency. The per-layer all-gather is the same shape as a cache miss that must fetch over the network: cheap when prefetched/overlapped, a stall when it isn't. The same prefetching cure applies.
  • Lazy materialization. Gathering a layer's params just-in-time and resharding after is lazy materialization — the same pattern as streaming a large dataset rather than loading it, or paging memory rather than resident-loading it.

Where this appears in production

  • Microsoft DeepSpeed ZeRO — the original; stages 1/2/3 and ZeRO-Infinity (NVMe offload) powered Turing-NLG and many large open models.
  • PyTorch FSDP / FSDP2 — native sharding used across the ecosystem; fully_shard with DTensor is the modern default for sharded data-parallel training.
  • Meta Llama training — FSDP is a core part of the stack, sharding the data-parallel dimension while tensor/pipeline parallelism handle the rest.
  • Hugging Face Accelerate / Trainer — exposes ZeRO stages and FSDP wrapping policies as config, so users pick a stage without writing collectives.
  • Mosaic/Databricks Composer — defaults to FSDP for billion-parameter training with block-level wrapping.
  • EleutherAI / GPT-NeoX — uses DeepSpeed ZeRO to fit large models on academic clusters where GPU count is limited.
  • Lightning (PyTorch Lightning) — strategy="fsdp" and the DeepSpeed strategies wrap the sharding for users.
  • bitsandbytes + FSDP (QLoRA at scale) — combines sharding with 4-bit quantized base weights to fit very large fine-tunes.
  • Axolotl / LLaMA-Factory — fine-tuning frameworks that default to FSDP/ZeRO so a 70B fine-tune fits on a single 8-GPU node.
  • NVIDIA NeMo — supports both Megatron-style parallelism and ZeRO-style sharding, choosing per model shape.
  • JAX / TPU shard_map — the same shard-and-gather pattern expressed in XLA collectives on TPU pods.
  • DeepSpeed ZeRO-Offload deployments — research labs fitting 13B–70B models on a single workstation by paying PCIe latency.
  • Ray Train + FSDP — orchestrates sharded training across cloud GPU clusters with FSDP under the hood.

Pause and recall

  1. How much GPU memory does a 64-way data-parallel 70B run waste on redundant optimizer state, and why?
  2. What timing property of the optimizer state makes it the cheapest thing to shard?
  3. Name the three ZeRO stages and what each one shards.
  4. Why is going from data-parallel to ZeRO-2 nearly free in communication, while ZeRO-3 is not?
  5. What replaces the all-reduce in ZeRO, and which two collectives does it split into?
  6. What does FSDP2's reshard_after_forward trade off, and at what granularity?
  7. For 70B on 64 GPUs, what is the per-GPU static state under ZeRO-3? Compare to data-parallel.
  8. When does ZeRO-3 fail to help, and what mechanism do you reach for instead?

Interview Q&A

Q1. Walk me through the memory and communication tradeoff across ZeRO stages 1, 2, and 3. A. Stage 1 shards optimizer state (the biggest term, and idle outside the update step) — large memory win, near-zero extra communication. Stage 2 also shards gradients (kept only for the params you own) — more memory win, still no extra communication. Stage 3 shards parameters too, but parameters are needed continuously, so it must all-gather each layer's weights on every forward and backward and reshard after — the memory win is largest but it adds ~50% communication. Use the lowest stage that fits. Common wrong answer to avoid: "Higher stage is always better" — stage 3 adds per-layer all-gathers you pay for forever; if stage 2 fits, stage 3 is wasted communication.

Q2. A 70B model OOMs under ZeRO-2 but the team insists each layer is small. What do you check, and what's the fix? A. Check the per-GPU static breakdown: under ZeRO-2 the parameters are still fully replicated (140 GB), which alone exceeds 80 GB. The fix is ZeRO-3 / FSDP to shard the parameters too, dropping per-GPU static to ~17.5 GB. If even ZeRO-3 OOMs, then a single layer plus activations doesn't fit and you need tensor parallelism. Common wrong answer to avoid: "Lower the batch" — the static parameter replication, not activations, is overflowing under ZeRO-2.

Q3. Why is sharding optimizer state almost free but sharding parameters expensive? A. Timing. The optimizer state is touched only during the brief optimizer step, so it can sit sharded through the whole forward and backward pass and never has to be gathered: each GPU updates only the shard it owns, and the single all-gather of updated parameters costs about what the all-reduce already did. Parameters are read continuously throughout forward and backward, so sharding them means gathering each layer's weights on demand every time they're needed, a per-layer collective in both passes. Common wrong answer to avoid: "Parameters are bigger". At 140 GB they are tied with gradients for the smallest fixed term (the optimizer state is 840 GB); the cost is about access frequency, not size.

Q4. When would you use CPU offload (ZeRO-Infinity) over GPU sharding? A. When you're GPU-starved and time-rich — fitting a model that exceeds even your full cluster's GPU memory, e.g. a 70B model on a single 8-GPU node for research. The cost is PCIe-bound steps (~64 GB/s vs HBM's TB/s), so throughput craters. With enough GPUs to shard across, never offload; the cross-GPU all-gather is far faster than host transfers. Common wrong answer to avoid: "Offload is just a slower version of sharding, use it when memory is tight" — it's a fundamentally different bandwidth regime; use it only when you lack GPUs to shard across.

Q5. How do you tune FSDP wrapping granularity, and why does it matter? A. Wrap at transformer-block granularity. Too coarse (whole model as one unit) gathers all parameters at once — no memory win, defeats the purpose. Too fine (every linear layer) fires thousands of tiny all-gathers whose launch overhead dominates. Block-level wrapping gathers one block's worth of params at a time, overlapping the next block's gather with the current block's compute via prefetch. Common wrong answer to avoid: "Wrap everything individually for maximum sharding" — fine wrapping kills throughput via collective launch overhead.

Q6. (Cumulative.) Compare what data parallelism (file 02) and ZeRO-3 (file 03) do to memory and to the all-reduce. A. Data parallelism replicates the full 1,120 GB on every GPU and uses one all-reduce per step — scales throughput, not capacity. ZeRO-3 shards all three populations so each GPU holds ~17.5 GB, and replaces the all-reduce with a reduce-scatter (gradients) plus per-layer all-gathers (parameters) — scaling capacity at a ~50% communication premium. ZeRO-3 is data parallelism that stopped wasting memory, at the price of more collectives. Common wrong answer to avoid: "ZeRO-3 replaces data parallelism" — it is data parallelism, restructured to shard rather than replicate; the data-parallel group is the sharding group.

Design/debug exercise (10 min)

Step 1 — modeled example. Compute per-GPU static state for a 13B model under each scheme, N=8 GPUs:

                 params  grads   optimizer   per-GPU static
data-parallel    26      26      156         208 GB   → OOMs on 80 GB
ZeRO-1           26      26      19.5         71.5 GB  → barely fits, no activation room
ZeRO-2           26       3.3    19.5         48.8 GB  → fits with some room
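ZeRO-3 / FSDP     3.3     3.3    19.5         26.0 GB  → comfortable

Note that ZeRO-2 already fits 13B on 8 GPUs, so jumping to ZeRO-3 here would add per-layer all-gathers for memory you don't need. (The ZeRO-3 total is 208 / 8 = 26.0 GB exactly; the displayed 3.3 values are rounded from 3.25.)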

Step 2 — your turn. For the 70B running example on 64 GPUs, compute per-GPU static state under ZeRO-1, ZeRO-2, and ZeRO-3. Identify the lowest stage that fits under 80 GB with at least 30 GB left for activations. Then state the communication cost you accepted to get there.

Step 3 — reproduce from memory. Without looking, draw the memory-vs-communication tradeoff diagram (the four bars: DP, ZeRO-1, ZeRO-2, ZeRO-3, with memory falling and communication noted), and write the one-line rule: stage 1→2 is nearly free; stage 3 buys the rest of the memory by adding a per-layer parameter all-gather. Connect to file 02 in one sentence: ZeRO-3 is data parallelism that shards instead of replicates, turning the all-reduce into reduce-scatter + all-gather.

Operational memory

This file explained why a 64-way data-parallel run wastes 53 TB on redundant optimizer state, and how ZeRO/FSDP deletes that redundancy by sharding each population across the data-parallel group and reconstructing pieces on demand. The important idea is the communication-vs-memory tradeoff made concrete: stages 1 and 2 free most of the memory almost for free because optimizer state and gradients are idle outside the update step, while stage 3 frees the rest by paying a per-layer parameter all-gather forever.

You learned to compute per-GPU static state for each stage, to pick the lowest stage that fits, and to recognize FSDP as ZeRO-3 with a PyTorch-native API whose reshard_after_forward knob is the same tradeoff at layer granularity. That solves the 70B fitting problem from files 01 and 02 — the model that crashed at step zero now trains at ~17.5 GB per GPU.

Carry this diagnostic forward: when a sharded run is slow but fits, suspect exposed all-gathers and check the profiler before raising the stage. When even ZeRO-3 OOMs, stop climbing stages — a single layer doesn't fit, and you need to split the layer itself, which is the next file.

Remember:

  • ZeRO replaces replication with sharding + on-demand reconstruction; memory drops up to N×, communication rises.
  • Stage 1 (optimizer), 2 (+gradients), 3 (+parameters) — each shards more and pays more.
  • Stages 1→2 are nearly free in communication (idle state); stage 3 adds a per-layer all-gather (~1.5× DP traffic).
  • FSDP = ZeRO-3 in PyTorch; wrap at transformer-block granularity, not whole-model or per-linear.
  • Use the lowest stage that fits — unused memory is not a problem; paid-for communication is.
  • ZeRO-3 cannot split a single layer — when one layer won't fit, you've left ZeRO's domain.

Bridge. ZeRO let the 70B model fit by sharding the state of data parallelism — but it never splits a layer. Every GPU still computes whole layers, just on gathered-then-discarded weights. That holds until a single layer's weights or activations exceed one GPU, or until the per-layer all-gathers swamp a slow interconnect. The next file makes a different cut: instead of sharding the optimizer's bookkeeping, it splits the math of a layer across GPUs (tensor parallelism) and splits the sequence of layers across GPUs (pipeline parallelism) — introducing a new waste we haven't seen yet, the pipeline bubble. → 04-tensor-and-pipeline-parallelism.md