ML Coding Rounds — Interview Questions
The "implement this from scratch in 30-45 min on a whiteboard / shared editor" round. AI engineer loops increasingly include one of these — usually multi-head attention, beam search, top-p/top-k sampling, an autoregressive generation loop, or a small LoRA layer. Distinct from classic-algo.md (DSA-flavored: LRU, trie, union-find) and practical-takehomes.md (multi-hour deliverables: web crawler + LLM, JSON+LLM pipeline).
The senior tell is not memorized recall — it's the running commentary. State the shapes of every tensor before you write the einsum. Identify which operation is the bottleneck. Mention where numerical stability matters. Get the masking right the first time. Discuss what would change at training vs inference. The interviewer wants to see the mental model, not just a working snippet.
Code examples in this file use PyTorch idioms (the 2026 default for ML coding rounds). NumPy is sometimes preferred for "from-scratch, no frameworks" framings; flag this with the interviewer up front.
Attention & transformer internals
Q: "Implement scaled dot-product attention."
Tags: senior · very-common · coding · source: 2026 AI engineer loops; standard ML-coding probe across all top labs
Answer outline:
- Inputs: Q (B, H, T_q, D), K (B, H, T_k, D), V (B, H, T_k, D), optional mask (B, 1, T_q, T_k) or (T_q, T_k).
- Output: (B, H, T_q, D).
- Steps:
1. Compute scores: scores = Q @ K.transpose(-2, -1) / sqrt(D). Shape (B, H, T_q, T_k).
2. Apply mask: scores = scores.masked_fill(mask == 0, -inf) (or -1e9).
3. Softmax along the last dim: attn = softmax(scores, dim=-1).
4. Multiply with V: out = attn @ V. Shape (B, H, T_q, D).
- Reference implementation:
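```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q: (B, H, T_q, D), K/V: (B, H, T_k, D); mask broadcastable to (B, H, T_q, T_k)
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)  # (B, H, T_q, T_k)
    if mask is not None:
        # mask == 0 marks positions that must not be attended to
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)  # normalize over keys
    return torch.matmul(attn, V), attn  # (B, H, T_q, D), (B, H, T_q, T_k)
```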
- Why sqrt(D): if the components of two D-dim vectors are independent with zero mean and unit variance, their dot product has variance D (std √D). Without scaling, the logits grow with D, softmax saturates, and gradients vanish.
- Mention: in production you'd call F.scaled_dot_product_attention (which picks FlashAttention / memory-efficient implementations); the from-scratch version is for understanding.
- Numbers to drop: "memory complexity: O(T²) for the score matrix — the wall for long contexts", "FlashAttention removes the O(T²) memory by tiling and recomputing"
Common follow-ups: - "How does FlashAttention change this?" - "Add causal masking." - "Extend to GQA (grouped-query attention)."
Traps: - Forgetting to scale by sqrt(D). Classic miss. - Off-by-one in masking. Causal mask: strictly upper triangle = 0 (masked), lower triangle including the diagonal = 1 (attend). - Softmax over the wrong axis.
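A minimal sketch of the "Add causal masking" follow-up, assuming self-attention where T_q == T_k (no KV-cache offset). It builds the mask with torch.tril and can be cross-checked against PyTorch's built-in F.scaled_dot_product_attention with is_causal=True (PyTorch 2.0+). The function name is illustrative.

```python
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    # Q, K, V: (B, H, T, D); assumes T_q == T_k (self-attention, no cache offset)
    T_q, T_k = Q.size(-2), K.size(-2)
    # True = may attend (lower triangle incl. diagonal), False = masked (strictly upper)
    mask = torch.tril(torch.ones(T_q, T_k, dtype=torch.bool, device=Q.device))
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 4, 6, 8) for _ in range(3))
out = causal_attention(Q, K, V)
ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
```

Position 0 can only see itself, so its output is exactly V at position 0: a cheap sanity check to say out loud.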
Related cross-cutting: Evaluation & quality
Related module: learning/00_ai_foundation/02_tokens_embeddings_context/
Q: "Implement multi-head attention."
Tags: senior · very-common · coding · source: 2026 AI engineer loops; standard ML-coding probe
Answer outline: - Reference implementation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        # project, split d_model into (n_heads, d_head), move heads before time
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / (self.d_head ** 0.5)  # (B, H, T, T)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        # merge heads: (B, H, T, d_head) -> (B, T, H, d_head) -> (B, T, D)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)
```
- Shapes: x: (B, T, D).
- After projection + view + transpose: (B, H, T, d_head).
- Scores: (B, H, T, T).
- Output before merge: (B, H, T, d_head).
- After transpose + contiguous + view: (B, T, D).
- Note bias=False for projections is conventional in modern transformer impls.
- For cross-attention, K and V come from a separate sequence; otherwise identical.
- Numbers to drop: "GPT-3 175B: d_model=12288, n_heads=96, d_head=128", "modern models use GQA — fewer K/V heads than Q heads to shrink KV cache"
Common follow-ups: - "Modify for grouped-query attention (GQA)." - "Add rotary position embeddings (RoPE)." - "Add KV caching for inference."
Traps:
- Forgetting .contiguous() before the final view. After transpose the tensor is non-contiguous, so .view raises a stride error; use .contiguous().view or .reshape.
- Conflating d_model and d_head. Always state both.
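The .contiguous() trap reproduces in a few lines: after transpose(1, 2) the per-head output is a non-contiguous view, so .view cannot merge the head and d_head dims and raises a RuntimeError, while .contiguous().view (or .reshape, which copies when needed) works. The sizes below are arbitrary.

```python
import torch

B, H, T, d_head = 2, 4, 3, 5
per_head = torch.randn(B, H, T, d_head)  # attention output before merging heads
swapped = per_head.transpose(1, 2)       # (B, T, H, d_head): same storage, non-contiguous

try:
    swapped.view(B, T, H * d_head)
    view_failed = False
except RuntimeError:
    view_failed = True  # H and d_head are not adjacent in memory, so view refuses

merged = swapped.contiguous().view(B, T, H * d_head)  # copy first, then view is legal
```

In the merged tensor, head h of token t occupies columns h*d_head to (h+1)*d_head, which is what W_o expects.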
Related cross-cutting: Evaluation & quality
Related module: learning/00_ai_foundation/02_tokens_embeddings_context/
Q: "Modify MHA to grouped-query attention (GQA)."
Tags: senior · common · coding · source: 2026 AI engineer loops; Llama-3 / Mistral / Qwen2 architecture
Answer outline:
- GQA: n_q_heads query heads share n_kv_heads K/V heads, where n_q_heads % n_kv_heads == 0. Each KV head serves n_q_heads / n_kv_heads query heads.
- Why: K and V dominate the KV cache. Cutting their head count by 4-8× shrinks the cache by the same factor with minimal quality loss.
- Implementation: project K and V to n_kv_heads * d_head instead of n_heads * d_head. Then repeat (broadcast) K and V along the head dim before the attention compute.
- Reference snippet:
```python
# inside forward(), with x: (B, T, d_model)
Q = self.W_q(x).view(B, T, self.n_q_heads, self.d_head).transpose(1, 2)   # (B, n_q, T, d)
K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)  # (B, n_kv, T, d)
V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)  # (B, n_kv, T, d)
repeat = self.n_q_heads // self.n_kv_heads
# query head h reads KV head h // repeat
K = K.repeat_interleave(repeat, dim=1)  # (B, n_q, T, d)
V = V.repeat_interleave(repeat, dim=1)  # (B, n_q, T, d)
```
- MQA (multi-query attention) is the extreme case n_kv_heads = 1.
- Numbers to drop: "Llama-3 70B: n_q=64, n_kv=8 → 8× KV cache shrink", "GQA quality loss vs MHA: <0.5% on most benchmarks", "MQA: more aggressive, slightly more quality loss"
Common follow-ups: - "Why does GQA exist? What problem does it solve?" - "Show the KV cache size before and after."
Traps: - Forgetting to expand K and V. Shape mismatch.
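A self-contained sketch of the full GQA module, assuming the same conventions as the MHA answer (bias-free projections, mask == 0 means masked). The class and argument names are illustrative. W_k and W_v project to only n_kv_heads * d_head features, which is where both the parameter savings and the KV-cache savings come from.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_q_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_q_heads == 0 and n_q_heads % n_kv_heads == 0
        self.n_q_heads, self.n_kv_heads = n_q_heads, n_kv_heads
        self.d_head = d_model // n_q_heads
        self.W_q = nn.Linear(d_model, n_q_heads * self.d_head, bias=False)
        # K/V projections are smaller: only n_kv_heads heads
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(n_q_heads * self.d_head, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, _ = x.shape
        Q = self.W_q(x).view(B, T, self.n_q_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        # share each KV head across a group of query heads
        repeat = self.n_q_heads // self.n_kv_heads
        K = K.repeat_interleave(repeat, dim=1)  # (B, n_q, T, d_head)
        V = V.repeat_interleave(repeat, dim=1)
        scores = Q @ K.transpose(-2, -1) / self.d_head ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        out = F.softmax(scores, dim=-1) @ V  # (B, n_q, T, d_head)
        return self.W_o(out.transpose(1, 2).reshape(B, T, -1))
```

In a real decoder, the cache would store K and V before repeat_interleave, so it holds only n_kv_heads heads.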
Related cross-cutting: Cost & latency
Related module: learning/00_ai_foundation/02_tokens_embeddings_context/
Q: "Add KV caching for autoregressive inference."
Tags: senior · common · coding · source: 2026 AI engineer loops; standard ML-coding probe
Answer outline:
- At decode time, each step generates one token. Without caching, you'd recompute K and V for all prior tokens at every step: O(T²) total work for T tokens.
- With KV caching: store K and V tensors per layer, append the new step's K and V, and reuse the rest. Each step is O(T) (the new Q attends to all cached K, V).
- Reference shape:
```python
# During decode, x has shape (B, 1, D): one new token.
# cache_K / cache_V: (B, n_h, T, d), filled by prefill and earlier decode steps.
Q = self.W_q(x).view(B, 1, n_h, d).transpose(1, 2)      # (B, n_h, 1, d)
K_new = self.W_k(x).view(B, 1, n_h, d).transpose(1, 2)
V_new = self.W_v(x).view(B, 1, n_h, d).transpose(1, 2)
K = torch.cat([cache_K, K_new], dim=2)  # (B, n_h, T+1, d)
V = torch.cat([cache_V, V_new], dim=2)
# update cache for next step
cache_K, cache_V = K, V
scores = (Q @ K.transpose(-2, -1)) / (d ** 0.5)  # (B, n_h, 1, T+1)
# no causal mask needed: the single new query may attend to every cached key
attn = F.softmax(scores, dim=-1)
out = attn @ V  # (B, n_h, 1, d)
```
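- KV cache size: 2 * n_layers * B * T * n_kv_heads * d_head * dtype_size (the factor 2 covers K and V). For Llama-3 70B (80 layers, 8 KV heads, d_head 128) at 8k context, BF16, B=1: ~2.7 GB (2.5 GiB). At B=64: ~172 GB (160 GiB).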
- This is why PagedAttention (vLLM) exists: managing this memory efficiently across many concurrent sequences.
- Prefill is the first forward pass on all prompt tokens; populates the cache. Decode is one step at a time after that.
- Numbers to drop: "Llama-3 70B KV cache (BF16): ~320 KiB per token, ~0.33 GB per 1k tokens per sequence", "PagedAttention shrinks waste from ~60% to <4%", "GQA cuts KV cache by n_q/n_kv×"
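The cache-size formula is easy to sanity-check in code. The sketch below assumes the published Llama-3 70B shape (80 layers, 8 KV heads, d_head = 128) and BF16 (2 bytes per element); the helper name is illustrative, and other configs can be swapped in.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    # leading 2: one K and one V tensor per layer
    return 2 * n_layers * batch * seq_len * n_kv_heads * d_head * bytes_per_elem

# Assumed Llama-3 70B shape: 80 layers, 8 KV heads (GQA), head dim 128, BF16
per_token = kv_cache_bytes(80, 8, 128, seq_len=1)
ctx_8k = kv_cache_bytes(80, 8, 128, seq_len=8192)
ctx_8k_b64 = kv_cache_bytes(80, 8, 128, seq_len=8192, batch=64)
# Same model with full MHA (64 KV heads) for the "before GQA" comparison
ctx_8k_mha = kv_cache_bytes(80, 64, 128, seq_len=8192)

print(per_token / 2**10)    # 320.0 KiB per token
print(ctx_8k / 2**30)       # 2.5 GiB at 8k, B=1
print(ctx_8k_b64 / 2**30)   # 160.0 GiB at 8k, B=64
print(ctx_8k_mha / ctx_8k)  # 8.0, the n_q / n_kv ratio
```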
Common follow-ups: - "How does PagedAttention work?" - "What's the difference between prefill and decode caching?"
Traps: - Recomputing K/V for cached tokens. Misses the point.
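A quick way to convince yourself (or the interviewer) that caching is exact: run a toy single-head attention layer once over the whole sequence with a causal mask, then token by token with a growing K/V cache, and compare. The random projection matrices and sizes are placeholders, not a real model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D = 2, 6, 16
# toy single-head layer: random projections stand in for trained W_q, W_k, W_v
W_q, W_k, W_v = (torch.randn(D, D) / D ** 0.5 for _ in range(3))
x = torch.randn(B, T, D)

def full_causal_attention(x):
    # prefill-style: all positions at once with a causal mask
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    return F.scaled_dot_product_attention(Q, K, V, is_causal=True)

def decode_with_cache(x):
    # decode-style: one token per step, appending to the K/V cache
    cache_K = x.new_zeros(x.size(0), 0, D)
    cache_V = x.new_zeros(x.size(0), 0, D)
    outs = []
    for t in range(x.size(1)):
        x_t = x[:, t:t + 1]  # (B, 1, D)
        cache_K = torch.cat([cache_K, x_t @ W_k], dim=1)
        cache_V = torch.cat([cache_V, x_t @ W_v], dim=1)
        scores = (x_t @ W_q) @ cache_K.transpose(-2, -1) / D ** 0.5  # (B, 1, t+1)
        outs.append(F.softmax(scores, dim=-1) @ cache_V)            # (B, 1, D)
    return torch.cat(outs, dim=1)

full = full_causal_attention(x)
cached = decode_with_cache(x)
```

The two outputs should match up to float rounding; each step projects only the new token, which is the whole point of the cache.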
Related cross-cutting: Cost & latency
Related module: learning/00_ai_foundation/02_tokens_embeddings_context/, learning/02_ai_infrastructure/02_inference_serving_systems/
Sampling
Q: "Implement top-K sampling."
Tags: senior · very-common · coding · source: 2026 AI engineer loops; standard ML-coding probe
Answer outline:
- Input: logits (V,) for vocab size V. Pick top K, set the rest to -inf, renormalize, sample.
- Reference:
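```python
import torch
import torch.nn.functional as F

def top_k_sample(logits, k, temperature=1.0):
    # logits: (V,) or (B, V); temperature must be > 0
    logits = logits / temperature
    vals, idx = torch.topk(logits, k, dim=-1)
    # everything outside the top K gets -inf, i.e. probability 0 after softmax
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, idx, vals)
    probs = F.softmax(filtered, dim=-1)  # renormalizes over the K survivors
    return torch.multinomial(probs, num_samples=1)
```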
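- Edge cases: k=1 is greedy (argmax). k=vocab_size is plain sampling from the full softmax. Temperature must be > 0; for temperature 0, call argmax directly.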
- Numbers to drop: "typical k: 40-100", "k=1 is deterministic = greedy", "temperature: 0.7-0.9 for chat, 1.0 for creative"
Common follow-ups: - "What if you want both top-K and top-p?" - "Implement greedy decode."
Traps: - Taking the softmax over the full vocab before top-K. Extra work, and the surviving top-K probabilities must be renormalized before sampling. - Wrong dim in topk for batched inputs.
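For the "both top-K and top-p" follow-up, one common ordering is to apply top-K first and then top-p on the survivors. The sketch below assumes logits have already been divided by the temperature; the function name and the k=0 / p=1.0 "disabled" conventions are illustrative.

```python
import torch
import torch.nn.functional as F

def filter_top_k_top_p(logits, k=0, p=1.0):
    # logits: (..., V). k=0 disables top-K; p=1.0 disables top-p.
    if k > 0:
        kth = torch.topk(logits, k, dim=-1).values[..., -1:]      # k-th largest logit
        logits = logits.masked_fill(logits < kth, float("-inf"))  # ties with the k-th are kept
    if p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = F.softmax(sorted_logits, dim=-1)
        remove = probs.cumsum(dim=-1) - probs >= p  # exclusive cumsum keeps the top-1 token
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # scatter back to vocab order
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return logits

logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
filtered = filter_top_k_top_p(logits, k=3, p=0.5)
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
```

Here top-K keeps three tokens, their renormalized probabilities are about [0.53, 0.32, 0.16], and p=0.5 then keeps only the first, so the sample is always token 0.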
Related cross-cutting: Cost & latency
Related module: learning/01_ai_engineering/03_decoding_strategies/
Q: "Implement nucleus (top-p) sampling."
Tags: senior · very-common · coding · source: 2026 AI engineer loops; standard ML-coding probe
Answer outline: - Idea: instead of a fixed K, take the smallest set of tokens whose cumulative probability ≥ p. Adaptive: the number of candidate tokens varies with the model's uncertainty at each step. - Reference:
```python
import torch
import torch.nn.functional as F

def top_p_sample(logits, p, temperature=1.0):
    logits = logits / temperature
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = F.softmax(sorted_logits, dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    # exclusive cumsum: drop a token only if the tokens ranked above it already reach p,
    # so the top-1 token is always kept
    remove = (cum_probs - probs) >= p
    sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
    # restore original vocab order
    logits_filtered = torch.full_like(logits, float('-inf'))
    logits_filtered.scatter_(-1, sorted_idx, sorted_logits)
    probs = F.softmax(logits_filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```
Common follow-ups: - "Why nucleus instead of top-K?" - "What's the failure mode of top-p in high-uncertainty steps?"
Traps: - Excluding the highest-probability token if it alone exceeds p. Subtle off-by-one. - Sorting but forgetting to unsort before sampling.
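To see why top-p is adaptive (and where it fails), count how many tokens the nucleus keeps for a peaked versus a flat distribution. The helper name is illustrative; it uses the same exclusive-cumsum rule as the reference above.

```python
import torch

def nucleus_size(probs, p):
    # probs: (V,) probabilities; returns how many tokens top-p would keep
    sorted_probs = torch.sort(probs, descending=True).values
    exclusive_cum = sorted_probs.cumsum(dim=-1) - sorted_probs
    return int((exclusive_cum < p).sum())

peaked = torch.tensor([0.9, 0.05, 0.05])  # confident step
flat = torch.full((8,), 0.125)            # uncertain step: uniform over 8 tokens
print(nucleus_size(peaked, 0.5), nucleus_size(flat, 0.5))  # 1 4
```

With a realistic vocab of tens of thousands of tokens, a near-flat step can admit a long tail of low-probability tokens under the same p, which is the high-uncertainty failure mode the follow-up asks about.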
Related cross-cutting: Cost & latency
Related module: learning/01_ai_engineering/03_decoding_strategies/
Q: "Implement beam search."
Tags: senior · common · coding · source: 2026 AI engineer loops; classic ML-coding probe
Answer outline: - Idea: maintain B "beams" (partial sequences with cumulative log-probabilities). At each step, expand every beam by every vocab token, score, keep top B. - Reference skeleton (pseudocode-leaning):
```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def beam_search(model, prompt_ids, beam_width, max_len, eos_id):
    # prompt_ids: (1, T_prompt); model(tokens) -> logits (1, T, V)
    prompt_len = prompt_ids.size(-1)
    beams = [(prompt_ids, 0.0)]  # (token_ids, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            logits = model(tokens)[0, -1, :]  # (V,)
            log_probs = F.log_softmax(logits, dim=-1)
            top_lp, top_idx = log_probs.topk(beam_width)
            for lp, idx in zip(top_lp.tolist(), top_idx.tolist()):
                new_tokens = torch.cat([tokens, tokens.new_tensor([[idx]])], dim=-1)
                candidates.append((new_tokens, score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            if tokens[0, -1].item() == eos_id:
                finished.append((tokens, score))  # raw score; normalize once below
            else:
                beams.append((tokens, score))
        if not beams:
            break

    def normalized(c):
        # length-normalize by generated tokens only (prompt excluded)
        return c[1] / max(c[0].size(-1) - prompt_len, 1)

    return max(finished + beams, key=normalized)[0]
```
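- Length normalization: dividing by the generated length (as above) is the simplest fix; GNMT-style decoders divide by the (5+len)^α / 6^α heuristic instead.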
- Beam search is deterministic given fixed model + prompt. Used in machine translation, summarization, code completion. Less common in chat (greedy / sampling preferred for diversity).
- Numbers to drop: "beam width: 4-10 typical, diminishing returns past 16", "length penalty α: 0.6-1.0", "compute: B× a single greedy decode"
Common follow-ups: - "Why length-normalize?" - "When would you use beam search vs sampling?"
Traps: - No length normalization. Beams collapse to short EOS-terminated outputs. - Returning the longest beam instead of the highest-scoring one.
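The GNMT-style length penalty as a small function: divide the summed log-probability by lp(len) = ((5 + len) / 6)^α; α = 0 disables it. The hypothesis scores below are made up to show how normalization can flip the ranking toward the longer output.

```python
def gnmt_length_penalty(length, alpha=0.6):
    # ((5 + len) / 6) ** alpha; equals 1.0 when length == 1 or alpha == 0
    return ((5 + length) / 6) ** alpha

def normalized_score(sum_log_prob, length, alpha=0.6):
    # log-probs are negative, so a larger penalty pulls longer hypotheses toward 0
    return sum_log_prob / gnmt_length_penalty(length, alpha)

short = normalized_score(-2.0, length=2, alpha=1.0)  # -2.0 / (7/6), about -1.71
long = normalized_score(-3.0, length=10, alpha=1.0)  # -3.0 / (15/6) = -1.2
```

Raw scores would pick the short hypothesis (-2.0 > -3.0); after normalization, the longer one wins.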
Related cross-cutting: Cost & latency
Related module: learning/01_ai_engineering/03_decoding_strategies/
Q: "Write a generation loop using your sampler."
Tags: senior · very-common · coding · source: 2026 AI engineer loops; standard ML-coding probe
Answer outline: - Skeleton:
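```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, eos_id, sampler):
    # model.forward_with_cache(ids, cache) -> (logits, cache) is a hypothetical interface
    tokens = prompt_ids.clone()  # (B, T_prompt)
    cache = None  # KV cache
    for _ in range(max_new_tokens):
        # prefill on the full prompt first, then feed only the newest token
        step_input = tokens if cache is None else tokens[:, -1:]
        logits, cache = model.forward_with_cache(step_input, cache)
        next_token = sampler(logits[:, -1, :])  # (B, 1)
        tokens = torch.cat([tokens, next_token], dim=-1)
        if (next_token == eos_id).all():
            break
    return tokens
```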
- Guard the context window (if tokens.size(-1) >= model.context_length) → truncate or error.
- Stopping criteria can be richer than EOS: regex on the decoded text, stop sequences, max tokens.
- For production: this is exactly what vLLM / TGI / SGLang do, plus continuous batching across many requests. Mention it.
- Numbers to drop: "naive O(T²); with KV cache O(T)", "batched generation: track per-sequence finished state"
Common follow-ups: - "Handle batched generation with different EOS times." - "How would you stop on a specific string (not just EOS)?"
Traps: - Re-running the model on the full prefix each step. O(T²) trap. - Not handling per-sequence EOS in batches.
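For the batched follow-up, a sketch that tracks a per-sequence finished flag and pads sequences after their EOS. step_fn is a hypothetical stand-in for a model call that returns next-token logits of shape (B, V). Decoding is greedy for brevity, and the toy step function simply predicts (last_token + 1) mod V.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate_batched(step_fn, prompt_ids, max_new_tokens, eos_id, pad_id):
    # step_fn(tokens) -> next-token logits (B, V); hypothetical model interface
    tokens = prompt_ids.clone()
    finished = torch.zeros(tokens.size(0), dtype=torch.bool, device=tokens.device)
    for _ in range(max_new_tokens):
        next_token = step_fn(tokens).argmax(dim=-1)  # greedy, (B,)
        # sequences that already emitted EOS get padding instead of new tokens
        next_token = torch.where(finished, torch.full_like(next_token, pad_id), next_token)
        tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
        finished |= next_token == eos_id
        if finished.all():  # stop only when every sequence is done
            break
    return tokens

V, EOS, PAD = 5, 4, 0

def toy_step(toks):
    # always "predicts" last_token + 1 (mod V)
    return F.one_hot((toks[:, -1] + 1) % V, V).float()

out = generate_batched(toy_step, torch.tensor([[3], [1]]), max_new_tokens=10, eos_id=EOS, pad_id=PAD)
# out.tolist() → [[3, 4, 0, 0], [1, 2, 3, 4]]
```

Row 0 hits EOS after one step and is padded while row 1 keeps going; a real implementation would also pass an attention mask so padded positions are ignored.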
Related cross-cutting: Cost & latency
Related module: learning/01_ai_engineering/03_decoding_strategies/, learning/02_ai_infrastructure/02_inference_serving_systems/
Fine-tuning primitives
Q: "Implement a LoRA layer."
Tags: senior · common · coding · source: 2026 AI engineer loops; HF PEFT-aware loops
Answer outline:
- LoRA: replace W ∈ R^{d×k} with W + BA, where B ∈ R^{d×r}, A ∈ R^{r×k}, r << min(d, k). Train only A and B; freeze W.
- Reference:
```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int, alpha: int = 16, dropout: float = 0.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze W (and bias)
        self.r = r
        self.scaling = alpha / r
        self.A = nn.Parameter(torch.zeros(r, base.in_features))   # (r, k)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # (d, r)
        nn.init.kaiming_uniform_(self.A, a=5 ** 0.5)
        # B initialized to zero so that initial output equals base
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # base path + low-rank path: x A^T B^T == x (BA)^T
        return self.base(x) + self.dropout(x) @ self.A.T @ self.B.T * self.scaling
```
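- Scaling: scaling = alpha / r keeps the update's magnitude roughly stable across rank choices, so you don't have to retune the learning rate every time you change r.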
- At inference, you can merge W + BA * scaling back into W for zero overhead.
- Numbers to drop: "typical rank: 8-64", "alpha: usually 16 or 32", "trainable params: ~0.1-1% of full FT", "memory savings: 80-95% on 7B+ models"
Common follow-ups: - "How does QLoRA differ?" - "Where in the transformer do you apply LoRA?" - "Merge LoRA back into the base weights."
Traps: - Initializing B non-zero. Step-0 output diverges from base. - Forgetting to freeze the base. Defeats the purpose.
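For the "merge LoRA back" follow-up: a sketch that folds the low-rank update into a fresh nn.Linear, assuming the same shapes as LoRALinear (A: (r, in_features), B: (out_features, r)). The merge_lora name and the toy sizes are illustrative. The merged layer should reproduce base + LoRA outputs up to float rounding.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def merge_lora(base: nn.Linear, A: torch.Tensor, B: torch.Tensor, scaling: float) -> nn.Linear:
    # W' = W + scaling * (B @ A); shapes: W (out, in), B (out, r), A (r, in)
    merged = nn.Linear(base.in_features, base.out_features, bias=base.bias is not None,
                       device=base.weight.device, dtype=base.weight.dtype)
    merged.weight.copy_(base.weight + scaling * (B @ A))
    if base.bias is not None:
        merged.bias.copy_(base.bias)
    return merged

torch.manual_seed(0)
base = nn.Linear(8, 4)
A, B = torch.randn(2, 8), torch.randn(4, 2)  # rank r = 2
scaling = 16 / 2                              # alpha / r
x = torch.randn(3, 8)
merged = merge_lora(base, A, B, scaling)
```

After merging, inference runs a single matmul per layer, so there is zero adapter overhead; the trade-off is that you can no longer hot-swap adapters on that copy of the weights.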
Related cross-cutting: Cost & latency
Related module: learning/00_ai_foundation/06_adaptation_compression/
Q: "Implement cross-entropy loss for next-token prediction."
Tags: mid · common · coding · source: 2026 AI engineer loops; foundational ML-coding probe
Answer outline:
- Inputs: logits (B, T, V), targets (B, T) (token IDs).
- Shift: predict token t+1 from position t. So logits[:, :-1] predicts targets[:, 1:].
- Reference:
```python
import torch.nn.functional as F

def next_token_ce(logits, targets, ignore_index=-100):
    # logits: (B, T, V), targets: (B, T)
    pred = logits[:, :-1, :].contiguous()  # (B, T-1, V): position t predicts t+1
    tgt = targets[:, 1:].contiguous()      # (B, T-1)
    return F.cross_entropy(
        pred.view(-1, pred.size(-1)),  # (B*(T-1), V)
        tgt.view(-1),                  # (B*(T-1),)
        ignore_index=ignore_index,
    )
```
- Padding: ignore_index=-100 is the convention (and F.cross_entropy's default) for masking padding tokens.
- Numerical detail: F.cross_entropy combines log-softmax + NLL in a numerically stable way. Don't roll your own softmax and then NLL — you'll lose precision on extreme logits.
- For instruction-tuning, also mask the loss on prompt tokens (only compute loss on the assistant's response). Implemented by setting target IDs in the prompt region to ignore_index.
- Numbers to drop: "PyTorch convention: ignore_index=-100", "F.cross_entropy is log-softmax + NLL fused, numerically stable"
Common follow-ups: - "How do you mask out prompt tokens?" - "Compute loss only on a specific span."
Traps: - Forgetting the shift. Computing loss against the same position is the classic off-by-one. - Manual log-softmax + manual NLL. Lose stability.
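For the "mask out prompt tokens" follow-up: a sketch that builds labels with ignore_index over each example's prompt, assuming prompt_lens gives the number of prompt tokens per row (a hypothetical input; real pipelines usually derive it from the chat template). The labels then go through the usual shifted cross-entropy.

```python
import torch
import torch.nn.functional as F

def mask_prompt(input_ids, prompt_lens, ignore_index=-100):
    # input_ids: (B, T); prompt_lens: (B,) number of prompt tokens per row
    labels = input_ids.clone()
    positions = torch.arange(input_ids.size(1), device=input_ids.device)  # (T,)
    is_prompt = positions.unsqueeze(0) < prompt_lens.unsqueeze(1)         # (B, T)
    labels[is_prompt] = ignore_index
    return labels

input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]])
labels = mask_prompt(input_ids, torch.tensor([2, 3]))
# labels.tolist() → [[-100, -100, 7, 8], [-100, -100, -100, 4]]

# after the usual shift, only response tokens contribute to the loss
torch.manual_seed(0)
logits = torch.randn(2, 4, 10)
loss = F.cross_entropy(logits[:, :-1].reshape(-1, 10), labels[:, 1:].reshape(-1), ignore_index=-100)
```

Only three positions survive here: row 0 predicting 7 and 8, and row 1 predicting 4, and the mean is taken over those three.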
Related cross-cutting: Evaluation & quality
Related module: learning/01_ai_engineering/04_training_basics/
Tokenization & embeddings
Q: "Implement a simple BPE tokenizer (encode only)."
Tags: senior · occasional · coding · source: 2026 AI engineer loops; tokenization-aware probes
Answer outline:
- Given a learned merge table (a dict mapping each (token_a, token_b) pair to its merge rank; lower rank = learned earlier), encode a string by:
1. Pre-tokenize (e.g., split on whitespace).
2. Split each piece into characters (bytes for byte-level BPE).
3. Repeatedly find the highest-priority (earliest-learned) adjacent pair that's in the merge table; merge it. Stop when no more merges apply.
4. Look up final tokens in the vocab to get IDs.
- Reference:
```python
def bpe_encode(word, merges, vocab):
    # word: string or list of initial subtokens (characters / bytes)
    # merges: {(token_a, token_b): rank}, lower rank = higher priority
    # vocab: {token: id}
    tokens = list(word)
    while True:
        pairs = [(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]
        # find the pair with the lowest merge rank (= highest priority)
        best = None
        best_rank = float('inf')
        for i, p in enumerate(pairs):
            if p in merges and merges[p] < best_rank:
                best, best_rank = i, merges[p]
        if best is None:
            break
        merged = tokens[best] + tokens[best+1]
        tokens = tokens[:best] + [merged] + tokens[best+2:]
    return [vocab[t] for t in tokens]
```
Common follow-ups: - "How does this differ from SentencePiece/Unigram?" - "Why byte-level?"
Traps: - Merging the first matching pair instead of the highest-priority one. Wrong output.
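A more compact variant of the same encode loop (returning token strings rather than IDs), plus toy merge tables showing why priority matters. The merge tables are made up for illustration.

```python
def bpe_merge(word, merges):
    # merges: {(a, b): rank}; lower rank = learned earlier = merged first
    tokens = list(word)
    while len(tokens) > 1:
        ranks = [merges.get(pair, float("inf")) for pair in zip(tokens, tokens[1:])]
        best = min(range(len(ranks)), key=ranks.__getitem__)  # leftmost on ties
        if ranks[best] == float("inf"):
            break  # no adjacent pair is in the merge table
        tokens[best:best + 2] = [tokens[best] + tokens[best + 1]]
    return tokens

merges = {("l", "o"): 0, ("lo", "w"): 1, ("e", "r"): 2}
print(bpe_merge("lower", merges))  # ['low', 'er']

# ("e", "r") outranks ("w", "e"), even though ("w", "e") appears first in the word
print(bpe_merge("lower", {("e", "r"): 0, ("w", "e"): 1}))  # ['l', 'o', 'w', 'er']
```

A first-match implementation would return ['l', 'o', 'we', 'r'] for the second table, which is exactly the trap above.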
Related cross-cutting: Architecture choices
Related module: learning/00_ai_foundation/02_tokens_embeddings_context/
Q: "Compute cosine similarity between two batches of embeddings."
Tags: mid · common · coding · source: 2026 AI engineer loops; foundational RAG/embeddings probe
Answer outline:
- Inputs: A (N, D), B (M, D). Output: (N, M) similarity matrix.
- Reference:
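A minimal reference sketch, assuming float tensors. F.normalize returns new tensors, so the inputs are not mutated (see the first trap below); eps guards against zero vectors. The function name is illustrative.

```python
import torch
import torch.nn.functional as F

def cosine_sim_matrix(A, B, eps=1e-8):
    # A: (N, D), B: (M, D) -> (N, M)
    A_n = F.normalize(A, dim=-1, eps=eps)  # row-wise L2 normalization, out of place
    B_n = F.normalize(B, dim=-1, eps=eps)
    return A_n @ B_n.T                     # dot product of unit vectors = cosine

A = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
B = torch.tensor([[1.0, 1.0], [3.0, 0.0], [0.0, -1.0]])
sim = cosine_sim_matrix(A, B)  # ≈ [[0.707, 1.0, 0.0], [0.707, 0.0, -1.0]]
```

Normalizing once and doing a single matmul is O(N·M·D) and maps directly onto an inner-product index over pre-normalized vectors.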
Common follow-ups: - "How would you do this in FAISS?" - "When is dot product without normalization the right choice?"
Traps: - L2-normalizing in-place on tensors you need later. Mutates state. - Computing cosine on un-normalized vectors and calling it "cosine".
Related cross-cutting: Architecture choices
Related module: learning/01_ai_engineering/08_rag_system_design/, learning/01_ai_engineering/14_retrieval_ranking/
Numerical stability
Q: "Implement softmax in a numerically stable way."
Tags: mid · common · coding · source: 2026 AI engineer loops; foundational numerical probe
Answer outline:
- Naive exp(x) / sum(exp(x)) overflows for large x. Trick: subtract max(x) before exponentiating.
- Reference:
```python
def stable_softmax(x, dim=-1):
    x_max = x.max(dim=dim, keepdim=True).values
    x_shifted = x - x_max  # largest entry becomes 0
    ex = x_shifted.exp()   # every value lands in [0, 1], no overflow
    return ex / ex.sum(dim=dim, keepdim=True)
```
- Why it works: softmax(x - c) == softmax(x) for any constant c. Subtracting the max makes the largest exponent 0, so nothing overflows. Underflow on the smallest entries is harmless (they round to 0).
- For loss: prefer log_softmax + NLL over log(softmax). The latter loses precision near 0.
- Numbers to drop: "fp32 max exp arg: ~88 (else overflow)", "bf16 max exp arg: ~88 (same 8-bit exponent range as fp32)", "fp16 max exp arg: ~11 (max finite value 65504, much tighter)"
Common follow-ups: - "Why does this matter in attention?" - "What goes wrong in fp16 specifically?"
Traps: - Subtracting mean instead of max. Doesn't bound the exponent.
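A side-by-side check makes the failure concrete: with logits around 1000, the naive version overflows to inf / inf = NaN even in fp32, while the max-shifted version is exact. The last line ties the fp16 number above to the format's largest finite value.

```python
import math
import torch

def naive_softmax(x, dim=-1):
    ex = x.exp()
    return ex / ex.sum(dim=dim, keepdim=True)

def stable_softmax(x, dim=-1):
    ex = (x - x.max(dim=dim, keepdim=True).values).exp()
    return ex / ex.sum(dim=dim, keepdim=True)

x = torch.tensor([1000.0, 1000.0])
naive = naive_softmax(x)    # both entries NaN: exp(1000) = inf, inf / inf = NaN
stable = stable_softmax(x)  # both entries 0.5

# fp16 tops out at 65504, so exp(12) (about 162755) already overflows it
fp16_exp_limit_ok = math.exp(11) < torch.finfo(torch.float16).max < math.exp(12)
```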
Related cross-cutting: Evaluation & quality
Related module: learning/01_ai_engineering/04_training_basics/
Q: "Implement layer normalization from scratch."
Tags: mid · common · coding · source: 2026 AI engineer loops; foundational probe
Answer outline: - Normalize across the feature dim (not batch dim, unlike batch norm). - Reference:
```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))  # learned scale
        self.beta = nn.Parameter(torch.zeros(dim))  # learned shift
        self.eps = eps

    def forward(self, x):
        # statistics over the feature (last) dim, per token
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / (var + self.eps).sqrt() + self.beta
```
- Use unbiased=False (population variance) to match the standard formulation and nn.LayerNorm.
- Modern transformers often use RMSNorm instead (Llama, Mistral): it drops the mean-centering (and the bias) and divides by the root mean square. Faster, comparable quality.
- RMSNorm:
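```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # divide by the root mean square over the feature dim; no mean-centering, no bias
    # (eps inside the sqrt, as in the Llama reference implementation)
    return weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
```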
Common follow-ups: - "Why LayerNorm over BatchNorm in transformers?" - "Implement RMSNorm." - "Pre-norm vs post-norm?"
Traps:
- Normalizing over the wrong dim. Feature dim is the last one.
- Using unbiased=True — small mismatch with reference impls.
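A quick numerical check against PyTorch's built-in F.layer_norm, which also demonstrates the unbiased=True trap: with a small feature dim the mismatch is large enough to fail allclose. The unbiased flag on this helper exists only to show the trap.

```python
import torch
import torch.nn.functional as F

def layer_norm(x, gamma, beta, eps=1e-5, unbiased=False):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=unbiased)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
D = 8
x = torch.randn(4, D)
gamma, beta = torch.randn(D), torch.randn(D)
ref = F.layer_norm(x, (D,), gamma, beta, eps=1e-5)

ok_biased = torch.allclose(layer_norm(x, gamma, beta), ref, atol=1e-5)
ok_unbiased = torch.allclose(layer_norm(x, gamma, beta, unbiased=True), ref, atol=1e-5)
```

With D = 8 the unbiased variance is 8/7 times larger, so every normalized value shrinks by about 6.5%; at D = 8192 the same mistake is a ~0.006% drift that is easy to miss in tests.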
Related cross-cutting: Evaluation & quality
Related module: learning/00_ai_foundation/02_tokens_embeddings_context/
Senior-level: think-aloud framing
Q: "Walk me through optimizing a generation loop for production."
Tags: staff · common · design · source: 2026 senior/staff AI engineer loops
Answer outline:
- Treat this as a design probe with code touchpoints, not a from-scratch implementation.
- Stages and the wins available at each:
- Prefill: FlashAttention (3-10× attention speedup at long context), chunked prefill (smoother p99 under mixed loads), prefix caching (skip recompute for shared prompts).
- Decode: KV caching (mandatory), GQA / MQA (smaller cache), speculative decoding (1.5-3× speedup), Medusa heads (similar gain, no draft model), quantized KV cache (FP8/INT8 halves cache memory).
- Across sequences: continuous batching (~20-30× throughput vs static batching).
- Across requests: prompt caching (provider-side or self-hosted), response caching (semantic similarity-based for repeated queries).
- Quantization decisions:
- BF16: default.
- FP8 (H100+): ~1.7× throughput, ~0-1% quality loss.
- INT4 (AWQ/GPTQ): ~3-4× throughput, ~1-3% quality loss.
- Trade per the deployment's quality bar.
- Engine choice: vLLM by default; TensorRT-LLM when one model serves at very high QPS and you can pay compile cost; SGLang for shared-prefix workloads.
- Observability: TTFT, end-to-end latency, tokens/sec, KV cache utilization, preemption rate, queue depth — see inference-serving.md.
- Mention what not to over-optimize early: speculative decoding helps less at high batch (compute-bound regime); A100 has no FP8 tensor cores, so FP8 throughput gains need H100-class hardware; TensorRT-LLM compile time is a real ops cost.
- Numbers to drop: "continuous batching alone: 20-30× over naive", "speculative decoding: 1.5-3×, decode-only, batch-sensitive", "FP8 on H100: 1.5-1.7× over BF16"
Common follow-ups: - "When does each optimization start to fight the others?" - "How would you benchmark these end to end?"
Traps: - Stacking every optimization without measuring. Some hurt at certain batch sizes.
Related cross-cutting: Cost & latency
Related module: learning/02_ai_infrastructure/05_agent_performance_economics/, learning/02_ai_infrastructure/02_inference_serving_systems/
Q: "Implement a stripped-down transformer block."
Tags: senior · common · coding · source: 2026 AI engineer loops; classic from-scratch probe
Answer outline: - Modern (pre-norm, RMSNorm, SwiGLU) block:
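```python
import torch.nn as nn

# RMSNorm, GroupedQueryAttention and SwiGLU are assumed to be defined elsewhere
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads, mlp_hidden):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.attn = GroupedQueryAttention(d_model, n_heads, n_kv_heads)
        self.mlp_norm = RMSNorm(d_model)
        self.mlp = SwiGLU(d_model, mlp_hidden)

    def forward(self, x, mask=None):
        # pre-norm: normalize each sub-block's input, add its output back to the residual stream
        x = x + self.attn(self.attn_norm(x), mask=mask)
        x = x + self.mlp(self.mlp_norm(x))
        return x
```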
- SwiGLU MLP: W_down(SiLU(W_gate(x)) * W_up(x)). Three weight matrices instead of two, so 1.5× the params of a vanilla MLP at the same hidden size; better quality.
- Residual connections wrap each sub-block. Required for gradient flow at depth.
- Modern (2024+) blocks usually also have RoPE applied inside attention on Q and K before the score computation.
- Compare to the original (2017) post-norm block: x = LayerNorm(x + attn(x)). Less stable at depth; pre-norm is the default in 2026.
- Numbers to drop: "Llama-3 70B: 80 blocks, d_model 8192, mlp_hidden 28672", "SwiGLU MLP: 1.5× the parameter count of a vanilla MLP at the same hidden size, ~1% quality bump"
Common follow-ups: - "Implement SwiGLU." - "Add RoPE to the attention." - "Why pre-norm and not post-norm?"
Traps: - Post-norm in a deep block. Trains poorly past ~24 layers without tricks. - Forgetting residuals.
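For the "Implement SwiGLU" follow-up, a sketch of the W_down(SiLU(W_gate(x)) * W_up(x)) MLP, assuming bias-free projections as in Llama-style blocks. Counting parameters against a two-matrix MLP with the same hidden size shows the 1.5× ratio. Sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    # W_down(SiLU(W_gate x) * W_up x), bias-free
    def __init__(self, d_model, hidden):
        super().__init__()
        self.W_gate = nn.Linear(d_model, hidden, bias=False)
        self.W_up = nn.Linear(d_model, hidden, bias=False)
        self.W_down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        return self.W_down(F.silu(self.W_gate(x)) * self.W_up(x))

d_model, hidden = 16, 64
swiglu = SwiGLU(d_model, hidden)
vanilla = nn.Sequential(nn.Linear(d_model, hidden, bias=False), nn.GELU(),
                        nn.Linear(hidden, d_model, bias=False))
n_swiglu = sum(p.numel() for p in swiglu.parameters())    # 3 * 16 * 64 = 3072
n_vanilla = sum(p.numel() for p in vanilla.parameters())  # 2 * 16 * 64 = 2048
```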
Related cross-cutting: Evaluation & quality
Related module: learning/00_ai_foundation/02_tokens_embeddings_context/