Training neural networks requires memory—far more than you might expect. A model with 1 billion parameters (4GB in float32) can require 50GB or more of GPU memory during training. Where does this memory go?
The answer lies in a fundamental property of reverse-mode automatic differentiation: to compute gradients during the backward pass, we must remember what happened during the forward pass. Every intermediate activation, every pre-normalized value, every feature map must be stored until the backward pass processes that layer.
This page explores the memory requirements of backpropagation in depth, explaining why memory grows with network depth, how different operations contribute to memory usage, and what strategies we can employ to train large models within limited GPU memory.
By the end of this page, you will understand: (1) Why backpropagation requires storing forward-pass activations, (2) How memory scales with model depth, batch size, and sequence length, (3) The memory breakdown of different neural network components, (4) Techniques for reducing memory usage (mixed precision, activation recomputation, gradient accumulation), and (5) How to analyze and optimize memory usage in practice.
The memory requirement of backpropagation stems directly from how VJPs (vector-Jacobian products) are computed. To understand this, let's examine what information the backward pass needs.
The VJP Dependency:
Consider a simple operation $y = f(x)$. The VJP computes: $$\text{grad}_x = \mathbf{J}_f(x)^T \cdot \text{grad}_y$$
Notice that computing the Jacobian $\mathbf{J}_f(x)$ requires knowing $x$—the input value from the forward pass. Without storing $x$, we cannot compute the gradient.
Example: Matrix Multiplication
For $Y = XW$, the VJPs are: $$\text{grad}_X = \text{grad}_Y \cdot W^T, \qquad \text{grad}_W = X^T \cdot \text{grad}_Y$$
The gradient w.r.t. weights requires the input activations $X$. This is why we must cache $X$ during the forward pass.
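As a sanity check, here is a small sketch (with arbitrary random matrices chosen for illustration) that compares the closed-form VJP for the weights against a finite-difference approximation:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))       # batch of 4 inputs, 3 features
W = rng.standard_normal((3, 2))       # weight matrix
grad_Y = rng.standard_normal((4, 2))  # upstream gradient (arbitrary)

# Closed-form VJPs for Y = X @ W
grad_X = grad_Y @ W.T   # uses W
grad_W = X.T @ grad_Y   # uses X -- this is why X must be cached

# Finite-difference check of grad_W[0, 0]
# A scalar "loss" = sum(Y * grad_Y) has gradient grad_Y w.r.t. Y
loss = lambda W_: np.sum((X @ W_) * grad_Y)
eps = 1e-6
W_pert = W.copy()
W_pert[0, 0] += eps
fd = (loss(W_pert) - loss(W)) / eps

print(f"analytic grad_W[0,0] = {grad_W[0, 0]:.6f}")
print(f"finite-diff estimate = {fd:.6f}")
```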
Example: ReLU Activation
For $y = \max(0, x)$, the VJP is: $$\text{grad}_x = \text{grad}_y \odot \mathbf{1}_{x > 0}$$
We need to know where $x$ was positive—information only available from the forward pass. We must either store $x$ or store the mask $\mathbf{1}_{x > 0}$.
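A minimal sketch of this choice: caching the boolean mask instead of the full float32 input cuts the cached bytes per element from 4 to 1 (shapes below are arbitrary examples).

```python
import numpy as np

x = np.random.randn(128, 4096).astype(np.float32)  # activations for one layer

# Option 1: cache the full input
cache_input = x
# Option 2: cache only the sign mask needed by the ReLU VJP
cache_mask = (x > 0)

print(f"float32 input cache: {cache_input.nbytes / 1e6:.2f} MB")
print(f"boolean mask cache:  {cache_mask.nbytes / 1e6:.2f} MB")

# Backward pass using only the mask
grad_y = np.random.randn(*x.shape).astype(np.float32)
grad_x = grad_y * cache_mask   # identical to grad_y * (x > 0)
```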
Example: Batch Normalization
BatchNorm is particularly memory-hungry. The backward pass needs:
- the input $x$ (or the normalized values $\hat{x}$),
- the batch mean $\mu$ and batch variance $\sigma^2$ (or the inverse standard deviation $1/\sqrt{\sigma^2 + \epsilon}$),
- the scale parameter $\gamma$.
Each of these must be stored during forward pass and retrieved during backward.
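The sketch below (a simplified numpy forward pass, ignoring running statistics and training/eval modes) makes explicit which quantities end up in the cache:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Simplified BatchNorm forward that returns the cache its backward needs."""
    mu = x.mean(axis=0)                  # batch mean
    var = x.var(axis=0)                  # batch variance
    inv_std = 1.0 / np.sqrt(var + eps)   # inverse standard deviation
    x_hat = (x - mu) * inv_std           # normalized input
    out = gamma * x_hat + beta

    # Everything the backward pass will need:
    cache = {'x_hat': x_hat, 'inv_std': inv_std, 'gamma': gamma}
    return out, cache

x = np.random.randn(64, 256).astype(np.float32)
out, cache = batchnorm_forward(x, gamma=np.ones(256), beta=np.zeros(256))
cached_bytes = sum(v.nbytes for v in cache.values())
print(f"Cached for backward: {cached_bytes / 1e6:.2f} MB for input shape {x.shape}")
```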
Backpropagation trades memory for computation. By storing intermediate values, we avoid recomputing them during the backward pass. Without this storage, we'd need to recompute the entire forward pass for every layer's gradient—exponentially more computation.
```python
import numpy as np


class LinearLayerWithCache:
    """Demonstrates what must be cached for backpropagation."""

    def __init__(self, in_features, out_features):
        self.W = np.random.randn(in_features, out_features) * 0.01
        self.b = np.zeros(out_features)
        self.cache = None  # Will store forward-pass data

    def forward(self, X):
        """
        Forward pass: Y = X @ W + b
        Must cache X for backward pass!
        """
        self.cache = {
            'X': X,  # Needed for: grad_W = X.T @ grad_Y
            # Note: We don't need to cache Y itself
            # (we receive grad_Y from upstream)
        }
        return X @ self.W + self.b

    def backward(self, grad_Y):
        """
        Backward pass: Compute VJPs for X and parameters.
        Uses cached values from forward pass.
        """
        X = self.cache['X']  # Retrieve cached input

        # VJPs
        grad_X = grad_Y @ self.W.T   # Uses W (parameters, always available)
        grad_W = X.T @ grad_Y        # Uses X (MUST be cached!)
        grad_b = grad_Y.sum(axis=0)

        # Clear cache to free memory (done after layer's backward)
        self.cache = None

        return grad_X, grad_W, grad_b

    def memory_usage(self, batch_size, in_features):
        """Estimate memory used by this layer's cache."""
        # Cache stores X of shape (batch_size, in_features)
        bytes_per_float = 4  # float32
        cache_bytes = batch_size * in_features * bytes_per_float
        return cache_bytes


# Calculate memory for a network
def estimate_network_memory(layer_dims, batch_size, dtype_bytes=4):
    """
    Estimate activation memory for a feedforward network.

    layer_dims: list of layer widths [input, hidden1, hidden2, ..., output]
    """
    total_bytes = 0

    print(f"Activation memory breakdown (batch_size={batch_size}):")
    print("-" * 50)

    for i in range(len(layer_dims) - 1):
        # Each layer caches its input
        cache_size = batch_size * layer_dims[i] * dtype_bytes
        total_bytes += cache_size
        print(f"Layer {i+1}: input ({batch_size} × {layer_dims[i]}) = "
              f"{cache_size / 1e6:.2f} MB")

    print("-" * 50)
    print(f"Total activation memory: {total_bytes / 1e6:.2f} MB")
    print(f"                         {total_bytes / 1e9:.4f} GB")
    return total_bytes


# Example: A modest network
layer_dims = [1024, 2048, 2048, 2048, 1000]  # Input: 1024, Output: 1000
batch_size = 128

print("\nNetwork architecture:", " → ".join(map(str, layer_dims)))
print()
estimate_network_memory(layer_dims, batch_size)
```

Understanding how memory scales with different factors is essential for designing trainable models. Let's analyze the key scaling relationships.
1. Scaling with Network Depth (L layers):
Memory for activations scales linearly with depth. Each layer adds its cached activations to the total. For a network with $L$ layers of width $H$:
$$\text{Activation Memory} = O(L \cdot B \cdot H)$$
where $B$ is batch size. This is why very deep networks (100+ layers) require substantial memory.
2. Scaling with Batch Size:
Memory scales linearly with batch size. Doubling the batch doubles activation memory. This is often the limiting factor for training—we'd like larger batches for stable gradients, but memory constrains us.
3. Scaling with Sequence Length (Transformers):
For transformers, the self-attention mechanism computes attention scores of shape $(B, H, T, T)$, where $B$ is batch size, $H$ is the number of attention heads, and $T$ is sequence length. Memory scales quadratically with sequence length:
$$\text{Attention Memory} = O(B \cdot H \cdot T^2)$$
This is why long-context transformers are memory-intensive. A sequence of length 8192 uses 64× more memory than length 1024.
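A quick back-of-the-envelope check of this claim, looking only at the attention-score tensor of a single layer (assuming 12 heads, batch size 8, and float32 scores; these values are illustrative):

```python
def attention_score_memory_gb(batch, heads, seq_len, bytes_per_el=4):
    """Memory of the (B, H, T, T) attention-score tensor for ONE layer."""
    return batch * heads * seq_len * seq_len * bytes_per_el / 1e9

for T in [1024, 2048, 4096, 8192]:
    gb = attention_score_memory_gb(batch=8, heads=12, seq_len=T)
    print(f"T = {T:5d}: {gb:7.2f} GB per layer "
          f"({(T / 1024) ** 2:.0f}× the T=1024 cost)")
```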
4. Scaling with Hidden Dimension:
For dense layers, activations scale linearly with hidden dimension. For attention, the scaling is also linear in the hidden dimension $D$ (number of heads $H$ times head dimension $D/H$).
Combined Scaling:
| Factor | MLP | CNN (ResNet) | Transformer |
|---|---|---|---|
| Depth (L layers) | O(L) | O(L) | O(L) |
| Batch Size (B) | O(B) | O(B) | O(B) |
| Width/Channels | O(H) | O(C) | O(D) |
| Spatial/Sequence | N/A | O(H × W) | O(T²) for attention |
| Typical bottleneck | Depth + Width | Early layers (large feature maps) | Attention (long sequences) |
```python
import numpy as np
import matplotlib.pyplot as plt


def memory_scaling_analysis():
    """
    Analyze how different factors affect training memory.
    """

    # Transformer memory model (simplified)
    def transformer_memory_gb(
        num_layers,
        hidden_dim,
        num_heads,
        seq_length,
        batch_size,
        bytes_per_param=4  # float32
    ):
        """Estimate transformer training memory"""
        # Activations per layer
        # 1. Attention: Q, K, V projections + attention weights + output
        attention_memory = (
            3 * batch_size * seq_length * hidden_dim +          # Q, K, V
            batch_size * num_heads * seq_length * seq_length +  # Attention scores
            batch_size * seq_length * hidden_dim                # Attention output
        )

        # 2. FFN: Typically 4x hidden dim intermediate
        ffn_memory = (
            batch_size * seq_length * hidden_dim * 4 +  # Intermediate
            batch_size * seq_length * hidden_dim        # Output
        )

        # Total for all layers
        total_activations = num_layers * (attention_memory + ffn_memory)
        return total_activations * bytes_per_param / 1e9

    # Analyze batch size scaling
    print("1. Memory vs Batch Size (seq=512, layers=12, hidden=768)")
    print("-" * 50)
    batch_sizes = [1, 2, 4, 8, 16, 32, 64]
    for bs in batch_sizes:
        mem = transformer_memory_gb(12, 768, 12, 512, bs)
        print(f"Batch size {bs:3d}: {mem:6.2f} GB")

    print()
    print("2. Memory vs Sequence Length (batch=8, layers=12, hidden=768)")
    print("-" * 50)
    seq_lengths = [128, 256, 512, 1024, 2048, 4096]
    for seq in seq_lengths:
        mem = transformer_memory_gb(12, 768, 12, seq, 8)
        print(f"Seq length {seq:4d}: {mem:6.2f} GB")

    print()
    print("3. Memory vs Model Depth (batch=8, seq=512, hidden=768)")
    print("-" * 50)
    depths = [6, 12, 24, 48, 96]
    for depth in depths:
        mem = transformer_memory_gb(depth, 768, 12, 512, 8)
        print(f"Layers {depth:3d}: {mem:6.2f} GB")

    print()
    print("Key Insights:")
    print("- Memory scales linearly with batch size and depth")
    print("- Memory scales QUADRATICALLY with sequence length (attention)")
    print("- This is why long-context models are memory-constrained")


memory_scaling_analysis()
```

While activations dominate memory during training, several other components contribute. Understanding the full breakdown helps identify optimization opportunities.
1. Model Parameters (Weights)
The actual parameters of the model. For a model with $P$ parameters in float32: $4P$ bytes.
2. Gradients
During backward pass, we compute and store gradients for all parameters. Same size as parameters: $4P$ bytes.
3. Optimizer States
Most optimizers store additional state per parameter:
- SGD with momentum: one momentum buffer ($4P$ bytes in float32),
- Adam/AdamW: first-moment ($m$) and second-moment ($v$) estimates ($8P$ bytes),
- Adafactor: factored second-moment statistics (substantially less than $4P$ bytes for large weight matrices).
For Adam, optimizer states add $8P$ bytes.
4. Activations
Intermediate values cached for backward pass. This scales with batch size and sequence length, not just parameter count. Often dominates memory for large batches.
5. Temporary Buffers
Allocated during computation:
- workspace for matrix-multiplication and convolution kernels (e.g., cuBLAS/cuDNN),
- gradient-communication buffers in distributed training,
- memory lost to allocator fragmentation.

These are harder to predict but can be significant. The table below shows a representative breakdown for a 1B-parameter model trained with Adam in float32:
| Component | Calculation | Memory (GB) | % of Total |
|---|---|---|---|
| Parameters | 1B × 4 bytes | 4 GB | ~10% |
| Gradients | 1B × 4 bytes | 4 GB | ~10% |
| Adam optimizer (m, v) | 1B × 8 bytes | 8 GB | ~20% |
| Activations (varies) | Batch × Seq × Layers | ~20 GB | ~50% |
| Temporary buffers | Variable | ~4 GB | ~10% |
| Total | | ~40 GB | 100% |
For training with Adam in float32, expect roughly 16-20 bytes of GPU memory per parameter (parameters, gradients, and optimizer states) before activations are counted: a 7B-parameter model needs 112-140 GB. This is why quantization, offloading, and activation checkpointing are essential for large models.
```python
import numpy as np


def training_memory_breakdown(
    num_params_billions,
    batch_size,
    seq_length,
    num_layers,
    hidden_dim,
    optimizer='adam',
    dtype='float32'
):
    """
    Comprehensive training memory estimation.
    """
    bytes_per_param = 4 if dtype == 'float32' else 2
    num_params = num_params_billions * 1e9

    # 1. Model parameters
    params_gb = (num_params * bytes_per_param) / 1e9

    # 2. Gradients (same size as params)
    grads_gb = params_gb

    # 3. Optimizer states
    if optimizer == 'adam':
        opt_multiplier = 2  # m and v
    elif optimizer == 'sgd_momentum':
        opt_multiplier = 1  # momentum buffer
    elif optimizer == 'adafactor':
        opt_multiplier = 1  # factored states (approximate)
    else:
        opt_multiplier = 0
    optimizer_gb = params_gb * opt_multiplier

    # 4. Activations (simplified transformer model)
    # Each layer stores: input, attention outputs, FFN intermediates
    activations_per_layer = (
        batch_size * seq_length * hidden_dim * 3 +  # Q, K, V
        batch_size * seq_length * hidden_dim +      # attention out
        batch_size * seq_length * hidden_dim * 4 +  # FFN intermediate
        batch_size * seq_length * hidden_dim        # FFN out
    )
    activations_total = activations_per_layer * num_layers * bytes_per_param
    activations_gb = activations_total / 1e9

    # 5. Temporary buffers (rough estimate: 10-20% of activations)
    temp_gb = activations_gb * 0.15

    # Total
    total_gb = params_gb + grads_gb + optimizer_gb + activations_gb + temp_gb

    print(f"\nTraining Memory Breakdown: {num_params_billions}B parameter model")
    print("=" * 60)
    print("Configuration:")
    print(f"  Batch size: {batch_size}, Seq length: {seq_length}")
    print(f"  Layers: {num_layers}, Hidden dim: {hidden_dim}")
    print(f"  Dtype: {dtype}, Optimizer: {optimizer}")
    print()
    print(f"{'Component':<25} {'Memory (GB)':>12} {'% of Total':>12}")
    print("-" * 60)
    print(f"{'Model parameters':<25} {params_gb:>12.2f} {100*params_gb/total_gb:>11.1f}%")
    print(f"{'Gradients':<25} {grads_gb:>12.2f} {100*grads_gb/total_gb:>11.1f}%")
    print(f"{'Optimizer states':<25} {optimizer_gb:>12.2f} {100*optimizer_gb/total_gb:>11.1f}%")
    print(f"{'Activations':<25} {activations_gb:>12.2f} {100*activations_gb/total_gb:>11.1f}%")
    print(f"{'Temporary buffers':<25} {temp_gb:>12.2f} {100*temp_gb/total_gb:>11.1f}%")
    print("-" * 60)
    print(f"{'TOTAL':<25} {total_gb:>12.2f} {'100.0%':>12}")

    # Practical GPU recommendations
    print()
    print("GPU Recommendations:")
    if total_gb <= 12:
        print("  ✓ Fits on consumer GPU (RTX 3080/3090, 12GB)")
    elif total_gb <= 24:
        print("  ✓ Fits on RTX 4090 / A10 (24GB)")
    elif total_gb <= 48:
        print("  ✓ Fits on A40 / 2x A10 (48GB)")
    elif total_gb <= 80:
        print("  ✓ Fits on A100 80GB")
    else:
        print(f"  ⚠ Requires model parallelism or ~{int(np.ceil(total_gb/80))} A100 80GB")

    return {
        'params': params_gb,
        'grads': grads_gb,
        'optimizer': optimizer_gb,
        'activations': activations_gb,
        'temp': temp_gb,
        'total': total_gb
    }


# Example: 1B parameter model
training_memory_breakdown(
    num_params_billions=1.0,
    batch_size=8,
    seq_length=2048,
    num_layers=24,
    hidden_dim=2048,
    optimizer='adam',
    dtype='float32'
)
```

Mixed precision training uses lower-precision floating point formats (float16 or bfloat16) to reduce memory usage and increase compute throughput, while maintaining training stability.
The Core Idea:
- Keep a master copy of the weights in float32.
- Run the forward and backward passes (and store activations) in float16 or bfloat16.
- Apply the weight update in float32 to avoid accumulation error.
This reduces activation memory by ~50% and enables using tensor cores for faster matrix operations.
| Format | Bits | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|
| float32 (FP32) | 32 | 8 bits | 23 bits | ±3.4×10³⁸ | High |
| float16 (FP16) | 16 | 5 bits | 10 bits | ±65504 | Low (loss scaling needed) |
| bfloat16 (BF16) | 16 | 8 bits | 7 bits | ±3.4×10³⁸ | Medium (same range as FP32) |
| TensorFloat32 | 19 | 8 bits | 10 bits | ±3.4×10³⁸ | Medium (A100+ only) |
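The ranges and precisions above can be inspected directly; a small sketch using `torch.finfo` (the equivalent `np.finfo` covers float16/float32 but not bfloat16):

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):<16} bits={info.bits:<3} max={info.max:<12.3e} "
          f"smallest normal={info.tiny:.3e} eps={info.eps:.3e}")
```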
Why Loss Scaling is Required for FP16:
FP16 has a very limited dynamic range. Small gradients (common in deep networks) can underflow to zero, killing training. Loss scaling multiplies the loss by a large factor (e.g., 1024) before the backward pass, then divides the gradients by the same factor before the optimizer step, keeping small gradients representable in FP16.
BFloat16 Advantage: Same exponent bits as FP32, so same dynamic range—no loss scaling needed! This is why BF16 is preferred on modern hardware (A100+, H100).
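A small demonstration of the underflow problem and how scaling sidesteps it (pure tensor arithmetic with an illustrative gradient value, not a full AMP loop):

```python
import torch

g = torch.tensor(1e-8)  # a tiny gradient value of the kind seen in deep networks

print(f"float32:  {g.item():.3e}")
print(f"float16:  {g.to(torch.float16).item():.3e}")   # underflows to 0.0
print(f"bfloat16: {g.to(torch.bfloat16).item():.3e}")  # survives (FP32-like range)

# Loss scaling: scale up before casting to FP16, divide back in float32
scale = 1024.0
scaled = (g * scale).to(torch.float16)       # now representable in FP16
recovered = scaled.to(torch.float32) / scale
print(f"with loss scaling: {recovered.item():.3e}")
```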
Mixed precision typically reduces training memory by 35-50%: activations are halved (16-bit), the optimizer states and master weights remain in 32-bit, and gradients can be computed in 16-bit but accumulated in 32-bit. Modern frameworks (PyTorch AMP, TensorFlow Mixed Precision) make this automatic.
```python
# PyTorch Automatic Mixed Precision Example

import torch
from torch.cuda.amp import autocast, GradScaler


def mixed_precision_training_loop(model, optimizer, loss_fn, dataloader, num_epochs):
    """
    Training loop with automatic mixed precision.

    Key components:
    1. autocast: Runs forward pass in FP16/BF16
    2. GradScaler: Scales loss and unscales gradients (for FP16)
    """
    # GradScaler handles loss scaling for FP16
    # Not needed for BF16, but doesn't hurt
    scaler = GradScaler()

    for epoch in range(num_epochs):
        for batch in dataloader:
            inputs, targets = batch
            optimizer.zero_grad()

            # Forward pass in mixed precision
            with autocast(dtype=torch.float16):  # or torch.bfloat16
                outputs = model(inputs)
                loss = loss_fn(outputs, targets)

            # Backward pass with scaled gradients
            scaler.scale(loss).backward()

            # Unscale gradients and update weights
            scaler.step(optimizer)
            scaler.update()


# Memory comparison
def compare_memory_usage():
    """
    Compare memory usage: FP32 vs FP16 activations
    """
    # Simulate activation tensor
    batch, seq, hidden = 32, 2048, 4096

    # FP32 activation
    fp32_tensor = torch.randn(batch, seq, hidden, dtype=torch.float32)
    fp32_bytes = fp32_tensor.element_size() * fp32_tensor.nelement()

    # FP16 activation
    fp16_tensor = torch.randn(batch, seq, hidden, dtype=torch.float16)
    fp16_bytes = fp16_tensor.element_size() * fp16_tensor.nelement()

    print(f"Activation shape: ({batch}, {seq}, {hidden})")
    print(f"FP32 memory: {fp32_bytes / 1e9:.3f} GB")
    print(f"FP16 memory: {fp16_bytes / 1e9:.3f} GB")
    print(f"Memory reduction: {100 * (1 - fp16_bytes/fp32_bytes):.0f}%")

    # For a full model with many layers
    num_layers = 24
    total_fp32 = num_layers * fp32_bytes
    total_fp16 = num_layers * fp16_bytes

    print(f"\n24-layer model activations:")
    print(f"FP32 total: {total_fp32 / 1e9:.2f} GB")
    print(f"FP16 total: {total_fp16 / 1e9:.2f} GB")
    print(f"Saved: {(total_fp32 - total_fp16) / 1e9:.2f} GB")


# Note: Always use mixed precision for production training!
# It's essentially free performance and memory savings.
```

Gradient accumulation simulates larger batch sizes without the memory cost. Instead of processing one large batch, we process multiple smaller batches and accumulate their gradients before updating weights.
The Technique:
1. Run the forward and backward pass on a small mini-batch, but do not call the optimizer yet.
2. Let the gradients accumulate (sum) across $k$ mini-batches.
3. After $k$ mini-batches, apply one optimizer step and zero the gradients.
Effective batch size = mini-batch size × accumulation steps
Why This Works:
Gradient descent with batch size $B$ computes: $$\theta \leftarrow \theta - \eta \frac{1}{B} \sum_{i=1}^{B} \nabla L_i$$
This is mathematically equivalent to: $$\theta \leftarrow \theta - \eta \frac{1}{k \cdot b} \sum_{j=1}^{k} \sum_{i=1}^{b} \nabla L_{j,i}$$
where $b$ is mini-batch size and $k$ is accumulation steps. We get the exact same update as a batch of size $B = k \cdot b$!
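A small numerical check of this equivalence for a linear least-squares loss (illustrative sizes: a full batch of 128 split into 8 micro-batches of 16):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((128, 16))   # full batch B = 128
y = rng.standard_normal(128)
w = rng.standard_normal(16)

def grad(Xb, yb, w):
    """Mean gradient of 0.5 * (Xb @ w - yb)^2 over a batch."""
    return Xb.T @ (Xb @ w - yb) / len(yb)

# Full-batch gradient (B = 128)
g_full = grad(X, y, w)

# Accumulated gradient: k = 8 micro-batches of b = 16
k, b = 8, 16
g_accum = np.zeros_like(w)
for j in range(k):
    Xb, yb = X[j*b:(j+1)*b], y[j*b:(j+1)*b]
    g_accum += grad(Xb, yb, w) / k   # divide by k, as in the loss-scaling trick

print(np.allclose(g_full, g_accum))  # True -> identical update
```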
Memory Impact: Activations only need to fit one mini-batch. We trade compute time (more forward/backward passes) for memory (smaller batch at once).
Use gradient accumulation when: (1) Your desired batch size exceeds GPU memory, (2) You're training on consumer hardware but want stable training, (3) You're using very high-resolution inputs (images) or long sequences. The main downside is slower training (k× the forward/backward passes).
```python
# Gradient accumulation implementation

def training_with_gradient_accumulation(
    model,
    optimizer,
    loss_fn,
    dataloader,
    accumulation_steps=4
):
    """
    Training loop with gradient accumulation.

    Effective batch size = dataloader.batch_size * accumulation_steps
    """
    model.train()
    optimizer.zero_grad()

    for step, batch in enumerate(dataloader):
        inputs, targets = batch

        # Forward and backward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        # Scale loss by accumulation steps
        # (so that accumulated gradient is properly averaged)
        loss = loss / accumulation_steps
        loss.backward()  # Gradients accumulate in .grad

        # Update weights every accumulation_steps
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Handle remaining steps
    if (step + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()


# Memory comparison
def compare_batch_strategies():
    """
    Compare memory requirements for different batching strategies.
    """
    # Target batch size: 128
    # GPU memory limit: 16 GB
    # Activation memory per sample: 200 MB
    mem_per_sample = 0.2  # GB
    target_batch = 128
    gpu_memory = 16  # GB

    print("Scenario: Target batch size 128, 200 MB/sample activations")
    print("=" * 60)

    # Strategy 1: Full batch (if it fits)
    full_batch_mem = target_batch * mem_per_sample
    print(f"\n1. Full batch (128 samples):")
    print(f"   Required: {full_batch_mem:.1f} GB")
    print(f"   Status: {'✓ Fits' if full_batch_mem <= gpu_memory else '✗ Exceeds limit'}")

    # Strategy 2: Gradient accumulation
    accumulation_steps = 8
    micro_batch = target_batch // accumulation_steps
    micro_batch_mem = micro_batch * mem_per_sample
    print(f"\n2. Gradient accumulation (8 steps × 16 samples):")
    print(f"   Required: {micro_batch_mem:.1f} GB")
    print(f"   Status: {'✓ Fits' if micro_batch_mem <= gpu_memory else '✗ Exceeds limit'}")
    print(f"   Tradeoff: 8× forward/backward passes")

    # Strategy 3: Smaller effective batch
    max_batch = int(gpu_memory / mem_per_sample)
    print(f"\n3. Reduced batch size (max {max_batch} samples):")
    print(f"   Required: {max_batch * mem_per_sample:.1f} GB")
    print(f"   Status: ✓ Fits")
    print(f"   Tradeoff: Different training dynamics, may need LR adjustment")


compare_batch_strategies()
```

So far we've seen techniques that work within the constraint of storing all activations. But there's a more radical approach: what if we don't store all activations?
The Observation:
We only need activations during the backward pass. Instead of storing them, we could recompute them when needed. This trades compute time for memory.
Basic Idea:
1. During the forward pass, store activations only at a few checkpoint layers.
2. During the backward pass, recompute the activations between checkpoints on demand.
3. Discard each recomputed segment as soon as its gradients have been computed.
This technique, called activation checkpointing or gradient checkpointing, is covered in depth in the next section. It's essential for training very large models.
Activation checkpointing exemplifies a fundamental principle: memory and computation are often interchangeable. By spending more compute (recomputing activations), we save memory. Many memory optimization techniques exploit this tradeoff.
Quick Comparison:
| Approach | Memory | Compute | When to Use |
|---|---|---|---|
| Standard | O(L) | 1× forward, 1× backward | Memory abundant |
| Checkpointing | O(√L) | 1× forward + ~1× extra forward of recomputation during backward | Memory limited |
| Full recompute | O(1) | L× forward per backward | Extreme memory limits |
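As a preview, here is a minimal sketch of what checkpointing looks like in PyTorch using `torch.utils.checkpoint` (the model, shapes, and per-block checkpoint placement below are illustrative; optimal placement is derived in the next section):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A deep stack of blocks whose intermediate activations would normally all be cached
blocks = nn.ModuleList(
    nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()) for _ in range(24)
)

x = torch.randn(64, 1024, requires_grad=True)

# Checkpoint every block: only each block's input is saved; the values computed
# inside the block are recomputed during the backward pass.
out = x
for block in blocks:
    out = checkpoint(block, out, use_reentrant=False)

out.sum().backward()   # triggers recomputation block by block
print(x.grad.shape)    # gradients flow exactly as in standard backprop
```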
The next section will dive deep into gradient checkpointing, deriving the optimal checkpoint placement and understanding the exact tradeoffs involved.
We have developed a comprehensive understanding of why backpropagation consumes memory and how that memory scales with model and training parameters.
Looking Ahead:
The next section dives deep into gradient checkpointing—the technique for trading compute for memory by strategically recomputing activations. We'll derive the optimal checkpoint placement, understand the exact memory-compute tradeoff, and see how modern frameworks implement this essential technique for training large models.
You now have a mental model for understanding and estimating training memory requirements. This knowledge is essential for planning how to train models on available hardware, and for making informed decisions about memory optimization techniques. Next, we'll master gradient checkpointing—the key technique for training models larger than your GPU.