What if the model you want to train requires more activation memory than your GPU can hold? You could reduce batch size to 1, but that might still not be enough for very deep or very wide networks. You could buy more GPUs, but that's expensive and not always possible.
There's another option: gradient checkpointing (also called activation recomputation or rematerialization). Instead of storing all intermediate activations for the backward pass, we store only a few "checkpoints" and recompute the rest when needed.
This technique enables training models that would otherwise be impossible on available hardware. It's the key to fitting GPT-3 scale models on reasonable GPU clusters, and it's essential knowledge for any practitioner working with large models.
By the end of this page, you will understand: (1) The core idea of activation recomputation, (2) The memory-compute tradeoff mathematically, (3) Optimal checkpoint placement strategies, (4) Implementation in modern frameworks, and (5) When and how to apply checkpointing in practice.
Standard backpropagation caches all activations during the forward pass for use in the backward pass. But we could also choose a different strategy:
Standard Backprop: store every layer's activations during the forward pass, then consume them one by one during the backward pass.
Checkpointed Backprop: store activations only at a handful of checkpoint layers; during the backward pass, recompute each segment's activations from its nearest checkpoint just before they are needed.
Why This Works:
The key insight is that activations are deterministic. Given the same inputs and weights, a layer will produce the same outputs. So we can always recompute any activation—we just need to have saved its inputs (or the ability to recompute those inputs in turn).
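To make the determinism point concrete, here is a minimal sketch in plain PyTorch (the helper `layer` and the `checkpoint_every` spacing are illustrative choices, not part of any framework API). It keeps only every fourth activation and then replays one segment to recover an activation it never stored:

```python
import torch

# A toy "deep network": a chain of identical tanh layers with fixed weights.
torch.manual_seed(0)
num_layers = 8
weights = [torch.randn(4, 4) for _ in range(num_layers)]

def layer(w, h):
    return torch.tanh(h @ w)

x = torch.randn(2, 4)

# Standard forward: keep every intermediate activation alive.
acts = [x]
for w in weights:
    acts.append(layer(w, acts[-1]))
print(f"stored activations: {len(acts)}")       # L + 1 tensors

# Checkpointed forward: remember only every 4th activation ("checkpoints").
checkpoint_every = 4
checkpoints = {0: x}
h = x
for i, w in enumerate(weights):
    h = layer(w, h)
    if (i + 1) % checkpoint_every == 0:
        checkpoints[i + 1] = h
print(f"stored checkpoints: {len(checkpoints)}")  # far fewer tensors

# Because each layer is deterministic, replaying the first segment from
# checkpoint 0 reproduces exactly the activation the standard pass stored.
h = checkpoints[0]
for i in range(checkpoint_every):
    h = layer(weights[i], h)
print(torch.allclose(h, acts[checkpoint_every]))  # True: recomputation is exact
```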
The Tradeoff:
We save memory at the cost of about 33% extra compute (recomputing the forward pass portions). For memory-constrained training, this is an excellent trade.
Checkpointing is used in essentially every large-scale training run (GPT-3, LLaMA, PaLM, and others). It's not a niche optimization; it's essential infrastructure for training models with billions of parameters, and understanding it is crucial for anyone working at scale.
Let's derive the optimal checkpoint strategy mathematically. Consider a network with $L$ layers, where each layer's activations require memory $M$.
Standard Backprop: all $L$ activations are held simultaneously, so memory is $L \cdot M = O(L)$.
Checkpointing with $k$ Checkpoints:
If we place $k$ evenly-spaced checkpoints, the network is divided into $k$ segments of $L/k$ layers each.
Memory Analysis:
During the backward pass through segment $i$, we hold the $k$ checkpoints plus the $L/k$ recomputed activations of the current segment: $$\text{Memory} = k \cdot M + \frac{L}{k} \cdot M = \left(k + \frac{L}{k}\right) M$$
Optimal Checkpoint Count:
To minimize memory, take the derivative with respect to $k$ and set it to zero: $$\frac{d}{dk}\left(k + \frac{L}{k}\right) = 1 - \frac{L}{k^2} = 0 \quad\Rightarrow\quad k^* = \sqrt{L}$$
Optimal Memory: substituting $k^* = \sqrt{L}$ gives $\left(\sqrt{L} + \frac{L}{\sqrt{L}}\right) M = 2\sqrt{L} \cdot M = O(\sqrt{L})$.
This is a remarkable result: by placing $\sqrt{L}$ checkpoints, we reduce memory from $O(L)$ to $O(\sqrt{L})$!
| Layers (L) | Standard Memory | Checkpointed Memory | Reduction |
|---|---|---|---|
| 16 | 16M | 8M (k=4) | 50% |
| 64 | 64M | 16M (k=8) | 75% |
| 144 | 144M | 24M (k=12) | 83% |
| 256 | 256M | 32M (k=16) | 87% |
| 1024 | 1024M | 64M (k=32) | 94% |
Compute Overhead:
Recomputing each segment during the backward pass effectively adds one extra forward pass per training step.
Overhead: under the simplified accounting used in the table and code below (forward and backward each cost $L$ units), total compute rises from $2L$ to $3L$, a 50% increase, and the backward phase takes roughly twice as long. If the backward pass is counted as roughly twice the cost of the forward pass, as is typical in FLOP estimates, the extra forward pass adds closer to the 33% quoted earlier.
In Practice: the overhead is often less painful than it sounds in wall-clock terms. Memory pressure frequently forces small batch sizes, which leaves the GPU under-utilized; by freeing activation memory, checkpointing enables larger batches, and the improved utilization can partially offset the recomputation cost.
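As a quick sanity check on those two figures, the short sketch below (an illustration under stated cost assumptions, not a profiler measurement) computes the relative overhead under both accounting conventions:

```python
def recompute_overhead(backward_cost_per_forward):
    """Relative extra compute from recomputing every activation once.

    Per layer: forward = 1 unit, backward = backward_cost_per_forward units,
    recomputation = 1 extra forward unit.
    """
    base = 1 + backward_cost_per_forward   # standard training step
    with_recompute = base + 1              # plus one extra forward pass
    return with_recompute / base - 1

# Simplified model used in the analysis below: backward costs 1x forward.
print(f"backward = 1x forward: {recompute_overhead(1):.0%} overhead")  # 50%
# More realistic FLOP accounting: backward costs ~2x forward.
print(f"backward = 2x forward: {recompute_overhead(2):.0%} overhead")  # ~33%
```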
```python
import numpy as np
import matplotlib.pyplot as plt

def analyze_checkpointing(num_layers, activation_per_layer_mb=100):
    """
    Analyze memory-compute tradeoff for gradient checkpointing.
    """
    L = num_layers
    M = activation_per_layer_mb

    print(f"Network: {L} layers, {M} MB per activation")
    print("=" * 60)

    # Standard backprop
    standard_memory = L * M
    standard_compute = 2 * L  # 1 forward + 1 backward
    print("\nStandard backpropagation:")
    print(f"  Memory:  {standard_memory} MB = {standard_memory/1024:.2f} GB")
    print(f"  Compute: {standard_compute} units")

    # Optimal checkpointing
    k_optimal = int(np.sqrt(L))
    checkpoint_memory = (k_optimal + L / k_optimal) * M
    checkpoint_compute = 3 * L  # 1 forward + 1 recompute + 1 backward
    print(f"\nOptimal checkpointing (k={k_optimal}):")
    print(f"  Memory:  {checkpoint_memory:.0f} MB = {checkpoint_memory/1024:.2f} GB")
    print(f"  Compute: {checkpoint_compute} units")
    print(f"  Memory reduction: {100 * (1 - checkpoint_memory/standard_memory):.0f}%")
    print(f"  Compute overhead: {100 * (checkpoint_compute/standard_compute - 1):.0f}%")

    # Explore different checkpoint counts
    print("\nMemory for different checkpoint counts:")
    print("-" * 40)
    k_values = range(1, L + 1)
    memories = []
    for k in k_values:
        mem = (k + L / k) * M
        memories.append(mem)

    # Print key values
    for k in [1, 2, int(np.sqrt(L)), L // 2, L]:
        if k <= L:
            mem = (k + L / k) * M
            print(f"  k={k:3d}: {mem:8.0f} MB")

    return k_values, memories

def plot_memory_vs_checkpoints(num_layers=100):
    """
    Visualize memory as a function of checkpoint count.
    """
    L = num_layers
    k_values = np.arange(1, L + 1)
    memories = k_values + L / k_values
    k_optimal = np.sqrt(L)

    plt.figure(figsize=(10, 6))
    plt.plot(k_values, memories, 'b-', linewidth=2)
    plt.axvline(k_optimal, color='r', linestyle='--', label=f'Optimal k=√L={k_optimal:.0f}')
    plt.axhline(L, color='gray', linestyle=':', alpha=0.5, label='Standard (no checkpointing)')
    plt.scatter([k_optimal], [2 * np.sqrt(L)], color='red', s=100, zorder=5)
    plt.xlabel('Number of Checkpoints (k)')
    plt.ylabel('Memory Units (normalized by M)')
    plt.title(f'Memory vs Checkpoint Count (L={L} layers)')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Annotate minimum
    plt.annotate(f'Minimum: 2√L = {2*np.sqrt(L):.1f}',
                 xy=(k_optimal, 2 * np.sqrt(L)),
                 xytext=(k_optimal + 10, 2 * np.sqrt(L) + 10),
                 arrowprops=dict(arrowstyle='->', color='red'))

    plt.savefig('checkpointing_memory.png', dpi=150)
    plt.show()

# Run analysis
analyze_checkpointing(num_layers=100, activation_per_layer_mb=100)
```

While the mathematical analysis suggests evenly-spaced checkpoints, real networks have structure that can be exploited for smarter placement.
Strategy 1: Uniform Checkpointing
Place checkpoints every $k$ layers. Simple and effective for homogeneous networks (all layers similar).
$$\text{Checkpoint at layers: } k, 2k, 3k, \ldots$$
Strategy 2: Checkpoint at Block Boundaries
Modern networks are organized into blocks (residual blocks, transformer layers). Checkpoint at block boundaries: store each block's input, and recompute everything inside the block during the backward pass. This aligns checkpoints with the model's natural structure and keeps the implementation simple.
Strategy 3: Selective Checkpointing
Some operations are memory-intensive but cheap to recompute (e.g., elementwise activation functions and normalizations). Others are expensive to recompute (e.g., large matrix multiplications and attention). Smart checkpointing keeps the outputs of the expensive operations and recomputes only the cheap, memory-heavy ones (often called selective activation recomputation).
Strategy 4: Automatic Optimal Checkpointing
Some frameworks help automate the choice. JAX's jax.checkpoint, for example, accepts rematerialization policies that declare which intermediates to save and which to recompute, rather than requiring manual placement of every checkpoint.
Strategy 5: Trade-off Tuning
Allow the user to specify a memory budget $B$, then automatically select a checkpoint count that satisfies it: find $k$ such that $$\left(k + \frac{L}{k}\right) M \leq B,$$ falling back to the minimum achievable $2\sqrt{L} \cdot M$ if no such $k$ exists.
For transformers, a common pattern is to checkpoint at each transformer block boundary, or just after the attention sub-layer so that the expensive attention output is kept and only the cheaper FFN is recomputed.
```python
import numpy as np

def compare_checkpoint_strategies(num_layers, block_size=4):
    """
    Compare different checkpoint placement strategies.
    """
    L = num_layers

    print(f"Network: {L} layers (organized into blocks of {block_size})")
    print("=" * 60)

    # Strategy 1: Uniform checkpointing (optimal spacing)
    k_uniform = int(np.sqrt(L))
    uniform_memory = k_uniform + L / k_uniform
    uniform_checkpoints = list(range(0, L, L // k_uniform))[:k_uniform]
    print(f"\n1. Uniform Checkpointing (k={k_uniform}):")
    print(f"   Checkpoints: {uniform_checkpoints[:5]}...")
    print(f"   Memory: {uniform_memory:.1f} units")

    # Strategy 2: Block boundaries
    num_blocks = L // block_size
    k_blocks = int(np.sqrt(num_blocks))
    checkpoint_blocks = list(range(0, L, block_size * (num_blocks // k_blocks)))
    block_memory = len(checkpoint_blocks) + L / len(checkpoint_blocks)
    print("\n2. Block Boundary Checkpointing:")
    print(f"   Checkpoints at blocks: {checkpoint_blocks[:5]}...")
    print(f"   Memory: {block_memory:.1f} units")

    # Strategy 3: Selective (checkpoint only attention in transformers)
    # Assume 50% of layers are "expensive" (attention)
    expensive_layers = L // 2
    checkpoint_expensive = list(range(0, L, 2))  # Every attention layer
    # Memory: all attention outputs + small working set
    selective_memory = expensive_layers + block_size
    print("\n3. Selective Checkpointing (expensive layers only):")
    print(f"   Checkpoint expensive outputs: {checkpoint_expensive[:5]}...")
    print(f"   Memory: {selective_memory:.1f} units")
    print("   (Recomputes only cheap layers)")

    # Strategy 4: Custom memory budget
    target_memory = L // 4
    # Solve k + L/k = target_memory, i.e. k^2 - target*k + L = 0
    discriminant = target_memory**2 - 4 * L
    if discriminant >= 0:
        k_custom = int((target_memory + np.sqrt(discriminant)) / 2)
        actual_memory = k_custom + L / k_custom
        print(f"\n4. Memory-Budget Constrained (target: {target_memory} units):")
        print(f"   Checkpoints: k={k_custom}")
        print(f"   Actual memory: {actual_memory:.1f} units")
    else:
        print(f"\n4. Memory-Budget Constrained (target: {target_memory} units):")
        print(f"   Target too aggressive - minimum is {2*np.sqrt(L):.1f}")

    return {
        'uniform': uniform_memory,
        'blocks': block_memory,
        'selective': selective_memory,
    }

compare_checkpoint_strategies(num_layers=48, block_size=4)
```

Modern deep learning frameworks provide built-in support for gradient checkpointing. Understanding how to use these APIs is essential for training large models.
PyTorch: torch.utils.checkpoint
PyTorch provides checkpoint_sequential for sequential models and checkpoint for wrapping arbitrary functions.
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

class TransformerBlock(nn.Module):
    """A single transformer block with attention and FFN"""
    def __init__(self, hidden_dim, num_heads, ff_dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # Self-attention with residual
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        # FFN with residual
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x

class CheckpointedTransformer(nn.Module):
    """
    Transformer with gradient checkpointing.
    Only checkpoint every nth block to balance memory and compute.
    """
    def __init__(self, num_layers, hidden_dim, num_heads, ff_dim, checkpoint_every=2):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])
        self.checkpoint_every = checkpoint_every

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if self.training and (i + 1) % self.checkpoint_every == 0:
                # Checkpoint this layer - don't store activations
                # use_reentrant=False is recommended for new code
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                # Normal forward pass - store activations
                x = layer(x)
        return x

# Simple checkpointing with checkpoint_sequential
class SimpleCheckpointedModel(nn.Module):
    """Use checkpoint_sequential for a sequence of layers"""
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        # Split the layer stack into 4 checkpoint segments
        segments = 4
        return checkpoint_sequential(self.layers, segments, x)

# Memory comparison
def compare_memory(hidden_dim=512, seq_len=1024, batch=8, num_layers=12):
    """Compare memory with and without checkpointing"""
    # Without checkpointing
    model_standard = nn.Sequential(*[
        TransformerBlock(hidden_dim, 8, hidden_dim * 4)
        for _ in range(num_layers)
    ]).cuda()
    x = torch.randn(seq_len, batch, hidden_dim).cuda()

    # Clear cache and measure
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = model_standard(x)
    loss = out.sum()
    loss.backward()
    standard_mem = torch.cuda.max_memory_allocated() / 1e9

    # With checkpointing
    model_checkpoint = CheckpointedTransformer(
        num_layers, hidden_dim, 8, hidden_dim * 4, checkpoint_every=2
    ).cuda()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = model_checkpoint(x)
    loss = out.sum()
    loss.backward()
    checkpoint_mem = torch.cuda.max_memory_allocated() / 1e9

    print(f"Memory without checkpointing: {standard_mem:.2f} GB")
    print(f"Memory with checkpointing:    {checkpoint_mem:.2f} GB")
    print(f"Reduction: {100 * (1 - checkpoint_mem/standard_mem):.0f}%")
```

TensorFlow/Keras:
Use tf.recompute_grad to wrap functions for gradient checkpointing:
```python
import tensorflow as tf

@tf.recompute_grad
def checkpointed_layer(x):
    # `layer` is any Keras layer or callable defined elsewhere
    return layer(x)
```
JAX:
JAX provides jax.checkpoint (or jax.remat for "rematerialization"):
```python
import jax

@jax.checkpoint
def transformer_block(params, x):
    # Block computation
    return output
```
JAX's approach is notably flexible: the same decorator, combined with a rematerialization policy, controls both where recomputation happens and which intermediates are kept.
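For illustration, here is a minimal sketch of attaching a policy; the block body and parameter names are made up for the example, and `checkpoint_dots` (save matrix-multiply results, recompute the rest) is just one of the available policies in `jax.checkpoint_policies`:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Save the results of matrix multiplies, recompute everything else.
@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def mlp_block(params, x):
    h = jax.nn.gelu(jnp.dot(x, params["w1"]))
    return jnp.dot(h, params["w2"])

params = {"w1": jnp.ones((8, 32)), "w2": jnp.ones((32, 8))}
x = jnp.ones((4, 8))
loss = lambda p: mlp_block(p, x).sum()
grads = jax.grad(loss)(params)  # gradients computed with selective rematerialization
```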
Beyond basic activation checkpointing, several advanced techniques push the memory-compute tradeoff further.
1. Selective Layer Checkpointing
Not all layers are equal. Attention layers produce large intermediate tensors and are relatively expensive to recompute; linear and elementwise layers are cheaper. Strategy: keep the outputs of the expensive operations and recompute only the cheap ones, as in the selective checkpointing sketch further below.
2. CPU Offloading
Instead of recomputing, offload activations to CPU memory: copy them to host RAM during the forward pass and prefetch them back to the GPU just before the backward pass needs them.
This trades memory for CPU↔GPU bandwidth rather than compute.
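PyTorch ships a context manager that implements this pattern for ordinary autograd. A minimal sketch, assuming a reasonably recent PyTorch (1.10+); the layer sizes are placeholders:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(
    nn.Linear(1024, 4096), nn.GELU(),
    nn.Linear(4096, 1024), nn.GELU(),
).to(device)
x = torch.randn(32, 1024, device=device)

# save_on_cpu moves the tensors autograd would normally keep on the GPU
# into host memory during the forward pass, and copies them back on demand
# during the backward pass. Pinned memory speeds up the transfers.
with torch.autograd.graph.save_on_cpu(pin_memory=(device == "cuda")):
    loss = model(x).sum()
loss.backward()
print(f"grad norm of first layer: {model[0].weight.grad.norm():.3f}")
```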
3. Disk Offloading
For extreme cases, offload to NVMe SSD. Very slow but enables training of arbitrarily large models on limited hardware.
4. Pipeline Parallelism + Checkpointing
Combine gradient checkpointing with pipeline parallelism: each pipeline stage checkpoints at its stage boundary, storing only the inputs of the micro-batches currently in flight and recomputing activations inside the stage during its backward pass. This combination is common in large-scale training frameworks.
```python
import torch
import torch.nn as nn
from collections import deque
from torch.utils.checkpoint import checkpoint_sequential

class CPUOffloadCheckpoint:
    """
    Offload activations to CPU instead of recomputing.
    Useful when recomputation is expensive (attention).
    """
    def __init__(self, max_cpu_tensors=10):
        self.cached = {}
        self.cpu_cache = deque(maxlen=max_cpu_tensors)

    def save_for_backward(self, name, tensor):
        """Move tensor to CPU asynchronously"""
        # Non-blocking copy to CPU
        cpu_tensor = tensor.to('cpu', non_blocking=True)
        self.cached[name] = cpu_tensor

    def restore(self, name):
        """Move tensor back to GPU"""
        if name in self.cached:
            gpu_tensor = self.cached[name].to('cuda', non_blocking=True)
            del self.cached[name]
            return gpu_tensor
        raise KeyError(f"Tensor {name} not found in cache")

class SelectiveCheckpointTransformer(nn.Module):
    """
    Selectively checkpoint based on operation type.
    - Expensive ops (attention): CPU offload
    - Cheap ops (LayerNorm, GELU): recompute
    """
    def __init__(self, hidden_dim, num_heads, ff_dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ff1 = nn.Linear(hidden_dim, ff_dim)
        self.gelu = nn.GELU()
        self.ff2 = nn.Linear(ff_dim, hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.offload = CPUOffloadCheckpoint()

    def forward(self, x):
        # Attention (expensive to recompute - offload result)
        attn_out, _ = self.attention(x, x, x)
        self.offload.save_for_backward('attn_out', attn_out)

        # Cheap ops - will recompute
        residual = x + attn_out
        normed = self.norm1(residual)               # Cheap to recompute

        # FFN intermediate (large but matrix mult is cheap)
        intermediate = self.gelu(self.ff1(normed))  # Can recompute

        # Final output
        out = self.ff2(intermediate)
        out = self.norm2(residual + out)
        return out

class GradientAccumWithCheckpoint:
    """
    Combine gradient accumulation with checkpointing.
    This enables very large effective batch sizes on limited memory.
    """
    def __init__(self, model, optimizer, loss_fn, accumulation_steps=8):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.accumulation_steps = accumulation_steps

    def train_step(self, batches):
        """
        Process multiple micro-batches with gradient accumulation.
        Each micro-batch uses checkpointing.
        """
        self.optimizer.zero_grad()
        total_loss = 0

        for x, y in batches:
            # Checkpointed forward pass, also using mixed precision
            with torch.cuda.amp.autocast():
                output = self.forward_with_checkpoint(x)
                loss = self.loss_fn(output, y) / self.accumulation_steps

            # Backward (gradients accumulate across micro-batches)
            loss.backward()
            total_loss += loss.item()

            # Release cached blocks between micro-batches, keep gradients
            torch.cuda.empty_cache()

        # Update weights with accumulated gradients
        self.optimizer.step()
        return total_loss

    def forward_with_checkpoint(self, x):
        """Forward pass with checkpointing enabled"""
        return checkpoint_sequential(
            self.model.layers,
            4,   # Number of checkpoint segments
            x,
        )
```

Gradient checkpointing isn't always necessary or beneficial. Here's practical guidance for when and how to apply it.
Start without checkpointing. If you run out of memory, enable it on every other layer (checkpoint_every=2). If you are still memory-constrained, checkpoint every layer. Monitor training throughput (samples/second) to make sure the compute overhead is acceptable.
Tuning Checkpoint Frequency:
| Memory Pressure | Checkpoint Every | Approx. Overhead |
|---|---|---|
| Mild | 4 layers | 10-15% |
| Moderate | 2 layers | 25-35% |
| High | 1 layer | 40-50% |
| Extreme | 1 layer + CPU offload | 50-100% |
Always measure actual memory and throughput for your specific model and hardware.
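One way to do that measurement is sketched below. It reuses the `CheckpointedTransformer` class from the PyTorch example above; the tensor sizes, step counts, and the large sentinel value used to disable checkpointing are placeholders to adjust for your model:

```python
import time
import torch

def measure(model, x, steps=5):
    """Return (peak memory in GB, samples/sec) over a few forward+backward steps."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        loss = model(x).sum()
        loss.backward()
        model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    return peak_gb, steps * x.shape[1] / elapsed   # batch is dim 1: (seq, batch, hidden)

x = torch.randn(1024, 8, 512, device="cuda")
for every in (None, 4, 2, 1):   # None = checkpointing disabled
    model = CheckpointedTransformer(
        num_layers=12, hidden_dim=512, num_heads=8, ff_dim=2048,
        # A spacing larger than the layer count means no layer is ever checkpointed.
        checkpoint_every=every if every is not None else 10_000,
    ).cuda()
    mem, tput = measure(model, x)
    print(f"checkpoint_every={every}: peak {mem:.2f} GB, {tput:.1f} samples/s")
```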
We have developed a comprehensive understanding of gradient checkpointing, an essential technique for training models whose activations would otherwise exceed available GPU memory.
Module Complete:
With this page, we conclude our deep exploration of the backpropagation algorithm.
You now have a complete, professional-level understanding of how neural networks are trained. This knowledge is essential for diagnosing training issues, designing architectures, and scaling to large models.
You have achieved deep mastery of the backpropagation algorithm—from the chain rule mathematics to practical memory optimization techniques. This knowledge underlies all of modern deep learning. Whether you're debugging gradient issues, optimizing for hardware, or pushing the limits of model scale, you now have the conceptual foundation to understand and solve the challenges you'll face.