The beginning of neural network training is paradoxically the most delicate phase. Parameters are randomly initialized, gradient estimates are maximally noisy, adaptive optimizer statistics are uninitialized, and batch normalization is calibrating. Jumping immediately to a high learning rate in this volatile state can cause irreversible training failures—exploding gradients, divergence, or convergence to poor solutions.
Learning rate warmup addresses this challenge by gradually increasing the learning rate from a small initial value to the target learning rate over the first portion of training. This seemingly simple technique has become essential for training large models, large batch training, and transformer architectures.
This page provides a deep exploration of warmup strategies: the theoretical principles that necessitate warmup, the various warmup curve shapes and their properties, practical implementation considerations, and guidelines for tuning warmup hyperparameters. By the end, you'll understand warmup not as a mysterious trick but as a principled response to early-training optimization dynamics.
By the end of this page, you will understand why warmup is essential for large models and large batches, implement linear, exponential, and gradual warmup variants, tune warmup duration based on your specific training scenario, and diagnose warmup-related training issues.
Understanding why warmup helps requires examining several interacting factors that make early training inherently unstable.
The Random Initialization Problem:
Neural network weights are typically initialized from distributions designed to maintain activation variance (Xavier, He initialization, etc.). Despite these efforts, the initial parameter configuration is essentially arbitrary—the network hasn't learned anything yet.
In this state, gradients are particularly problematic:
High Variance: Gradient estimates vary wildly between batches because the model's predictions are essentially random.
Poor Direction Quality: The 'direction' indicated by early gradients often points nowhere useful—the loss landscape near random initialization differs dramatically from the landscape near a good solution.
Magnitude Instability: Gradient magnitudes can be unexpectedly large or small depending on the random initialization.
Applying a high learning rate to these chaotic gradients amplifies the chaos, potentially causing exploding gradients, loss divergence, or early commitment to a poor region of the loss landscape. The table below summarizes the main sources of early instability and how warmup addresses each:
| Factor | Why It's Unstable Early | How Warmup Helps |
|---|---|---|
| Random weights | Arbitrary starting point, no learned structure | Small LR prevents large, arbitrary updates |
| Noisy gradients | High variance from mini-batch sampling | Small LR reduces noise amplification |
| Adam/RMSprop init | Second moment estimates (v) start at zero | Time for optimizer statistics to stabilize |
| Batch normalization | Running statistics need calibration | Stable updates while stats accumulate |
| Large batch size | Linear scaling hypothesis breakdown | Gradual LR increase respects convergence limits |
The Adaptive Optimizer Initialization Problem:
Adaptive optimizers (Adam, RMSprop, AdaGrad) maintain per-parameter statistics for learning rate scaling—most importantly a running average of squared gradients (Adam's second moment $v_t$), and in Adam's case also a running average of the gradients themselves (the first moment $m_t$).
These are typically initialized to zero (or small values). Before sufficient gradient history accumulates, the adaptive scaling is unreliable:
$$\text{Adam update} = \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon}$$
The bias correction terms $(1 - \beta^t)$ help, but the underlying issue remains: early $v_t$ estimates are noisy. Combined with a high learning rate, this can produce wildly incorrect updates.
Warmup provides time for these statistics to stabilize before applying aggressive learning rates.
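To make this concrete, here is a minimal simulation (illustrative only, not part of the schedulers implemented below): gradients are drawn from a unit normal, so the true second moment is exactly 1.0, yet Adam's bias-corrected estimate $\hat{v}_t$ is effectively an average over only about $t$ samples early on—at step 1 it is literally a single squared gradient—so the per-parameter scaling $1/\sqrt{\hat{v}_t}$ can be badly off for the first hundreds of steps.

```python
import numpy as np

rng = np.random.default_rng(0)
beta2 = 0.999
true_second_moment = 1.0   # gradients ~ N(0, 1), so E[g^2] = 1
v = 0.0

for t in range(1, 10_001):
    g = rng.normal()                      # stand-in for a noisy mini-batch gradient
    v = beta2 * v + (1 - beta2) * g**2    # Adam's second-moment accumulator
    v_hat = v / (1 - beta2**t)            # bias-corrected estimate
    if t in (1, 10, 100, 1_000, 10_000):
        rel_err = abs(v_hat - true_second_moment) / true_second_moment
        print(f"step {t:>6}: v_hat = {v_hat:.3f}  (relative error {rel_err:.0%})")
```

Running this shows large relative errors in the first tens of steps that shrink as history accumulates—exactly the window warmup is meant to protect.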
Large batch training (batch size >> 1024) particularly requires warmup. The linear scaling rule suggests multiplying LR proportionally to batch size, but this fails dramatically at training start. Warmup is essential to prevent immediate divergence when using large batches with correspondingly high learning rates.
The most common warmup strategy is linear warmup, but several alternatives exist with different properties.
Linear Warmup:
The learning rate increases linearly from a starting value to the target:
$$\eta_t = \eta_{start} + \frac{t}{T_w}(\eta_{target} - \eta_{start})$$
Where $\eta_t$ is the learning rate at step $t$, $\eta_{start}$ is the initial warmup learning rate, $\eta_{target}$ is the target learning rate reached at the end of warmup, and $T_w$ is the warmup duration (in steps or epochs).
Properties: the learning rate increases by the same amount every step, reaches the target exactly at $t = T_w$, and is simple to implement and reason about—which is why linear warmup is the usual default.
Exponential Warmup:
The learning rate increases exponentially:
$$\eta_t = \eta_{start} \cdot \left(\frac{\eta_{target}}{\eta_{start}}\right)^{t/T_w}$$
Properties: growth is multiplicative rather than additive, so the learning rate stays small for most of the warmup window and ramps steeply near the end; it also requires $\eta_{start} > 0$.
| Curve Type | Formula | Shape | Best For |
|---|---|---|---|
| Linear | η_start + (t/T_w)(η_target - η_start) | Straight line | Most scenarios, default choice |
| Exponential | η_start × (η_target/η_start)^(t/T_w) | Curves upward | Quick stabilization, aggressive ramp |
| Polynomial | η_target × (t/T_w)^p | Customizable curve | Fine control over warmup shape |
| Gradual (sqrt) | η_target × sqrt(t/T_w) | Steep start, gentle approach to target | Reaching a usable LR quickly, smooth hand-off |
| Cosine warmup | η_target × (1 - cos(π×t/T_w))/2 | S-curve | Smooth start and smooth transition |
```python
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR
from typing import Optional, Literal
import math

# =====================================================
# Implementation 1: Linear Warmup
# =====================================================
def linear_warmup(
    optimizer,
    warmup_epochs: int,
    warmup_start_factor: float = 0.0,  # 0 = start from zero
    last_epoch: int = -1
):
    """
    Standard linear warmup.

    LR ramps linearly from warmup_start_factor * base_lr
    to base_lr over warmup_epochs epochs.
    """
    def lr_lambda(epoch: int) -> float:
        if epoch < warmup_epochs:
            # Linear interpolation
            return warmup_start_factor + (1.0 - warmup_start_factor) * epoch / warmup_epochs
        return 1.0  # After warmup, return factor of 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


# =====================================================
# Implementation 2: Exponential Warmup
# =====================================================
def exponential_warmup(
    optimizer,
    warmup_epochs: int,
    warmup_start_factor: float = 0.01,  # Must be > 0 for exp
    last_epoch: int = -1
):
    """
    Exponential warmup: LR grows exponentially to target.
    Stays small for most of warmup, ramps steeply near the end.
    """
    def lr_lambda(epoch: int) -> float:
        if epoch < warmup_epochs:
            # Exponential growth from start_factor to 1.0
            progress = epoch / warmup_epochs
            return warmup_start_factor * (1.0 / warmup_start_factor) ** progress
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


# =====================================================
# Implementation 3: Polynomial Warmup
# =====================================================
def polynomial_warmup(
    optimizer,
    warmup_epochs: int,
    power: float = 1.0,  # 1.0 = linear, 2.0 = quadratic, 0.5 = sqrt
    last_epoch: int = -1
):
    """
    Polynomial warmup: LR = base_lr * (progress)^power

    power < 1: Faster initial increase (like sqrt)
    power = 1: Linear (standard)
    power > 1: Slower initial increase (like quadratic)
    """
    def lr_lambda(epoch: int) -> float:
        if epoch < warmup_epochs:
            return (epoch / warmup_epochs) ** power
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


# =====================================================
# Implementation 4: Cosine Warmup (S-curve)
# =====================================================
def cosine_warmup(
    optimizer,
    warmup_epochs: int,
    last_epoch: int = -1
):
    """
    Cosine warmup: S-shaped curve from 0 to target.
    LR = base_lr * (1 - cos(π * t / T_w)) / 2

    Smooth (zero slope) at both the start and the transition
    to the main schedule.
    """
    def lr_lambda(epoch: int) -> float:
        if epoch < warmup_epochs:
            # Rises from 0 at epoch 0 to 1 at epoch warmup_epochs
            return (1 - math.cos(math.pi * epoch / warmup_epochs)) / 2
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


# =====================================================
# Implementation 5: Unified Warmup Scheduler
# =====================================================
class UnifiedWarmup:
    """
    Unified warmup implementation with multiple curve options
    and subsequent schedule integration.
    """

    WARMUP_TYPES = Literal['linear', 'exponential', 'polynomial', 'cosine']

    def __init__(
        self,
        optimizer,
        warmup_epochs: int,
        warmup_type: str = 'linear',
        warmup_start_factor: float = 0.0,
        polynomial_power: float = 1.0,
        post_warmup_scheduler=None
    ):
        """
        Args:
            optimizer: Wrapped optimizer
            warmup_epochs: Duration of warmup
            warmup_type: 'linear', 'exponential', 'polynomial', or 'cosine'
            warmup_start_factor: Starting LR as fraction of base (for linear/exp)
            polynomial_power: Exponent for polynomial warmup
            post_warmup_scheduler: Scheduler to use after warmup (optional)
        """
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.warmup_type = warmup_type
        self.warmup_start_factor = warmup_start_factor
        self.polynomial_power = polynomial_power
        self.post_warmup_scheduler = post_warmup_scheduler
        self.base_lrs = [pg['lr'] for pg in optimizer.param_groups]
        self.current_epoch = 0

        # Validation
        if warmup_type == 'exponential' and warmup_start_factor <= 0:
            raise ValueError("Exponential warmup requires warmup_start_factor > 0")

    def _get_warmup_factor(self) -> float:
        """Compute LR multiplier for current epoch during warmup."""
        progress = self.current_epoch / self.warmup_epochs

        if self.warmup_type == 'linear':
            return self.warmup_start_factor + (1 - self.warmup_start_factor) * progress
        elif self.warmup_type == 'exponential':
            return self.warmup_start_factor * (1 / self.warmup_start_factor) ** progress
        elif self.warmup_type == 'polynomial':
            return progress ** self.polynomial_power
        elif self.warmup_type == 'cosine':
            return (1 - math.cos(math.pi * progress)) / 2
        else:
            raise ValueError(f"Unknown warmup type: {self.warmup_type}")

    def get_lr(self) -> list:
        if self.current_epoch < self.warmup_epochs:
            factor = self._get_warmup_factor()
            return [base_lr * factor for base_lr in self.base_lrs]
        elif self.post_warmup_scheduler is not None:
            return [pg['lr'] for pg in self.optimizer.param_groups]
        else:
            return self.base_lrs

    def step(self, epoch: Optional[int] = None):
        if epoch is not None:
            self.current_epoch = epoch
        else:
            self.current_epoch += 1

        if self.current_epoch < self.warmup_epochs:
            # Apply warmup
            lrs = self.get_lr()
            for param_group, lr in zip(self.optimizer.param_groups, lrs):
                param_group['lr'] = lr
        elif self.post_warmup_scheduler is not None:
            # Delegate to post-warmup scheduler
            self.post_warmup_scheduler.step()
        else:
            # No post-warmup scheduler: hold the LR at base_lr
            for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
                param_group['lr'] = base_lr

    def state_dict(self):
        state = {
            'current_epoch': self.current_epoch,
            'warmup_epochs': self.warmup_epochs,
            'warmup_type': self.warmup_type,
            'base_lrs': self.base_lrs,
        }
        if self.post_warmup_scheduler is not None:
            state['post_warmup_state'] = self.post_warmup_scheduler.state_dict()
        return state

    def load_state_dict(self, state_dict):
        self.current_epoch = state_dict['current_epoch']
        self.warmup_epochs = state_dict['warmup_epochs']
        self.warmup_type = state_dict['warmup_type']
        self.base_lrs = state_dict['base_lrs']
        if 'post_warmup_state' in state_dict and self.post_warmup_scheduler:
            self.post_warmup_scheduler.load_state_dict(state_dict['post_warmup_state'])


# =====================================================
# Implementation 6: Step-Level vs Epoch-Level Warmup
# =====================================================
class StepLevelWarmup:
    """
    Warmup that operates at step (iteration) level rather than epoch level.
    More fine-grained, especially useful for very short warmup periods
    or when epochs are very long.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        warmup_type: str = 'linear',
        warmup_start_factor: float = 0.0
    ):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.warmup_type = warmup_type
        self.warmup_start_factor = warmup_start_factor
        self.base_lrs = [pg['lr'] for pg in optimizer.param_groups]
        self.current_step = 0
        self.warmup_complete = False

    def step(self):
        """Call after each optimizer.step(), not after each epoch."""
        if self.warmup_complete:
            return

        self.current_step += 1

        if self.current_step >= self.warmup_steps:
            self.warmup_complete = True
            # Set to full base LR
            for param_group, lr in zip(self.optimizer.param_groups, self.base_lrs):
                param_group['lr'] = lr
            return

        # Compute warmup factor
        progress = self.current_step / self.warmup_steps

        if self.warmup_type == 'linear':
            factor = self.warmup_start_factor + (1 - self.warmup_start_factor) * progress
        elif self.warmup_type == 'cosine':
            factor = (1 - math.cos(math.pi * progress)) / 2
        else:
            factor = progress  # Default to linear

        # Apply
        for param_group, lr in zip(self.optimizer.param_groups, self.base_lrs):
            param_group['lr'] = lr * factor


# =====================================================
# Visualization
# =====================================================
def visualize_warmup_curves(warmup_epochs: int = 20):
    """Compare different warmup curve shapes."""
    import matplotlib.pyplot as plt

    epochs = np.arange(warmup_epochs + 1)
    base_lr = 0.1
    start_factor = 0.01

    # Linear
    linear = np.where(
        epochs <= warmup_epochs,
        base_lr * (start_factor + (1 - start_factor) * epochs / warmup_epochs),
        base_lr
    )

    # Exponential
    exp = np.where(
        epochs <= warmup_epochs,
        base_lr * start_factor * (1 / start_factor) ** (epochs / warmup_epochs),
        base_lr
    )

    # Square root (polynomial p=0.5)
    sqrt = np.where(
        epochs <= warmup_epochs,
        base_lr * np.sqrt(epochs / warmup_epochs),
        base_lr
    )

    # Cosine
    cosine = np.where(
        epochs <= warmup_epochs,
        base_lr * (1 - np.cos(np.pi * epochs / warmup_epochs)) / 2,
        base_lr
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, linear, label='Linear', linewidth=2)
    ax.plot(epochs, exp, label='Exponential', linewidth=2)
    ax.plot(epochs, sqrt, label='Square Root', linewidth=2)
    ax.plot(epochs, cosine, label='Cosine', linewidth=2)
    ax.axhline(y=base_lr, color='gray', linestyle='--', alpha=0.5, label='Target LR')
    ax.axvline(x=warmup_epochs, color='gray', linestyle=':', alpha=0.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Learning Rate')
    ax.set_title('Warmup Curve Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

    return fig
```

Despite the variety of warmup curves, linear warmup works well in the vast majority of cases. The other curves exist for edge cases or fine-tuning. If you're not sure which to use, start with linear—you can always experiment with alternatives if training is unstable.
Selecting the right warmup duration is one of the most impactful hyperparameter decisions. Too short, and warmup fails to stabilize training; too long, and you waste compute on suboptimal learning rates.
Factors Affecting Warmup Duration:
Batch Size: Larger batches require longer warmup. The gradient noise reduction from large batches makes early training more sensitive to high LR.
Model Size: Larger models (more parameters) often benefit from longer warmup. More parameters mean more potential for gradient chaos.
Optimizer Choice: SGD with momentum is often more stable than Adam early in training; Adam may benefit from longer warmup.
Architecture: Transformers typically need more warmup than CNNs due to attention layer sensitivity.
Learning Rate: Higher target LR necessitates longer warmup to safely approach it.
General Guidelines:
| Scenario | Warmup Duration | As % of Training | Reasoning |
|---|---|---|---|
| Small CNN (CIFAR) | 0-5 epochs | 0-5% | Simple architecture, stable gradients |
| ResNet-50 (ImageNet) | 5-10 epochs | 5-10% | Deeper network benefits from stability |
| BERT pretraining | 10,000 steps | ~1% | Large model, high LR, critical setup |
| GPT training | 2,000-10,000 steps | Variable | Depends on sequence length, model size |
| Vision Transformer | 10-20 epochs | 5-10% | Attention layers need stabilization |
| Large batch (4096+) | Extended (10-20%) | 10-20% | Compensates for linear LR scaling stress |
The Batch Size Scaling Rule:
When using the linear scaling rule (LR proportional to batch size), warmup duration should also scale:
$$T_{warmup} \propto \sqrt{\frac{\text{batch size}}{\text{base batch size}}}$$
Or more conservatively, scale warmup linearly with effective LR increase.
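As a worked example of the square-root rule (illustrative numbers): suppose the reference recipe uses 5 warmup epochs at a base batch size of 256, and you scale up to a batch size of 4,096. Then

$$T_{warmup} = 5 \cdot \sqrt{\frac{4096}{256}} = 5 \cdot 4 = 20 \text{ epochs}$$

The more conservative linear rule would instead give $5 \cdot 16 = 80$ epochs, which is usually only worth it if the square-root rule still leaves early training unstable.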
Practical Heuristics:
Transformers: Start with 10% of total training or 10,000 steps (whichever is shorter), adjust based on loss curves.
CNNs: 5% of total epochs is often sufficient, or even less for stable architectures.
Fine-tuning: 1-5 epochs or 1-2% of total training. Pretrained models are already in a good region.
Large batch: For each 2× increase in batch size with linear LR scaling, add 50-100% more warmup.
Diagnosing Warmup Issues:
Loss spikes or divergence within the first few epochs usually mean warmup is too short (or the starting LR is too high). Sluggish early progress that only picks up once warmup ends suggests warmup is too long. Instability that begins exactly when warmup finishes points to a target learning rate that is too high, not to the warmup itself.
Not all training requires warmup. Small models, small batches, low learning rates, and stable architectures may train fine without warmup. If you've never had early training instability, you can likely skip warmup. But when training large-scale models or using aggressive hyperparameters, warmup becomes essential.
Large batch training is where warmup transforms from 'nice to have' to 'absolutely essential.' Understanding this connection illuminates both why warmup matters and how to configure it for distributed training.
The Linear Scaling Rule:
The seminal insight from large batch training research: when multiplying batch size by $k$, multiply learning rate by $k$ to maintain equivalent training dynamics.
$$LR_{large} = k \cdot LR_{base}$$
This works because the gradient variance reduction from averaging more samples should be offset by larger step sizes. But this rule has limits.
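For example, with the common ImageNet reference point of LR 0.1 at batch size 256 (an illustrative baseline), scaling to a batch size of 8,192 gives $k = 32$:

$$LR_{large} = \frac{8192}{256} \cdot 0.1 = 32 \cdot 0.1 = 3.2$$

Applying LR 3.2 from the very first step is precisely the situation where training diverges without warmup.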
The Breakdown at High LR:
At training start, the linear scaling rule fails because:
Curvature Effects: The loss landscape curvature isn't flat; high LR can overshoot even when gradient direction is correct.
Gradient Quality: Early gradients are noisy regardless of batch size; scaling LR amplifies this noise.
Optimizer State: Adaptive optimizer statistics need time to calibrate; immediate high LR causes wrong adaptations.
Warmup as the Solution:
Gradual warmup allows the model (and optimizer) to reach a region where the linear scaling rule holds:
$$LR(t) = k \cdot LR_{base} \cdot \min\left(1, \frac{t}{T_{warmup}}\right)$$
This has enabled training with batch sizes of 8,192, 16,384, and beyond—achieving near-linear scaling of training speed with batch size.
```python
import torch
import torch.distributed as dist
from typing import Optional


class LargeBatchWarmup:
    """
    Warmup schedule specifically designed for large batch training.

    Implements:
    - Linear LR scaling with batch size
    - Extended warmup proportional to LR increase
    - Gradient accumulation awareness
    - Distributed training compatibility
    """

    def __init__(
        self,
        optimizer,
        base_lr: float,
        base_batch_size: int,
        actual_batch_size: int,
        warmup_epochs: int,
        total_epochs: int,
        gradient_accumulation_steps: int = 1,
        linear_scaling: bool = True,
        warmup_extension_factor: float = 1.0  # Increase for extra stability
    ):
        """
        Args:
            optimizer: Wrapped optimizer
            base_lr: Learning rate for base_batch_size
            base_batch_size: Reference batch size (usually 256 for ImageNet)
            actual_batch_size: Actual per-GPU batch size
            warmup_epochs: Base warmup duration (will be extended for large batches)
            total_epochs: Total training epochs
            gradient_accumulation_steps: For effective batch size calculation
            linear_scaling: Whether to apply linear LR scaling
            warmup_extension_factor: Extra multiplier for warmup duration
        """
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.base_batch_size = base_batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.total_epochs = total_epochs

        # Compute world size for distributed training
        if dist.is_initialized():
            world_size = dist.get_world_size()
        else:
            world_size = 1

        # Effective batch size
        self.effective_batch_size = actual_batch_size * world_size * gradient_accumulation_steps

        # Compute scaling factor
        if linear_scaling:
            self.lr_scale = self.effective_batch_size / base_batch_size
        else:
            self.lr_scale = 1.0

        self.target_lr = base_lr * self.lr_scale

        # Extended warmup for large batch
        # Rule: warmup ∝ sqrt(batch_scale) or linear with scale
        batch_ratio = self.effective_batch_size / base_batch_size
        if batch_ratio > 1:
            warmup_extension = batch_ratio ** 0.5  # sqrt scaling
        else:
            warmup_extension = 1.0

        self.warmup_epochs = int(warmup_epochs * warmup_extension * warmup_extension_factor)
        self.current_epoch = 0
        self.base_lrs = [pg['lr'] for pg in optimizer.param_groups]

        # Logging
        print(f"LargeBatchWarmup Configuration:")
        print(f"  Effective batch size: {self.effective_batch_size}")
        print(f"  LR scale factor: {self.lr_scale:.2f}x")
        print(f"  Target LR: {self.target_lr:.6f}")
        print(f"  Warmup epochs: {self.warmup_epochs}")

    def get_lr(self) -> float:
        if self.current_epoch < self.warmup_epochs:
            # Linear warmup to scaled target
            progress = self.current_epoch / self.warmup_epochs
            return self.base_lr + progress * (self.target_lr - self.base_lr)
        else:
            return self.target_lr

    def step(self, epoch: Optional[int] = None):
        if epoch is not None:
            self.current_epoch = epoch
        else:
            self.current_epoch += 1

        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        # Log at transitions
        if self.current_epoch == self.warmup_epochs:
            print(f"Warmup complete at epoch {self.current_epoch}, LR = {lr:.6f}")


def compute_optimal_warmup(
    batch_size: int,
    base_batch_size: int = 256,
    base_warmup_epochs: int = 5,
    warmup_strategy: str = 'sqrt'  # 'sqrt', 'linear', 'log'
) -> int:
    """
    Compute optimal warmup duration based on batch size scaling.

    Different strategies for different use cases:
    - 'sqrt': Conservative, standard choice
    - 'linear': Aggressive, for very large batches
    - 'log': Minimal extension, for stable architectures
    """
    batch_ratio = batch_size / base_batch_size

    if batch_ratio <= 1:
        return base_warmup_epochs

    if warmup_strategy == 'sqrt':
        factor = batch_ratio ** 0.5
    elif warmup_strategy == 'linear':
        factor = batch_ratio
    elif warmup_strategy == 'log':
        import math
        factor = 1 + math.log2(batch_ratio)
    else:
        factor = batch_ratio ** 0.5  # Default to sqrt

    return int(base_warmup_epochs * factor)


# Example usage
def setup_large_batch_training(
    model,
    base_lr: float = 0.1,
    batch_size_per_gpu: int = 64,
    num_gpus: int = 8  # informational; world size is read from torch.distributed at runtime
):
    """
    Complete setup for large batch distributed training.
    """
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=base_lr,  # Will be adjusted by scheduler
        momentum=0.9,
        weight_decay=1e-4
    )

    scheduler = LargeBatchWarmup(
        optimizer=optimizer,
        base_lr=base_lr,
        base_batch_size=256,
        actual_batch_size=batch_size_per_gpu,
        warmup_epochs=5,
        total_epochs=90,
        gradient_accumulation_steps=1,
        linear_scaling=True
    )

    return optimizer, scheduler
```

Even with warmup, extremely large batches (>32K) can cause generalization issues—the model may overfit to batch-level statistics or converge to sharp minima that generalize poorly. Warmup mitigates early instability but doesn't solve all large-batch challenges. Monitor validation performance carefully.
Different optimizers have different warmup needs due to their internal mechanics. Understanding these interactions enables more effective warmup configuration.
SGD with Momentum:
Momentum maintains an exponential moving average of gradients: $$v_t = \beta v_{t-1} + \nabla L(\theta_t)$$ $$\theta_{t+1} = \theta_t - \eta \cdot v_t$$
Warmup helps because the momentum buffer initially accumulates the noisy, poorly-directed gradients of early training; a small learning rate limits how far those accumulated directions can push the parameters while the buffer builds up a meaningful average.
Moderate warmup (5-10% of training) usually suffices.
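A convenient way to wire this up without custom code is PyTorch's built-in LinearLR and SequentialLR schedulers. The sketch below is illustrative (model, LR values, and the 90-epoch budget with a 5-epoch warmup, roughly 5%, are assumptions):

```python
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = torch.nn.Linear(128, 10)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# 5-epoch linear warmup from 10% of the base LR, then cosine decay over the rest.
warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=5)
decay = CosineAnnealingLR(optimizer, T_max=85)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[5])

for epoch in range(90):
    # ... training loop for one epoch ...
    scheduler.step()  # epoch-level scheduling: call once per epoch
```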
Adam and Variants:
Adam maintains both first and second moment estimates: $$m_t = \beta_1 m_{t-1} + (1-\beta_1) \nabla L$$ $$v_t = \beta_2 v_{t-1} + (1-\beta_2) (\nabla L)^2$$
The second moment $v_t$ critically affects step scaling: early in training it is estimated from only a handful of gradients, so the effective per-parameter step size $\eta / (\sqrt{v_t} + \epsilon)$ can be far too large wherever $v_t$ happens to be underestimated.
Longer warmup (often 10%+ of training) benefits Adam, especially with high β₂ values.
AdaFactor and LAMB:
These optimizers are designed for large-scale training and include their own internal warmup mechanisms or stabilization. External warmup remains beneficial but may need less duration.
| Optimizer | Recommended Warmup | Key Consideration | Special Notes |
|---|---|---|---|
| SGD (no momentum) | 5% or less | Minimal internal state | Often works without warmup |
| SGD + Momentum | 5-10% | Momentum buffer stabilization | Standard choice for CNNs |
| Adam | 10%+ | Second moment calibration | May spike if undercalibrated |
| AdamW | 10%+ | Same as Adam | Weight decay separate from optimization |
| LAMB | 5-10% | Built-in layer-wise adaptation | Designed for large batch |
| AdaFactor | 5-10% | Reduced memory, β₂ factorization | Often used with transformers |
Adam's ε parameter affects stability during warmup. Default ε (1e-8) can be too small; increasing to 1e-6 or 1e-4 provides additional stability, especially for large models. This can complement warmup or even reduce required warmup duration.
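In PyTorch this is a one-line change to the optimizer constructor (the model and hyperparameter values below are illustrative):

```python
import torch

model = torch.nn.Linear(128, 10)  # stand-in model

# Larger eps damps the 1/sqrt(v_t) term when v_t is tiny and poorly estimated.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.999),
    eps=1e-6,          # default is 1e-8; 1e-6 trades a little adaptivity for stability
    weight_decay=0.01,
)
```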
Deploying warmup in production training pipelines requires attention to several practical considerations beyond the basic algorithm.
```python
import torch
import math
from torch.optim.lr_scheduler import _LRScheduler
from typing import Optional, Dict, Any


class ProductionWarmupScheduler(_LRScheduler):
    """
    Production-ready linear warmup with arbitrary post-warmup schedule.

    Features:
    - Smooth transition from warmup to main schedule
    - Per-step or per-epoch scheduling
    - Full state serialization for checkpointing
    - Gradient norm tracking integration
    - Automatic validation
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        total_steps: int,
        warmup_start_lr: float = 0.0,
        post_warmup_schedule: str = 'cosine',  # 'cosine', 'linear', 'constant'
        min_lr: float = 0.0,
        last_epoch: int = -1
    ):
        """
        Args:
            optimizer: Wrapped optimizer
            warmup_steps: Number of warmup steps (not epochs)
            total_steps: Total training steps
            warmup_start_lr: LR at step 0
            post_warmup_schedule: Schedule after warmup
            min_lr: Minimum LR for decay schedules
            last_epoch: For resuming
        """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.warmup_start_lr = warmup_start_lr
        self.post_warmup_schedule = post_warmup_schedule
        self.min_lr = min_lr

        self._validate_config()
        super().__init__(optimizer, last_epoch)

    def _validate_config(self):
        """Validate configuration at initialization."""
        if self.warmup_steps >= self.total_steps:
            raise ValueError(
                f"warmup_steps ({self.warmup_steps}) must be < "
                f"total_steps ({self.total_steps})"
            )
        if self.warmup_steps < 0:
            raise ValueError("warmup_steps must be non-negative")
        if self.post_warmup_schedule not in ['cosine', 'linear', 'constant']:
            raise ValueError(f"Unknown schedule: {self.post_warmup_schedule}")

        warmup_ratio = self.warmup_steps / self.total_steps
        if warmup_ratio > 0.3:
            print(f"Warning: warmup_ratio ({warmup_ratio:.1%}) is quite high. "
                  "Consider reducing warmup_steps.")

    def get_lr(self):
        step = self.last_epoch + 1  # last_epoch is actually last step here

        if step < self.warmup_steps:
            # Warmup phase: linear from warmup_start_lr to base_lr
            alpha = step / self.warmup_steps
            return [
                self.warmup_start_lr + alpha * (base_lr - self.warmup_start_lr)
                for base_lr in self.base_lrs
            ]
        else:
            # Post-warmup phase
            post_warmup_step = step - self.warmup_steps
            post_warmup_total = self.total_steps - self.warmup_steps

            if self.post_warmup_schedule == 'constant':
                return list(self.base_lrs)
            elif self.post_warmup_schedule == 'linear':
                # Linear decay to min_lr
                progress = post_warmup_step / post_warmup_total
                return [
                    self.min_lr + (base_lr - self.min_lr) * (1 - progress)
                    for base_lr in self.base_lrs
                ]
            elif self.post_warmup_schedule == 'cosine':
                # Cosine decay to min_lr
                progress = post_warmup_step / post_warmup_total
                cosine_factor = 0.5 * (1 + math.cos(math.pi * progress))
                return [
                    self.min_lr + (base_lr - self.min_lr) * cosine_factor
                    for base_lr in self.base_lrs
                ]
            else:
                return list(self.base_lrs)

    def state_dict(self) -> Dict[str, Any]:
        """Full state for checkpointing."""
        return {
            'warmup_steps': self.warmup_steps,
            'total_steps': self.total_steps,
            'warmup_start_lr': self.warmup_start_lr,
            'post_warmup_schedule': self.post_warmup_schedule,
            'min_lr': self.min_lr,
            'base_lrs': self.base_lrs,
            'last_epoch': self.last_epoch
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore from checkpoint."""
        self.warmup_steps = state_dict['warmup_steps']
        self.total_steps = state_dict['total_steps']
        self.warmup_start_lr = state_dict['warmup_start_lr']
        self.post_warmup_schedule = state_dict['post_warmup_schedule']
        self.min_lr = state_dict['min_lr']
        self.base_lrs = state_dict['base_lrs']
        self.last_epoch = state_dict['last_epoch']

    def get_diagnostics(self) -> Dict[str, Any]:
        """Return diagnostic information for logging."""
        step = self.last_epoch + 1
        return {
            'current_step': step,
            'warmup_complete': step >= self.warmup_steps,
            'warmup_progress': min(step / self.warmup_steps, 1.0) if self.warmup_steps > 0 else 1.0,
            'overall_progress': step / self.total_steps,
            'current_lr': self.get_last_lr()[0],
            'schedule_phase': 'warmup' if step < self.warmup_steps else self.post_warmup_schedule
        }


class GradientNormTracker:
    """
    Companion utility to track gradient norms during warmup.
    Useful for diagnosing warmup effectiveness.
    """

    def __init__(self, window_size: int = 100):
        self.norms = []
        self.window_size = window_size

    def update(self, model) -> float:
        """Compute and store gradient norm."""
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        self.norms.append(total_norm)
        if len(self.norms) > self.window_size:
            self.norms.pop(0)

        return total_norm

    def get_statistics(self) -> Dict[str, float]:
        """Return gradient norm statistics."""
        if not self.norms:
            return {'mean': 0.0, 'max': 0.0, 'min': 0.0, 'std': 0.0}

        import numpy as np
        return {
            'mean': np.mean(self.norms),
            'max': np.max(self.norms),
            'min': np.min(self.norms),
            'std': np.std(self.norms)
        }

    def is_stable(self, threshold: float = 2.0) -> bool:
        """
        Check if gradients are stable (no major spikes).
        Returns True if max/mean ratio is below threshold.
        """
        stats = self.get_statistics()
        if stats['mean'] < 1e-8:
            return False  # Something is wrong if gradients are near-zero
        return stats['max'] / stats['mean'] < threshold
```

If training resumes from a checkpoint taken during warmup, the scheduler must correctly restore its state. A common bug is reinitializing the scheduler from epoch 0, causing unexpected LR behavior. Always verify that scheduler.last_epoch matches the checkpoint epoch after loading.
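A minimal checkpoint round-trip that avoids this bug might look like the following sketch (it uses the ProductionWarmupScheduler above; the stand-in model, file path, and step counts are illustrative):

```python
# Stand-in training objects.
model = torch.nn.Linear(128, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = ProductionWarmupScheduler(optimizer, warmup_steps=1_000, total_steps=100_000)

# Saving: capture model, optimizer, and scheduler state together.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pt')

# Resuming: rebuild the objects with the same constructor arguments,
# then restore state BEFORE training continues.
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
print('scheduler resumes from step', scheduler.last_epoch)  # should match the checkpoint, not 0
```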
Learning rate warmup transforms the chaotic early phase of neural network training into a stable foundation for effective optimization. By gradually increasing the learning rate, warmup gives model parameters, batch normalization statistics, and optimizer state time to stabilize before aggressive updates begin.
What's Next:
The final page of this module explores cyclical learning rates, which take a radically different approach by oscillating the learning rate throughout training. We'll see how this can improve exploration, enable snapshot ensembles, and sometimes outperform monotonically decreasing schedules.
You now understand why warmup is essential, how to implement various warmup curves, and how to tune warmup duration for your specific training scenario. Combined with the decay schedules from previous pages, you have a complete toolkit for learning rate scheduling.