Inference Fundamentals
Before diving into optimizations, you need to understand what happens when an LLM generates text. Unlike training (which processes entire sequences in parallel), inference is fundamentally sequential for autoregressive models.
The Autoregressive Process
Decoder-only transformers (GPT, Llama, etc.) generate text one token at a time. Each new token depends on all previous tokens. The model:
- Takes the current sequence as input
- Computes attention over all tokens
- Predicts a probability distribution over the vocabulary
- Samples the next token
- Appends it to the sequence and repeats
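The loop above can be sketched in a few lines of Python. Here `toy_model` is a stand-in for a real transformer forward pass (it just deterministically favors the next vocabulary id), and greedy argmax replaces true sampling so the sketch is reproducible; all names are illustrative.

```python
import math

VOCAB = ["<bos>", "the", "quick", "brown", "fox", "<eos>"]

def toy_model(token_ids):
    """Stand-in for a transformer forward pass: returns logits for the
    next token given the full sequence so far (autoregressive)."""
    last = token_ids[-1]
    # Favor the token after the last one, saturating at <eos>.
    return [10.0 if i == min(last + 1, len(VOCAB) - 1) else 0.0
            for i in range(len(VOCAB))]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def generate(prompt_ids, max_new_tokens=8):
    seq = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = toy_model(seq)            # forward pass over the whole sequence
        probs = softmax(logits)            # distribution over the vocabulary
        next_id = probs.index(max(probs))  # greedy "sampling"
        seq.append(next_id)                # append and repeat
        if VOCAB[next_id] == "<eos>":
            break
    return seq

ids = generate([0])  # start from <bos>
print([VOCAB[i] for i in ids])
```

Note that every iteration calls the model on the *entire* sequence so far — exactly the inefficiency that the KV cache (below) removes.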
The Quadratic Problem
Without optimization, generating N tokens requires N forward passes, and each pass computes attention over an increasingly long sequence. Naive attention is O(n²) in sequence length. For a 100K context, this becomes computationally prohibitive.
Two Phases of Inference
LLM inference splits into two distinct phases with very different computational characteristics:
Phase 1: Prefill (Prompt Processing)
- What happens: Process all input tokens in parallel
- Computational profile: Compute-bound (limited by FLOPS)
- Memory access pattern: Sequential, predictable
- Parallelism: High - all input tokens processed together
- Output: KV cache populated, first token generated
Phase 2: Decode (Token Generation)
- What happens: Generate tokens one at a time
- Computational profile: Memory-bandwidth-bound
- Memory access pattern: Random access to KV cache
- Parallelism: Low - sequential token generation
- Bottleneck: Reading model weights + KV cache from HBM
Why Decode is Memory-Bound
During decode, we generate one token but must read the entire model weights and KV cache. For a 70B parameter model in FP16, that's 140GB of data read to produce a single token. The arithmetic intensity (FLOPS per byte) is extremely low, making HBM bandwidth the bottleneck.
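A back-of-envelope calculation makes this concrete: if every decode step must stream the weights and KV cache from HBM at least once, bandwidth alone gives a hard floor on per-token latency. The numbers below (140 GB of FP16 weights, a 2.68 GB KV cache, 3.35 TB/s of HBM bandwidth) are illustrative, not a benchmark.

```python
def decode_token_floor_ms(weights_gb, kv_cache_gb, hbm_tbps):
    """Lower bound on per-token decode latency: every generated token
    must read the weights (and KV cache) from HBM at least once."""
    bytes_read = (weights_gb + kv_cache_gb) * 1e9
    return bytes_read / (hbm_tbps * 1e12) * 1e3  # milliseconds

# 70B params in FP16 (~140 GB) + 2.68 GB KV cache at 3.35 TB/s
t = decode_token_floor_ms(140, 2.68, 3.35)
print(f"{t:.1f} ms/token floor -> ~{1000 / t:.0f} tokens/s max")
```

About 42 ms per token, i.e. roughly 23 tokens/s per sequence — regardless of how many FLOPS the GPU has. This is why batching, quantization, and KV cache shrinking dominate decode optimization.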
Key Performance Metrics
| Metric | Definition | What It Measures |
|---|---|---|
| TTFT | Time to First Token | Prefill latency - how fast the prompt is processed |
| TPOT | Time Per Output Token | Decode latency - time between tokens |
| ITL | Inter-Token Latency | Variance in token generation time |
| Throughput | Tokens/second (system) | Total capacity across all requests |
| Latency | End-to-end time | TTFT + (output_tokens × TPOT) |
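The latency row composes directly from the other metrics. A tiny sketch with illustrative numbers:

```python
def end_to_end_latency_ms(ttft_ms, tpot_ms, output_tokens):
    """End-to-end latency per the table: TTFT + output_tokens * TPOT."""
    return ttft_ms + output_tokens * tpot_ms

# e.g. 200 ms prefill, 30 ms/token decode, 256 output tokens
lat = end_to_end_latency_ms(200, 30, 256)
print(lat)  # 7880.0 -> decode dominates for long generations
```

For long generations, TPOT dwarfs TTFT — which is why so much of this document focuses on the decode phase.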
The Memory Hierarchy
Understanding where data lives is crucial for optimization:
┌─────────────────────────────────────────────────────────┐
│ Registers ~20 KB ~10 TB/s Fastest, smallest │
├─────────────────────────────────────────────────────────┤
│ Shared Memory ~200 KB ~19 TB/s Per SM, programmer-managed │
├─────────────────────────────────────────────────────────┤
│ L2 Cache ~50 MB ~5 TB/s Shared across GPU │
├─────────────────────────────────────────────────────────┤
│ HBM (VRAM) 80-192GB 2-5 TB/s Main GPU memory │
├─────────────────────────────────────────────────────────┤
│ CPU RAM 512GB+ ~200 GB/s System memory │
├─────────────────────────────────────────────────────────┤
│ NVMe SSD TB+ ~7 GB/s Storage │
└─────────────────────────────────────────────────────────┘
The 10-50x gap between HBM bandwidth and on-chip SRAM bandwidth is why memory optimization dominates inference work.
KV Cache: The Core Optimization
The KV cache is the single most important concept in LLM inference. It reduces the per-token cost of generation from O(n²) to O(n) by caching intermediate computations.
What Gets Cached
In transformer attention, we compute Query (Q), Key (K), and Value (V) projections from the input hidden states X: Q = XW_Q, K = XW_K, V = XW_V.
For autoregressive generation, Q comes from the new token only, while the K and V computed for previous tokens never change. Instead of recomputing them every step, we cache them.
How KV Caching Works
Step 1 (Prefill): Process "The quick brown"
- Compute K₁, V₁ for "The"
- Compute K₂, V₂ for "quick"
- Compute K₃, V₃ for "brown"
- Cache: [(K₁,V₁), (K₂,V₂), (K₃,V₃)]
Step 2 (Decode): Generate "fox"
- Compute Q₄ for position 4
- Retrieve cached K₁₋₃, V₁₋₃
- Compute K₄, V₄ for "fox"
- Attention: Q₄ attends to K₁₋₄
- Cache: [(K₁,V₁), (K₂,V₂), (K₃,V₃), (K₄,V₄)]
Step 3 (Decode): Generate "jumps"
- Only compute Q₅, K₅, V₅ for new position
- Reuse all previous K, V from cache
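The steps above can be verified numerically. This single-head numpy sketch (weights and inputs are random, dimensions are illustrative) shows that attending with cached K, V plus one new Q is bit-for-bit equivalent to recomputing everything from scratch:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                   # head dimension
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

def attend(q, K, V):
    """Single-head attention for one query against all keys/values."""
    scores = q @ K.T / np.sqrt(d)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

# Prefill: compute and cache K, V for the prompt tokens.
prompt = rng.standard_normal((3, d))    # "The quick brown"
K_cache, V_cache = prompt @ Wk, prompt @ Wv

# Decode: only the NEW token's Q, K, V are computed; cached K, V are reused.
x_new = rng.standard_normal(d)          # "fox"
q4 = x_new @ Wq
K_cache = np.vstack([K_cache, x_new @ Wk])
V_cache = np.vstack([V_cache, x_new @ Wv])
out_cached = attend(q4, K_cache, V_cache)

# Sanity check: identical to recomputing K, V from scratch for all 4 tokens.
full = np.vstack([prompt, x_new])
out_full = attend(full[-1] @ Wq, full @ Wk, full @ Wv)
print(np.allclose(out_cached, out_full))  # True
```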
Memory Calculation
KV cache size = 2 × num_layers × num_kv_heads × head_dim × seq_len × precision_bytes. The factor of 2 accounts for both K and V. Let's calculate for Llama-3 70B:
| Parameter | Llama-3 70B Value |
|---|---|
| Layers | 80 |
| KV Heads (GQA) | 8 |
| Head Dimension | 128 |
| Sequence Length | 8,192 (or 128K) |
| Precision | FP16 (2 bytes) |
8K context: 2 × 80 × 8 × 128 × 8192 × 2 = 2.68 GB per sequence
64K context: 2 × 80 × 8 × 128 × 65536 × 2 = 21.5 GB per sequence
128K context: 2 × 80 × 8 × 128 × 131072 × 2 = 43 GB per sequence
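The calculations above fold into one small helper (function name is illustrative):

```python
def kv_cache_gb(layers, kv_heads, head_dim, seq_len, bytes_per_el=2):
    """KV cache per sequence: 2 (K and V) x layers x kv_heads x head_dim
    x seq_len x precision bytes, returned in GB."""
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_el / 1e9

# Llama-3 70B (GQA, 8 KV heads, head_dim 128, FP16)
for ctx in (8_192, 65_536, 131_072):
    print(f"{ctx:>7} tokens: {kv_cache_gb(80, 8, 128, ctx):.2f} GB")
```

Note the linear scaling in sequence length: doubling the context doubles the cache, which is why long-context serving is memory-dominated.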
The Real Memory Bottleneck
An 80GB GPU running Llama-3 70B (140GB in FP16, quantized to ~35GB in INT4) leaves only 45GB for KV cache. With 128K context, you can serve approximately one concurrent request. This is why KV cache optimization is critical.
Attention Variants That Reduce KV Cache
Modern architectures reduce KV cache size by sharing heads:
Multi-Head Attention (MHA) - Original
- Separate K, V projections per head
- num_kv_heads = num_attention_heads
- Maximum expressiveness, maximum memory
Multi-Query Attention (MQA)
- Single shared K, V across all heads
- num_kv_heads = 1
- 8-16x KV cache reduction
- Used by: PaLM, Falcon
Grouped-Query Attention (GQA)
- Groups of heads share K, V
- num_kv_heads = num_attention_heads / group_size
- Balance between MHA and MQA
- Used by: Llama 2/3, Mistral, Gemma
Multi-Latent Attention (MLA)
- Low-rank compression of KV cache
- Compresses K, V into latent space
- 16x+ reduction possible
- Used by: DeepSeek-V2, DeepSeek-V3
| Variant | KV Heads (70B) | Cache per 8K seq | Reduction |
|---|---|---|---|
| MHA | 64 | 21.5 GB | 1x (baseline) |
| GQA-8 | 8 | 2.68 GB | 8x |
| MQA | 1 | 0.34 GB | 64x |
PagedAttention & Memory Virtualization
PagedAttention, introduced by vLLM in 2023, revolutionized LLM serving by applying OS virtual memory concepts to GPU memory management.
The Problem: Memory Fragmentation
Traditional serving systems pre-allocate contiguous memory for each request's maximum sequence length:
Request 1: Allocate KV cache for 8192 tokens (2.68 GB) as one contiguous block
Request 2: Allocate KV cache for 8192 tokens (2.68 GB) as one contiguous block
...
Reality:
Request 1 uses 500 tokens → 7692 tokens WASTED
Request 2 uses 2000 tokens → 6192 tokens WASTED
Studies found existing systems waste 60-80% of KV cache memory.
The Solution: Virtual Memory for GPUs
PagedAttention divides KV cache into fixed-size blocks, allocated on-demand:
| OS Concept | PagedAttention Analog |
|---|---|
| Page (4KB) | KV Block (16-32 tokens) |
| Virtual Address | Logical Block Index |
| Physical Address | GPU Memory Location |
| Page Table | Block Table |
| Page Fault | Block Allocation |
| Free Page Pool (LRU) | Free Block Pool |
How It Works
Physical GPU Memory:
┌────┬────┬────┬────┬────┬────┬────┬────┬────┬────┐
│ B0 │ B1 │ B2 │ B3 │ B4 │ B5 │ B6 │ B7 │FREE│FREE│
└────┴────┴────┴────┴────┴────┴────┴────┴────┴────┘
Request A (needs 3 blocks):
Block Table A: [0, 3, 5] → Physical blocks B0, B3, B5
Request B (needs 2 blocks):
Block Table B: [1, 4] → Physical blocks B1, B4
Request C (needs 2 blocks):
Block Table C: [2, 6] → Physical blocks B2, B6
Blocks are NOT contiguous! Managed like virtual memory.
Block Structure
Each block stores KV cache for a fixed number of tokens:
Block Size: 16 tokens (typical)
Block Memory: 16 × (2 × num_layers × num_kv_heads × head_dim × precision)
= 16 × 2.68GB/8192 = ~5.2 MB for Llama-3 70B GQA
Benefits
- Near-zero waste: vLLM achieves <4% memory waste vs 60-80%
- No external fragmentation: All blocks same size
- Minimal internal fragmentation: Only last block partially filled
- Dynamic allocation: Blocks allocated as sequence grows
Prefix Caching (Copy-on-Write)
Multiple requests with same prefix can share physical blocks:
System Prompt: "You are a helpful assistant..." (Block 0, 1)
Request A: System prompt + "What is Python?"
Block Table: [0, 1, 2, 3] ← Shares blocks 0,1
Request B: System prompt + "Explain machine learning"
Block Table: [0, 1, 4, 5] ← Shares blocks 0,1
Physical storage: Only ONE copy of system prompt blocks!
Reference counting ensures blocks freed when all requests done.
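A toy allocator makes the block-table and reference-counting mechanics concrete. This is a deliberately minimal sketch — class and method names are invented, and real systems (e.g. vLLM) track token contents, hashes, and eviction on top of this:

```python
class BlockPool:
    """Toy PagedAttention-style allocator: fixed-size blocks handed out
    on demand and shared via reference counts (names are illustrative)."""
    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))
        self.refcount = [0] * num_blocks

    def alloc(self):
        b = self.free.pop(0)
        self.refcount[b] = 1
        return b

    def share(self, b):            # another sequence reuses this block
        self.refcount[b] += 1
        return b

    def release(self, b):
        self.refcount[b] -= 1
        if self.refcount[b] == 0:  # freed only when the last user is done
            self.free.append(b)

pool = BlockPool(num_blocks=10)

# Request A: three blocks of its own.
table_a = [pool.alloc() for _ in range(3)]

# Request B: shares A's first two blocks (common system prompt), owns one.
table_b = [pool.share(table_a[0]), pool.share(table_a[1]), pool.alloc()]
print(table_a, table_b, len(pool.free))

# A finishes: its private block is freed, shared blocks survive for B.
for b in table_a:
    pool.release(b)
print(pool.refcount[table_a[0]], len(pool.free))
```

Only four physical blocks back the two requests, and releasing request A frees exactly one block — the two prompt blocks stay alive until B also releases them.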
Automatic Prefix Caching (APC)
vLLM's APC automatically detects common prefixes across requests using content-based hashing. No configuration needed - system prompts, few-shot examples, and repeated context are cached automatically.
Block Size Trade-offs
| Block Size | Pros | Cons |
|---|---|---|
| Small (8 tokens) | Less internal fragmentation | More kernel launch overhead, larger block tables |
| Medium (16 tokens) | Balanced | Standard choice |
| Large (32+ tokens) | Better GPU utilization per kernel | More memory waste in last block |
Continuous Batching
Continuous batching (iteration-level scheduling) maximizes GPU utilization by dynamically managing requests at each decode step.
Static Batching Problems
Static Batch of 4 requests:
┌─────────────────────────────────────────────────────────┐
│ Req 1: ████████████████████████████████████ (1000 tokens)│
│ Req 2: ████████ (200 tokens) [PADDING...............] │
│ Req 3: ████████████████ (400 tokens) [PADDING.......] │
│ Req 4: ████ (100 tokens) [PADDING...................] │
└─────────────────────────────────────────────────────────┘
↑ All must wait for longest request
↑ Padding wastes compute
Problems:
- Short requests wait for long ones (head-of-line blocking)
- Padding wastes compute on meaningless tokens
- GPU utilization drops as requests finish
- New requests must wait for entire batch to complete
Continuous Batching Solution
Time →
Step 1: [Req1, Req2, Req3, Req4] - all active
Step 2: [Req1, Req2, Req3, Req4] - all active
Step 3: [Req1, Req2, Req3, ----] - Req4 done, slot open
Step 4: [Req1, Req2, Req3, Req5] - New request joins!
Step 5: [Req1, ----, Req3, Req5] - Req2 done
Step 6: [Req1, Req6, Req3, Req5] - Another joins
...
Benefits:
- Requests return immediately when done
- New requests join without waiting
- GPU stays fully utilized
- No padding waste
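The timeline above can be simulated in a few lines. This toy scheduler (all names invented) decodes one token per active request per step, frees slots the moment a request finishes, and admits waiting requests immediately:

```python
from collections import deque

def continuous_batching(jobs, slots=4):
    """Toy iteration-level scheduler: jobs maps request name -> tokens
    to generate. Returns the step at which each request finishes."""
    waiting = deque(jobs.items())
    active, finish_step, step = {}, {}, 0
    while waiting or active:
        while waiting and len(active) < slots:   # fill any open slots
            name, need = waiting.popleft()
            active[name] = need
        step += 1
        for name in list(active):                # one decode step for all
            active[name] -= 1
            if active[name] == 0:
                finish_step[name] = step         # returns immediately
                del active[name]
    return finish_step

done = continuous_batching({"A": 6, "B": 2, "C": 4, "D": 1, "E": 3})
print(done)
```

Request D finishes at step 1 and request E slots in at step 2 — under static batching, everything would have waited for A's 6 steps.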
Ragged Batching
To avoid padding, sequences are concatenated into a single "super-sequence":
Traditional (padded):
Seq 1: [A, B, C, PAD, PAD]
Seq 2: [D, E, F, G, H]
Seq 3: [I, J, PAD, PAD, PAD]
Ragged (concatenated):
Super-sequence: [A, B, C, D, E, F, G, H, I, J]
Position IDs: [0, 1, 2, 0, 1, 2, 3, 4, 0, 1]
Sequence IDs: [1, 1, 1, 2, 2, 2, 2, 2, 3, 3]
Attention mask ensures each sequence only attends to itself.
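Building the ragged metadata is mechanical — this sketch reproduces the diagram above (function name is illustrative):

```python
def ragged_metadata(seqs):
    """Concatenate variable-length sequences into one super-sequence and
    build per-token position ids and sequence ids (no padding anywhere)."""
    tokens, pos_ids, seq_ids = [], [], []
    for sid, seq in enumerate(seqs, start=1):
        tokens.extend(seq)
        pos_ids.extend(range(len(seq)))      # positions restart per sequence
        seq_ids.extend([sid] * len(seq))     # used to build the block mask
    return tokens, pos_ids, seq_ids

tokens, pos, sid = ragged_metadata(
    [["A", "B", "C"], ["D", "E", "F", "G", "H"], ["I", "J"]])
print(tokens)
print(pos)   # [0, 1, 2, 0, 1, 2, 3, 4, 0, 1]
print(sid)   # [1, 1, 1, 2, 2, 2, 2, 2, 3, 3]
```

The attention kernel then masks token i from token j whenever `sid[i] != sid[j]`, so each sequence only attends to itself.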
Chunked Prefill
Long prompts can monopolize an iteration, delaying all other requests:
Head-of-Line Blocking
A 100K token prompt takes significant time to prefill. Without chunking, all decode requests wait, causing latency spikes.
Solution: Split long prefills into chunks, interleaved with decodes:
Without chunking:
Step 1: Prefill 100K tokens (Request A) - BLOCKS EVERYONE
Step 2: Decode all
With chunked prefill (chunk_size=8192):
Step 1: Prefill 8K (Req A chunk 1) + Decode (Req B, C, D)
Step 2: Prefill 8K (Req A chunk 2) + Decode (Req B, C, D)
...continues interleaved...
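A minimal sketch of this chunking scheme (the function and its arguments are invented for illustration; real schedulers also budget decode tokens against the chunk size):

```python
def chunked_schedule(prefill_tokens, chunk, decode_reqs):
    """Toy chunked-prefill plan: each step carries one prefill chunk of
    the long request plus one decode token for every waiting request."""
    schedule, done = [], 0
    while done < prefill_tokens:
        this = min(chunk, prefill_tokens - done)
        schedule.append((this, list(decode_reqs)))  # chunk + decodes together
        done += this
    return schedule

plan = chunked_schedule(20_000, 8_192, ["B", "C", "D"])
for chunk, decodes in plan:
    print(f"prefill {chunk} tokens of A + decode {decodes}")
```

Requests B, C, D make progress on every step instead of stalling for the entire 20K-token prefill.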
Scheduling Policies
| Policy | Description | Use Case |
|---|---|---|
| FCFS | First-come, first-served | Fair, simple |
| SJF | Shortest job first (by estimated output) | Minimize average latency |
| Priority | Based on SLO or customer tier | Multi-tenant serving |
| Preemptive | Pause long requests for short ones | Strict latency SLOs |
FlashAttention & Kernel Optimization
FlashAttention revolutionized attention computation by being IO-aware - minimizing HBM reads/writes through algorithmic restructuring.
Standard Attention Problem
Naive attention materializes large intermediate matrices in HBM:
1. S = Q @ K^T # Shape: (seq_len, seq_len) - WRITE to HBM
2. P = softmax(S) # READ S, WRITE P to HBM
3. O = P @ V # READ P, WRITE O to HBM
For seq_len=8192: S and P are each 8192×8192 per head — 134MB in FP16 (268MB in FP32)!
Memory traffic: Multiple reads/writes of these large matrices.
FlashAttention Solution: Tiling
FlashAttention computes attention in tiles that fit in SRAM (shared memory):
Algorithm (simplified):
1. Divide Q, K, V into blocks that fit in SRAM
2. For each Q block:
a. Load Q block to SRAM
b. For each K, V block:
- Load K, V blocks to SRAM
- Compute partial attention in SRAM
- Accumulate results (online softmax)
c. Write final output block to HBM
Key insight: Never materialize full S or P matrices!
Only final output O written to HBM.
IO Complexity Reduction
Standard attention: O(N² + Nd) HBM accesses
FlashAttention: O(N²d² / M) where M is SRAM size
For typical configs, this is 5-20x fewer memory accesses.
Online Softmax
The trick that makes tiling work - compute softmax incrementally:
Standard softmax requires all values to compute denominator.
Online softmax maintains running max and sum:
For each new block of scores:
1. Update running max: m_new = max(m_old, block_max)
2. Rescale previous sum: sum_new = sum_old × exp(m_old - m_new)
3. Add new contributions: sum_new += sum(exp(block - m_new))
4. Update running output with correction factor
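Steps 1-3 above can be checked directly. This pure-Python sketch streams the scores block by block, maintaining only the running max and sum, and matches the ordinary two-pass softmax denominator exactly:

```python
import math

def online_softmax_denominator(scores, block=2):
    """Streaming softmax statistics: process scores one block at a time,
    keeping a running max m and a rescaled running sum s."""
    m, s = float("-inf"), 0.0
    for i in range(0, len(scores), block):
        blk = scores[i:i + block]
        m_new = max(m, max(blk))
        s = s * math.exp(m - m_new)                  # rescale previous sum
        s += sum(math.exp(x - m_new) for x in blk)   # add new contributions
        m = m_new
    return m, s

scores = [1.0, 3.0, 0.5, 2.0, 4.0]
m, s = online_softmax_denominator(scores)

# Reference: the usual two-pass softmax over all scores at once.
m_ref = max(scores)
s_ref = sum(math.exp(x - m_ref) for x in scores)
print(m == m_ref, abs(s - s_ref) < 1e-12)
```

The same correction factor `exp(m_old - m_new)` is applied to the running output accumulator in FlashAttention (step 4), which is omitted here for brevity.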
FlashAttention-2 Improvements
- Better parallelism: Parallelize over sequence length, not just batch
- Reduced non-matmul FLOPs: Fewer warp synchronizations
- Better work partitioning: Improved occupancy
- Result: 2x speedup over FlashAttention-1
FlashAttention-3 (Hopper)
- Exploits H100 Tensor Memory Accelerator (TMA)
- Asynchronous data movement overlapped with compute
- FP8 support for further speedup
- Result: 1.5-2x speedup over FA2 on H100
Skip Softmax (January 2026)
Many attention scores are near-zero and don't contribute meaningfully. Skip Softmax detects and skips these blocks:
Standard FlashAttention:
For each KV block: compute attention (even if scores ≈ 0)
Skip Softmax:
For each KV block:
1. Quick check: are all scores below threshold?
2. If yes: SKIP this block entirely
3. If no: compute normally
Results: Up to 1.4x faster TTFT and TPOT
No retraining needed - drop-in replacement.
FlashInfer (MLSys 2025 Best Paper)
NVIDIA's open-source library for high-performance attention:
- JIT-compiled kernels adapt to runtime settings
- Block-sparse KV cache formats
- Composable attention for complex scenarios
- Direct integration with vLLM, SGLang
Quantization for Inference
Quantization reduces precision of weights and activations, decreasing memory footprint and increasing arithmetic throughput.
Precision Formats Overview
| Format | Bits | Exponent/Mantissa | Range |
|---|---|---|---|
| FP32 | 32 | 8/23 | ±3.4×10³⁸ |
| FP16 | 16 | 5/10 | ±65504 |
| BF16 | 16 | 8/7 | ±3.4×10³⁸ (FP32 range!) |
| FP8 E4M3 | 8 | 4/3 | ±448 (weights) |
| FP8 E5M2 | 8 | 5/2 | ±57344 (gradients) |
| INT8 | 8 | - | -128 to 127 |
| INT4 | 4 | - | -8 to 7 |
| NVFP4 E2M1 | 4 | 2/1 | ±6 |
Why BF16 Over FP16
BF16 (Brain Float) uses same exponent bits as FP32, maintaining dynamic range while reducing precision:
FP32: 1 sign + 8 exponent + 23 mantissa = 32 bits
BF16: 1 sign + 8 exponent + 7 mantissa = 16 bits
FP16: 1 sign + 5 exponent + 10 mantissa = 16 bits
BF16 advantage: Same range as FP32, simpler conversion
FP16 advantage: More precision within smaller range
Weight-Only Quantization
Quantize only weights, compute in higher precision:
- INT4 AWQ (Activation-aware Weight Quantization): Preserves salient weights that handle important activations
- GPTQ: Layer-by-layer quantization with error compensation
- Benefit: 4x memory reduction for weights, minimal accuracy loss
W4A16: 4-bit weights, 16-bit activations
- Weights stored in INT4 (4x compression)
- Dequantized to FP16 before compute
- Memory-bound scenarios benefit most
Full Quantization (Weights + Activations)
Quantize both for maximum throughput:
- INT8 SmoothQuant: Mathematically smooths outliers between weights and activations
- FP8: Native hardware support on H100/MI300X/Blackwell
- W4A8: 4-bit weights, 8-bit activations - good balance
NVFP4: NVIDIA's 4-bit Float (Blackwell)
NVFP4 uses micro-block scaling for accuracy:
- Values grouped into blocks of 16
- Each block shares an FP8 (E4M3) scale factor
- Additional per-tensor FP32 scale
Results: 3.5x memory reduction vs FP16, <1% accuracy loss on most tasks.
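The micro-block idea can be sketched in plain Python. This is a simplified model, not bit-exact NVFP4: it uses the E2M1 magnitude grid and a per-block scale (chosen so the largest value maps to 6.0), but represents the scale as a plain float rather than FP8:

```python
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 magnitudes

def quantize_block(values):
    """Micro-block scaling sketch (not bit-exact NVFP4): a block of 16
    values shares one scale, then each value snaps to the nearest
    representable FP4 magnitude. Returns the dequantized values."""
    scale = max(abs(v) for v in values) / 6.0 or 1.0
    out = []
    for v in values:
        mag = min(FP4_GRID, key=lambda g: abs(abs(v) / scale - g))
        out.append((mag if v >= 0 else -mag) * scale)
    return out

block = [0.01, -0.3, 0.22, 0.6, -0.45, 0.11, 0.02, -0.08,
         0.31, 0.05, -0.6, 0.17, 0.09, -0.02, 0.4, 0.25]
deq = quantize_block(block)
err = max(abs(a - b) for a, b in zip(block, deq))
print(f"max abs error: {err:.3f}")
```

Because each block of 16 gets its own scale, an outlier in one block cannot destroy the precision of every other block — the key advantage over per-tensor 4-bit scaling.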
KV Cache Quantization
Quantizing KV cache separately provides additional savings:
| KV Precision | Memory vs FP16 | Accuracy Impact |
|---|---|---|
| FP16 | 1x (baseline) | None |
| FP8 | 0.5x | Minimal |
| NVFP4 | 0.25x | Small (<1%) |
Quantization Best Practices
- Calibration data matters: Use representative samples from production workload
- Last layers are sensitive: Keep final 1-2 layers in higher precision
- Monitor per-task accuracy: Some tasks degrade more than others
- Layer-wise mixed precision: Quantize insensitive layers more aggressively
NVIDIA Inference Stack
NVIDIA provides a comprehensive stack optimized for LLM inference: TensorRT-LLM, FlashInfer, and Triton Inference Server.
TensorRT-LLM Architecture
┌─────────────────────────────────────────────────────────┐
│ User Application │
├─────────────────────────────────────────────────────────┤
│ TensorRT-LLM Runtime │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐│
│ │ In-flight │ │ Paged │ │ Speculative ││
│ │ Batching │ │ KV Caching │ │ Decoding ││
│ └─────────────┘ └─────────────┘ └─────────────────────┘│
├─────────────────────────────────────────────────────────┤
│ TensorRT Compiler │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐│
│ │ Graph │ │ Kernel │ │ Quantization ││
│ │ Optimization│ │ Fusion │ │ (FP8/FP4/INT) ││
│ └─────────────┘ └─────────────┘ └─────────────────────┘│
├─────────────────────────────────────────────────────────┤
│ Custom CUDA Kernels │
│ FlashAttention │ PagedAttention │ FP8 GEMM │ Fused Ops │
├─────────────────────────────────────────────────────────┤
│ CUDA / cuBLAS │
└─────────────────────────────────────────────────────────┘
Key Optimizations
1. Kernel Fusion
TensorRT identifies and fuses adjacent operations:
Before fusion (3 kernel launches, 3 HBM round-trips):
LayerNorm → HBM → FP8 Quantize → HBM → Linear
After fusion (1 kernel launch):
LayerNorm + FP8 Quantize + Linear (fused)
Benefit: ~3x kernel time reduction for fused all-reduce + layernorm
2. In-flight Batching
Continuous batching implementation with configurable policies:
- Max tokens in batch
- Max sequences in batch
- Scheduling policy (capacity, guaranteed_no_evict)
3. Speculative Decoding Support
- Draft model: Smaller model proposes tokens
- EAGLE-3: Learned draft head on target model
- Multi-token prediction: Model predicts multiple future tokens
- Medusa: Multiple parallel heads for speculation
Hardware: Blackwell Architecture
| Feature | Specification |
|---|---|
| Tensor Cores | 5th generation, native FP4/FP8 |
| Precision Support | FP64, FP32, TF32, FP16, BF16, FP8, FP6, FP4 |
| Memory | HBM3e, higher bandwidth than H100 |
| Interconnect | NVLink 5.0 (1.8 TB/s per GPU) |
| Transformer Engine | Automatic mixed precision management |
NIM (NVIDIA Inference Microservices)
Pre-optimized containers for common models:
- Optimized TensorRT-LLM engines pre-built
- REST/gRPC APIs compatible with OpenAI format
- Automatic batching and scaling
- Multi-GPU support out of the box
Triton Inference Server
Production serving layer:
- Model ensemble and pipeline support
- Dynamic batching
- Multi-framework support (TensorRT, PyTorch, ONNX)
- Metrics and monitoring (Prometheus)
- Model versioning and A/B testing
AMD ROCm & MI300X
AMD's ROCm platform has matured significantly, achieving first-class support in vLLM as of January 2026.
MI300X Architecture
Hardware Specifications
- Compute Dies: 8 XCDs (Accelerator Complex Dies)
- Compute Units: 304 total (38 per XCD)
- Memory: 192 GB HBM3 unified
- Memory Bandwidth: 5.3 TB/s (higher than H100's 3.35 TB/s)
- FP8 Support: Native via CDNA 3
Why MI300X Excels at Decode
Decode phase is memory-bandwidth-bound. MI300X's higher HBM bandwidth provides advantage:
| GPU | HBM Bandwidth | Decode Advantage |
|---|---|---|
| H100 SXM | 3.35 TB/s | Baseline |
| MI300X | 5.3 TB/s | +58% theoretical |
| H200 | 4.8 TB/s | +43% |
Real-world results show MI300X matching or exceeding H100 for batch sizes 1-64 in decode-heavy workloads.
ROCm Software Stack
┌─────────────────────────────────────────────────────────┐
│ vLLM / SGLang │
├─────────────────────────────────────────────────────────┤
│ AITER Kernels │
│ (Attention, FP8 GEMM, Fused Ops) │
├─────────────────────────────────────────────────────────┤
│ Triton / hipBLAS │
├─────────────────────────────────────────────────────────┤
│ ROCm / HIP │
└─────────────────────────────────────────────────────────┘
vLLM on ROCm (January 2026 Status)
- First-class platform: Pre-built Docker images on Docker Hub
- Test coverage: 93% of test groups passing (up from 37% in Nov 2025)
- No source builds needed:
docker pull rocm/vllm-omni
Key Optimizations
| Optimization | Impact |
|---|---|
| AITER FP8 kernels | Native FP8 compute |
| Fused LayerNorm/SiLU FP8 | Reduced memory traffic |
| Optimized PagedAttention | Assembly-level tuning |
| fastsafetensors | Faster model loading |
| Multimodal batch parallelism | 45% throughput gain on vision |
GPU Partitioning
MI300X supports hardware-enforced partitioning for multi-tenant deployments:
# Run multiple vLLM instances on partitioned MI300X
# Each instance gets isolated GPU resources
Instance 1: 1/2 of MI300X (96GB, 152 CUs)
Instance 2: 1/2 of MI300X (96GB, 152 CUs)
# Or 4-way partition for smaller models
Instance 1-4: 1/4 of MI300X each (48GB, 76 CUs)
Known Limitations
MoE Regression
AITER has known issues with MoE models (Mixtral, DeepSeek-R1). For MoE workloads, use older image: rocm/vllm:rocm7.0.0_vllm_0.11.1_20251103
RCCL vs NCCL
RCCL is AMD's collective communication library (NCCL equivalent):
- Meta's DDA solutions show 10-50% speedup over baseline RCCL
- Active development to close gap with NCCL
- InfiniBand and RoCE support
Parallelism Strategies
Scaling inference across multiple GPUs requires understanding when to use each parallelism strategy.
Tensor Parallelism (TP)
Splits individual weight matrices across GPUs. Each GPU holds a slice of every layer.
Single GPU:
Linear layer: W (4096 × 4096, FP32) = 64MB
TP=4 (4 GPUs):
GPU 0: W[:, 0:1024] (4096 × 1024) = 16MB
GPU 1: W[:, 1024:2048] (4096 × 1024) = 16MB
GPU 2: W[:, 2048:3072] (4096 × 1024) = 16MB
GPU 3: W[:, 3072:4096] (4096 × 1024) = 16MB
Forward pass: Each GPU computes a partial result from its shard
Combine: a column-parallel layer's output slices feed a row-parallel layer, whose partial sums are combined with an all-reduce
Communication: All-reduce after every transformer block
Best for: Single node, high-bandwidth NVLink interconnect
Scaling limit: TP=8 typical max (communication overhead)
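The standard Megatron-style pairing — a column-parallel linear followed by a row-parallel linear, with a single all-reduce at the end — can be simulated on one machine with numpy (dimensions and the 4-way split are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 8))    # batch of activations
W1 = rng.standard_normal((8, 8))   # first linear: split by COLUMNS
W2 = rng.standard_normal((8, 8))   # second linear: split by ROWS
tp = 4

# Column-parallel: each "GPU" holds a column slice of W1 and produces a
# slice of the intermediate activation - no communication needed, and the
# elementwise ReLU can be applied locally.
h_shards = [np.maximum(x @ W1[:, i::tp], 0) for i in range(tp)]

# Row-parallel: each "GPU" multiplies its slice by the matching row slice
# of W2, yielding a PARTIAL sum of the full product.
partials = [h_shards[i] @ W2[i::tp, :] for i in range(tp)]

# All-reduce: summing the partials reconstructs the full output.
y_tp = sum(partials)

# Reference: the same two layers on a single "GPU".
y_ref = np.maximum(x @ W1, 0) @ W2
print(np.allclose(y_tp, y_ref))  # True
```

Note where communication happens: nothing between the two linears, one all-reduce at the end — which is exactly why this pairing is preferred over naively syncing after every matmul.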
Pipeline Parallelism (PP)
Splits model vertically by layers. Each GPU holds consecutive layers.
PP=4 for 32-layer model:
GPU 0: Layers 0-7 (embedding + first 8 transformer blocks)
GPU 1: Layers 8-15
GPU 2: Layers 16-23
GPU 3: Layers 24-31 (+ output projection)
Forward pass: Activations flow GPU0 → GPU1 → GPU2 → GPU3
Backward pass: Gradients flow GPU3 → GPU2 → GPU1 → GPU0
Communication: Point-to-point between adjacent stages
Challenge: Pipeline bubbles reduce utilization
Best for: Multi-node, limited inter-node bandwidth
Expert Parallelism (EP) for MoE
Distributes experts across GPUs instead of replicating all experts everywhere.
MoE layer with 8 experts, EP=4:
GPU 0: Expert 0, Expert 1
GPU 1: Expert 2, Expert 3
GPU 2: Expert 4, Expert 5
GPU 3: Expert 6, Expert 7
Forward pass:
1. Router determines which experts each token needs
2. All-to-all: Send tokens to GPUs with their experts
3. Each GPU processes tokens for its experts
4. All-to-all: Return results to original GPUs
Communication: All-to-all for token routing
Benefit: Memory efficient - each GPU stores subset of experts
Hybrid Parallelism
Production deployments combine strategies:
DeepSeek-R1 on 8 GPUs:
TP=2: Split attention/FFN within each expert
EP=4: Distribute 256 experts across 4 groups
TensorRT-LLM config:
convert_checkpoint.py --moe_tp_size 2 --moe_ep_size 4
| Configuration | Use Case |
|---|---|
| TP only | Dense models, single node |
| PP only | Very deep models, multi-node |
| EP only | MoE models, memory-constrained |
| TP + PP | Large dense models, multi-node |
| TP + EP | Large MoE models |
| TP + EP + PP | Extremely large MoE (100B+) |
Context Parallelism (CP)
For very long contexts, splits sequence across GPUs:
128K context with CP=4:
GPU 0: Tokens 0-32K
GPU 1: Tokens 32K-64K
GPU 2: Tokens 64K-96K
GPU 3: Tokens 96K-128K
Attention uses ring-attention pattern for cross-chunk attention.
Speculative Decoding
Speculative decoding accelerates generation by drafting multiple tokens in parallel, then verifying them efficiently.
Core Insight
Decode is memory-bandwidth-bound. Verifying K tokens costs similar to generating 1 token because we're reading the same weights regardless.
Algorithm
1. DRAFT: Small model generates K candidate tokens
draft_tokens = [t₁, t₂, t₃, t₄, t₅] (K=5)
2. VERIFY: Target model evaluates all K tokens in ONE forward pass
target_logits = target_model([context + draft_tokens])
3. ACCEPT/REJECT: Compare distributions
For i in 0..K:
if sample(target_logits[i]) matches draft_tokens[i]:
ACCEPT token i
else:
REJECT, resample from target, stop
4. BONUS: Always get at least 1 token (resampled if all rejected)
Example:
Draft: [the, quick, brown, fox, jumped]
Target accepts: [the, quick, brown] ✓
Target rejects: [fox] ✗ (target wanted "red")
Output: [the, quick, brown, red] = 4 tokens from 1 verify pass!
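The accept/reject loop can be sketched with a greedy acceptance rule (real implementations use rejection sampling on the draft/target probability ratio to preserve the target distribution; this simplification and all names are illustrative):

```python
def verify(draft, target_greedy):
    """Greedy-acceptance sketch: target_greedy[i] is the token the target
    model prefers at draft position i (all positions scored in ONE forward
    pass). Accept matching draft tokens; on the first mismatch, take the
    target's token instead and stop."""
    out = []
    for d, t in zip(draft, target_greedy):
        out.append(t)    # the target's pick is always what we emit
        if d != t:       # draft diverged: everything after this is invalid
            break
    return out           # always at least 1 token

draft  = ["the", "quick", "brown", "fox", "jumped"]
target = ["the", "quick", "brown", "red", "over"]  # target's greedy picks
out = verify(draft, target)
print(out, f"-> {len(out)} tokens from one verify pass")
```

This reproduces the example above: three drafts accepted, "fox" rejected in favor of "red", four tokens emitted for the cost of one target forward pass.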
Draft Model Options
| Approach | Description | Pros/Cons |
|---|---|---|
| Smaller model | Llama-7B drafts for Llama-70B | Simple, requires separate model |
| EAGLE | Learned draft head on target | High acceptance, trained on target |
| Medusa | Multiple parallel heads | No separate model needed |
| Self-speculation | Early exit from target layers | Single model, lower acceptance |
| Lookahead | N-gram based prediction | Training-free, works on any model |
EAGLE-3 (TensorRT-LLM)
State-of-the-art learned speculation:
- Lightweight draft head trained on target model outputs
- Tree-structured speculation (multiple branches)
- Acceptance rates 70-90% for typical queries
- Integrated in TensorRT-LLM runtime
When Speculative Decoding Helps
γ-Tolerance (ICLR 2026)
Research shows memory bandwidth remains the bottleneck even at large batch sizes. The "γ-tolerance" criterion characterizes when speculation helps:
- Low batch size: Almost always helps
- High batch size: Still helps if draft is fast enough
- Key metric: draft_latency / verify_latency ratio
Multi-Token Prediction
Some models trained to predict multiple future tokens natively:
Standard LM head: P(t_{n+1} | t_1..t_n)
Multi-token head:
P(t_{n+1} | t_1..t_n)
P(t_{n+2} | t_1..t_n) # Direct prediction, not autoregressive!
P(t_{n+3} | t_1..t_n)
...
Benefit: Native speculation without separate draft model.
Disaggregated Prefill/Decode
Disaggregated serving separates prefill and decode onto different hardware pools for independent optimization and scaling.
Why Disaggregate?
Prefill and decode have fundamentally different resource needs:
| Characteristic | Prefill | Decode |
|---|---|---|
| Bound by | Compute (FLOPS) | Memory bandwidth |
| Batch efficiency | High (parallel tokens) | Low (sequential) |
| Latency sensitivity | TTFT | TPOT, ITL |
| Ideal hardware | High FLOPS (H100) | High bandwidth (MI300X) |
| Scaling dimension | Prompt length | Concurrent users |
Architecture
┌─────────────────────┐
│ Load Balancer │
└──────────┬──────────┘
│
┌────────────────┴────────────────┐
▼ ▼
┌─────────────────────┐ ┌─────────────────────┐
│ Prefill Cluster │ │ Decode Cluster │
│ (H100 / H200) │ │ (MI300X / A100) │
│ │ │ │
│ - Process prompts │ KV Cache │ - Generate tokens │
│ - High FLOPS │────────▶│ - High bandwidth │
│ - Compute-bound │ Transfer │ - Memory-bound │
└─────────────────────┘ └─────────────────────┘
KV Cache Transfer
The critical challenge: moving KV cache from prefill to decode cluster quickly.
Transfer Overhead
If KV transfer takes longer than compute savings, disaggregation hurts performance. Solutions must support fast, low-latency protocols.
vLLM KV transfer backends:
- NIXLConnector: NVIDIA Inference Xfer Library
- UCX: Unified Communication X
- libfabric: OpenFabrics interfaces
- EFA: AWS Elastic Fabric Adapter
Industry Adoption (2026)
Disaggregation is now standard in production:
- Frameworks: NVIDIA Dynamo, llm-d, Ray Serve LLM, SGLang, vLLM, LMCache
- Companies: Fireworks AI, Perplexity, Meta, Amazon, Modular, DeepInfra
Performance Results
- Up to 6.4x throughput improvement
- 20x reduction in latency variance
- 15-40% infrastructure cost reduction
When to Disaggregate
- High scale: Thousands of concurrent users
- Variable workloads: Mix of long prompts and many short outputs
- Strict SLOs: Need to optimize TTFT and TPOT independently
- Heterogeneous hardware: Mix of GPU types available
GPU Communication & Collective Operations
Inter-GPU communication often determines scaling efficiency. Understanding NCCL and alternatives is critical.
Collective Operations
| Operation | Description | Use Case |
|---|---|---|
| All-Reduce | Sum/average across all GPUs, result everywhere | Tensor parallelism sync |
| All-Gather | Gather data from all GPUs to all GPUs | Sequence parallelism |
| Reduce-Scatter | Reduce then scatter result | ZeRO optimizer |
| All-to-All | Each GPU sends different data to each GPU | Expert parallelism routing |
| Broadcast | One GPU sends to all others | Weight synchronization |
All-Reduce: The TP Bottleneck
In tensor parallel inference, all-reduce synchronizes partial results after the attention and MLP sub-layers of every transformer block:
8-way TP forward pass (MLP):
Linear 1 (column-parallel) → Activation → Linear 2 (row-parallel) → All-Reduce
All-reduce time: ~23% of decode latency on 8×H100!
NCCL (NVIDIA Collective Communications Library)
- Optimized for NVLink intra-node communication
- Ring and tree algorithms for different message sizes
- Scales poorly for small messages across nodes
- Standard choice for NVIDIA GPUs
Optimization Strategies
1. Custom Single-Shot All-Reduce
Standard ring all-reduce: 2×(N-1) communication steps
Single-shot: Each rank aggregates from all peers in one stage
Result: ~3x kernel time speedup when fused with LayerNorm
~27% end-to-end decode improvement
2. Kernel Fusion
Before: All-Reduce → LayerNorm → Add (3 kernels)
After: Fused All-Reduce + LayerNorm + Add (1 kernel)
Eliminates intermediate HBM round-trips.
3. Communication Compression
Flash All-Reduce with INT4 quantization:
- Quantize tensors before communication
- Transfer 4x less data
- Dequantize after receive
Result: 3.18x latency reduction for >64MB messages
NVRAR (Research Alternative)
Hierarchical all-reduce based on recursive doubling with NVSHMEM:
- 1.9-3.6x lower latency than NCCL for 128KB-2MB
- 1.72x reduction in end-to-end batch latency for Llama 3.1 405B
GPUDirect RDMA
Bypassing CPU
GPUDirect RDMA allows network adapters (InfiniBand, RoCE) to directly access GPU memory without CPU involvement:
- Eliminates CPU memory copy
- Reduces latency by microseconds
- Critical for multi-node tensor parallelism
Unified Memory (Grace Hopper/Blackwell)
NVLink-C2C creates coherent CPU-GPU memory:
- Bandwidth: 900 GB/s between CPU and GPU
- Capacity: GH200 = 96GB HBM + 480GB LPDDR unified
- Benefit: Models larger than GPU memory "just work"
NCCLX (Meta, 100K+ GPUs)
Extended NCCL for extreme scale:
- Supports 100,000+ GPU collective operations
- Developed for Llama 4 training/inference
- Optimizations for both throughput and latency
Key Topics for System Design
Distilled essentials for LLM inference & training system design interviews. These are the concepts that matter most.
1. Prefill vs Decode: The Two Phases
Prefill (prompt processing): Compute-bound, processes all input tokens in parallel, generates the KV cache. Decode (generation): Memory-bandwidth-bound, generates one token at a time, reads the entire KV cache per token. Most optimization effort targets the decode phase.
2. Memory Hierarchy: HBM → SRAM → Registers
HBM: 80-192GB, 2-5 TB/s bandwidth (main memory). SRAM: ~50MB per GPU, an order of magnitude faster (L2 cache + SM shared memory). The goal is to keep data in SRAM as long as possible. FlashAttention exists because HBM bandwidth is the bottleneck.
3. KV Cache: The Memory Dominator
Formula: 2 × layers × kv_heads × head_dim × seq_len × precision_bytes. For Llama-70B with 8K context: ~2.6GB per sequence. KV cache often exceeds model weights in memory usage for long contexts. This is why PagedAttention and context length matter so much.
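The formula is simple enough to check directly. A small calculator, plugging in the Llama-70B GQA shapes quoted above (80 layers, 8 KV heads, head_dim 128, FP16):

```python
def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, precision_bytes=2):
    # Factor of 2 = one K tensor + one V tensor per layer.
    return 2 * layers * kv_heads * head_dim * seq_len * precision_bytes

# Llama-70B with GQA: 80 layers, 8 KV heads, head_dim 128, FP16, 8K context.
per_seq = kv_cache_bytes(layers=80, kv_heads=8, head_dim=128, seq_len=8192)
print(f"{per_seq / 1e9:.2f} GB per sequence")  # 2.68 GB
```

Swapping `kv_heads=8` for 64 (i.e., full MHA instead of GQA) multiplies this by 8, to roughly 21GB per sequence — which is why GQA is standard in modern models.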
4. Continuous Batching: No Head-of-Line Blocking
Static batching wastes GPU cycles waiting for the longest sequence in the batch. Continuous batching inserts new requests as others complete. Key insight: Iteration-level scheduling, not request-level. vLLM, TensorRT-LLM, and all modern serving systems use this.
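A toy scheduler makes the iteration-level idea concrete (illustrative only — real schedulers also handle prefill admission, preemption, and memory limits):

```python
from collections import deque

def continuous_batching(requests, max_batch):
    """Toy iteration-level scheduler. `requests` = [(id, tokens_to_generate)].
    A finished sequence frees its slot at the END of an iteration, and a
    waiting request claims it at the START of the next -- no waiting for
    the whole batch to drain."""
    waiting, running, timeline = deque(requests), {}, []
    while waiting or running:
        while waiting and len(running) < max_batch:   # admit into free slots
            rid, n = waiting.popleft()
            running[rid] = n
        timeline.append(sorted(running))              # batch this iteration
        for rid in list(running):                     # one decode step each
            running[rid] -= 1
            if running[rid] == 0:
                del running[rid]
    return timeline

# "C" starts the moment "A" finishes, while "B" keeps generating.
print(continuous_batching([("A", 2), ("B", 5), ("C", 2)], max_batch=2))
```

Under static batching, "C" would have to wait all five iterations for "B" to finish; here it waits two.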
5. FlashAttention: Tiling for Memory Efficiency
Standard attention materializes O(N²) attention matrix in HBM. FlashAttention tiles Q, K, V into blocks that fit in SRAM, computes partial softmax with online correction. Result: No O(N²) memory, 2-4x speedup. Now the default in all inference engines.
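The online correction is the heart of the trick. A minimal single-query sketch of streaming softmax over score blocks (plain NumPy, not the fused CUDA kernel):

```python
import numpy as np

def streaming_attention_row(score_blocks, value_blocks):
    """One query row, processed block by block: keep a running max m,
    denominator l, and output o, rescaling old partials whenever a new
    block raises the max (the FlashAttention online-softmax idea)."""
    m, l, o = -np.inf, 0.0, 0.0
    for s, v in zip(score_blocks, value_blocks):
        m_new = max(m, float(s.max()))
        alpha = np.exp(m - m_new)      # correction factor for earlier blocks
        p = np.exp(s - m_new)
        l = l * alpha + p.sum()
        o = o * alpha + p @ v
        m = m_new
    return o / l

scores = np.array([1.0, 3.0, 0.5, 2.0])
values = np.array([10.0, 20.0, 30.0, 40.0])
tiled = streaming_attention_row([scores[:2], scores[2:]],
                                [values[:2], values[2:]])
full = np.exp(scores - scores.max())
full = (full / full.sum()) @ values
assert np.isclose(tiled, full)         # identical result, no full row stored
```

Only the current block plus three running scalars ever exist at once, which is what lets the real kernel keep everything in SRAM.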
6. PagedAttention: Virtual Memory for KV Cache
Problem: Variable sequence lengths cause fragmentation. Solution: Allocate KV cache in fixed blocks (like OS pages), use block table for indirection. Achieves near-zero fragmentation, enables memory sharing for beam search. Core innovation behind vLLM.
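A toy block-table allocator shows the indirection (illustrative — vLLM's actual manager also handles copy-on-write sharing for beam search, and commonly uses a block size of 16):

```python
class PagedKVCache:
    """Toy block-table allocator: logical token positions map through a
    per-sequence block table to fixed-size physical blocks, like OS pages."""
    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))
        self.tables = {}                    # seq_id -> [physical block ids]

    def append_token(self, seq_id, pos):
        table = self.tables.setdefault(seq_id, [])
        if pos % self.block_size == 0:      # previous block full: allocate
            table.append(self.free.pop())
        # Indirection: logical position -> (physical block, offset).
        return table[pos // self.block_size], pos % self.block_size

    def release(self, seq_id):              # free all blocks on completion
        self.free.extend(self.tables.pop(seq_id))

cache = PagedKVCache(num_blocks=8, block_size=4)
for pos in range(6):
    cache.append_token("req-0", pos)
print(len(cache.tables["req-0"]))           # 6 tokens fit in 2 blocks
```

Waste is capped at one partially filled block per sequence, instead of the large contiguous over-allocation static schemes need.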
7. Quantization: Trading Precision for Speed
Weight-only quantization (INT4/INT8 weights, FP16 activations): Reduces memory, speeds up memory-bound decode. Full quantization (FP8/INT8 everything): Faster compute on Tensor Cores. Key trade-off: accuracy vs throughput. AWQ, GPTQ for weights; FP8 on Hopper/Blackwell for full quantization.
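A minimal NumPy sketch of per-channel weight-only INT8 — the generic technique only; AWQ and GPTQ layer activation-aware scaling and error compensation on top:

```python
import numpy as np

def quantize_weights_int8(w):
    """Per-output-channel symmetric INT8: each output row gets its own scale."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
    return np.round(w / scale).astype(np.int8), scale

def linear_weight_only(x, q, scale):
    # Dequantize on the fly; in a real kernel this happens in registers,
    # halving the HBM weight traffic that bottlenecks decode.
    return x @ (q.astype(np.float32) * scale).T

rng = np.random.default_rng(0)
w = rng.standard_normal((16, 64)).astype(np.float32)   # [out, in] layout
x = rng.standard_normal((1, 64)).astype(np.float32)
q, scale = quantize_weights_int8(w)
err = float(np.abs(linear_weight_only(x, q, scale) - x @ w.T).max())
print(f"max abs error: {err:.4f}")
```

Activations stay in FP16/FP32, so accuracy loss comes only from weight rounding — which is why weight-only schemes degrade so little at INT8.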
8. Tensor Parallelism: Splitting Within Layers
Splits attention heads and FFN columns across GPUs. Each GPU holds 1/N of weights, computes partial results, all-reduce to combine. Low latency (single forward pass), but requires fast interconnect (NVLink). Use for latency-sensitive serving, typically 2-8 GPUs.
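The partial-result-plus-all-reduce structure can be verified in a few lines; NumPy shards stand in for the GPUs, and a plain `sum` plays the role of NCCL's all-reduce:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 64))      # activations, present on every rank
w = rng.standard_normal((64, 32))     # full weight matrix (for reference)

world = 4                             # 4 simulated "GPUs"
# Row-parallel linear: each rank owns 1/world of the input dimension,
# computes a partial matmul, and an all-reduce sums the partials.
x_shards = np.split(x, world, axis=1)
w_shards = np.split(w, world, axis=0)
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
y = sum(partials)                     # stands in for the NCCL all-reduce

assert np.allclose(y, x @ w)          # matches the unsharded computation
```

Each rank stores and reads only a quarter of the weights, but the sum over partials must cross the interconnect every layer — hence the NVLink requirement.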
9. Pipeline Parallelism: Splitting Across Layers
Different GPUs hold different layers. Data flows through pipeline, micro-batching hides bubble overhead. Lower communication than TP, but higher latency. Use for training or when model doesn't fit with TP alone. Combine with TP for very large models.
10. Speculative Decoding: Guess and Verify
Draft model generates K candidate tokens cheaply, target model verifies in single forward pass (parallel!). If correct, accept all K; if wrong, accept up to the first mismatch. 2-3x speedup for latency without quality loss. Works because verification is parallel, generation is sequential.
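A greedy-verification sketch of the accept/reject loop. The sampled-token variant uses a probability-ratio test instead, and `target_next_token` here is a hypothetical stand-in for "argmax of the target model given this prefix":

```python
def speculative_step(draft_tokens, target_next_token):
    """Greedy verify: the target scores all K draft positions in one
    parallel forward pass, then we accept the longest matching prefix."""
    accepted, prefix = [], ()
    for t in draft_tokens:
        correct = target_next_token(prefix)
        if correct == t:
            accepted.append(t)           # draft agreed: token is "free"
            prefix += (t,)
        else:
            accepted.append(correct)     # target's correction, then stop
            break
    return accepted

# Toy target model that always continues 1, 2, 3, ...
target = lambda prefix: len(prefix) + 1
print(speculative_step([1, 2, 9], target))  # [1, 2, 3]
```

Even on a mismatch the step still yields useful tokens (the matching prefix plus the target's correction), so quality is unchanged and only wasted draft compute is lost.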
11. Disaggregated Serving: Prefill ≠ Decode
Prefill is compute-bound, decode is memory-bound — why run both on same hardware? Disaggregated architecture: Prefill nodes with high compute, decode nodes with high memory bandwidth. Transfer KV cache between them. Emerging pattern for large-scale serving (Mooncake, DistServe).
12. Arithmetic Intensity: The Core Metric
AI = FLOPs / bytes moved. Compare to hardware's ops:byte ratio (H100: ~300 for dense FP16). If AI < ratio → memory-bound. Decode AI ≈ 1-2 (one token, read all weights), prefill AI ≈ batch_size × seq_len. This explains why decode is always memory-bound.
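Plugging in rough H100 numbers makes the gap concrete (dense FP16 throughput is used below; the with-sparsity figure often quoted is 2x higher):

```python
# Roofline-style sanity check with approximate H100 SXM specs.
flops_fp16 = 989e12        # dense FP16 Tensor Core FLOP/s
hbm_bw = 3.35e12           # HBM3 bandwidth, bytes/s
ops_per_byte = flops_fp16 / hbm_bw
print(f"H100 ops:byte ratio ~ {ops_per_byte:.0f}")

# Decode at batch 1: ~2 FLOPs per weight, each FP16 weight costs 2 bytes.
decode_ai = 2 / 2          # ~1 FLOP/byte -> far below the ratio: memory-bound
# Batching raises AI roughly linearly, since weights are read once per batch.
print(f"batch needed to approach compute-bound: ~{ops_per_byte / decode_ai:.0f}")
```

The hundreds-wide gap between decode's AI and the hardware ratio is why every decode optimization in this document attacks bytes moved, not FLOPs.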
13. Hardware: NVIDIA vs AMD vs Custom
H100: 80GB HBM3, 3.35 TB/s, NVLink 900 GB/s, dominant ecosystem. MI300X: 192GB HBM3, 5.3 TB/s, better memory, weaker software. Blackwell (B200): FP4 support, 2x H100 performance. Know the specs, but software maturity often matters more.
14. Serving Frameworks: vLLM vs TensorRT-LLM
vLLM: PagedAttention, Python, flexible, great for research/startups. TensorRT-LLM: NVIDIA optimized, compiled graphs, best raw performance. SGLang: RadixAttention for prefix sharing. Choice depends on: flexibility vs performance, NVIDIA-only vs multi-hardware.
15. Training at Scale: Data, Tensor, Pipeline, Expert
Data Parallel (DP/FSDP): Replicate model, split data, gradient sync. Tensor Parallel: Split ops. Pipeline Parallel: Split layers. Expert Parallel: MoE routing. Modern training uses 3D/4D parallelism combining all. FSDP shards optimizer states across DP ranks for memory efficiency.
Interview Quick Hits
- Why is decode slow? Memory-bandwidth-bound: read all weights for just 1 token.
- Why KV cache? Avoid recomputing attention for all previous tokens each step.
- Why does batching help? Amortizes weight loading across multiple sequences.
- Why TP over DP for inference? Lower latency (single request), not throughput-focused.
- Why FP8 over INT8? No calibration needed, direct training support on Hopper+.
- Why does speculative decoding work? Verification is parallel (batch of K), generation is serial.
- Why disaggregate? Match hardware to workload characteristics (compute vs memory).
Knowledge Check
Test your understanding of LLM inference concepts. Select the best answer for each question.
Question 1: Inference Phases
During LLM inference, which phase is typically memory-bandwidth-bound?
Question 2: KV Cache Memory
For a 70B parameter model with GQA (8 KV heads), 80 layers, 128 head dimension, and FP16 precision, approximately how much KV cache memory is needed per 8K token sequence?
Question 3: PagedAttention
What is the primary benefit of PagedAttention's block-based memory management?
Question 4: Continuous Batching
What problem does continuous batching solve that static batching cannot?
Question 5: FlashAttention
FlashAttention achieves its speedup primarily by:
Question 6: Quantization
NVIDIA's NVFP4 format on Blackwell achieves accuracy preservation through:
Question 7: AMD MI300X
Why does AMD MI300X often match or exceed H100 performance in the decode phase?
Question 8: Tensor Parallelism
In tensor parallelism, what collective operation is performed after each transformer block?
Question 9: Expert Parallelism
Expert parallelism for MoE models uses which collective operation to route tokens to the correct GPUs?
Question 10: Speculative Decoding
Why does speculative decoding work even though it requires running two models?
Question 11: Disaggregated Serving
What is the main challenge in disaggregated prefill/decode architecture?
Question 12: Communication Optimization
GPUDirect RDMA improves multi-node inference by:
Question 13: GQA vs MHA
Grouped-Query Attention (GQA) reduces KV cache size by:
Question 14: TTFT vs TPOT
Which optimization primarily improves Time to First Token (TTFT)?
Question 15: Memory Hierarchy
The primary bottleneck in LLM decode is the bandwidth gap between: