While step decay creates distinct training phases with sharp transitions, many real-world optimization landscapes evolve continuously rather than in discrete jumps. Exponential decay addresses this by smoothly reducing the learning rate at every step or epoch, creating a gradual transition from aggressive exploration to precise refinement.
The philosophical distinction is subtle but important: step decay assumes training naturally separates into phases, while exponential decay assumes the optimal learning rate decreases continuously as training progresses. Neither assumption is universally correct—the choice depends on your specific problem, architecture, and training dynamics.
This page explores exponential decay in depth: its mathematical formulation, theoretical justifications, implementation nuances, and the practical scenarios where it outperforms discrete alternatives. By the end, you'll understand not just how to implement exponential decay, but when and why to choose it.
By the end of this page, you will understand the mathematical formulations of exponential decay (continuous and discrete), implement multiple variants with proper production considerations, recognize scenarios where exponential decay outperforms step-based approaches, and tune decay rate hyperparameters effectively.
Exponential decay admits two equivalent mathematical formulations that reveal different aspects of its behavior. Understanding both perspectives deepens intuition and guides hyperparameter selection.
Multiplicative Form (Discrete):
$$\eta_t = \eta_{t-1} \cdot \gamma = \eta_0 \cdot \gamma^t$$
Where $\eta_t$ is the learning rate at step $t$, $\eta_0$ is the initial learning rate, and $\gamma \in (0, 1)$ is the decay factor.
This form shows that each step multiplies the previous learning rate by a fixed factor. If $\gamma = 0.99$, each step reduces the learning rate by 1%.
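As a quick sanity check (plain Python, illustrative numbers), iterating the multiplicative rule reproduces the closed form $\eta_0 \cdot \gamma^t$:

```python
eta0, gamma = 0.1, 0.99

# Iterate the recurrence eta_t = eta_{t-1} * gamma for 100 steps
eta = eta0
for _ in range(100):
    eta *= gamma

# Compare against the closed form eta_0 * gamma^t
closed_form = eta0 * gamma ** 100
print(eta, closed_form)  # both ~0.0366: 1% per-step decay compounds to ~63% total
```

Note that 100 steps of 1% decay remove far more than 100% × 1%: the reductions compound multiplicatively, not additively.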
Continuous Form (Half-Life Parameterization):
$$\eta_t = \eta_0 \cdot e^{-\lambda t}$$
Where $\lambda > 0$ is the continuous decay rate and $t$ is the epoch (or step) index.
This form connects to radioactive decay and other natural exponential processes. The half-life (time for LR to halve) is:
$$t_{1/2} = \frac{\ln(2)}{\lambda} = \frac{-\ln(2)}{\ln(\gamma)}$$
| γ (Multiplicative) | λ (Decay Rate) | Half-Life (epochs) | LR after 100 epochs |
|---|---|---|---|
| 0.99 | 0.01005 | 69 epochs | 36.6% of initial |
| 0.995 | 0.00501 | 138 epochs | 60.6% of initial |
| 0.999 | 0.001 | 693 epochs | 90.5% of initial |
| 0.95 | 0.0513 | 13.5 epochs | 0.59% of initial |
| 0.9 | 0.1054 | 6.6 epochs | 0.003% of initial |
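The conversions in the table above can be reproduced in a few lines (a small sketch using only the standard library):

```python
import math

def decay_stats(gamma: float, epochs: int = 100):
    """Convert a multiplicative gamma into lambda, half-life, and remaining LR fraction."""
    lam = -math.log(gamma)           # continuous decay rate: gamma = e^(-lambda)
    half_life = math.log(2) / lam    # epochs for the LR to halve
    remaining = gamma ** epochs      # fraction of the initial LR left after `epochs`
    return lam, half_life, remaining

lam, t_half, frac = decay_stats(0.99)
print(f"lambda={lam:.5f}, half-life={t_half:.0f} epochs, {frac:.1%} of initial")
# lambda=0.01005, half-life=69 epochs, 36.6% of initial — matching the first table row
```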
A critical decision: should decay apply per step (iteration) or per epoch? Per-epoch decay is more common and easier to tune, but per-step provides smoother curves. If applying per-step, your gamma should be much closer to 1 (e.g., γ = 0.9999 per step vs. γ = 0.95 per epoch).
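To keep the two granularities equivalent, the per-step factor is the per-epoch factor raised to the power $1/\text{steps per epoch}$. A small sketch (the 500-steps-per-epoch figure is an illustrative assumption):

```python
gamma_epoch = 0.95
steps_per_epoch = 500  # hypothetical; depends on dataset size and batch size

# Per-step gamma that compounds to the same per-epoch decay
gamma_step = gamma_epoch ** (1 / steps_per_epoch)
print(gamma_step)  # ~0.9998974 — much closer to 1, as the text suggests

# After one epoch's worth of steps, the two schedules agree
assert abs(gamma_step ** steps_per_epoch - gamma_epoch) < 1e-12
```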
Target-Based Formulation:
Often it's more intuitive to specify the final learning rate and let the decay rate be computed:
$$\gamma = \left(\frac{\eta_{\text{final}}}{\eta_0}\right)^{1/T}$$
Where $T$ is the total number of epochs. This ensures the learning rate reaches exactly $\eta_{\text{final}}$ at the end of training.
Example: To decay from $\eta_0 = 0.1$ to $\eta_{\text{final}} = 0.0001$ over 100 epochs: $$\gamma = \left(\frac{0.0001}{0.1}\right)^{1/100} = (0.001)^{0.01} \approx 0.9333$$
This target-based formulation is highly practical: you specify the endpoints and the schedule is automatically computed.
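In code, the target-based rule is a one-liner, and reversing it confirms the schedule lands exactly on the endpoint:

```python
eta0, eta_final, T = 0.1, 0.0001, 100

# Solve eta_final = eta0 * gamma^T for gamma
gamma = (eta_final / eta0) ** (1 / T)
print(f"gamma = {gamma:.4f}")  # ~0.9333

# The schedule hits the target exactly at epoch T
lr_at_T = eta0 * gamma ** T
assert abs(lr_at_T - eta_final) < 1e-12
```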
Exponential decay's effectiveness traces to fundamental properties of optimization dynamics and stochastic gradient descent. Understanding these foundations transforms hyperparameter selection from trial-and-error into principled reasoning.
Optimal Convergence Rate Theory:
For strongly convex optimization problems, theoretical results establish that the optimal learning rate sequence for SGD satisfies:
$$\eta_t \propto \frac{1}{t^\alpha}$$
for some $\alpha \in (0.5, 1)$. Polynomial decay is theoretically optimal, but exponential decay provides a practical approximation that's easier to tune and often performs comparably.
The key insight: as training progresses, the signal-to-noise ratio of gradients decreases. Early gradients strongly point toward the optimum; late gradients increasingly reflect random noise. Reducing the learning rate attenuates this noise amplification.
SGD Noise Variance Scaling:
For SGD with learning rate $\eta$ and batch size $B$, the variance of weight updates scales as:
$$\text{Var}(\Delta\theta) \propto \frac{\eta^2 \sigma^2}{B}$$
where $\sigma^2$ is the gradient variance across samples. Exponential decay continuously reduces this variance, enabling increasingly precise convergence.
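The scaling relation can be checked analytically with toy numbers (all values here are illustrative assumptions, not measurements): halving $\eta$ quarters the update variance, mimicking the effect of a 4× larger batch.

```python
sigma2 = 1.0  # per-sample gradient variance (toy value)
B = 32        # batch size (toy value)

def update_variance(eta: float) -> float:
    # Var(eta * mean of B noisy gradients) = eta^2 * sigma^2 / B
    return eta ** 2 * sigma2 / B

v_early = update_variance(0.1)   # early training, large LR
v_late = update_variance(0.05)   # after the LR has halved
print(v_early / v_late)  # 4.0: halving the LR quarters the update noise
```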
The Continuity Advantage:
Step decay's sharp transitions can cause temporary training instability. When the learning rate drops by 10×, the effective batch size (in terms of optimization dynamics) shifts dramatically. This can manifest as brief loss spikes, transient plateaus, or momentum buffers that are momentarily mismatched to the new step size.
Exponential decay avoids these discontinuity effects. The smooth curve means no single step creates dramatic dynamics changes, promoting stability throughout training.
Think of exponential decay as maintaining a constant 'relative change rate.' A 1% reduction per epoch means the same proportional change whether the learning rate is 0.1 or 0.0001. This scale-invariance aligns with the logarithmic nature of optimization dynamics.
Exponential decay admits several implementation variants, each with distinct properties. Understanding these variants enables selecting the right tool for each situation.
```python
import numpy as np
import torch
from torch.optim.lr_scheduler import ExponentialLR, LambdaLR
from typing import Optional

# =====================================================
# Implementation 1: PyTorch Native ExponentialLR
# =====================================================

def create_exponential_scheduler(
    optimizer,
    gamma: float = 0.99,
    last_epoch: int = -1
):
    """
    Standard PyTorch exponential decay.

    LR(t) = LR(0) * gamma^t

    Args:
        optimizer: Wrapped optimizer
        gamma: Multiplicative factor (0 < gamma < 1)
        last_epoch: For resuming training

    Example: gamma=0.99 means 1% reduction per epoch
    """
    return ExponentialLR(optimizer, gamma=gamma, last_epoch=last_epoch)

# =====================================================
# Implementation 2: Target-Based Exponential Decay
# =====================================================

class TargetExponentialDecay:
    """
    Specify initial and final LR; decay rate computed automatically.

    More intuitive than raw gamma: "I want LR to go from 0.1 to
    0.0001 over 100 epochs" rather than "I want gamma = 0.9333".
    """

    def __init__(
        self,
        optimizer,
        initial_lr: float,
        final_lr: float,
        total_epochs: int,
        last_epoch: int = -1
    ):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.total_epochs = total_epochs

        # Compute gamma from endpoints:
        #   final_lr = initial_lr * gamma^total_epochs
        #   gamma = (final_lr / initial_lr)^(1/total_epochs)
        self.gamma = (final_lr / initial_lr) ** (1.0 / total_epochs)
        self.current_epoch = last_epoch

        # Validate parameters
        if not (0 < self.gamma < 1):
            raise ValueError(
                f"Invalid decay: from {initial_lr} to {final_lr} in "
                f"{total_epochs} epochs gives gamma={self.gamma:.6f}"
            )

        print(f"TargetExponentialDecay: gamma={self.gamma:.6f}, "
              f"half-life={-np.log(2)/np.log(self.gamma):.1f} epochs")

    def get_lr(self) -> float:
        if self.current_epoch < 0:
            return self.initial_lr
        return self.initial_lr * (self.gamma ** self.current_epoch)

    def step(self, epoch: Optional[int] = None):
        if epoch is not None:
            self.current_epoch = epoch
        else:
            self.current_epoch += 1

        new_lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

    def state_dict(self):
        return {
            'initial_lr': self.initial_lr,
            'final_lr': self.final_lr,
            'total_epochs': self.total_epochs,
            'gamma': self.gamma,
            'current_epoch': self.current_epoch
        }

    def load_state_dict(self, state_dict):
        self.initial_lr = state_dict['initial_lr']
        self.final_lr = state_dict['final_lr']
        self.total_epochs = state_dict['total_epochs']
        self.gamma = state_dict['gamma']
        self.current_epoch = state_dict['current_epoch']

# =====================================================
# Implementation 3: Per-Step Exponential Decay
# =====================================================

class PerStepExponentialDecay:
    """
    Decay learning rate every step (iteration) rather than epoch.

    Provides a smoother decay curve at the cost of more frequent updates.
    Gamma should be much closer to 1 (e.g., 0.9999 vs 0.95 for per-epoch).
    """

    def __init__(
        self,
        optimizer,
        initial_lr: float,
        final_lr: float,
        total_steps: int
    ):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.total_steps = total_steps

        # Per-step gamma for a smooth transition
        self.gamma = (final_lr / initial_lr) ** (1.0 / total_steps)
        self.current_step = 0

        print(f"PerStepExponentialDecay: gamma={self.gamma:.8f} "
              f"(applied {total_steps} times)")

    def step(self):
        """Call after each optimizer.step(), not after each epoch."""
        self.current_step += 1
        new_lr = self.initial_lr * (self.gamma ** self.current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

    def get_lr(self) -> float:
        return self.initial_lr * (self.gamma ** self.current_step)

# =====================================================
# Implementation 4: Exponential Decay with Warmup
# =====================================================

class WarmupExponentialDecay:
    """
    Linear warmup followed by exponential decay.

    Warmup prevents unstable early training (especially with
    large batches or transformer architectures).
    """

    def __init__(
        self,
        optimizer,
        base_lr: float,
        warmup_epochs: int,
        decay_epochs: int,
        final_lr: float,
        warmup_start_lr: float = 1e-7
    ):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_epochs = warmup_epochs
        self.decay_epochs = decay_epochs
        self.final_lr = final_lr
        self.warmup_start_lr = warmup_start_lr

        # Compute decay gamma for the post-warmup period
        self.gamma = (final_lr / base_lr) ** (1.0 / decay_epochs)
        self.current_epoch = 0
        self.total_epochs = warmup_epochs + decay_epochs

    def get_lr(self) -> float:
        if self.current_epoch < self.warmup_epochs:
            # Linear warmup
            alpha = self.current_epoch / self.warmup_epochs
            return self.warmup_start_lr + alpha * (self.base_lr - self.warmup_start_lr)
        else:
            # Exponential decay
            decay_step = self.current_epoch - self.warmup_epochs
            return self.base_lr * (self.gamma ** decay_step)

    def step(self, epoch: Optional[int] = None):
        if epoch is not None:
            self.current_epoch = epoch
        else:
            self.current_epoch += 1

        new_lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

# =====================================================
# Implementation 5: Lambda-Based Flexible Schedule
# =====================================================

def create_lambda_exponential(
    optimizer,
    gamma: float = 0.99,
    warmup_epochs: int = 0,
    min_lr_ratio: float = 0.001,
    last_epoch: int = -1
):
    """
    Use LambdaLR for maximum flexibility.

    LambdaLR takes a function that maps epoch -> lr_multiplier.
    This enables combining exponential decay with warmup and floors.
    """
    def lr_lambda(epoch: int) -> float:
        if epoch < warmup_epochs:
            # Linear warmup
            return epoch / warmup_epochs
        else:
            # Exponential decay with floor
            decay_factor = gamma ** (epoch - warmup_epochs)
            return max(decay_factor, min_lr_ratio)

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

# =====================================================
# Implementation 6: Staircase Exponential (Hybrid)
# =====================================================

class StaircaseExponential:
    """
    Apply exponential decay but update LR only at fixed intervals.

    Combines smoothness of exponential with interpretability of step.
    LR is computed exponentially but applied in discrete steps.

    Example: Compute LR exponentially, but update every 5 epochs.
    """

    def __init__(
        self,
        optimizer,
        initial_lr: float,
        gamma: float,
        staircase_interval: int = 10
    ):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.gamma = gamma
        self.staircase_interval = staircase_interval
        self.current_epoch = 0
        self.last_applied_lr = initial_lr

    def get_lr(self) -> float:
        # Compute the exponential value at the nearest staircase step
        staircase_epoch = (self.current_epoch // self.staircase_interval) * self.staircase_interval
        return self.initial_lr * (self.gamma ** staircase_epoch)

    def step(self, epoch: Optional[int] = None):
        if epoch is not None:
            self.current_epoch = epoch
        else:
            self.current_epoch += 1

        new_lr = self.get_lr()

        # Only log when the LR actually changes
        if abs(new_lr - self.last_applied_lr) > 1e-10:
            print(f"Epoch {self.current_epoch}: LR -> {new_lr:.6f}")
            self.last_applied_lr = new_lr

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

# =====================================================
# Visualization and Analysis Tools
# =====================================================

def compare_exponential_schedules(total_epochs: int = 100):
    """
    Visualize different exponential decay configurations.
    """
    import matplotlib.pyplot as plt

    epochs = np.arange(total_epochs)
    initial_lr = 0.1

    schedules = {
        'γ=0.99 (Slow)': initial_lr * (0.99 ** epochs),
        'γ=0.95 (Medium)': initial_lr * (0.95 ** epochs),
        'γ=0.9 (Fast)': initial_lr * (0.9 ** epochs),
        'Target: 0.1→0.001': initial_lr * ((0.001 / 0.1) ** (epochs / total_epochs)),
    }

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Linear scale
    ax1 = axes[0]
    for name, lr_curve in schedules.items():
        ax1.plot(epochs, lr_curve, label=name, linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Learning Rate')
    ax1.set_title('Exponential Decay Schedules (Linear Scale)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Log scale (shows exponential decay as a straight line)
    ax2 = axes[1]
    for name, lr_curve in schedules.items():
        ax2.semilogy(epochs, lr_curve, label=name, linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Learning Rate (log scale)')
    ax2.set_title('Exponential Decay Schedules (Log Scale)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

def analyze_decay_rate(
    initial_lr: float,
    final_lr: float,
    total_epochs: int
) -> dict:
    """
    Analyze properties of exponential decay for given parameters.
    """
    gamma = (final_lr / initial_lr) ** (1.0 / total_epochs)
    half_life = -np.log(2) / np.log(gamma)

    # Compute LR at key points
    key_epochs = [0, total_epochs // 4, total_epochs // 2,
                  3 * total_epochs // 4, total_epochs - 1]
    lr_values = {
        epoch: initial_lr * (gamma ** epoch)
        for epoch in key_epochs
    }

    return {
        'gamma': gamma,
        'lambda': -np.log(gamma),
        'half_life_epochs': half_life,
        'decay_per_epoch_pct': (1 - gamma) * 100,
        'lr_at_epochs': lr_values,
        'total_decay_factor': final_lr / initial_lr
    }
```

Exponential decay appears as a straight line on a log-scale plot. This provides a quick visual diagnostic: if your LR curve isn't linear on log scale, you're not using pure exponential decay. This also makes exponential decay's cumulative effect easy to predict.
The decay rate (γ or λ) is the central hyperparameter of exponential decay. Selecting it well requires understanding both the mathematical implications and the training-specific considerations.
The Half-Life Perspective:
The most intuitive way to reason about decay rate is through half-life: the number of epochs for the learning rate to halve.
$$t_{1/2} = \frac{\ln(2)}{\lambda} = -\frac{\ln(2)}{\ln(\gamma)}$$
Practical Guidelines:
Short training (50-100 epochs): Half-life of 15-30 epochs. This ensures meaningful decay while preserving exploration early.
Medium training (100-200 epochs): Half-life of 30-60 epochs. More gradual decay to sustain optimization over the longer horizon.
Long training (200+ epochs): Half-life of 60-100+ epochs. Very gentle decay that maintains learning capability throughout.
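Inverting the half-life formula gives the γ that achieves a chosen half-life. A minimal sketch, using a 20-epoch half-life as an example:

```python
import math

def gamma_for_half_life(half_life: float) -> float:
    # Solve gamma ** half_life == 0.5 for gamma
    return math.exp(-math.log(2) / half_life)

g = gamma_for_half_life(20)
print(f"gamma = {g:.4f}")  # ~0.9659

# Check: the LR halves after exactly 20 epochs
assert abs(g ** 20 - 0.5) < 1e-12
```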
The Final LR Criterion:
Often the most practical approach is to specify the desired final learning rate and compute γ accordingly:
| Training Duration | Recommended Half-Life | Example γ (per epoch) | Final LR Ratio |
|---|---|---|---|
| 50 epochs | 15-20 epochs | 0.95 - 0.965 | 1/10 to 1/20 |
| 100 epochs | 25-35 epochs | 0.97 - 0.98 | 1/50 to 1/200 |
| 200 epochs | 50-70 epochs | 0.985 - 0.99 | 1/100 to 1/500 |
| 500 epochs | 100-150 epochs | 0.993 - 0.995 | 1/500 to 1/2000 |
```python
import numpy as np

def calculate_decay_params(
    initial_lr: float,
    final_lr: float,
    total_epochs: int
) -> dict:
    """
    Calculate all decay parameters from initial/final LR goals.

    This is the recommended approach: specify endpoints, compute decay.
    """
    # Compute gamma
    gamma = (final_lr / initial_lr) ** (1.0 / total_epochs)

    # Compute lambda (continuous decay rate)
    lambda_rate = -np.log(gamma)

    # Compute half-life
    half_life = np.log(2) / lambda_rate

    # Decay per epoch in percentage
    decay_per_epoch = (1 - gamma) * 100

    # Time to reach 1% of initial (useful benchmark)
    time_to_1pct = np.log(0.01) / np.log(gamma)

    return {
        'gamma': gamma,
        'lambda': lambda_rate,
        'half_life_epochs': half_life,
        'decay_per_epoch_pct': decay_per_epoch,
        'epochs_to_1pct': time_to_1pct,
        'summary': f"LR halves every {half_life:.1f} epochs, "
                   f"decays {decay_per_epoch:.2f}% per epoch"
    }

def gamma_from_half_life(half_life_epochs: float) -> float:
    """
    Compute gamma that achieves the specified half-life.
    """
    return np.exp(-np.log(2) / half_life_epochs)

def validate_schedule(
    initial_lr: float,
    gamma: float,
    total_epochs: int,
    min_acceptable_lr: float = 1e-8
) -> dict:
    """
    Validate an exponential schedule for common issues.
    """
    final_lr = initial_lr * (gamma ** total_epochs)
    half_life = -np.log(2) / np.log(gamma)

    issues = []

    if final_lr < min_acceptable_lr:
        issues.append(
            f"Final LR ({final_lr:.2e}) below minimum threshold "
            f"({min_acceptable_lr}). May cause numerical issues."
        )

    if half_life > total_epochs:
        issues.append(
            f"Half-life ({half_life:.1f}) exceeds total epochs ({total_epochs}). "
            "LR won't reduce by even 50%."
        )

    if half_life < total_epochs / 10:
        issues.append(
            f"Half-life ({half_life:.1f}) very short relative to training. "
            "LR may decay too quickly."
        )

    # Check for a reasonable number of halvings
    num_halvings = total_epochs / half_life
    if num_halvings < 1:
        issues.append("Less than one full halving during training - consider faster decay.")
    elif num_halvings > 10:
        issues.append("More than 10 halvings - LR may become negligibly small.")

    return {
        'valid': len(issues) == 0,
        'issues': issues,
        'final_lr': final_lr,
        'half_life': half_life,
        'num_halvings': num_halvings
    }

# Interactive exploration
if __name__ == "__main__":
    # Example: Design a schedule for 150-epoch training
    result = calculate_decay_params(
        initial_lr=0.1,
        final_lr=0.0001,  # 1/1000 of initial
        total_epochs=150
    )

    print("Schedule Design:")
    for key, value in result.items():
        print(f"  {key}: {value}")

    # Validation
    validation = validate_schedule(
        initial_lr=0.1,
        gamma=result['gamma'],
        total_epochs=150
    )

    print("Validation:")
    print(f"  Valid: {validation['valid']}")
    for issue in validation['issues']:
        print(f"  ⚠️ {issue}")
```

Aggressive exponential decay can reduce learning rates to numerically insignificant values (< 1e-8), causing training to effectively stall while still consuming compute. Always validate that your final LR is above a reasonable threshold, or implement a minimum LR floor.
The choice between exponential and step decay isn't arbitrary—each has distinct properties that favor different training scenarios. Understanding these differences enables principled schedule selection.
Core Difference: Continuous vs. Discrete Change
Step decay: LR stays constant for many epochs, then drops sharply. Exponential decay: LR changes smoothly every epoch.
This fundamental difference propagates through training dynamics in subtle but important ways.
| Aspect | Exponential Decay | Step Decay |
|---|---|---|
| Transition smoothness | Continuous, no discontinuities | Sharp phase transitions |
| Interpretability | Harder to reason about phases | Clear phase boundaries |
| Hyperparameter count | One: γ or λ | Two: γ and step_size |
| Stability near transitions | No transition-related instability | May spike after LR drops |
| Debugging ease | Moderate (gradual changes) | Easier (discrete phases) |
| Batch norm compatibility | Better (no sudden shifts) | May need recalibration |
| Momentum buffer scaling | Gradual adjustment | Sudden rescaling needed |
| Empirical performance | Comparable on average | Strong baselines exist |
Empirical Equivalence Region:
For many problems, carefully tuned exponential and step decay achieve similar final performance. The differences emerge at the margins: in stability around transitions, in sensitivity to hyperparameter choices, and in how gracefully each schedule tolerates a mistuned decay rate.
The Practitioner's Rule: Start with whatever schedule has published baselines for your problem. If no baselines exist, both are reasonable starting points—exponential if stability matters, step if interpretability matters.
Some practitioners use 'staircase exponential': compute LR exponentially but update only every N epochs. This combines exponential's cumulative smooth behavior with step decay's discrete interpretability. PyTorch doesn't have this built-in, but it's easy to implement with LambdaLR.
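A minimal LambdaLR sketch of this staircase idea (γ = 0.95 and a 5-epoch interval are illustrative values, not recommendations):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def staircase_exponential(gamma: float = 0.95, interval: int = 5):
    # Exponential in value, staircase in time: floor the epoch to the interval
    return lambda epoch: gamma ** (interval * (epoch // interval))

# Plug the multiplier function into LambdaLR as usual
optimizer = torch.optim.SGD(torch.nn.Linear(4, 1).parameters(), lr=0.1)
scheduler = LambdaLR(optimizer, staircase_exponential())

for epoch in range(12):
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # 0.1 * 0.95**10: two staircase drops so far
```

Between drops the LR is flat, but each drop lands exactly on the underlying exponential curve.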
Production deployment of exponential decay requires attention to several advanced considerations that academic treatments often overlook.
Interaction with Adaptive Optimizers:
Adaptive optimizers (Adam, AdaGrad, RMSprop) maintain per-parameter learning rate scales. Exponential decay of the base learning rate compounds with these adaptive scales:
$$\text{Effective LR}_{i,t} = \frac{\eta_0 \cdot \gamma^t}{\sqrt{v_i} + \epsilon}$$
where $v_i$ is the second moment estimate for parameter $i$. This means:
Recommendation for Adam: Use gentler exponential decay (larger γ) compared to SGD, as Adam's adaptivity already provides automatic per-parameter adjustment.
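For instance, pairing Adam with PyTorch's built-in ExponentialLR and a gentle γ (the values below are illustrative):

```python
import torch
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ExponentialLR(optimizer, gamma=0.995)  # gentler than a typical SGD gamma

for epoch in range(3):
    # ... training loop: forward, backward, then ...
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # 1e-3 * 0.995**3 after three epochs
```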
Distributed Training Considerations:
In distributed training with data parallelism, all workers must use identical scheduler state, or effective learning rates will silently diverge across ranks. The simplest robust strategy is to compute the learning rate on rank 0 and broadcast it to every worker each epoch.
```python
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional

class ProductionExponentialDecay:
    """
    Production-ready exponential decay with:
    - Minimum LR floor
    - Warmup support
    - Checkpointing
    - Distributed training compatibility
    - Logging hooks
    """

    def __init__(
        self,
        optimizer,
        initial_lr: float,
        gamma: float,
        min_lr: float = 1e-7,
        warmup_epochs: int = 0,
        warmup_start_lr: float = 1e-7,
        log_lr_every: int = 1,
        distributed: bool = False
    ):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.gamma = gamma
        self.min_lr = min_lr
        self.warmup_epochs = warmup_epochs
        self.warmup_start_lr = warmup_start_lr
        self.log_lr_every = log_lr_every
        self.distributed = distributed

        self.current_epoch = 0
        self.current_lr = warmup_start_lr if warmup_epochs > 0 else initial_lr

        # Apply initial LR
        self._apply_lr(self.current_lr)

        # Log schedule statistics
        self._log_schedule_info()

    def _log_schedule_info(self):
        """Log schedule configuration for debugging."""
        half_life = -np.log(2) / np.log(self.gamma)
        print(f"ExponentialDecay Schedule:")
        print(f"  Initial LR: {self.initial_lr}")
        print(f"  Gamma: {self.gamma}")
        print(f"  Half-life: {half_life:.1f} epochs")
        print(f"  Min LR: {self.min_lr}")
        print(f"  Warmup: {self.warmup_epochs} epochs")

    def _compute_lr(self, epoch: int) -> float:
        """Compute learning rate for the given epoch."""
        if epoch < self.warmup_epochs:
            # Linear warmup
            alpha = epoch / self.warmup_epochs
            return self.warmup_start_lr + alpha * (self.initial_lr - self.warmup_start_lr)
        else:
            # Exponential decay with floor
            decay_epoch = epoch - self.warmup_epochs
            lr = self.initial_lr * (self.gamma ** decay_epoch)
            return max(lr, self.min_lr)

    def _apply_lr(self, lr: float):
        """Apply learning rate to all parameter groups."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def step(self, epoch: Optional[int] = None):
        """
        Advance scheduler by one epoch.

        In distributed mode, ensures all ranks use the same LR.
        """
        if epoch is not None:
            self.current_epoch = epoch
        else:
            self.current_epoch += 1

        self.current_lr = self._compute_lr(self.current_epoch)

        # In distributed mode, broadcast LR from rank 0
        if self.distributed and dist.is_initialized():
            lr_tensor = torch.tensor([self.current_lr], dtype=torch.float32)
            dist.broadcast(lr_tensor, src=0)
            self.current_lr = lr_tensor.item()

        self._apply_lr(self.current_lr)

        # Logging
        if self.current_epoch % self.log_lr_every == 0:
            print(f"Epoch {self.current_epoch}: LR = {self.current_lr:.2e}")

    def get_lr(self) -> float:
        return self.current_lr

    def state_dict(self) -> dict:
        """Return scheduler state for checkpointing."""
        return {
            'current_epoch': self.current_epoch,
            'current_lr': self.current_lr,
            'initial_lr': self.initial_lr,
            'gamma': self.gamma,
            'min_lr': self.min_lr,
            'warmup_epochs': self.warmup_epochs,
            'warmup_start_lr': self.warmup_start_lr,
        }

    def load_state_dict(self, state_dict: dict):
        """Load scheduler state from checkpoint."""
        self.current_epoch = state_dict['current_epoch']
        self.current_lr = state_dict['current_lr']
        self.initial_lr = state_dict['initial_lr']
        self.gamma = state_dict['gamma']
        self.min_lr = state_dict['min_lr']
        self.warmup_epochs = state_dict['warmup_epochs']
        self.warmup_start_lr = state_dict['warmup_start_lr']

        # Apply loaded LR
        self._apply_lr(self.current_lr)
        print(f"Loaded scheduler: epoch={self.current_epoch}, lr={self.current_lr:.2e}")

def save_training_state(
    model,
    optimizer,
    scheduler,
    epoch: int,
    loss: float,
    path: str
):
    """Proper checkpoint saving with scheduler state."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
    }, path)
    print(f"Checkpoint saved: {path}")

def load_training_state(
    model,
    optimizer,
    scheduler,
    path: str
) -> int:
    """Load checkpoint and return the starting epoch."""
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resumed from epoch {checkpoint['epoch']}, loss={checkpoint['loss']:.4f}")
    return start_epoch
```

Exponential decay provides a smooth, continuous alternative to step-based learning rate scheduling. Its gradual reduction matches many real-world optimization dynamics and provides stability benefits in sensitive training scenarios.
What's Next:
The next page explores cosine annealing, which combines the benefits of smooth decay with principled restart mechanisms. We'll see how cosine schedules have become the default choice for modern transformer training.
You now understand exponential decay from mathematical foundations through production implementation. This continuous scheduling approach complements your step decay knowledge and provides an alternative for scenarios where smoothness matters.