
07. Tensor parallelism — how one giant layer gets sliced across many GPUs

~17 min read. When one kitchen is too small, we split the stove itself and pay a communication bill.

Built on the ELI5 in 00-eli5.md. The kitchen may be too small for one full model, so we divide the burners across rooms and carry partial dishes between them.


1) Picture first: split the layer, not the request

Suppose one matrix multiply is too large for one GPU. We do not split the order ticket, the request itself, into pieces. We split the weight matrix. Each GPU computes its shard of the result. Then the shards are combined.

input x
  ├──→ GPU 0 holds W slice 0
  ├──→ GPU 1 holds W slice 1
  ├──→ GPU 2 holds W slice 2
  └──→ GPU 3 holds W slice 3
      partial outputs combine

See. The model is still one model. The order ticket is still one request. Only the physical placement changes. That is tensor parallelism in plain words.
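To make the picture concrete, here is a minimal single-process sketch. Each simulated "GPU" is just a list entry holding a slice of W's columns, and plain Python lists stand in for GPU tensors. The helper names (`matmul`, `split_columns`, `concat_columns`) are illustrative, not from any framework. The sketch checks that combining the shards reproduces the unsplit result.

```python
# Simulate tensor parallelism on one machine: each "GPU" is just a list entry.
# Pure Python so the sketch runs anywhere; real systems use GPU tensors.

def matmul(x, w):
    """Multiply x (rows x k) by w (k x cols) using plain lists."""
    return [[sum(xr[i] * w[i][j] for i in range(len(w))) for j in range(len(w[0]))]
            for xr in x]

def split_columns(w, parts):
    """Give each simulated GPU an equal, contiguous block of W's columns."""
    cols = len(w[0])
    assert cols % parts == 0, "width must divide the column count"
    step = cols // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]

def concat_columns(blocks):
    """Glue per-GPU output blocks back together side by side."""
    return [sum((b[r] for b in blocks), []) for r in range(len(blocks[0]))]

# One request: 2 token rows, hidden size 4. W maps 4 -> 8 features.
x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
w = [[(i + j) % 5 for j in range(8)] for i in range(4)]

shards = split_columns(w, parts=4)          # GPU p holds W slice p
partials = [matmul(x, s) for s in shards]   # every GPU sees the same x
combined = concat_columns(partials)         # the "partial outputs combine" step

assert combined == matmul(x, w)             # same answer as one big GPU
```

Real frameworks do the same bookkeeping in device memory, and replace the concatenation with a collective only when the full output is actually needed.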


2) Column-parallel and row-parallel are the usual pattern

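Two common sharding shapes appear again and again. A column-parallel layer splits the weight matrix by output columns, so each GPU produces a slice of the output features. A row-parallel layer splits the weight matrix by input rows, so each GPU consumes a slice of the input features and produces a partial sum. They usually alternate inside transformer blocks, with a column-parallel layer feeding a row-parallel one.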

column parallel                  row parallel
x ──→ [W0 | W1 | W2 | W3]       x split ──→ W0, W1, W2, W3
        │   │   │   │                              │
        └─ partial outputs ─┐                      └─ partial sums ──→ all-reduce
                            └──→ concat or gather

A column-parallel layer needs an all-gather only if the next step wants the full output; when it feeds a row-parallel layer directly, each GPU simply keeps its slice. A row-parallel layer ends with an all-reduce sum. Look. The arithmetic stays local first. Communication appears only at the merge points. That is why layer design and collective placement matter.
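The pairing can be checked numerically with a toy two-GPU MLP: a column-parallel first projection, an elementwise ReLU, then a row-parallel second projection. Everything runs in one Python process, and `all_reduce_sum` is a hypothetical stand-in for a real collective such as an NCCL all-reduce. The assumption worth noticing is that the ReLU is elementwise, which is why each GPU can apply it to its own slice without talking to the others.

```python
# Toy tensor-parallel MLP: column-parallel w_a, ReLU, row-parallel w_b.
# Plain Python lists stand in for GPU tensors; one process plays every GPU.

def matmul(x, w):
    """x is (rows x k), w is (k x cols)."""
    return [[sum(xr[i] * w[i][j] for i in range(len(w))) for j in range(len(w[0]))]
            for xr in x]

def relu(m):
    return [[max(v, 0) for v in row] for row in m]

def add(a, b):
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]

def split_columns(w, parts):
    step = len(w[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]

def split_rows(w, parts):
    step = len(w) // parts
    return [w[p * step:(p + 1) * step] for p in range(parts)]

def all_reduce_sum(partials):
    """Hypothetical stand-in for an all-reduce: every GPU gets the elementwise sum."""
    total = partials[0]
    for p in partials[1:]:
        total = add(total, p)
    return [[row[:] for row in total] for _ in partials]  # one copy per GPU

TP = 2
x = [[1, -2, 3, 0], [2, 1, -1, 4]]                                  # 2 tokens, hidden 4
w_a = [[(3 * i + j) % 7 - 3 for j in range(8)] for i in range(4)]   # 4 -> 8
w_b = [[(i + 2 * j) % 5 - 2 for j in range(4)] for i in range(8)]   # 8 -> 4

a_shards = split_columns(w_a, TP)   # GPU p: its own output columns of w_a
b_shards = split_rows(w_b, TP)      # GPU p: the matching input rows of w_b

# Local work only: no communication between the two projections.
partial_sums = [matmul(relu(matmul(x, a)), b) for a, b in zip(a_shards, b_shards)]
outputs = all_reduce_sum(partial_sums)  # the single merge point

reference = matmul(relu(matmul(x, w_a)), w_b)
assert all(out == reference for out in outputs)
```

The column split of `w_a` and the row split of `w_b` must use the same boundaries, so the slice of hidden features each GPU produces is exactly the slice its half of `w_b` expects.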


3) Worked example: hidden size and communication bytes

Take a hidden size of 16,384. Split it across 4 GPUs. Each GPU owns 16,384 / 4 = 4,096 output columns in a column-parallel projection.

Suppose one micro-batch has 8 active token rows. Local output per GPU is:

  • 8 × 4,096 = 32,768 values

At fp16, that is:

  • 32,768 × 2 bytes = 65,536 bytes = 64 KB per GPU
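The same arithmetic as a small Python helper, so you can swap in other widths or token counts. The function names are hypothetical, and the code assumes fp16 (2 bytes per value) and a hidden size that divides evenly by the tensor-parallel width, as in the example above.

```python
FP16_BYTES = 2  # bytes per fp16 value

def column_shard(hidden, tp_width):
    """Output columns each GPU owns in a column-parallel projection."""
    assert hidden % tp_width == 0, "hidden size must divide evenly"
    return hidden // tp_width

def local_output_bytes(tokens, hidden, tp_width, bytes_per_value=FP16_BYTES):
    """Bytes of the per-GPU output slice: tokens x local columns x value size."""
    return tokens * column_shard(hidden, tp_width) * bytes_per_value

cols = column_shard(16_384, 4)               # 4,096 columns per GPU
local = local_output_bytes(8, 16_384, 4)     # 65,536 bytes
print(cols, local, local / 1024)             # 4096 65536 64.0
```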

After a later row-parallel layer, we may need an all-reduce over the full 8 × 16,384 activation. That payload is:

  • 8 × 16,384 × 2 bytes = 262,144 bytes = 256 KB per all-reduce

Multiply that by the all-reduces in every block (typically two: one after attention, one after the MLP) and then by dozens of blocks. Communication is not a side note anymore.
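A rough sketch of how those bytes add up, under loudly stated assumptions: two all-reduces per transformer block, a hypothetical 80-block model, and the textbook ring all-reduce, in which each GPU sends 2(N-1)/N of the payload. Real libraries such as NCCL may pick other algorithms, so treat the totals as order-of-magnitude estimates.

```python
FP16_BYTES = 2

def all_reduce_payload(tokens, hidden, bytes_per_value=FP16_BYTES):
    """Size of the full activation every GPU must end up holding."""
    return tokens * hidden * bytes_per_value

def ring_bytes_sent_per_gpu(payload, tp_width):
    """Textbook ring all-reduce: each GPU sends 2 * (N - 1) / N of the payload."""
    return 2 * (tp_width - 1) * payload // tp_width

payload = all_reduce_payload(8, 16_384)          # 262,144 bytes = 256 KB
per_gpu = ring_bytes_sent_per_gpu(payload, 4)    # 393,216 bytes = 384 KB

ALL_REDUCES_PER_BLOCK = 2   # assumption: after attention and after the MLP
BLOCKS = 80                 # assumption: a hypothetical 80-block model
per_step = per_gpu * ALL_REDUCES_PER_BLOCK * BLOCKS
print(payload, per_gpu, per_step)  # 262144 393216 62914560
```

Roughly 60 MiB leaves each GPU for one forward pass over just 8 token rows, which is why the link between GPUs matters so much.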


4) Why interconnect quality changes everything

Now what is the problem? A shard on one GPU is useless without the others. Every collective waits for the slowest link. If GPUs talk over a fast fabric, such as NVLink inside one node, tensor parallelism works well. If they talk over a slower path, such as PCIe or a network between nodes, communication can eat the latency gain.

So what to do? Use tensor parallelism when the model truly needs it, and when interconnect bandwidth supports it. Do not assume “more GPUs” means “lower latency.” Sometimes you simply moved the bottleneck from memory to networking. Simple, no?
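A toy alpha-beta cost model (fixed startup latency plus bytes divided by bandwidth) shows how the same sharding can win on one link and lose on another. Every number here is made up for illustration: the link latencies and bandwidths, the single-GPU step time, and the 160 all-reduces per step. The model also optimistically assumes compute splits perfectly and never overlaps with communication.

```python
# Toy alpha-beta model. All link numbers are hypothetical, not measurements.
LINKS = {
    "fast fabric": {"latency_us": 5.0, "bandwidth_gb_s": 300.0},
    "slow path": {"latency_us": 20.0, "bandwidth_gb_s": 25.0},
}

def collective_time_us(bytes_sent, latency_us, bandwidth_gb_s):
    # 1 GB/s moves 1,000 bytes per microsecond
    return latency_us + bytes_sent / (bandwidth_gb_s * 1_000)

def tp_step_time_us(single_gpu_us, tp_width, comm_us):
    # Optimistic: compute splits perfectly and never overlaps communication
    return single_gpu_us / tp_width + comm_us

BYTES_PER_ALL_REDUCE = 393_216   # ring traffic per GPU for 8 x 16,384 fp16, width 4
ALL_REDUCES_PER_STEP = 160       # assumption: 2 per block x 80 blocks
SINGLE_GPU_STEP_US = 6_000.0     # hypothetical single-GPU step time
TP_WIDTH = 4

results = {}
for name, link in LINKS.items():
    comm = ALL_REDUCES_PER_STEP * collective_time_us(BYTES_PER_ALL_REDUCE, **link)
    results[name] = tp_step_time_us(SINGLE_GPU_STEP_US, TP_WIDTH, comm)
    verdict = "faster" if results[name] < SINGLE_GPU_STEP_US else "slower"
    print(f"{name}: {results[name]:.0f} us vs {SINGLE_GPU_STEP_US:.0f} us -> {verdict}")
```

With these invented numbers the fast fabric finishes in about 2.5 ms against 6 ms on one GPU, while the slow path lands above 7 ms, slower than not sharding at all. The specific values mean nothing; the shape of the trade-off is the point.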


5) When to choose tensor parallelism versus other splits

Tensor parallelism helps when the model's weights are too large for one GPU's memory, or when one GPU cannot meet the per-request latency target alone. It is less attractive when:

  • small models already fit comfortably,
  • network links are slow,
  • batch sizes are tiny,
  • engineering simplicity matters more than peak speed.

You may then prefer data parallel replicas, pipeline parallel stages, or just a better single-GPU kernel stack. Whichever split you choose, most teams do not hand-wire the collectives themselves. They rely on serving frameworks that already know these patterns. Next we compare those frameworks directly.


Where this lives in the wild

  • Llama-70B serving on 4× or 8× GPUs — tensor parallelism is the standard way to fit one interactive model replica across devices.

  • NVIDIA TensorRT-LLM deployments — communication-aware tensor sharding is a first-class optimization knob.

  • Open-source model hosting with DeepSpeed-MII or vLLM — large checkpoints often need tensor parallel groups just to start.

  • Enterprise copilots on H100 nodes — teams trade replica count against tensor-parallel width based on latency targets.

  • Multi-GPU inference for long-context chatbots — one big target model can stay interactive only if shard communication stays cheap.
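The Llama-70B bullet and the "too large for one GPU's memory" condition both come down to weight arithmetic. Here is a rough sizing sketch. The function name, the 50% weight budget (the rest reserved for KV cache, activations, and runtime overhead), and the power-of-two widths are all assumptions, and the count ignores small replicated tensors.

```python
def min_tp_width(params_billion, bytes_per_param, gpu_mem_gb,
                 weight_fraction=0.5, max_width=8):
    """Smallest power-of-two width whose per-GPU weight shard fits the budget.

    weight_fraction is an assumption: the rest of each GPU's memory is kept
    for the KV cache, activations, and runtime overhead.
    """
    weights_gb = params_billion * bytes_per_param  # 1e9 params x bytes = GB
    budget_gb = gpu_mem_gb * weight_fraction
    width = 1
    while width <= max_width:
        if weights_gb / width <= budget_gb:
            return width
        width *= 2
    return None  # does not fit at max_width: more GPUs or pipeline stages needed

print(min_tp_width(70, 2, 80))  # 4: 140 GB of fp16 weights, 40 GB budget per GPU
print(min_tp_width(70, 2, 40))  # 8: a smaller GPU needs a wider group
print(min_tp_width(7, 2, 80))   # 1: fits on one GPU, plain replicas suffice
```

Raising `weight_fraction` shrinks the width, but it also leaves less room for the KV cache, which is the trade-off the enterprise-copilot bullet describes.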


Pause and recall

  • What exactly is being split in tensor parallelism: the request, the layer, or the tokenizer?

  • In the worked example, how many output columns did each GPU own?

  • Why can communication become the new bottleneck after sharding?

  • When might a plain single-GPU replica be better than tensor parallelism?


Interview Q&A

Q: Why use tensor parallelism instead of only adding more independent replicas?

A: Replicas help throughput when one full model fits on each GPU. Tensor parallelism is for the case where a single model replica must span devices because memory or per-request speed demands it.

Common wrong answer to avoid: "More replicas and tensor parallelism are basically the same." They solve different scaling problems.

Q: Why can tensor parallelism hurt latency on the wrong hardware?

A: Because every layer merge needs communication. If interconnect bandwidth or latency is poor, collective overhead can dominate the saved local compute.

Common wrong answer to avoid: "More GPUs always means faster inference." Collectives can cancel the gain.

Q: Why alternate column-parallel and row-parallel layers instead of one single sharding trick everywhere?

A: Because different projections need different merge behavior. Pairing a column-parallel layer with a row-parallel layer keeps the intermediate activation sharded on each GPU, so the pair needs one all-reduce at the end rather than a gather in the middle.

Common wrong answer to avoid: "One shard pattern is universally optimal." Tensor shapes and merge needs differ by layer.

Q: Why do serving teams often hide tensor parallelism behind a framework?

A: Because collective placement, memory layout, kernel fusion, and topology tuning are easy to get wrong manually. Frameworks encode these hard-won patterns.

Common wrong answer to avoid: "It is just a few NCCL calls." Production-quality sharding is much more involved.


Apply now (5 min)

Take hidden size 12,288 and tensor-parallel width 3. Compute the per-GPU output columns for a column-parallel projection. Then compute the fp16 bytes in an 8-token all-reduce over the full hidden state. Sketch from memory:

  • the layer-splitting diagram,

  • the column-vs-row comparison,

  • and the communication-byte calculation.


Bridge. Once you see how many moving parts exist, you stop wanting to build the whole stack by hand. Next we compare the major serving frameworks that package scheduling, memory management, kernels, and multi-GPU logic together. → 08-serving-frameworks.md