In the previous page, we established that catastrophic forgetting occurs because gradient descent updates weights without regard for previously learned tasks. The regularization approach to continual learning addresses this by adding constraints to the learning process itself—penalizing changes to parameters that were important for old tasks while allowing free modification of less critical parameters.

This philosophy is elegant in its simplicity: not all parameters contribute equally to a task's performance. Some weights are critical—small changes cause large performance drops. Others are incidental—they can vary substantially without affecting outputs. If we can identify which weights are important and protect them during new learning, we can achieve both stability (for important weights) and plasticity (for unimportant weights).

This page provides a rigorous treatment of the major regularization approaches, their mathematical foundations, implementation details, and comparative analysis.
By the end of this page, you will understand the mathematical derivation of Elastic Weight Consolidation (EWC), implement importance-weighted regularization from scratch, compare EWC, Synaptic Intelligence (SI), Memory Aware Synapses (MAS), and Learning without Forgetting (LwF), and know when to apply each method based on your use case constraints.
All regularization approaches share a common framework. Given a neural network with parameters $\theta$ trained on task $T_A$ (resulting in parameters $\theta_A^*$), we want to train on new task $T_B$ while preventing forgetting of $T_A$.

Standard Training Loss (No Protection):

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta)$$

This optimizes purely for task B, ignoring task A entirely—leading to catastrophic forgetting.

Regularization-Protected Loss:

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i \Omega_i (\theta_i - \theta_{A,i}^*)^2$$

Here:
- $\mathcal{L}_B(\theta)$ is the loss for the new task
- $\theta_{A,i}^*$ is the optimal value of parameter $i$ after task A
- $\Omega_i$ is the importance of parameter $i$ for task A
- $\lambda$ is a hyperparameter controlling regularization strength
- The sum is over all parameters in the network

This formulation penalizes deviations from the optimal task-A configuration, weighted by each parameter's importance. High-importance parameters incur large penalties for change; low-importance parameters can change freely.
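Before looking at specific methods, it may help to see this shared framework as code. Below is a minimal sketch (assuming PyTorch; `importance` and `old_params` are hypothetical dictionaries keyed by parameter name, filled in differently by each method on this page) of the importance-weighted penalty that every approach instantiates:

```python
import torch
import torch.nn as nn
from typing import Dict

def importance_weighted_penalty(
    model: nn.Module,
    importance: Dict[str, torch.Tensor],   # Ω_i: per-parameter importance (method-specific)
    old_params: Dict[str, torch.Tensor],   # θ*_{A,i}: parameters saved after task A
    lam: float                             # λ: regularization strength
) -> torch.Tensor:
    """Generic penalty: (λ/2) Σ_i Ω_i (θ_i - θ*_{A,i})²."""
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in importance:
            penalty = penalty + (importance[name] * (param - old_params[name]) ** 2).sum()
    return (lam / 2) * penalty

# During task B, the training objective becomes:
#   loss = criterion(model(x), y) + importance_weighted_penalty(model, importance, old_params, lam)
```

The methods below differ only in how they fill the `importance` dictionary.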
The entire challenge of regularization-based continual learning reduces to: How do we compute $\Omega_i$? Different methods (EWC, SI, MAS, etc.) propose different ways to estimate parameter importance. The effectiveness of the approach depends critically on the quality of these importance estimates.
Why Quadratic Penalty?

The squared term $(\theta_i - \theta_{A,i}^*)^2$ creates a quadratic penalty centered at the old optimum. This has several desirable properties:

1. Smoothness: The gradient is proportional to deviation, providing smooth training dynamics
2. Symmetry: Equally penalizes positive and negative deviations
3. Convexity: The regularization term is convex, preventing complex optimization landscapes
4. Efficient computation: Simple to compute and differentiate

The gradient of the regularization term with respect to each parameter is:

$$\frac{\partial \mathcal{R}}{\partial \theta_i} = \lambda \, \Omega_i (\theta_i - \theta_{A,i}^*)$$

This acts as an elastic force pulling each parameter back toward its old-task optimal value, with strength proportional to its importance.
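A quick numerical check of this elastic-force picture, using autograd on a toy three-parameter penalty (all values illustrative):

```python
import torch

# Toy setup: 3 parameters, their old-task optima, and importances
theta = torch.tensor([1.5, -0.2, 0.8], requires_grad=True)
theta_star = torch.tensor([1.0, 0.0, 0.8])
omega = torch.tensor([10.0, 0.1, 5.0])
lam = 2.0

penalty = (lam / 2) * (omega * (theta - theta_star) ** 2).sum()
penalty.backward()

# The gradient matches λ * Ω_i * (θ_i - θ*_i): a strong pull on the important
# first parameter, a negligible pull on the unimportant second one
print(theta.grad)                                   # tensor([10.0000, -0.0400,  0.0000])
print(lam * omega * (theta.detach() - theta_star))  # same values: the analytic elastic force
```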
Elastic Weight Consolidation (EWC), introduced by Kirkpatrick et al. (2017) at DeepMind, was a landmark contribution that made continual learning practical for deep neural networks. EWC derives parameter importance from a principled Bayesian perspective, using the Fisher Information Matrix to measure sensitivity.

The Bayesian Framework:

From a Bayesian viewpoint, we want to find parameters $\theta$ that are probable given all data seen so far. Using Bayes' rule:

$$p(\theta | D_A, D_B) = \frac{p(D_B | \theta) \, p(\theta | D_A)}{p(D_B | D_A)}$$

where $D_A$ and $D_B$ are data from tasks A and B.

Taking the log and ignoring the normalizing constant:

$$\log p(\theta | D_A, D_B) = \log p(D_B | \theta) + \log p(\theta | D_A) + \text{const}$$

The first term is the likelihood for task B (the standard training objective). The second term is the posterior from task A—encoding what we learned from that task.
The Laplace Approximation:

The posterior $p(\theta | D_A)$ is intractable to compute exactly for deep networks. EWC approximates it as a Gaussian centered at the task-A optimum $\theta_A^*$:

$$p(\theta | D_A) \approx \mathcal{N}(\theta_A^*, F_A^{-1})$$

where $F_A$ is the Fisher Information Matrix computed at $\theta_A^*$.

The Fisher Information Matrix measures the curvature of the loss landscape—how much the likelihood changes when parameters change:

$$F = \mathbb{E}_{p(x)}\left[ \nabla_\theta \log p(y|x,\theta) \, \nabla_\theta \log p(y|x,\theta)^T \right]$$

Parameters in directions of high curvature (large Fisher Information) are important—small changes significantly affect predictions. Parameters in flat directions are unimportant.

The EWC Loss:

Using the Laplace approximation and taking only the diagonal of $F$ (for computational tractability), EWC adds:

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i F_{A,ii} (\theta_i - \theta_{A,i}^*)^2$$

where $F_{A,ii}$ is the $i$-th diagonal entry of the Fisher Information Matrix for task A.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
import copy


class EWC:
    """
    Elastic Weight Consolidation implementation.

    EWC protects important parameters by penalizing changes proportional
    to their Fisher Information, which measures parameter sensitivity.

    Mathematical formulation:
    L(θ) = L_new(θ) + (λ/2) Σᵢ Fᵢ(θᵢ - θ*ᵢ)²

    where F is the diagonal of the Fisher Information Matrix.
    """

    def __init__(self, model: nn.Module, lambda_ewc: float = 1000):
        """
        Args:
            model: The neural network model
            lambda_ewc: Regularization strength (higher = more protection)
        """
        self.model = model
        self.lambda_ewc = lambda_ewc

        # Storage for task-specific parameters and Fisher information
        self.saved_params: Dict[str, torch.Tensor] = {}
        self.fisher_diag: Dict[str, torch.Tensor] = {}

    def compute_fisher(
        self,
        dataloader: DataLoader,
        num_samples: int = 2000
    ) -> None:
        """
        Compute the diagonal of the Fisher Information Matrix.

        The Fisher Information measures how much the log-likelihood changes
        when parameters change. Higher Fisher = more important parameter.

        Mathematical definition:
        F = E[∇log p(y|x,θ) ∇log p(y|x,θ)ᵀ]

        Args:
            dataloader: DataLoader for the completed task
            num_samples: Number of samples for Fisher estimation
        """
        self.model.eval()

        # Initialize Fisher to zeros
        fisher = {n: torch.zeros_like(p)
                  for n, p in self.model.named_parameters()
                  if p.requires_grad}

        samples_seen = 0
        for inputs, targets in dataloader:
            if samples_seen >= num_samples:
                break

            inputs = inputs.to(next(self.model.parameters()).device)
            batch_size = inputs.size(0)

            # Forward pass
            outputs = self.model(inputs)

            # For classification: sample from output distribution
            # This is the key insight - we use the model's own predictions
            probs = F.softmax(outputs, dim=1)
            sampled_labels = probs.multinomial(1).squeeze()

            # Compute log-likelihood for sampled labels
            log_probs = F.log_softmax(outputs, dim=1)
            log_likelihood = log_probs.gather(1, sampled_labels.unsqueeze(1))

            # Compute gradients for each sample
            for i in range(min(batch_size, num_samples - samples_seen)):
                self.model.zero_grad()
                log_likelihood[i].backward(retain_graph=(i < batch_size - 1))

                # Accumulate squared gradients (diagonal of outer product)
                for n, p in self.model.named_parameters():
                    if p.requires_grad and p.grad is not None:
                        fisher[n] += p.grad.data.clone() ** 2

                samples_seen += 1

        # Normalize by number of samples
        for n in fisher:
            fisher[n] /= samples_seen

        self.fisher_diag = fisher

        # Save current parameters as the optimal for this task
        self.saved_params = {
            n: p.data.clone()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.model.train()

    def penalty(self) -> torch.Tensor:
        """
        Compute the EWC penalty term.

        Returns:
            Scalar tensor: (λ/2) Σᵢ Fᵢ(θᵢ - θ*ᵢ)²
        """
        if not self.fisher_diag:
            return torch.tensor(0.0)

        loss = torch.tensor(0.0).to(next(self.model.parameters()).device)

        for n, p in self.model.named_parameters():
            if n in self.fisher_diag:
                # Importance-weighted quadratic penalty
                fisher = self.fisher_diag[n]
                optimal = self.saved_params[n]
                loss += (fisher * (p - optimal) ** 2).sum()

        return (self.lambda_ewc / 2) * loss

    def training_step(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer
    ) -> float:
        """
        Perform one EWC training step.

        Returns:
            Combined loss value
        """
        optimizer.zero_grad()

        # Forward pass
        outputs = self.model(inputs)

        # Task loss + EWC penalty
        task_loss = criterion(outputs, targets)
        ewc_loss = self.penalty()
        total_loss = task_loss + ewc_loss

        # Backward and optimize
        total_loss.backward()
        optimizer.step()

        return total_loss.item()


# Example usage
def demonstrate_ewc():
    """Demonstrate EWC on sequential tasks."""
    # Simple MLP for demonstration
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    )

    ewc = EWC(model, lambda_ewc=1000)

    # After training on Task A:
    # ewc.compute_fisher(task_a_dataloader)

    # Training on Task B with EWC protection:
    # for inputs, targets in task_b_dataloader:
    #     loss = ewc.training_step(inputs, targets, criterion, optimizer)

    print("EWC ready for continual learning!")


demonstrate_ewc()
```

When computing the Fisher Information, EWC samples labels from the model's own output distribution rather than using ground-truth labels. This is mathematically important: we're measuring the curvature of the likelihood function, which depends on the model's predictions, not the data labels. This subtle detail significantly affects performance.
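To make that distinction concrete, here is a small self-contained sketch (separate from the `EWC` class above) of the per-sample log-likelihood term whose squared gradient is accumulated into the Fisher diagonal; the only difference between the two variants is where the labels come from (using ground-truth labels instead gives what is often called the "empirical" Fisher):

```python
import torch
import torch.nn.functional as F

def per_sample_log_likelihood(outputs: torch.Tensor,
                              targets: torch.Tensor,
                              use_model_samples: bool = True) -> torch.Tensor:
    """Log-likelihood term whose squared gradient feeds the Fisher diagonal.

    use_model_samples=True  -> labels drawn from the model's own softmax (what EWC uses)
    use_model_samples=False -> ground-truth labels (the 'empirical' Fisher shortcut)
    """
    log_probs = F.log_softmax(outputs, dim=1)
    if use_model_samples:
        labels = F.softmax(outputs, dim=1).multinomial(1).squeeze(1)
    else:
        labels = targets
    return log_probs.gather(1, labels.unsqueeze(1))

# Example: logits for a batch of 4 samples over 10 classes
logits = torch.randn(4, 10, requires_grad=True)
y = torch.randint(0, 10, (4,))
print(per_sample_log_likelihood(logits, y).shape)  # torch.Size([4, 1])
```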
The original EWC formulation has a significant scalability problem: it requires storing the Fisher Information and optimal parameters for every previous task. For $n$ tasks, this means $n$ times the model size in storage.

The Memory Problem:

With standard EWC, the loss function becomes:

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{new}}(\theta) + \frac{\lambda}{2} \sum_{k=1}^{n-1} \sum_i F_{k,ii} (\theta_i - \theta_{k,i}^*)^2$$

This sums over all previous tasks $k$, requiring storage of $F_k$ and $\theta_k^*$ for each.

Online EWC Solution:

Schwarz et al. (2018) proposed Online EWC, which maintains a single running importance estimate that accumulates in one fixed-size buffer. Instead of keeping separate Fisher matrices, Online EWC uses an exponential moving average:
```python
import torch
import torch.nn as nn
from typing import Dict


class OnlineEWC:
    """
    Online Elastic Weight Consolidation.

    Maintains a running Fisher estimate and reference point, avoiding
    linear storage growth with number of tasks.

    Update rules:
    F_cumulative = γ * F_old + F_new
    θ* = θ_current (updated after each task)

    This approach trades some precision for constant memory usage.
    """

    def __init__(
        self,
        model: nn.Module,
        lambda_ewc: float = 1000,
        gamma: float = 0.9
    ):
        """
        Args:
            model: Neural network model
            lambda_ewc: Regularization strength
            gamma: Decay factor for old Fisher (0 = only new, 1 = cumulative)
        """
        self.model = model
        self.lambda_ewc = lambda_ewc
        self.gamma = gamma

        self.cumulative_fisher: Dict[str, torch.Tensor] = {}
        self.reference_params: Dict[str, torch.Tensor] = {}

    def update_fisher(self, new_fisher: Dict[str, torch.Tensor]) -> None:
        """
        Update cumulative Fisher with new task's Fisher.

        Uses exponential moving average:
        F_cumulative = γ * F_old + F_new
        """
        if not self.cumulative_fisher:
            # First task: just store the Fisher
            self.cumulative_fisher = {k: v.clone() for k, v in new_fisher.items()}
        else:
            # Subsequent tasks: weighted combination
            for n in new_fisher:
                self.cumulative_fisher[n] = (
                    self.gamma * self.cumulative_fisher[n] + new_fisher[n]
                )

    def update_reference(self) -> None:
        """Update reference parameters to current values."""
        self.reference_params = {
            n: p.data.clone()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

    def consolidate(self, new_fisher: Dict[str, torch.Tensor]) -> None:
        """
        Consolidate knowledge after completing a task.

        Call this after training on each task:
        1. Update Fisher with new task's importance
        2. Store current parameters as reference point
        """
        self.update_fisher(new_fisher)
        self.update_reference()

    def penalty(self) -> torch.Tensor:
        """
        Compute Online EWC penalty.

        Same formulation as standard EWC but uses the cumulative Fisher
        and a single reference point.
        """
        if not self.cumulative_fisher:
            return torch.tensor(0.0)

        device = next(self.model.parameters()).device
        loss = torch.tensor(0.0).to(device)

        for n, p in self.model.named_parameters():
            if n in self.cumulative_fisher:
                fisher = self.cumulative_fisher[n]
                ref = self.reference_params[n]
                loss += (fisher * (p - ref) ** 2).sum()

        return (self.lambda_ewc / 2) * loss


class ProgressiveEWC:
    """
    Progressive Network variant with EWC for forward transfer.

    Freezes old task columns while allowing new capacity.
    Uses EWC only for shared representations.
    """

    def __init__(self, base_model: nn.Module, lambda_ewc: float = 1000):
        self.base_model = base_model
        self.lambda_ewc = lambda_ewc
        self.task_columns = [base_model]
        self.ewc_modules = [OnlineEWC(base_model, lambda_ewc)]

    def add_task_column(self, new_column: nn.Module) -> None:
        """Add a new column for a new task with lateral connections."""
        # Freeze the most recent column before adding the new one
        for param in self.task_columns[-1].parameters():
            param.requires_grad = False

        self.task_columns.append(new_column)
        self.ewc_modules.append(OnlineEWC(new_column, self.lambda_ewc))

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Forward through the appropriate column."""
        return self.task_columns[task_id](x)
```

The decay factor γ controls how much old importance persists. γ = 1 gives equal weight to all past tasks (cumulative Fisher). γ < 1 weights recent tasks more heavily, effectively 'forgetting' very old importance estimates. Values between 0.8 and 0.95 typically work well in practice.
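As a usage illustration, here is a hedged sketch of how `OnlineEWC` might sit in a task loop: train with the penalty, then consolidate after each task. `task_loaders`, `num_epochs`, and the Fisher helper `estimate_fisher_diag` are hypothetical placeholders (the helper could be adapted from `EWC.compute_fisher` above):

```python
import torch
import torch.nn as nn

# Placeholder setup: model, loss, optimizer, and a list of per-task DataLoaders
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
online_ewc = OnlineEWC(model, lambda_ewc=1000, gamma=0.9)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for task_id, loader in enumerate(task_loaders):    # task_loaders: hypothetical list of DataLoaders
    for epoch in range(num_epochs):                 # num_epochs: hypothetical constant
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets) + online_ewc.penalty()
            loss.backward()
            optimizer.step()

    # After finishing the task: estimate this task's Fisher diagonal and fold it in.
    # estimate_fisher_diag is a hypothetical helper (e.g. adapted from EWC.compute_fisher).
    new_fisher = estimate_fisher_diag(model, loader)
    online_ewc.consolidate(new_fisher)
```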
Synaptic Intelligence (SI), introduced by Zenke et al. (2017), offers an alternative to EWC's Fisher Information approach. Instead of computing importance after training completes, SI accumulates importance during training by tracking how much each parameter contributed to loss reduction.

The Key Insight:

A parameter is important if changing it significantly affected the loss. SI measures this by tracking the path integral

$$\omega_k = -\sum_t \left( \frac{\partial \mathcal{L}}{\partial \theta_k} \right)_t \Delta \theta_k^t$$

where the sum is over all training steps. This computes the total 'work' each parameter did in reducing the loss—parameters that moved consistently against the gradient (the direction that lowers the loss) did more work and are hence more important.

Normalization and Regularization:

To avoid scale issues, SI normalizes by the total change in each parameter:

$$\Omega_k = \frac{\omega_k}{(\Delta \theta_k)^2 + \xi}$$

where $\Delta \theta_k$ is the total change in parameter $k$ during training and $\xi$ is a damping constant for numerical stability.
```python
import torch
import torch.nn as nn
from typing import Dict, Optional
import copy


class SynapticIntelligence:
    """
    Synaptic Intelligence for continual learning.

    SI measures parameter importance by tracking the contribution of
    each parameter to the loss reduction during training.

    Key equations:
    ω_k = Σ_t -(∂L/∂θ_k)_t * Δθ_k^t    (online importance)
    Ω_k = ω_k / ((Δθ_k)² + ξ)           (normalized importance)

    Reference: Zenke et al., "Continual Learning Through Synaptic Intelligence"
    """

    def __init__(
        self,
        model: nn.Module,
        c: float = 0.1,     # Regularization strength
        xi: float = 1e-3    # Damping constant
    ):
        """
        Args:
            model: Neural network model
            c: Coefficient for importance regularization (like λ in EWC)
            xi: Small constant for numerical stability
        """
        self.model = model
        self.c = c
        self.xi = xi

        # Parameter tracking
        self.params: Dict[str, torch.Tensor] = {}
        self.prev_params: Dict[str, torch.Tensor] = {}

        # Importance accumulator (ω in the paper)
        self.omega_accumulator: Dict[str, torch.Tensor] = {}

        # Consolidated importance (Ω in the paper)
        self.importance: Dict[str, torch.Tensor] = {}

        # Reference parameters from previous task
        self.reference_params: Dict[str, torch.Tensor] = {}

        self._initialize_tracking()

    def _initialize_tracking(self) -> None:
        """Initialize parameter tracking dictionaries."""
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                self.params[n] = p
                self.prev_params[n] = p.data.clone()
                self.omega_accumulator[n] = torch.zeros_like(p.data)
                # Use the initial parameters as the reference for the first task,
                # so the total change Δθ in consolidate() is well defined
                self.reference_params[n] = p.data.clone()

    def update_omega(self) -> None:
        """
        Update importance accumulator after each training step.

        Called after optimizer.step() to track:
        ω_k += -gradient_k * Δθ_k

        Note: We use the negative gradient because the parameter update is
        in the direction of -gradient, so the contribution to loss
        reduction is -gradient * Δθ.
        """
        for n, p in self.model.named_parameters():
            if n in self.omega_accumulator and p.grad is not None:
                # Parameter change in this step
                delta = p.data - self.prev_params[n]

                # Accumulate contribution: -gradient * delta
                # (clamped at zero so only loss-reducing movement counts)
                self.omega_accumulator[n] += (-p.grad.data * delta).clamp(min=0)

                # Update previous parameter value
                self.prev_params[n] = p.data.clone()

    def consolidate(self) -> None:
        """
        Consolidate importance after completing a task.

        Computes normalized importance:
        Ω_k = ω_k / ((Δθ_k)² + ξ)

        and adds it to the cumulative importance from previous tasks.
        """
        for n, p in self.model.named_parameters():
            if n in self.omega_accumulator:
                # Total parameter change during this task
                delta_squared = (p.data - self.reference_params[n]) ** 2

                # Compute normalized importance for this task
                task_importance = self.omega_accumulator[n] / (delta_squared + self.xi)

                # Add to cumulative importance
                if n in self.importance:
                    self.importance[n] += task_importance
                else:
                    self.importance[n] = task_importance

                # Reset accumulator for next task
                self.omega_accumulator[n].zero_()

        # Update reference parameters to current values
        self.reference_params = {
            n: p.data.clone()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

    def penalty(self) -> torch.Tensor:
        """
        Compute SI regularization penalty.

        Returns:
            c * Σ_k Ω_k * (θ_k - θ*_k)²
        """
        if not self.importance:
            return torch.tensor(0.0)

        device = next(self.model.parameters()).device
        loss = torch.tensor(0.0).to(device)

        for n, p in self.model.named_parameters():
            if n in self.importance:
                ref = self.reference_params[n]
                loss += (self.importance[n] * (p - ref) ** 2).sum()

        return self.c * loss

    def training_step(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer
    ) -> float:
        """
        Single SI training step with importance tracking.

        IMPORTANT: Call update_omega() after optimizer.step()!
        """
        optimizer.zero_grad()

        outputs = self.model(inputs)
        task_loss = criterion(outputs, targets)
        si_loss = self.penalty()
        total_loss = task_loss + si_loss

        total_loss.backward()
        optimizer.step()

        # CRITICAL: Update omega after the parameter update
        self.update_omega()

        return total_loss.item()


def compare_ewc_si():
    """Side-by-side comparison of importance computation."""
    print("EWC vs SI Importance Computation")
    print("=" * 50)
    print()
    print("EWC (Fisher Information):")
    print("  - Computed AFTER training on task")
    print("  - Based on: curvature of likelihood")
    print("  - Formula: F = E[∇log p(y|x)²]")
    print("  - Requires: separate pass over data")
    print()
    print("SI (Path Integral):")
    print("  - Computed DURING training")
    print("  - Based on: contribution to loss reduction")
    print("  - Formula: ω = Σ (-∇L) · Δθ")
    print("  - Requires: tracking at each step")


compare_ewc_si()
```

Memory Aware Synapses (MAS), introduced by Aljundi et al. (2018), takes yet another approach to computing parameter importance. Unlike EWC (which uses Fisher Information based on log-likelihood) and SI (which tracks loss reduction), MAS measures how much the network's outputs change when parameters change—regardless of labels.

The Output Sensitivity Approach:

MAS importance is based on:

$$\Omega_k = \frac{1}{N} \sum_{n=1}^{N} \left\| \frac{\partial F(x_n)}{\partial \theta_k} \right\|$$

where $F(x)$ is the network's output function and the sum is over $N$ data points.

This measures: if I change parameter $k$, how much does the output change? Parameters that significantly affect outputs are important; those that don't matter can safely change.
A key advantage of MAS is that it doesn't require labels at all. The importance is based purely on output sensitivity, making MAS applicable in unsupervised and self-supervised settings where EWC's likelihood-based approach doesn't apply directly.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict


class MemoryAwareSynapses:
    """
    Memory Aware Synapses (MAS) for continual learning.

    MAS computes importance based on output sensitivity: how much the
    network output changes when parameters change.

    Key equation:
    Ω_k = (1/N) Σ_n || ∂F(x_n)/∂θ_k ||

    This is label-free, making MAS suitable for unsupervised learning.

    Reference: Aljundi et al., "Memory Aware Synapses"
    """

    def __init__(
        self,
        model: nn.Module,
        lambda_reg: float = 1.0
    ):
        """
        Args:
            model: Neural network model
            lambda_reg: Regularization strength
        """
        self.model = model
        self.lambda_reg = lambda_reg

        self.importance: Dict[str, torch.Tensor] = {}
        self.reference_params: Dict[str, torch.Tensor] = {}

    def compute_importance(
        self,
        dataloader: DataLoader,
        num_samples: int = 2000
    ) -> None:
        """
        Compute MAS importance weights.

        For each parameter, we compute the average gradient magnitude
        of the output with respect to that parameter.

        Args:
            dataloader: Data to compute importance over
            num_samples: Maximum samples to use
        """
        self.model.eval()

        # Initialize importance accumulators
        importance = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        samples_seen = 0
        for inputs, _ in dataloader:  # Labels not needed!
            if samples_seen >= num_samples:
                break

            inputs = inputs.to(next(self.model.parameters()).device)
            batch_size = inputs.size(0)

            # Forward pass
            outputs = self.model(inputs)

            # Compute gradient of output magnitude w.r.t. parameters
            # We use the L2 norm of the output as a scalar measure
            output_norm = outputs.norm(dim=-1).sum()

            self.model.zero_grad()
            output_norm.backward()

            # Accumulate absolute gradients
            for n, p in self.model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    importance[n] += p.grad.data.abs()

            samples_seen += batch_size

        # Normalize by number of samples
        for n in importance:
            importance[n] /= samples_seen

        # Update cumulative importance
        if not self.importance:
            self.importance = importance
        else:
            for n in importance:
                self.importance[n] += importance[n]

        # Save reference parameters
        self.reference_params = {
            n: p.data.clone()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.model.train()

    def penalty(self) -> torch.Tensor:
        """
        Compute MAS regularization penalty.

        Returns:
            λ * Σ_k Ω_k * (θ_k - θ*_k)²
        """
        if not self.importance:
            return torch.tensor(0.0)

        device = next(self.model.parameters()).device
        loss = torch.tensor(0.0).to(device)

        for n, p in self.model.named_parameters():
            if n in self.importance:
                ref = self.reference_params[n]
                loss += (self.importance[n] * (p - ref) ** 2).sum()

        return self.lambda_reg * loss


class GradientEpisodicMemory:
    """
    Gradient Episodic Memory (GEM) - a constraint-based approach.

    Instead of soft regularization, GEM projects gradients to ensure
    they don't increase loss on previous tasks.

    This is a fundamentally different approach: rather than penalizing
    weight changes, GEM directly modifies gradients.
    """

    def __init__(
        self,
        model: nn.Module,
        memory_per_task: int = 256,
        margin: float = 0.5
    ):
        self.model = model
        self.memory_per_task = memory_per_task
        self.margin = margin

        self.episodic_memory: Dict[int, tuple] = {}  # task_id -> (inputs, targets)
        self.reference_gradients: Dict[int, Dict[str, torch.Tensor]] = {}

    def store_memory(self, task_id: int, inputs: torch.Tensor, targets: torch.Tensor):
        """Store exemplars for gradient projection."""
        self.episodic_memory[task_id] = (inputs, targets)

    def project_gradient(self, current_grad: Dict[str, torch.Tensor]):
        """
        Project the current gradient so it does not conflict with previous tasks.

        Uses quadratic programming to find the closest gradient that
        doesn't increase loss on any previous task. The projection
        (solving the QP and overwriting the gradients in-place) is
        omitted in this sketch.
        """
        pass
```

Learning without Forgetting (LwF), proposed by Li and Hoiem (2016), takes a fundamentally different approach from parameter-based regularization. Instead of protecting parameters, LwF uses knowledge distillation to preserve the network's behavior on old tasks.

The Distillation Approach:

Before training on a new task, LwF:
1. Records the network's outputs on the new task's data using the old model
2. Uses these 'soft targets' as additional supervision during new-task training

The loss function becomes:

$$\mathcal{L} = \mathcal{L}_{\text{new}}(y, \hat{y}_{\text{new}}) + \lambda \, \mathcal{L}_{\text{distill}}(\hat{y}_{\text{old}}, \tilde{y}_{\text{old}})$$

where:
- $\hat{y}_{\text{new}}$ is the prediction for the new task
- $\hat{y}_{\text{old}}$ is the current model's prediction on the old task outputs
- $\tilde{y}_{\text{old}}$ is the frozen old model's prediction (recorded before training)
LwF is essentially self-distillation: the old model acts as a 'teacher' for the new model on previous tasks. The soft targets from the teacher contain richer information than hard labels—they encode relative class confidences and inter-class relationships. This helps preserve nuanced knowledge that hard labels would lose.
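To see concretely what a soft target adds over a hard label, here is a tiny standalone snippet (illustrative logits only; not part of the LwF implementation below) that softens the same teacher logits at several temperatures:

```python
import torch
import torch.nn.functional as F

# Illustrative teacher logits for one example over 4 classes
teacher_logits = torch.tensor([4.0, 2.5, -1.0, -1.5])
hard_label = torch.tensor([1.0, 0.0, 0.0, 0.0])  # all a hard label keeps

for T in [1.0, 2.0, 4.0]:
    soft_target = F.softmax(teacher_logits / T, dim=0)
    print(f"T={T}: {[round(p, 3) for p in soft_target.tolist()]}")

# The soft targets preserve the teacher's view that the second class is a plausible
# alternative, information the hard label throws away; larger T exposes more of it.
```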
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict
import copy


class LearningWithoutForgetting:
    """
    Learning without Forgetting (LwF) for continual learning.

    LwF preserves behavior on old tasks through knowledge distillation.
    Before training on a new task, it records the model's outputs and
    uses them as soft targets during training.

    Key insight: Soft targets (logits/probabilities) contain more
    information than hard labels, preserving inter-class relationships.

    Loss = L_new(input, label) + λ * L_distill(current_output_old, frozen_output_old)
    """

    def __init__(
        self,
        model: nn.Module,
        temperature: float = 2.0,
        lambda_distill: float = 1.0
    ):
        """
        Args:
            model: Neural network model
            temperature: Softmax temperature for distillation (higher = softer)
            lambda_distill: Weight for distillation loss
        """
        self.model = model
        self.temperature = temperature
        self.lambda_distill = lambda_distill

        self.old_model: Optional[nn.Module] = None
        self.old_task_output_dim = 0

    def prepare_for_new_task(self) -> None:
        """
        Prepare for training on a new task.

        Creates a frozen copy of the current model to serve as the
        'teacher' for knowledge distillation.
        """
        # Create frozen copy of current model
        self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

        # Freeze all parameters
        for param in self.old_model.parameters():
            param.requires_grad = False

        # Track output dimension for old tasks
        # (assumes the model ends in a Linear layer)
        self.old_task_output_dim = self._get_output_dim()

    def _get_output_dim(self) -> int:
        """Get the out_features of the model's last Linear layer."""
        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                last_linear = module
        return last_linear.out_features

    def distillation_loss(
        self,
        current_logits: torch.Tensor,
        old_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute knowledge distillation loss.

        Uses softmax with temperature to create soft targets, then
        computes KL divergence.

        Temperature > 1 creates softer probability distributions,
        which transfers more information about relative confidences.
        """
        T = self.temperature

        # Soft probabilities from old model (teacher)
        soft_targets = F.softmax(old_logits / T, dim=1)

        # Log-soft probabilities from current model (student)
        log_soft_current = F.log_softmax(current_logits / T, dim=1)

        # KL divergence loss (scaled by T²)
        # The T² scaling is important for gradient magnitudes
        loss = F.kl_div(
            log_soft_current,
            soft_targets,
            reduction='batchmean'
        ) * (T ** 2)

        return loss

    def compute_loss(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        criterion: nn.Module
    ) -> torch.Tensor:
        """
        Compute combined LwF loss.

        Returns:
            L_task + λ * L_distillation
        """
        # Forward through current model
        outputs = self.model(inputs)

        # Task loss for new task (only on new output heads)
        task_loss = criterion(outputs, targets)

        # Distillation loss for old tasks
        if self.old_model is not None:
            with torch.no_grad():
                old_outputs = self.old_model(inputs)

            # Only distill on old task output dimensions
            current_old = outputs[:, :self.old_task_output_dim]
            distill_loss = self.distillation_loss(current_old, old_outputs)

            total_loss = task_loss + self.lambda_distill * distill_loss
        else:
            total_loss = task_loss

        return total_loss


class LwFMultiTask:
    """
    LwF variant for class-incremental learning with a shared head.

    In class-incremental learning, new classes are added over time.
    This changes the output layer architecture, requiring careful
    handling of distillation.
    """

    def __init__(
        self,
        feature_extractor: nn.Module,
        initial_classes: int,
        temperature: float = 2.0,
        lambda_distill: float = 1.0
    ):
        self.feature_extractor = feature_extractor
        self.temperature = temperature
        self.lambda_distill = lambda_distill

        # Get feature dimension
        with torch.no_grad():
            dummy = torch.randn(1, 3, 32, 32)
            feat_dim = feature_extractor(dummy).shape[1]

        # Output layer (grows with new classes)
        self.classifier = nn.Linear(feat_dim, initial_classes)
        self.num_classes = initial_classes

        # Frozen copy for distillation
        self.old_feature_extractor = None
        self.old_classifier = None
        self.old_num_classes = 0

    def add_classes(self, new_classes: int) -> None:
        """
        Expand classifier for new classes while preserving old weights.
        """
        # Store old model
        self.old_feature_extractor = copy.deepcopy(self.feature_extractor)
        self.old_classifier = copy.deepcopy(self.classifier)
        self.old_num_classes = self.num_classes

        # Freeze old model
        for p in self.old_feature_extractor.parameters():
            p.requires_grad = False
        for p in self.old_classifier.parameters():
            p.requires_grad = False

        # Create expanded classifier
        old_weight = self.classifier.weight.data
        old_bias = self.classifier.bias.data

        new_total = self.num_classes + new_classes
        new_classifier = nn.Linear(old_weight.shape[1], new_total)

        # Copy old weights
        new_classifier.weight.data[:self.num_classes] = old_weight
        new_classifier.bias.data[:self.num_classes] = old_bias

        self.classifier = new_classifier
        self.num_classes = new_total

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feature_extractor(x)
        return self.classifier(features)

    def compute_loss(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        criterion: nn.Module
    ) -> torch.Tensor:
        """Combined task and distillation loss."""
        # Current model predictions
        features = self.feature_extractor(inputs)
        outputs = self.classifier(features)

        # Task loss
        task_loss = criterion(outputs, targets)

        # Distillation loss on old class outputs
        if self.old_feature_extractor is not None:
            with torch.no_grad():
                old_features = self.old_feature_extractor(inputs)
                old_outputs = self.old_classifier(old_features)

            current_old = outputs[:, :self.old_num_classes]

            distill_loss = F.kl_div(
                F.log_softmax(current_old / self.temperature, dim=1),
                F.softmax(old_outputs / self.temperature, dim=1),
                reduction='batchmean'
            ) * (self.temperature ** 2)

            return task_loss + self.lambda_distill * distill_loss

        return task_loss
```

Higher temperatures (T > 1) produce softer probability distributions, which contain more information about inter-class similarities. T = 2-4 typically works well for continual learning. Too high a temperature can wash out important distinctions; too low approaches hard labels and loses the distillation benefit.
We've covered four major regularization approaches: EWC, SI, MAS, and LwF. Each has distinct characteristics making them suitable for different scenarios. Let's synthesize this understanding.
| Method | Importance Basis | When Computed | Memory | Label Required | Best For |
|---|---|---|---|---|---|
| EWC | Fisher Information (likelihood curvature) | After training | O(params) per task; O(params) total with Online EWC | Yes | Well-defined tasks with labeled data |
| SI | Contribution to loss reduction | During training | O(params) | Yes | When training trajectory matters |
| MAS | Output sensitivity | After training | O(params) | No | Unsupervised/self-supervised learning |
| LwF | Output matching (distillation) | Before new task | O(model) for old model | No (for old tasks) | Same-domain task incremental |
When to Choose Each Method:

Choose EWC when:
- You have well-labeled data with clear task boundaries
- The Bayesian interpretation is appealing for your use case
- You want a principled, theoretically grounded approach
- Fisher computation is computationally feasible

Choose SI when:
- You prefer online importance tracking without separate passes
- Training dynamics and trajectory are important
- You want a biologically inspired approach
- Tasks may have variable training lengths

Choose MAS when:
- You work with unsupervised or self-supervised learning
- Labels are unavailable or unreliable for old tasks
- You want purely output-based importance
- Model outputs are the key quantity to preserve

Choose LwF when:
- Tasks share similar input distributions
- You can afford to store a frozen model copy
- You want behavior-level preservation, not weight-level
- You're doing class-incremental learning
State-of-the-art approaches often combine regularization with other techniques. For example, regularization + small replay buffer, or regularization + task-specific output heads. Pure regularization has limits—it cannot perfectly preserve information when weight capacity is insufficient.
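As one concrete way such a combination can look, here is a minimal sketch (assumptions: the `EWC` class defined earlier on this page, a small list-based buffer of `(input, target)` tensor pairs kept from old tasks, and placeholder `criterion`/`optimizer`) of a training step that mixes the EWC penalty with a handful of replayed examples:

```python
import random
import torch
import torch.nn as nn

def ewc_plus_replay_step(model: nn.Module, ewc, replay_buffer, inputs, targets,
                         criterion, optimizer, replay_batch_size: int = 8) -> float:
    """One step combining the EWC penalty with a small rehearsal batch.

    replay_buffer: list of (input_tensor, target_tensor) pairs from earlier tasks.
    """
    optimizer.zero_grad()

    # New-task loss plus importance-weighted penalty
    loss = criterion(model(inputs), targets) + ewc.penalty()

    if replay_buffer:
        # Rehearse a few stored examples from previous tasks
        batch = random.sample(replay_buffer, min(replay_batch_size, len(replay_buffer)))
        old_x = torch.stack([x for x, _ in batch])
        old_y = torch.stack([y for _, y in batch])
        loss = loss + criterion(model(old_x), old_y)

    loss.backward()
    optimizer.step()
    return loss.item()
```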
Regularization approaches are elegant, but they have fundamental limitations that practitioners must understand.
Practical Recommendations:

1. Start with EWC or SI as baselines—they're well-understood and relatively easy to tune

2. Use validation data from old tasks (if available) to tune λ rather than guessing

3. Combine with replay for best results—even a small buffer helps significantly

4. Monitor per-task accuracy throughout training, not just average accuracy (see the sketch after this list)

5. Consider network capacity—larger networks with more parameters have more room for multiple tasks

6. Test on realistic task counts—many papers only show 5-10 tasks; production may need 50+
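A minimal sketch of such per-task monitoring (assuming you keep one held-out DataLoader per completed task; `task_loaders` is a hypothetical dict):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def per_task_accuracy(model: nn.Module, task_loaders: dict):
    """Accuracy on each completed task's held-out loader, plus the average.

    task_loaders: dict mapping task_id -> DataLoader of (inputs, targets).
    """
    model.eval()
    accuracies = {}
    for task_id, loader in task_loaders.items():
        correct, total = 0, 0
        for inputs, targets in loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
        accuracies[task_id] = correct / max(total, 1)
    model.train()
    average = sum(accuracies.values()) / max(len(accuracies), 1)
    return accuracies, average

# Logging the per-task numbers after every epoch makes forgetting visible as soon as it starts.
```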
Regularization alone cannot overcome the fundamental capacity limit. A network with fixed parameters can only encode so much information. For truly long-lived continual learning (100+ tasks), architectural solutions (next page) become necessary.
We have explored the family of regularization-based approaches to continual learning in depth; the comparison table and selection guidance above consolidate the key insights.
What's Next:

In the next page, we explore replay methods—approaches that maintain a memory of past experiences and rehearse them alongside new learning. Replay addresses some limitations of pure regularization by providing direct supervision from old tasks during new task training.
You now have deep understanding of regularization-based continual learning: the mathematical foundations of EWC, SI, MAS, and LwF, their implementations, and when to apply each. This knowledge enables you to select and tune these methods for practical applications while understanding their inherent limitations.