06. Mixed precision and activation recompute — buying memory with compute and bits

~19 min read. The model is spread across 64 GPUs at high MFU. Every per-GPU byte is spoken for. Then someone raises the sequence length from 8k to 32k for long-context training, and the run OOMs again — not in the static state this time, but in the activations that ZeRO and 3D parallelism never touched. This file buys back headroom from two places the topology can't reach: the bytes inside each number, and the activations you can recompute instead of store.

Built on 3D parallelism. We placed every collective on the wire it can afford and the 70B model runs at high MFU. But splitting the model across GPUs never shrank the per-GPU budget below what one GPU must hold — and activations, the one term the memory wall arithmetic (file 01) said responds only to a knob, are the term that explodes on long sequences. This file relieves that pressure with two compute-for-memory trades: lower precision (bf16/fp8) and activation recompute, each spending FLOPs or bits to buy the memory that lets a longer sequence fit.


What the topology can't shrink

The last four files spread a 70B model across 64 GPUs: data parallelism for throughput, ZeRO to shard state, tensor and pipeline parallelism to cut the model, 3D placement to put each collective on the right wire. Every one of those mechanisms answered the same question — how do we divide the fixed 16N static state across more GPUs? And they answered it well; each GPU now holds ~1/32 of the parameters with sharded optimizer state.

But none of them touched the fourth memory consumer from file 01: activations. Activations are the intermediate outputs saved during the forward pass so the backward pass can compute gradients. They don't scale with parameter count — they scale with batch size × sequence length × hidden size × layers. Splitting the model across GPUs reduces the params per GPU, but each GPU still computes its layers on the full sequence, so the activation memory for that work lives on that GPU regardless of how finely you've sharded the weights.

So picture the run that fit at sequence length 8k suddenly OOMing at 32k. The static state didn't change — it's still ~17.5 GB per GPU. What grew is activations, roughly linearly in sequence length (and quadratically in the attention scores if not using a fused kernel). The topology has no lever for this. The headroom must come from inside each GPU: fewer bytes per number, or fewer numbers stored. This file builds both.

What this file solves

A run that fits at one sequence length OOMs at a longer one because activation memory — untouched by any sharding or 3D split — grows with sequence length and batch size. This file shows two ways to buy that memory back. First, mixed precision: storing weights and activations in bf16 (and the hottest matmuls in fp8) so each number costs 2 bytes or 1 instead of 4, while a master fp32 copy and careful scaling preserve convergence. Second, activation recomputation (gradient checkpointing): discarding most activations after the forward pass and recomputing them during the backward pass, trading ~30% more compute for a large drop in activation memory. The concrete move: enable bf16 autocast with an fp32 master copy, then wrap transformer blocks in checkpointing when activations still overflow.

Trade one: spend bits, not bytes

The first lever is the size of each number. A naive training loop stores everything in fp32 — 4 bytes per number. But the H100's tensor cores run matrix multiplies far faster in 16-bit, and most values in a transformer don't need fp32's precision. Mixed precision keeps the compute in low precision (bf16, sometimes fp8) while keeping a master copy of the weights and the optimizer moments in fp32 for numerical stability — exactly the 16-bytes-per-param accounting from file 01.
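The 16-bytes-per-param accounting can be written out as a short sketch. The byte breakdown below is the conventional one for bf16 training with Adam and is assumed here to match file 01; the dictionary keys and the function name are illustrative, not taken from any library.

```python
# Assumed per-parameter byte breakdown for bf16 mixed precision with Adam.
BYTES_PER_PARAM = {
    "bf16_weights": 2,
    "bf16_grads": 2,
    "fp32_master_weights": 4,
    "fp32_adam_momentum": 4,
    "fp32_adam_variance": 4,
}

def static_state_gb(n_params: float, n_gpus: int) -> float:
    """Static state per GPU in GB (1 GB = 1e9 bytes), assuming even sharding."""
    total_bytes = sum(BYTES_PER_PARAM.values()) * n_params
    return total_bytes / n_gpus / 1e9

bytes_per_param = sum(BYTES_PER_PARAM.values())
per_gpu_gb = static_state_gb(70e9, 64)   # the 70B run on 64 GPUs
```

Under these assumptions the sum is 16 bytes per parameter and the 70B run lands on the ~17.5 GB per-GPU static figure used throughout this file. Only 4 of those 16 bytes are bf16; the fp32 master copy and moments are what precision changes leave untouched.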

Why bf16 rather than fp16? Both are 16 bits, but they split those bits differently. fp16 has 5 exponent bits and 10 mantissa bits — good precision, narrow range (max ~65,504), so gradients can overflow to infinity or underflow to zero. bf16 has 8 exponent bits and 7 mantissa bits — the same dynamic range as fp32 but less precision. For training, range matters more than precision: a gradient that underflows to zero stops learning, while a slightly-less-precise gradient still points the right way. bf16's wide range is why it became the default for transformers and why it usually needs no loss scaling at all.

   FORMAT          bits   exponent  mantissa   range          used for
   ─────────       ────   ────────  ────────   ──────────     ─────────────────────
   fp32             32      8         23        ±3.4e38        master weights, moments
   bf16             16      8          7        ±3.4e38        weights/activations (default)
   fp16             16      5         10        ±65,504        legacy; needs loss scaling
   fp8 E4M3          8      4          3        ±448           fwd matmul (precision)
   fp8 E5M2          8      5          2        ±57,344        bwd matmul (range)

Teacher voice. The reason bf16 quietly won is range, not precision. fp16's 10 mantissa bits are more precise, but its narrow exponent means small gradients vanish and large ones explode — so fp16 training needs loss scaling (multiply the loss by a constant to shift gradients into fp16's representable range, then divide back). bf16 keeps fp32's full exponent, so gradients rarely under/overflow and loss scaling is usually unnecessary. You traded mantissa bits you didn't need for range you did.
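The range-versus-precision split can be checked with nothing but the standard library. This is a minimal sketch: `to_fp16` uses Python's built-in half-precision packing, while `to_bf16` is a hand-rolled emulation (keep the top 16 bits of the fp32 encoding, round to nearest even) that handles finite inputs only. Both helper names are made up for this example.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE half precision; values beyond fp16's range become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def to_bf16(x: float) -> float:
    """Emulate bf16: keep the top 16 bits of the fp32 encoding,
    rounding to nearest even. Finite inputs only (NaN/inf not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

tiny_grad = 1e-8       # below fp16's smallest subnormal
big_value = 1e5        # above fp16's max of 65,504

fp16_tiny, bf16_tiny = to_fp16(tiny_grad), to_bf16(tiny_grad)
fp16_big, bf16_big = to_fp16(big_value), to_bf16(big_value)

# Precision goes the other way: fp16 resolves 1.001, bf16 rounds it to 1.0.
fp16_fine, bf16_fine = to_fp16(1.001), to_bf16(1.001)

# Loss scaling: shift the gradient into fp16's range, then divide back.
LOSS_SCALE = 65536.0
recovered = to_fp16(tiny_grad * LOSS_SCALE) / LOSS_SCALE
```

Here the tiny gradient becomes exactly zero in fp16 but survives in bf16, the large value overflows to infinity in fp16 but stays finite in bf16, and multiplying by the loss scale before rounding lets fp16 recover the gradient to within its mantissa precision.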

For the 70B run, bf16 storage of activations is already the default — it's why file 01's activation estimate used bf16. The next step down is fp8.

fp8 — when 16 bits is still too many

The H100's Transformer Engine can run the bulk matrix multiplies in 8-bit. fp8 comes in two flavors because no single 8-bit format works for both passes. E4M3 (4 exponent, 3 mantissa) has more precision and a smaller range (±448) — used for the forward pass, where activations and weights are well-behaved. E5M2 (5 exponent, 2 mantissa) has wider range and less precision — used for the backward pass, where gradients span a larger dynamic range. The engine picks the format per tensor and per pass.

But fp8's range is so narrow that bf16's "no scaling needed" no longer holds. fp8 needs per-tensor scaling: each tensor gets its own scale factor chosen from its recent value statistics, so its values land in fp8's representable window. Delayed scaling is the common approach — use the scale computed from the previous few steps' statistics, amortizing the cost of measuring the range. Get the scaling wrong and fp8 silently underflows or saturates, corrupting the gradient.
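Delayed scaling can be sketched in a few lines. The `DelayedScaler` class below is hypothetical and is not Transformer Engine's API; it models only the range side of fp8 (scale, then clamp to the E4M3 window) and deliberately ignores mantissa rounding.

```python
from collections import deque

E4M3_MAX = 448.0   # largest finite E4M3 magnitude

class DelayedScaler:
    """Hypothetical helper: derives the scale from the amax of recent steps."""

    def __init__(self, fp8_max: float = E4M3_MAX, history_len: int = 4):
        self.fp8_max = fp8_max
        self.amax_history = deque(maxlen=history_len)

    def scale(self) -> float:
        # No usable statistics yet: fall back to no scaling.
        if not self.amax_history or max(self.amax_history) == 0.0:
            return 1.0
        return self.fp8_max / max(self.amax_history)

    def quantize(self, values):
        """Scale using PREVIOUS steps' statistics, clamp to the fp8 window,
        then record this step's amax for later steps."""
        s = self.scale()
        clamped = [max(-self.fp8_max, min(self.fp8_max, v * s)) for v in values]
        self.amax_history.append(max(abs(v) for v in values))
        return clamped, s

scaler = DelayedScaler()
tensor = [0.5, -900.0, 12.0]
first, s0 = scaler.quantize(tensor)    # no history: -900 saturates at -448
second, s1 = scaler.quantize(tensor)   # scale now reflects the observed amax
```

The first call shows the failure mode the text warns about: with a stale or missing scale, the outlier silently saturates. The second call uses the recorded amax, so the same tensor fits inside the window.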

   PRECISION LADDER for the matmuls (H100)
   ─────────────────────────────────────────
   fp32  ████ 4B   slow, never needed for the matmul itself
   bf16  ██   2B   default; tensor cores ~2× fp32, no loss scaling
   fp8   █    1B   ~2× bf16 again; per-tensor delayed scaling REQUIRED
                    forward = E4M3 (precision), backward = E5M2 (range)

The payoff and the cost: fp8 roughly doubles matmul throughput over bf16 and halves the bytes for the tensors it touches — but it adds the scaling machinery, the risk of silent numerical corruption, and it does not touch the fp32 master weights or optimizer moments (those stay fp32 for convergence). fp8 is a throughput and activation-memory win on the matmul-dominated parts, not a free across-the-board halving.
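Why a 2× matmul speedup is not a 2× step speedup follows from Amdahl-style arithmetic. The sketch below assumes the step splits cleanly into an fp8-eligible matmul share and everything else; the 60% share is an illustrative assumption, not a measurement.

```python
def step_speedup(matmul_fraction: float, matmul_speedup: float = 2.0) -> float:
    """End-to-end speedup when only the matmul share of the step gets faster."""
    if not 0.0 <= matmul_fraction <= 1.0:
        raise ValueError("matmul_fraction must be in [0, 1]")
    return 1.0 / ((1.0 - matmul_fraction) + matmul_fraction / matmul_speedup)

speedup_60 = step_speedup(0.6)   # assumed: 60% of step time in eligible matmuls
speedup_0 = step_speedup(0.0)    # step dominated by communication: no gain
speedup_100 = step_speedup(1.0)  # upper bound: the whole step is matmul
```

With 60% of the step in matmuls the end-to-end gain is about 1.43×, and a step bound by communication or memory bandwidth gains nothing.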

Trade two: spend compute, not storage

The second lever is whether you store an activation at all. The backward pass needs each layer's input activations to compute that layer's gradient. The forward pass produces them; the naive approach keeps every one alive until the backward pass consumes it. For a deep transformer on a long sequence, that pile of saved activations is enormous — and it's exactly what overflows at 32k context.

Activation recomputation (also called gradient checkpointing) makes a different bet. Instead of storing every layer's activations, store only a few checkpoints — say, the input to each transformer block. Discard everything in between. When the backward pass reaches a block, recompute its internal activations from the saved checkpoint by re-running that block's forward pass, use them for the gradient, then discard them again.

   WITHOUT recompute              WITH recompute (checkpoint per block)
   ─────────────────              ──────────────────────────────────────
   fwd: store ALL activations     fwd: store only block INPUTS
        L1 L2 L3 ... L80                L1✓ (L2..internal discarded) L?✓ ...
   bwd: read stored activations   bwd: RE-RUN block forward to rebuild
                                       its activations, then compute grad
   memory: O(layers × seq)        memory: O(checkpoints) + O(one block)
   compute: 1 forward             compute: 2 forwards (~1.33× step; extra fwd in bwd)
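The mechanism in the diagram can be shown on a toy model. This is a minimal sketch, not how any framework implements it: each "layer" is the scalar function tanh(w·x), a "block" is a run of consecutive layers, and all function names are invented for the example.

```python
import math

def layer_forward(w, x):
    return math.tanh(w * x)

def layer_backward(w, x, grad_out):
    """Return (grad wrt w, grad wrt x) for y = tanh(w * x), given the layer INPUT x."""
    y = math.tanh(w * x)
    local = grad_out * (1.0 - y * y)
    return local * x, local * w

def grads_store_all(weights, x0):
    """Baseline: keep every layer's input alive until backward consumes it."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer_forward(w, x)
    grad, grad_w = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grad_w[i], grad = layer_backward(weights[i], inputs[i], grad)
    return grad_w, len(inputs)

def grads_with_recompute(weights, x0, block):
    """Keep only each block's input; re-run the block's forward during backward."""
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % block == 0:
            checkpoints[i] = x            # block-boundary checkpoint
        x = layer_forward(w, x)
    grad, grad_w, peak_live = 1.0, [0.0] * len(weights), 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + block, len(weights))
        inputs = [checkpoints[start]]     # rebuild this block's layer inputs
        for i in range(start, end - 1):
            inputs.append(layer_forward(weights[i], inputs[-1]))
        peak_live = max(peak_live, len(inputs))
        for i in reversed(range(start, end)):
            grad_w[i], grad = layer_backward(weights[i], inputs[i - start], grad)
    return grad_w, len(checkpoints), peak_live

weights = [0.5, -1.2, 0.8, 1.1, -0.7, 0.9, 1.3, -0.4]
baseline_grads, stored_all = grads_store_all(weights, 0.3)
recompute_grads, n_checkpoints, peak_live = grads_with_recompute(weights, 0.3, block=4)
```

With 8 layers in blocks of 4, the baseline holds 8 saved inputs while the recompute version holds 2 checkpoints and at most 4 rebuilt inputs at a time (that count includes the checkpoint the block starts from). The gradients are equal because the rebuilt inputs come from the same deterministic operations in the same order.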

The trade is explicit: storing only block-boundary activations cuts activation memory from "all layers" to "checkpoints plus one block's worth at a time," typically a large reduction. In exchange, the backward pass now re-runs each block's forward to rebuild what it threw away, adding roughly 30% more compute. The arithmetic: the backward already costs ~2× the forward, so a step is ~3 forward-units; one extra forward makes it ~4, about a third more, usually rounded to ~30%. You buy a big chunk of memory with a modest compute tax.
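The compute tax can be written as a small calculation. It assumes the rule of thumb used in this section (backward ≈ 2× forward) and treats the recomputed share of the forward pass as a free parameter; the function name is illustrative.

```python
FORWARD_UNITS = 1.0
BACKWARD_UNITS = 2.0   # rule of thumb: backward costs about twice the forward

def recompute_overhead(recomputed_fraction: float = 1.0) -> float:
    """Extra compute as a fraction of the baseline step.
    recomputed_fraction is the share of the forward pass re-run in backward."""
    baseline = FORWARD_UNITS + BACKWARD_UNITS
    extra = recomputed_fraction * FORWARD_UNITS
    return extra / baseline

full = recompute_overhead(1.0)   # every block re-run once
half = recompute_overhead(0.5)   # only half of the forward re-run
```

Re-running the whole forward costs one third of a step; re-running half of it costs one sixth, which is the lever that selective recompute schemes pull.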

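Mini-FAQ. "Why not checkpoint every single layer for maximum memory savings?" You can, but recompute granularity is itself a tradeoff, and finer is not automatically smaller. Memory held is the saved checkpoints plus the internals of the one segment being rebuilt. More checkpoints (finer) → more saved tensors, but a smaller segment live during the backward pass. Fewer checkpoints (coarser) → fewer saved tensors, but a larger segment to rebuild and hold at once. Per-transformer-block is the usual sweet spot: it's the natural module boundary, the saved checkpoint is small (one activation tensor), and the recomputed unit is large enough to keep the tensor cores busy. The compute side is tuned separately: selective recompute rebuilds only part of each block, paying less compute for less memory relief. The same memory-vs-compute curve you saw in ZeRO (file 03), now inside one GPU.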

The picture: two knobs against the activation wall

The canonical mental model for this file — the per-GPU budget, with the two new knobs acting on the one term the topology couldn't move:

   PER-GPU BUDGET (70B, 3D-split, one GPU)        80 GB ceiling
   ════════════════════════════════════════      ═════════════

   static (params+grads+opt, sharded)  ~17.5 GB   ── fixed by 3D split (files 03–05)
   activations @ seq 8k, bf16, no recompute ~50 GB ── the term that explodes

   ┌─ knob 1: PRECISION ──────────────────────────────────────┐
   │  fp32→bf16 activations: halve the bytes per number         │
   │  bf16→fp8 (matmul tensors): halve again + 2× matmul speed   │
   └────────────────────────────────────────────────────────────┘
   ┌─ knob 2: RECOMPUTE ──────────────────────────────────────┐
   │  store only block-boundary checkpoints, rebuild the rest    │
   │  in the backward pass: big memory cut, ~30% more compute    │
   └────────────────────────────────────────────────────────────┘

   together: a 32k-context run that OOM'd now fits — at the cost of
   ~30% compute (recompute) + scaling overhead (fp8)

Read the asymmetry: the static state is fixed by how you split the model, but activations are negotiable. Precision shrinks every number; recompute deletes most stored numbers and rebuilds them on demand. Both spend something the GPU has spare (FLOPs, or precision the gradient didn't need) to buy something it lacks (memory). That is the same compute-for-memory shape as ZeRO's gather-on-demand and a database recomputing a value instead of caching it.

The 70B run at long context

Thread the running example into long-context training. The 70B model fit comfortably at 8k sequence length (file 03's ~60 GB activation headroom). Push to 32k for long-context fine-tuning and activations roughly quadruple toward ~200 GB of would-be activation memory per GPU — far past the ~60 GB headroom. The OOM is back, and file 01's diagnostic says: it crashed on a long sequence, so it's the activation term, not the static state — reach for recompute, not more sharding.

   70B, 3D-split, per GPU, at seq 32k:
     static (sharded)                       ~17.5 GB
     activations, bf16, NO recompute        ~200 GB  → OOM (only ~62 GB free)
     activations, bf16, WITH block recompute ~35 GB  → FITS
     extra compute from recompute            ~+30% step time
     optional: fp8 matmuls                    further activation + throughput win

The fix order matters and follows file 01's rule. The OOM is in activations (long sequence), so recompute is the first lever — it directly attacks the overflowing term. bf16 is already on. fp8 is a throughput-and-memory bonus on the matmul-heavy parts, enabled once the run is stable. You don't reach for more sharding (files 03–05) — the static state didn't overflow; activations did.
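The budget check behind this table fits in a few lines. It reuses the section's own figures, treats "8k" as 8192 tokens, and scales activations linearly with sequence length, ignoring the quadratic attention-score term. All of those are simplifying assumptions of the sketch.

```python
GPU_GB = 80.0
STATIC_GB = 17.5                 # sharded params + grads + optimizer state
ACT_GB_AT_8K = 50.0              # bf16 activations at seq 8k, no recompute
ACT_GB_AT_32K_RECOMPUTE = 35.0   # this section's estimate with block recompute

def activations_gb(seq_len: int) -> float:
    """Linear-in-sequence estimate anchored at the 8k figure."""
    return ACT_GB_AT_8K * seq_len / 8192

def fits(activation_gb: float) -> bool:
    """True if static state plus activations stay under the GPU's capacity."""
    return STATIC_GB + activation_gb <= GPU_GB

headroom_gb = GPU_GB - STATIC_GB
fits_8k = fits(activations_gb(8192))
fits_32k = fits(activations_gb(32768))
fits_32k_recompute = fits(ACT_GB_AT_32K_RECOMPUTE)
```

Under these assumptions the run fits at 8k, overflows at 32k, and fits again once block recompute is on. The static term never changes across the three cases.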

Why recompute instead of just more GPUs or more sharding

The plausible alternative to recompute: throw more GPUs at it and shard activations across them via tensor or sequence parallelism. That works and is real (sequence/context parallelism splits the sequence dimension across GPUs), but it costs more GPUs and more communication: you're paying with hardware and bandwidth. Recompute costs ~30% more compute on the GPUs you already have and adds no communication. When you can spare the FLOPs (common) and memory is scarce (the whole module's pressure), recompute is the cheaper trade.

The plausible alternative to mixed precision: train in fp32 for maximum stability. But fp32 doubles every activation and gradient byte (worsening the exact OOM you're fighting) and halves matmul throughput by abandoning the tensor cores' bf16/fp8 paths. fp32 is strictly worse on both axes here — more memory, less speed — for stability that bf16's wide range already provides. The decision is governed by what's scarce: when memory is the binding constraint and FLOPs/precision are spare, spend the spare to buy the scarce.

Teacher voice. Notice the pattern across this whole module: every mechanism spends something abundant to buy something scarce. Data parallelism spends bandwidth to buy throughput. ZeRO spends communication to buy memory. Recompute spends FLOPs to buy memory. fp8 spends precision to buy memory and speed. The art is always identifying which resource is scarce right now — here it's per-GPU memory on long sequences — and paying with whatever you have spare.

Operational signals — is precision or recompute hurting?

Healthy behavior. With bf16: loss curve matches an fp32 baseline closely, no NaNs, no loss scaling needed. With recompute on: activation memory drops sharply, step time rises ~30%, MFU stays reasonable because the recomputed forward keeps tensor cores busy. With fp8: end-to-end throughput up ~1.3–1.5× over bf16 (only the matmuls approach 2×; the rest of the step is unchanged), loss curve tracking bf16 within noise.

First metric to degrade. Step time, when recompute is enabled — the extra forward pass is real compute, ~30% more per step. That's the expected cost, not a bug; the win is that the run fits at all. The first unhealthy signal is loss divergence or NaNs after enabling fp8 — a sign the per-tensor scaling is under/overflowing.

The misleading metric. GPU memory looking great after enabling recompute, while throughput quietly dropped 30% — teams celebrate the fit and miss that they're now compute-bound where they were memory-bound. Watch step time and MFU alongside memory. Another trap: blaming fp8 for a convergence problem that's actually a learning-rate or data issue — disable fp8 first to isolate, since fp8 is the riskiest precision lever.

The graph an expert opens first. The loss curve overlaid against a known-good bf16 (or fp32) baseline, plus the gradient-norm trace. fp8 instability shows as gradient-norm spikes or a loss curve peeling away from the baseline. Recompute, done right, shows identical loss to no-recompute (it's mathematically the same computation, just recomputed) — if the loss changes when you toggle recompute, there's a non-determinism or RNG-state bug in the recomputed forward (e.g. dropout re-rolled differently).

Boundary of applicability — where these trades break

Where they shine. Long-sequence or large-batch training where activations dominate the budget — recompute is close to mandatory above some context length. Matmul-heavy dense transformers on H100/H200, where fp8 doubles matmul throughput. bf16 is essentially always on for transformer training.

Where they become pathological. fp8 on models or layers with extreme value distributions (some attention patterns, very deep stacks) where per-tensor scaling can't keep both range and precision — fp8 silently corrupts and the loss diverges. Recompute when the step is already compute-bound (short sequences, small models) — you add 30% compute to buy memory you didn't need, a net loss. And recompute with non-deterministic ops (dropout, some fused kernels) requires careful RNG-state save/restore, or the recomputed forward differs from the original and gradients are wrong.
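The RNG-state hazard is easy to reproduce with a toy dropout mask. The sketch below uses Python's `random` module as a stand-in for the framework's generator, and `dropout_mask` is a made-up helper. In PyTorch, `torch.utils.checkpoint` exposes a `preserve_rng_state` argument for the same purpose.

```python
import random

def dropout_mask(n, p, rng):
    """Inverted-dropout mask: 0 for dropped units, 1/(1-p) for kept ones."""
    return [0.0 if rng.random() < p else 1.0 / (1.0 - p) for _ in range(n)]

rng = random.Random(0)

saved_state = rng.getstate()             # captured when the block first runs
original = dropout_mask(64, 0.5, rng)    # mask used in the original forward

rerolled = dropout_mask(64, 0.5, rng)    # recompute without restoring: new draws

rng.setstate(saved_state)                # restore the state before recomputing
restored = dropout_mask(64, 0.5, rng)    # same draws as the original forward
```

Without the restore, the recomputed forward drops different units than the original did, so the gradient is taken through a different function. With the restore, the two masks match exactly.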

The scale limit on intuition. "Lower precision is always faster and smaller" breaks at fp8: below bf16, you pay in scaling complexity and stability risk that can cost more (a diverged run wastes the whole compute budget) than the memory it saves. "Recompute everything for maximum savings" breaks too: the more of the forward pass you rebuild, the larger the compute tax, and past a point it outweighs the memory relief. Both knobs have a sweet spot, not a "more is better" monotonicity.

The wrong model to carry, and the right one

The seductive-but-wrong intuition: "OOM means I need to shard more or add GPUs." When the OOM is in activations on a long sequence, more sharding does nothing — file 03's mechanisms shard the static state, which isn't what overflowed. You'd add GPUs and OOM identically, the same mistake as file 02's "add GPUs to an OOMing run."

The right model: classify the OOM first (file 01's discipline), then pick the matching lever. Static-state OOM → shard (files 03–05). Activation OOM on long sequences → recompute (this file), then lower precision. Activations are the negotiable term; you buy them back with compute (recompute) and bits (precision), not with more hardware.
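The classify-then-pick rule can be written as a tiny decision helper. The function, its inputs, and the lever strings are illustrative only; the discriminator is the one this file uses, namely whether the OOM depends on sequence length.

```python
from typing import List

def classify_oom(ooms_at_short_seq: bool, ooms_at_long_seq: bool) -> str:
    """Classify an OOM by its sensitivity to sequence length."""
    if ooms_at_short_seq:
        return "static"        # overflows regardless of sequence length
    if ooms_at_long_seq:
        return "activations"   # appears only as the sequence grows
    return "none"

LEVERS = {
    "static": ["shard more (ZeRO stage, TP/PP split)"],
    "activations": ["activation recompute", "lower precision (fp8 matmuls)"],
    "none": [],
}

def pick_levers(ooms_at_short_seq: bool, ooms_at_long_seq: bool) -> List[str]:
    """Return the levers to try, in order, for the classified OOM."""
    return LEVERS[classify_oom(ooms_at_short_seq, ooms_at_long_seq)]

long_context_case = pick_levers(ooms_at_short_seq=False, ooms_at_long_seq=True)
```

For the 70B run that is clean at 8k and dies at 32k, the helper returns recompute first and precision second, and never suggests more sharding.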

Other ways these trades show up as a problem

  • OOM only at long sequence, fine at short — activations overflowed; enable recompute before any sharding change.
  • Step time up 30% after enabling recompute — expected; it's the extra forward pass, the price of fitting.
  • Loss diverges after enabling fp8 — per-tensor scaling under/overflowing; check scale factors, fall back to bf16 to confirm fp8 is the cause.
  • Loss changes when toggling recompute on/off — RNG-state bug; dropout or a stochastic op re-rolls differently in the recomputed forward. Recompute must be bit-identical.
  • fp16 (not bf16) run NaNs early — narrow fp16 range overflowed; either add loss scaling or switch to bf16's wide range.
  • Recompute enabled but memory barely drops — checkpoint granularity too coarse (or the big tensor is the seq × seq attention-score matrix, which needs a fused attention kernel, not block recompute).
  • fp8 enabled but no speedup — the bottleneck isn't the matmul (it's communication or memory bandwidth); fp8 only speeds the matmul-bound parts.

Where this fits the larger systems map

  • Compute-for-memory trade, recurring. Recompute spends FLOPs to avoid storing activations — the same store-vs-recompute curve as ZeRO's gather-on-demand (file 03), a query planner recomputing instead of caching, and a JIT recompiling instead of holding code. File 01 flagged this shape; here it returns as the primary lever.
  • Precision as a lossy compression. bf16/fp8 are lossy compression of the numbers, accepting controlled error to save space and bandwidth — the same bargain as quantized vector indexes (03_vector_retrieval_infrastructure) and lossy media codecs: shed bits the consumer can't perceive.
  • Range vs precision tradeoff. bf16 chose range over precision; fp8's two formats split the same way (E5M2 range, E4M3 precision). This range-vs-precision axis recurs anywhere fixed bits must cover a wide dynamic range — audio, sensor data, financial fixed-point.
  • Scarce-vs-abundant resource matching. Spending spare FLOPs to buy scarce memory is the same judgment as spending spare CPU to compress data for a scarce network, or spare disk to memoize an expensive query. Identify the binding constraint, pay with the slack.

Where this appears in production

Mixed precision:

  • NVIDIA Transformer Engine — the library that drives fp8 on H100/H200, picking E4M3/E5M2 per pass with delayed per-tensor scaling.
  • PyTorch AMP (torch.autocast + GradScaler) — the standard mixed-precision API; autocast runs ops in bf16/fp16, GradScaler handles fp16 loss scaling.
  • bf16 as the transformer default — Llama, Mistral, Falcon, and most modern LLMs train in bf16 precisely for its fp32-matching range.
  • Meta Llama 3 training — bf16 throughout, with fp8 explored for matmul speedups on H100s.
  • Microsoft FP8-LM / MS-AMP — frameworks demonstrating fp8 training of large LLMs with the scaling machinery productized.
  • DeepSpeed / Megatron fp8 integration — expose Transformer Engine fp8 as a config flag in the 3D-parallel stack.
  • Google bf16 on TPUs — bf16 originated on TPUs and is the native training precision there.

Activation recomputation:

  • PyTorch torch.utils.checkpoint / checkpoint_sequential — the native gradient-checkpointing API; wrap a module and its activations are recomputed in the backward pass.
  • Hugging Face gradient_checkpointing=True — a one-line Trainer flag that wraps transformer blocks for recompute, standard for fine-tuning on tight memory.
  • Megatron-LM selective activation recompute — recomputes only the cheap-to-recompute, memory-heavy parts (e.g. attention) rather than whole blocks, a finer point on the curve.
  • FlashAttention — fuses the attention computation so the giant seq × seq score matrix is never materialized, attacking the quadratic activation term recompute alone can't cheaply fix.
  • DeepSpeed activation checkpointing — integrates recompute with ZeRO and offload, including CPU-offloading the checkpoints themselves.
  • Long-context fine-tuning (32k–128k) — recompute is effectively mandatory above moderate context lengths; every long-context recipe enables it.
  • QLoRA — combines 4-bit quantized base weights with recompute to fit large fine-tunes on a single GPU.

Pause and recall

  1. Which of the four memory consumers (file 01) does this file attack, and why don't sharding/3D touch it?
  2. Why is bf16 preferred over fp16 for transformer training, despite fp16 having more mantissa bits?
  3. What are the two fp8 formats, and why does the forward pass use one and the backward the other?
  4. Why does fp8 need per-tensor (delayed) scaling when bf16 usually needs no loss scaling?
  5. Explain activation recomputation: what's stored, what's discarded, what's rebuilt, and the compute cost.
  6. For the 70B run at 32k context, why is recompute the right first lever and not more sharding?
  7. What's the misleading metric after enabling recompute, and what should you watch instead?
  8. If the loss changes when you toggle recompute on/off, what's the bug?

Interview Q&A

Q1. A 70B run fits at 8k context but OOMs at 32k. Walk through your fix. A. Classify the OOM first: it appears only on long sequences, so it's the activation term, not the static state — file 01's discipline. The lever is activation recomputation (gradient checkpointing): store only block-boundary activations, recompute the rest in the backward pass, cutting activation memory sharply for ~30% more compute. bf16 is already on. fp8 is a later throughput/memory bonus. Do not reach for more sharding — the static state didn't overflow. Common wrong answer to avoid: "Increase the ZeRO stage / add GPUs" — sharding attacks static state; the long-sequence OOM is activations.

Q2. Why did the field standardize on bf16 over fp16 for training? A. Range. bf16 has fp32's 8 exponent bits (same dynamic range) with fewer mantissa bits; fp16 has more mantissa but only 5 exponent bits, so its narrow range lets small gradients underflow to zero and large ones overflow to infinity. For training, a gradient pointing roughly the right way beats a precise gradient that vanished — so bf16's range matters more than fp16's precision, and bf16 usually needs no loss scaling. Common wrong answer to avoid: "fp16 is more precise so it's better for training" — its narrow range causes under/overflow that loss scaling must patch; bf16 sidesteps it.

Q3. What does activation recomputation cost, and why is it usually a good trade? A. It costs ~30% more compute per step: the backward pass re-runs each block's forward to rebuild the activations it discarded. A step is ~3 forward-units (1 forward + ~2 backward), so one extra forward makes it ~4, roughly a third more. It's a good trade because large-model steps are often memory-bound with spare FLOPs, so you spend an abundant resource (compute) to buy a scarce one (memory), and the alternative (more GPUs or sequence sharding) costs hardware and bandwidth. Common wrong answer to avoid: "It's free / it saves compute." It spends ~30% more compute; the win is fitting, not speed.

Q4. After enabling fp8 the loss starts diverging. What's happening and how do you isolate it? A. fp8's very narrow range means per-tensor scaling is likely under/overflowing — values saturating to ±max or rounding to zero, corrupting gradients. Isolate by disabling fp8 and falling back to bf16: if the divergence stops, fp8 scaling is the cause. Check the per-tensor scale factors and the gradient-norm trace for spikes. fp8 is the riskiest precision lever, so it's the first thing to toggle off when debugging convergence. Common wrong answer to avoid: "Lower the learning rate" — that masks the symptom; the root cause is fp8 numerical range, confirmed by toggling fp8 off.

Q5. Toggling recompute changes your loss curve. Bug or expected? A. Bug. Recomputation is mathematically the same computation, just recomputed in the backward pass, so the loss must match: bit-identical with deterministic kernels, within float noise otherwise. A changed loss curve means the recomputed forward differs from the original, almost always an RNG-state issue: dropout or a stochastic op re-rolls differently because the RNG state wasn't saved and restored around the recomputed region. Common wrong answer to avoid: "Recompute approximates the activations, small differences are normal." Recompute is exact; differences indicate a non-determinism bug.

Q6. (Cumulative.) A 70B 3D-parallel run OOMs. How do you tell whether it's a static-state problem (files 03–05) or an activation problem (this file)? A. Look at when and where it OOMs (file 01's method). At optimizer construction or step, batch-1, independent of sequence length → static state; shard more or adjust the 3D split. Only at long sequences or large micro-batches, after clean short-sequence runs → activations; enable recompute and check precision. The discriminator is sequence-length sensitivity: static state is insensitive to it, activations scale with it. Common wrong answer to avoid: "An OOM is an OOM, just shard more" — the two OOMs need opposite fixes; misclassifying wastes a cluster.

Design/debug exercise (10 min)

Step 1 — modeled example. Estimate the activation memory effect of recompute on a single GPU at 32k context:

   without recompute: store activations for all 20 layers (this stage) × seq 32k
     ≈ 20 layers × per-layer activation @ 32k ≈ ~200 GB  → OOM
   with block recompute: store only 20 block-INPUT checkpoints + 1 block live
     ≈ 20 small checkpoints + one block's internal activations ≈ ~35 GB  → fits
   compute: backward re-runs each block's forward → ~+30% step time
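The modeled example can be made explicit as arithmetic. The per-checkpoint size below is not given in the text; it is back-solved from the ~200 GB and ~35 GB totals so the two estimates agree, and should be read as an assumption of this sketch rather than a measurement.

```python
N_LAYERS = 20                                   # layers on this pipeline stage
NO_RECOMPUTE_GB = 200.0                         # all layers' activations at 32k
PER_LAYER_GB = NO_RECOMPUTE_GB / N_LAYERS       # one block's internal activations
WITH_RECOMPUTE_GB = 35.0                        # the estimate from Step 1

# Back-solved: what each block-input checkpoint must cost for the totals to agree.
CHECKPOINT_GB = (WITH_RECOMPUTE_GB - PER_LAYER_GB) / N_LAYERS

def recompute_memory_gb(n_layers: int, checkpoint_gb: float, per_layer_gb: float) -> float:
    """Checkpoints for every block plus one block's internals live at a time."""
    return n_layers * checkpoint_gb + per_layer_gb

estimate_gb = recompute_memory_gb(N_LAYERS, CHECKPOINT_GB, PER_LAYER_GB)
saving_fraction = 1.0 - estimate_gb / NO_RECOMPUTE_GB
```

Under these numbers each block holds about 10 GB of internals, each checkpoint about 1.25 GB, and block recompute removes roughly 80% of the activation memory.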

Step 2 — your turn. For the 70B running example, a teammate wants to train at 128k context. (a) Will recompute alone likely be enough, and what other mechanism attacks the quadratic attention-score term that block recompute doesn't cheaply fix? (b) If they also enable fp8 and the loss diverges, what's your first debugging step? (c) Order the levers you'd apply (bf16, recompute, fp8, FlashAttention) and justify the order by which resource each spends.

Step 3 — reproduce from memory. Without looking, redraw the two-knobs picture (the per-GPU budget with static state fixed and activations attacked by precision and recompute), write the precision ladder (fp32→bf16→fp8 with bytes and the range-vs-precision note), and state the one-line rule: activations are the negotiable term — buy them back with compute (recompute) and bits (precision), not with more sharding. Connect to file 01 in one sentence: only the activation consumer responds to these knobs; the other three are fixed by params and optimizer choice.

Operational memory

This chapter explained the OOM that returns on long sequences after the topology has done all it can — the activation term, which no sharding or 3D split touches because each GPU still computes its layers on the full sequence. The important idea is that activations are the negotiable memory consumer, and you buy them back two ways: lower precision (bf16 for range, fp8 for the matmuls) shrinks every number, and activation recomputation discards most stored activations and rebuilds them in the backward pass for ~30% more compute.

You learned to classify an OOM by sequence-length sensitivity, to enable bf16 with an fp32 master copy and add recompute when activations overflow, and to treat fp8 as a riskier throughput/memory bonus that needs per-tensor scaling. That solves the long-context OOM: the 70B run that died at 32k now fits by recomputing activations, at a compute cost the memory-bound step could afford.

Carry this diagnostic forward: when a run OOMs only on long sequences, the term is activations — reach for recompute first, then precision, never more sharding. When step time jumps 30% after enabling recompute, that's the expected price of fitting, not a bug. When the loss diverges after fp8 or changes when you toggle recompute, suspect scaling and RNG state respectively before the model.

Remember:

  • This file attacks activations, the one memory term sharding and 3D never touch (it scales with seq × batch, not params).
  • bf16 over fp16 because range beats precision for gradients; bf16 usually needs no loss scaling.
  • fp8 (E4M3 fwd, E5M2 bwd) doubles matmul speed and halves bytes but needs per-tensor delayed scaling and risks silent divergence.
  • Activation recompute stores block-boundary checkpoints, rebuilds the rest in the backward pass: big memory cut for ~30% more compute.
  • Classify the OOM first: static-state OOM → shard; activation OOM (long sequence) → recompute then precision.
  • Every lever spends an abundant resource to buy a scarce one — here, spare FLOPs/precision buy scarce memory.

Bridge. We can now fit a 70B model across 64 GPUs at high MFU and stretch it to long context by trading compute and bits for memory. The run trains. But "the run trains" is a single-step property — and a frontier run is not one step, it's a week of a thousand GPUs, where the arithmetic of failure is merciless: at that scale something is always broken. A GPU dies, a link flaps, a host reboots, and the synchronous all-reduce we built in file 02 means one dead rank freezes all the others. Worse, some failures are silent — a bit flips, a gradient is subtly wrong, and the loss keeps falling while the model quietly corrupts. The next file confronts the ugliest truth of scale: survival depends not on speed but on checkpointing, restart, straggler detection, and elastic recovery. → 07-fault-tolerance-and-checkpointing-at-scale.md