
Exercise 08 — Sampling Strategies

Timebox: 30-45 minutes

Goal

Implement greedy, temperature, top-k, and top-p (nucleus) sampling from a logits array. This covers the most common live-coding question on decoding.

Work in

  • sampling.py

Tasks

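  1. greedy(logits): return the index of the largest logit (argmax).
  2. temperature_sample(logits, T): divide the logits by T, apply softmax, and sample from the result.
  3. top_k(logits, k, T): keep the k largest logits, mask the rest to -inf, then sample.
  4. top_p(logits, p, T): nucleus sampling. Keep the smallest set of highest-probability tokens whose cumulative softmax probability is ≥ p, mask the rest, then sample.
  5. Combine: sample(logits, T, k=None, p=None) applies whichever filters were given.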
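
One possible shape for the five tasks, as a minimal NumPy sketch rather than a reference solution. Several choices here are assumptions, not requirements of the exercise: temperature is applied before the filters, top-k runs before top-p when both are given, T must be positive, and the `rng` argument and the `softmax`, `filter_top_k`, and `filter_top_p` helpers are hypothetical names added for testability.

```python
import numpy as np

def softmax(logits):
    # Subtract the max for numerical stability; -inf entries become probability 0.
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())
    return e / e.sum()

def greedy(logits):
    return int(np.argmax(logits))

def filter_top_k(logits, k):
    # Keep exactly k entries (ties broken arbitrarily), mask the rest to -inf.
    logits = np.asarray(logits, dtype=np.float64)
    k = max(1, min(int(k), logits.size))
    keep = np.argpartition(logits, -k)[-k:]
    out = np.full_like(logits, -np.inf)
    out[keep] = logits[keep]
    return out

def filter_top_p(logits, p):
    # Keep the smallest prefix of the descending-probability order whose
    # cumulative probability reaches p; mask everything else to -inf.
    logits = np.asarray(logits, dtype=np.float64)
    probs = softmax(logits)
    order = np.argsort(-probs)
    cum = np.cumsum(probs[order])
    cutoff = min(int(np.searchsorted(cum, p)) + 1, logits.size)
    keep = order[:cutoff]
    out = np.full_like(logits, -np.inf)
    out[keep] = logits[keep]
    return out

def sample(logits, T=1.0, k=None, p=None, rng=None):
    # Assumed order: temperature, then top-k, then top-p, then sample.
    if T <= 0:
        raise ValueError("T must be positive; use greedy() for deterministic decoding")
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=np.float64) / T
    if k is not None:
        scaled = filter_top_k(scaled, k)
    if p is not None:
        scaled = filter_top_p(scaled, p)
    probs = softmax(scaled)  # renormalizes over the surviving tokens
    return int(rng.choice(probs.size, p=probs))

# Thin wrappers matching the task signatures (plus the hypothetical rng argument).
def temperature_sample(logits, T, rng=None):
    return sample(logits, T=T, rng=rng)

def top_k(logits, k, T, rng=None):
    return sample(logits, T=T, k=k, rng=rng)

def top_p(logits, p, T, rng=None):
    return sample(logits, T=T, p=p, rng=rng)
```

Masking to -inf and reusing a single softmax keeps the renormalization in one place, so the combined sample() needs no special cases for each filter.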

Done when

  • Pure NumPy or pure PyTorch, no library helpers
  • A unit test confirms each function does what it says on a hand-picked logits array
  • You can explain to a peer why top-p is often preferred over top-k when the entropy of the next-token distribution varies a lot from step to step
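
To make the top-p explanation concrete, the following small sketch counts how many tokens the nucleus keeps for a peaked and a flat distribution. The helper `nucleus_size` and both logits arrays are hypothetical, hand-picked for illustration.

```python
import numpy as np

def softmax(logits):
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())
    return e / e.sum()

def nucleus_size(logits, p):
    # Number of tokens in the smallest set whose cumulative probability is >= p.
    probs = np.sort(softmax(logits))[::-1]  # descending probabilities
    cum = np.cumsum(probs)
    return min(int(np.searchsorted(cum, p)) + 1, probs.size)

peaked = [10.0, 0.0, 0.0, 0.0, 0.0]  # low entropy: one token dominates
flat = [0.0, 0.0, 0.0, 0.0, 0.0]     # high entropy: uniform over 5 tokens

print(nucleus_size(peaked, 0.9))  # 1
print(nucleus_size(flat, 0.9))    # 5
```

A fixed k=3 would keep three tokens in both cases: two very unlikely tokens in the peaked case, and only three of five equally good tokens in the flat case. The nucleus instead shrinks and grows with the distribution.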

Stretch

  • Add repetition penalty
  • Add min_p (mask any token whose probability is below min_p × max_prob)
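
Both stretch items can be written as logit filters that run before the final softmax. The sketch below is one way to do it, under stated assumptions: the repetition penalty follows one common convention (divide positive logits of already-generated tokens by the penalty, multiply negative ones), and the names `apply_repetition_penalty`, `filter_min_p`, `prev_tokens`, and `penalty` are hypothetical.

```python
import numpy as np

def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    # Assumed convention: penalty > 1 makes previously generated tokens less
    # likely by shrinking positive logits and pushing negative ones lower.
    out = np.asarray(logits, dtype=np.float64).copy()
    idx = np.unique(np.asarray(prev_tokens, dtype=np.int64))
    vals = out[idx]
    out[idx] = np.where(vals > 0, vals / penalty, vals * penalty)
    return out

def filter_min_p(logits, min_p):
    # Mask every token whose probability is below min_p * max_prob.
    logits = np.asarray(logits, dtype=np.float64)
    e = np.exp(logits - logits.max())
    probs = e / e.sum()
    return np.where(probs >= min_p * probs.max(), logits, -np.inf)

# Usage on hand-picked values.
print(apply_repetition_penalty([2.0, 1.0, -1.0], prev_tokens=[0, 2], penalty=2.0))
print(filter_min_p(np.log([0.6, 0.3, 0.1]), min_p=0.4))
```

Like top-p, min_p adapts to the distribution: the threshold scales with the most likely token's probability, so a confident step prunes aggressively and an uncertain step keeps more candidates.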