Catastrophic forgetting is the Achilles' heel of neural network fine-tuning. It occurs when a model, trained on a new task, abruptly loses the ability to perform previous tasks—not gradually, but catastrophically.
Consider a model pre-trained on ImageNet (1000 classes) that you fine-tune for a medical imaging task (10 classes). After fine-tuning, you might have excellent performance on medical images, but if you test on ImageNet again, accuracy may have plummeted from 76% to 15%. The model hasn't just adapted—it has forgotten.
This phenomenon is particularly troubling for continual learning systems, multi-task deployments, and any setting where a fine-tuned model must keep serving its original task.
This page explores the mechanisms of catastrophic forgetting, diagnostic approaches, and mitigation strategies that preserve pre-trained knowledge while enabling target task learning.
By the end of this page, you will understand why catastrophic forgetting occurs at a mechanistic level, know how to detect it through appropriate metrics, and be able to apply mitigation strategies including EWC, knowledge distillation, and architectural approaches.
Why Does Forgetting Happen?
Neural networks store knowledge in a distributed fashion across millions of parameters. When you optimize for a new task, you modify these parameters to minimize loss on new data. The problem: there's no inherent mechanism to preserve performance on old data.
The Optimization Perspective:
During pre-training, the model finds parameters θ* that minimize source task loss: $$\theta^*_{source} = \arg\min_{\theta} \mathcal{L}_{source}(\theta)$$
During fine-tuning, optimization finds new parameters: $$\theta^*_{target} = \arg\min_{\theta} \mathcal{L}_{target}(\theta)$$
These optimization objectives can be conflicting. Weights that are critical for source task performance may be modified to improve target task performance.
The Representation Perspective:
Pre-trained representations encode knowledge about the source domain. Fine-tuning reshapes these representations for the target domain. If the reshaping is too aggressive, the original representational structure is lost.
| Factor | More Forgetting | Less Forgetting |
|---|---|---|
| Learning Rate | High LR | Low LR with warmup |
| Training Duration | Many epochs | Early stopping |
| Task Similarity | Very different tasks | Related tasks |
| Data Size | Large target data | Small target data |
| Model Capacity | Small model (forced reuse) | Large model (spare capacity) |
| Fine-tuning Strategy | Full fine-tuning | Selective/frozen layers |
This is an instance of the classic stability-plasticity dilemma in neural systems. Stability (retaining old knowledge) and plasticity (learning new knowledge) are inherently in tension. A perfectly stable network cannot learn; a perfectly plastic network cannot remember.
Before mitigating forgetting, we must measure it. Several metrics quantify the extent of knowledge loss.
Backward Transfer (BWT):
Measures performance change on previous tasks after learning new ones: $$BWT = \frac{1}{T-1} \sum_{i=1}^{T-1} (R_{T,i} - R_{i,i})$$
where $R_{t,i}$ is performance on task $i$ after training on task $t$. Negative BWT indicates forgetting.
Forgetting Measure (FM):
Maximum performance drop for each task: $$FM_i = \max_{t \in \{1,\dots,T-1\}} (R_{t,i} - R_{T,i})$$
Remembering (REM):
Percentage of original performance retained: $$REM = \frac{R_{after}}{R_{before}} \times 100\%$$
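A tiny worked example of these metrics, using illustrative accuracy numbers that echo the ImageNet-to-medical-imaging scenario from the introduction:

```python
# R[(t, i)] = performance on task i after training on task t (percent).
# All numbers here are illustrative.
R = {
    ("imagenet", "imagenet"): 76.0,  # R_{i,i}: source accuracy after pre-training
    ("medical", "imagenet"): 60.0,   # R_{T,i}: source accuracy after fine-tuning
    ("medical", "medical"): 91.0,    # target accuracy after fine-tuning
}

# With a single prior task, BWT reduces to the drop on that task.
bwt = R[("medical", "imagenet")] - R[("imagenet", "imagenet")]

# REM: fraction of the original source accuracy retained.
rem = R[("medical", "imagenet")] / R[("imagenet", "imagenet")] * 100

print(bwt)            # -16.0 -> clear forgetting
print(round(rem, 1))  # 78.9 -> roughly 79% of source accuracy retained
```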
```python
import torch
from typing import Dict, List


class ForgettingMetrics:
    """Metrics for measuring catastrophic forgetting."""

    def __init__(self):
        # performance_matrix[t][i] = perf on task i after training on task t
        self.performance_matrix: Dict[str, Dict[str, float]] = {}

    def record_performance(self, training_task: str, eval_task: str, performance: float):
        """Record performance on eval_task after training on training_task."""
        if training_task not in self.performance_matrix:
            self.performance_matrix[training_task] = {}
        self.performance_matrix[training_task][eval_task] = performance

    def compute_backward_transfer(self, task_order: List[str]) -> float:
        """
        Compute Backward Transfer (BWT).
        Negative values indicate forgetting.
        """
        T = len(task_order)
        if T < 2:
            return 0.0

        bwt_sum = 0.0
        count = 0
        final_task = task_order[-1]

        for task in task_order[:-1]:
            R_final_i = self.performance_matrix.get(final_task, {}).get(task, 0)
            R_i_i = self.performance_matrix.get(task, {}).get(task, 0)
            bwt_sum += (R_final_i - R_i_i)
            count += 1

        return bwt_sum / count if count > 0 else 0.0

    def compute_remembering(self, task: str, before_task: str, after_task: str) -> float:
        """
        Compute remembering percentage for a specific task.

        Args:
            task: The task to evaluate
            before_task: Training state before fine-tuning
            after_task: Training state after fine-tuning
        """
        perf_before = self.performance_matrix.get(before_task, {}).get(task, 0)
        perf_after = self.performance_matrix.get(after_task, {}).get(task, 0)

        if perf_before == 0:
            return 100.0
        return (perf_after / perf_before) * 100.0

    def generate_report(self, task_order: List[str]) -> str:
        """Generate a forgetting analysis report."""
        report = ["=" * 50, "FORGETTING ANALYSIS REPORT", "=" * 50, ""]

        bwt = self.compute_backward_transfer(task_order)
        report.append(f"Backward Transfer (BWT): {bwt:.4f}")
        report.append("  (Negative = forgetting, Positive = positive transfer)")
        report.append("")

        # Per-task remembering
        report.append("Per-Task Remembering:")
        for task in task_order[:-1]:
            rem = self.compute_remembering(task, task, task_order[-1])
            report.append(f"  {task}: {rem:.1f}%")

        return "\n".join(report)


def evaluate_forgetting(model, source_loader, target_loader, device="cuda"):
    """
    Evaluate model on both source and target tasks to measure forgetting.

    Returns:
        dict with 'source_acc' and 'target_acc'
    """
    model.eval()

    def evaluate_loader(loader):
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        return 100.0 * correct / total

    return {
        'source_acc': evaluate_loader(source_loader),
        'target_acc': evaluate_loader(target_loader),
    }
```

The most practical forgetting check: evaluate on a held-out set from the source domain before and after fine-tuning. If accuracy drops more than 5-10%, forgetting is significant. Always maintain a small validation set from the source domain for this purpose.
Elastic Weight Consolidation (EWC) is a principled approach that constrains important weights from changing too much during fine-tuning.
Key Insight: Not all weights are equally important for the source task. Some weights are critical—changing them dramatically hurts source performance. Other weights are relatively unimportant and can be freely modified.
The Method:
EWC adds a penalty term to the loss: $$\mathcal{L}_{total} = \mathcal{L}_{target}(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta^*_i)^2$$
where $F_i$ is the Fisher information of parameter $i$, $\theta^*_i$ is that parameter's pre-trained (source-optimal) value, and $\lambda$ controls the strength of the penalty.
Computing Fisher Information:
Fisher information measures how much the loss changes when a weight changes. High Fisher = important weight.
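To make the penalty concrete, here is a minimal numeric sketch; the weights and Fisher values are made up for illustration:

```python
import numpy as np

# Illustrative values only: three weights, their pre-trained anchors,
# and a diagonal Fisher rating how important each weight is.
theta_star = np.array([0.5, -1.2, 2.0])   # pre-trained weights (theta*)
theta      = np.array([0.6, -1.2, 1.5])   # weights after some fine-tuning steps
fisher     = np.array([10.0, 0.1, 5.0])   # high Fisher = important weight

lam = 1000.0
penalty = 0.5 * lam * np.sum(fisher * (theta - theta_star) ** 2)

# The third weight moved by 0.5 with Fisher 5.0 and dominates the penalty:
# 0.5 * 1000 * (10*0.01 + 0.1*0 + 5*0.25) = 675.0
print(penalty)
```

Notice that the second weight moved not at all and the first only slightly, so nearly all of the penalty comes from the large shift in the one important weight.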
```python
import torch
import torch.nn as nn
from typing import Dict


class EWC:
    """
    Elastic Weight Consolidation for preventing catastrophic forgetting.

    Based on 'Overcoming catastrophic forgetting in neural networks'
    (Kirkpatrick et al., 2017).
    """

    def __init__(self, model: nn.Module, dataloader, device="cuda"):
        self.model = model
        self.device = device

        # Store optimal weights for the previous (source) task
        self.optimal_weights = {
            name: param.clone().detach()
            for name, param in model.named_parameters()
        }

        # Compute Fisher information
        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader) -> Dict[str, torch.Tensor]:
        """
        Compute a diagonal Fisher information approximation from squared
        gradients, using the model's own predictions as labels.
        """
        fisher = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
        }

        self.model.eval()
        for inputs, _ in dataloader:
            inputs = inputs.to(self.device)
            self.model.zero_grad()
            outputs = self.model(inputs)

            # Use the model's predicted labels (mode of the output distribution)
            log_probs = torch.log_softmax(outputs, dim=1)
            labels = torch.argmax(outputs, dim=1)
            loss = nn.functional.nll_loss(log_probs, labels)
            loss.backward()

            # Accumulate squared gradients
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    fisher[name] += param.grad.detach() ** 2

        # Average over batches (each gradient is already a batch mean)
        num_batches = max(len(dataloader), 1)
        for name in fisher:
            fisher[name] /= num_batches

        return fisher

    def penalty(self) -> torch.Tensor:
        """
        Compute the EWC penalty term. Add this to your target task loss.
        """
        loss = 0.0
        for name, param in self.model.named_parameters():
            if name in self.fisher:
                diff = param - self.optimal_weights[name]
                loss += (self.fisher[name] * diff ** 2).sum()
        return loss


class EWCTrainer:
    """Trainer incorporating EWC regularization."""

    def __init__(
        self,
        model: nn.Module,
        ewc: EWC,
        ewc_lambda: float = 1000,  # EWC strength
        device: str = "cuda"
    ):
        self.model = model
        self.ewc = ewc
        self.ewc_lambda = ewc_lambda
        self.device = device

    def training_step(self, inputs, targets, optimizer, criterion):
        """Single training step with EWC regularization."""
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        optimizer.zero_grad()
        outputs = self.model(inputs)

        # Target task loss
        task_loss = criterion(outputs, targets)

        # EWC regularization
        ewc_loss = self.ewc.penalty()

        # Total loss
        total_loss = task_loss + self.ewc_lambda * ewc_loss
        total_loss.backward()
        optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'task_loss': task_loss.item(),
            'ewc_loss': ewc_loss.item()
        }
```

λ controls the stability-plasticity trade-off. Too high: the model can't adapt (underfits the target). Too low: forgetting still occurs. Start with λ = 1000-10000 and tune based on the remembering metric; use a higher λ for stronger forgetting prevention.
Knowledge Distillation uses the pre-trained model as a 'teacher' to guide the fine-tuned 'student' model. This preserves source task behavior while learning the target task.
The Method:
Maintain a frozen copy of the pre-trained model. During fine-tuning, add a distillation loss that encourages the student's outputs to match the teacher's:
$$\mathcal{L}_{total} = \alpha \mathcal{L}_{target} + (1-\alpha) \mathcal{L}_{distill}$$
where the distillation term is the temperature-scaled KL divergence between the teacher and student distributions: $$\mathcal{L}_{distill} = KL\big(\sigma(z_t/T) \,\|\, \sigma(z_s/T)\big)$$ with $z_s$ and $z_t$ the student and teacher logits, $T$ the temperature, and $\sigma$ the softmax function.
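A minimal sketch of this loss on made-up logits, following the PyTorch convention that `F.kl_div` takes log-probabilities as its first argument:

```python
import torch
import torch.nn.functional as F

T = 2.0  # temperature; the logits below are illustrative
student_logits = torch.tensor([[2.0, 0.5, -1.0]])
teacher_logits = torch.tensor([[1.5, 1.0, -0.5]])

student_log_soft = F.log_softmax(student_logits / T, dim=1)
teacher_soft = F.softmax(teacher_logits / T, dim=1)

# KL(teacher || student), scaled by T^2 so gradient magnitudes stay
# comparable to the hard-label loss as T grows.
loss = F.kl_div(student_log_soft, teacher_soft, reduction="batchmean") * T ** 2

print(loss.item() > 0)  # True: the two distributions differ
```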
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy


class DistillationFineTuning:
    """
    Fine-tuning with knowledge distillation to prevent forgetting.

    The pre-trained model serves as teacher, providing soft targets
    that encode its learned knowledge.
    """

    def __init__(
        self,
        student_model: nn.Module,
        temperature: float = 2.0,
        alpha: float = 0.5,
        device: str = "cuda"
    ):
        self.student = student_model.to(device)

        # Create frozen teacher from the initial (pre-trained) weights
        self.teacher = deepcopy(student_model)
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher = self.teacher.to(device)

        self.temperature = temperature
        self.alpha = alpha
        self.device = device

    def distillation_loss(self, student_logits, teacher_logits):
        """
        Compute KL divergence between student and teacher distributions.

        Higher temperature produces softer distributions that transfer
        more knowledge about class relationships.
        """
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)

        # KL divergence, scaled by T^2 as per Hinton et al.
        kl_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
        return kl_loss * (self.temperature ** 2)

    def training_step(self, inputs, targets, optimizer, task_criterion):
        """Training step with distillation."""
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        optimizer.zero_grad()

        # Student forward
        student_logits = self.student(inputs)

        # Teacher forward (no gradients)
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        # Task loss (hard targets)
        task_loss = task_criterion(student_logits, targets)

        # Distillation loss (soft targets from teacher)
        distill_loss = self.distillation_loss(student_logits, teacher_logits)

        # Combined loss
        total_loss = self.alpha * task_loss + (1 - self.alpha) * distill_loss
        total_loss.backward()
        optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'task_loss': task_loss.item(),
            'distill_loss': distill_loss.item()
        }


class FeatureDistillation(DistillationFineTuning):
    """
    Extend distillation to intermediate features, not just outputs.
    More effective for preserving learned representations.
    """

    def __init__(
        self,
        student_model: nn.Module,
        feature_layers: list,  # Names of layers to match
        feature_weight: float = 0.1,
        **kwargs
    ):
        super().__init__(student_model, **kwargs)
        self.feature_layers = feature_layers
        self.feature_weight = feature_weight

        # Register hooks to capture features
        self.student_features = {}
        self.teacher_features = {}
        self._register_hooks()

    def _register_hooks(self):
        """Register forward hooks to capture intermediate features."""
        def make_hook(storage, name):
            def hook(module, input, output):
                storage[name] = output
            return hook

        for name, module in self.student.named_modules():
            if name in self.feature_layers:
                module.register_forward_hook(make_hook(self.student_features, name))
        for name, module in self.teacher.named_modules():
            if name in self.feature_layers:
                module.register_forward_hook(make_hook(self.teacher_features, name))

    def feature_matching_loss(self):
        """Compute MSE between student and teacher features."""
        if not self.feature_layers:
            return 0.0
        loss = 0.0
        for name in self.feature_layers:
            if name in self.student_features and name in self.teacher_features:
                s_feat = self.student_features[name]
                t_feat = self.teacher_features[name].detach()
                loss += F.mse_loss(s_feat, t_feat)
        return loss / len(self.feature_layers)
```

Temperature (T): higher values (2-4) create softer targets that transfer more knowledge. Start with T=2. Alpha (α): balances task and distillation losses. For more retention, use α=0.3-0.4; for more adaptation, use α=0.7-0.8.
Beyond EWC and distillation, several practical strategies help prevent forgetting:
1. Learning Rate Control: Lower learning rates = less forgetting. Use discriminative LRs with very low rates for early layers.
2. Selective Fine-Tuning: Freeze layers not needed for the target task. Frozen layers can't forget.
3. Rehearsal/Replay: Mix source domain data into target training. Even 5-10% source data significantly reduces forgetting.
4. Early Stopping: Stop before full convergence on target task. Trade some target performance for retention.
5. Regularization: L2 regularization toward pre-trained weights (simpler than EWC). Dropout increases robustness.
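Strategies 1 and 2 above can be sketched together with PyTorch parameter groups; the architecture and learning rates here are illustrative stand-ins for a real pre-trained backbone:

```python
import torch
import torch.nn as nn

# Hypothetical three-stage model standing in for a pre-trained network.
model = nn.Sequential(
    nn.Linear(128, 64), nn.ReLU(),   # early layers: freeze (frozen layers can't forget)
    nn.Linear(64, 32), nn.ReLU(),    # middle layers: very low LR
    nn.Linear(32, 10),               # new head: higher LR
)

# Strategy 2: selective fine-tuning via freezing.
for param in model[0].parameters():
    param.requires_grad = False

# Strategy 1: discriminative learning rates via per-group options.
optimizer = torch.optim.AdamW([
    {"params": model[2].parameters(), "lr": 1e-5},  # cautious updates
    {"params": model[4].parameters(), "lr": 1e-3},  # head adapts freely
])

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # parameters still free to move
```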
| Scenario | Priority | Recommended Strategies |
|---|---|---|
| Source access available | High retention | Rehearsal + Low LR + Distillation |
| Source access unavailable | High retention | EWC + Selective freezing + Low LR |
| Computational constraints | Moderate retention | Selective freezing + Early stopping |
| Multi-task deployment | All tasks matter | Adapters + Distillation ensemble |
What's Next:
We've now covered the major challenges of fine-tuning. The final page explores regularization approaches—techniques like dropout, weight decay, and data augmentation that stabilize fine-tuning and improve generalization on the target task.
You now understand catastrophic forgetting: why it occurs, how to measure it, and strategies from EWC to knowledge distillation for mitigating it. This prepares you for the regularization approaches covered next.