In the previous page, we explored regularization approaches that protect important parameters. While elegant, these methods face a fundamental limitation: they only indirectly preserve old task knowledge through weight constraints. As tasks accumulate, the constraints grow increasingly restrictive, leaving less and less room for new learning.

Replay methods take a fundamentally different approach: instead of restricting how weights can change, they ensure that old task knowledge is directly rehearsed during new task training. By mixing old and new examples in each training batch, replay maintains a consistent learning signal for all tasks.

This approach has strong biological precedent. During sleep, the hippocampus replays recent experiences to the neocortex, gradually consolidating episodic memories into more stable semantic knowledge. This consolidation process is thought to be crucial for biological continual learning, and replay methods implement a computational analog.
By the end of this page, you will understand experience replay with exemplar selection strategies, generative replay using learned generators, the trade-offs between true replay and pseudo-replay, modern hybrid approaches combining replay with regularization, and practical considerations for memory-efficient continual learning.
The fundamental insight of replay is simple: if you don't want to forget something, keep practicing it. In neural network terms, this means including samples from previous tasks in each training batch.

Formal Setup:

Given a sequence of tasks $T_1, T_2, \ldots, T_n$ with corresponding datasets $D_1, D_2, \ldots, D_n$, standard sequential training optimizes on each dataset independently:

$$\theta_t = \arg\min_\theta \mathcal{L}(\theta; D_t)$$

Replay-augmented training constructs a mixed dataset at each task:

$$\theta_t = \arg\min_\theta \mathcal{L}(\theta; D_t \cup M_t)$$

where $M_t$ is a memory buffer containing samples from tasks $T_1, \ldots, T_{t-1}$.

The key questions become:
1. What to store in $M_t$? (exemplar selection)
2. How much to store? (memory budget)
3. How to sample during training? (replay strategy)
4. Can we avoid storing real data? (generative replay)
Replay directly addresses the root cause of forgetting. Regularization methods try to prevent weight changes that would hurt old tasks; replay provides actual gradients from old tasks that actively maintain performance. It's the difference between saying 'don't move away from the old solution' and saying 'here are examples of what the old solution needs to do.'
The Ideal Baseline: Training on All Data

The theoretical gold standard is training on all data simultaneously:

$$\theta^* = \arg\min_\theta \sum_{t=1}^{n} \mathcal{L}(\theta; D_t)$$

This is called offline or joint training; it is what we would do if all data were available at once. Replay methods attempt to approximate this ideal with limited memory:

$$\mathcal{L}_{\text{replay}} \approx \mathcal{L}_{\text{joint}}$$

The quality of this approximation depends on how well the memory buffer $M$ represents the full data distribution.
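To make the approximation concrete, here is a minimal sketch (not from the original text) that contrasts the joint objective with its replay-based surrogate on toy data; the function names `joint_loss` and `replay_loss` and the random tensors are illustrative placeholders only.

```python
import torch
import torch.nn as nn

def joint_loss(model, datasets, criterion):
    """Joint objective: summed loss over ALL tasks' data (the offline ideal)."""
    total = 0.0
    for inputs, targets in datasets:          # one (X, y) pair per task
        total = total + criterion(model(inputs), targets)
    return total

def replay_loss(model, current, memory, criterion):
    """Replay approximation: current task data mixed with a small memory buffer."""
    cur_x, cur_y = current
    mem_x, mem_y = memory                     # buffer holds a subset of old-task data
    mixed_x = torch.cat([cur_x, mem_x])
    mixed_y = torch.cat([cur_y, mem_y])
    return criterion(model(mixed_x), mixed_y)

# Toy illustration with random data (shapes only; not a real benchmark)
model = nn.Linear(20, 5)
criterion = nn.CrossEntropyLoss()
tasks = [(torch.randn(64, 20), torch.randint(0, 5, (64,))) for _ in range(3)]
memory = (torch.cat([x[:10] for x, _ in tasks[:2]]),
          torch.cat([y[:10] for _, y in tasks[:2]]))

print("joint  loss:", joint_loss(model, tasks, criterion).item())
print("replay loss:", replay_loss(model, tasks[2], memory, criterion).item())
```

With a large, representative buffer the replay loss tracks the joint objective closely; as the buffer shrinks, it becomes a coarser estimate.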
Experience Replay (ER) stores a subset of real examples from previous tasks in a fixed-size memory buffer. During training on new tasks, these exemplars are sampled and mixed with current task data.

Memory Buffer Management:

With a fixed memory budget $|M| = K$, the challenge is distributing capacity across an increasing number of tasks. Common strategies:

1. Equal Allocation: Reserve $K/t$ samples per task after $t$ tasks
2. Age-Based: Keep more recent examples (ring buffer)
3. Importance-Based: Keep samples based on computed importance scores
4. Reservoir Sampling: Maintain a uniform random sample over all seen examples
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from typing import List, Tuple, Optional, Dict
import random


class MemoryBuffer:
    """
    Fixed-size memory buffer for experience replay.

    Maintains a balanced set of exemplars from all seen tasks,
    using reservoir sampling for online updates.
    """

    def __init__(self, max_size: int = 2000, per_class: bool = True):
        """
        Args:
            max_size: Maximum total samples in buffer
            per_class: If True, balance samples per class
        """
        self.max_size = max_size
        self.per_class = per_class
        # Storage: list of (input, target, task_id) tuples
        self.buffer: List[Tuple[torch.Tensor, int, int]] = []
        # Class-wise storage for balanced sampling
        self.class_buffers: Dict[int, List[Tuple[torch.Tensor, int]]] = {}
        # Counts for reservoir sampling
        self.seen_per_class: Dict[int, int] = {}

    def add(self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int) -> None:
        """
        Add new samples to the buffer.

        Uses reservoir sampling to maintain a uniform random subset
        without knowing the total data size in advance.
        """
        for x, y in zip(inputs, targets):
            y_item = y.item() if isinstance(y, torch.Tensor) else y
            if self.per_class:
                self._add_class_balanced(x.clone(), y_item, task_id)
            else:
                self._add_reservoir(x.clone(), y_item, task_id)

    def _add_class_balanced(self, x: torch.Tensor, y: int, task_id: int) -> None:
        """Add with class-balanced reservoir sampling."""
        if y not in self.class_buffers:
            self.class_buffers[y] = []
            self.seen_per_class[y] = 0
        self.seen_per_class[y] += 1
        n = self.seen_per_class[y]

        # Capacity per class (updated dynamically)
        n_classes = len(self.class_buffers)
        capacity = self.max_size // max(n_classes, 1)

        if len(self.class_buffers[y]) < capacity:
            # Buffer not full: just add
            self.class_buffers[y].append((x, task_id))
        else:
            # Reservoir sampling: replace a random element with prob capacity/n
            if random.random() < capacity / n:
                idx = random.randint(0, capacity - 1)
                self.class_buffers[y][idx] = (x, task_id)

    def _add_reservoir(self, x: torch.Tensor, y: int, task_id: int) -> None:
        """Standard reservoir sampling."""
        if len(self.buffer) < self.max_size:
            self.buffer.append((x, y, task_id))
        else:
            # Replace with probability max_size / seen
            n = len(self.buffer) + 1  # Simplified; track the actual seen count in practice
            if random.random() < self.max_size / n:
                idx = random.randint(0, self.max_size - 1)
                self.buffer[idx] = (x, y, task_id)

    def sample(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a batch of (inputs, targets) from the buffer."""
        if self.per_class:
            return self._sample_class_balanced(batch_size)
        return self._sample_uniform(batch_size)

    def _sample_class_balanced(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample uniformly across classes."""
        samples = []
        classes = list(self.class_buffers.keys())
        for _ in range(batch_size):
            c = random.choice(classes)  # Pick a random class
            if self.class_buffers[c]:
                x, task_id = random.choice(self.class_buffers[c])
                samples.append((x, c))
        if not samples:
            return None, None
        inputs = torch.stack([s[0] for s in samples])
        targets = torch.tensor([s[1] for s in samples])
        return inputs, targets

    def _sample_uniform(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample uniformly from the buffer."""
        if not self.buffer:
            return None, None
        samples = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        inputs = torch.stack([s[0] for s in samples])
        targets = torch.tensor([s[1] for s in samples])
        return inputs, targets

    def __len__(self) -> int:
        if self.per_class:
            return sum(len(buf) for buf in self.class_buffers.values())
        return len(self.buffer)


class ExperienceReplay:
    """
    Experience Replay training loop for continual learning.

    Mixes stored exemplars with current task data in each batch.
    """

    def __init__(self, model: nn.Module, buffer_size: int = 2000,
                 replay_batch_ratio: float = 0.5):
        """
        Args:
            model: Neural network model
            buffer_size: Total memory buffer size
            replay_batch_ratio: Fraction of each batch drawn from memory (0-1)
        """
        self.model = model
        self.memory = MemoryBuffer(max_size=buffer_size)
        self.replay_batch_ratio = replay_batch_ratio

    def train_task(self, dataloader: DataLoader, task_id: int, epochs: int,
                   criterion: nn.Module, optimizer: torch.optim.Optimizer) -> List[float]:
        """
        Train on a new task with experience replay.

        Each batch combines:
        - (1 - replay_ratio) from the current task
        - replay_ratio from the memory buffer
        """
        device = next(self.model.parameters()).device
        losses = []

        for epoch in range(epochs):
            epoch_loss = 0.0
            batches = 0
            for inputs, targets in dataloader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                batch_size = inputs.size(0)

                # Determine replay batch size
                replay_size = int(batch_size * self.replay_batch_ratio)
                current_size = batch_size - replay_size

                # Get replay samples
                if len(self.memory) > 0 and replay_size > 0:
                    replay_inputs, replay_targets = self.memory.sample(replay_size)
                    if replay_inputs is not None:
                        replay_inputs = replay_inputs.to(device)
                        replay_targets = replay_targets.to(device)
                        # Combine current and replay samples
                        combined_inputs = torch.cat([inputs[:current_size], replay_inputs])
                        combined_targets = torch.cat([targets[:current_size], replay_targets])
                    else:
                        combined_inputs, combined_targets = inputs, targets
                else:
                    combined_inputs, combined_targets = inputs, targets

                # Training step
                optimizer.zero_grad()
                outputs = self.model(combined_inputs)
                loss = criterion(outputs, combined_targets)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                batches += 1

            losses.append(epoch_loss / batches)

        # After training: add this task's samples to memory
        for inputs, targets in dataloader:
            self.memory.add(inputs, targets, task_id)

        return losses


# Example usage
def demonstrate_experience_replay():
    """Demonstrate experience replay on sequential tasks."""
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    )
    er = ExperienceReplay(model, buffer_size=2000, replay_batch_ratio=0.5)
    print("Experience Replay initialized")
    print(f"Buffer size: {er.memory.max_size}")
    print(f"Replay ratio: {er.replay_batch_ratio}")
    print("\nTraining loop mixes 50% current task, 50% memory buffer")


demonstrate_experience_replay()
```

Not all samples are equally valuable for preventing forgetting. Intelligent exemplar selection can dramatically improve replay effectiveness within the same memory budget.

The Selection Problem:

Given $N$ samples and a budget of $K \ll N$, which $K$ samples best summarize the data distribution for replay purposes?
```python
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Dict


class ExemplarSelector:
    """Advanced exemplar selection strategies for experience replay."""

    @staticmethod
    def random_selection(inputs, targets, budget):
        """Random selection baseline: simple, unbiased, no computational overhead."""
        n = inputs.size(0)
        if n <= budget:
            return inputs, targets
        indices = torch.randperm(n)[:budget]
        return inputs[indices], targets[indices]

    @staticmethod
    def herding_selection(features, inputs, targets, budget_per_class):
        """
        Herding selection (from iCaRL).

        For each class, iteratively select samples that bring the mean of the
        selected samples closest to the true class mean, so the selected set
        is a good summary of the class.

        Algorithm:
        1. Compute the class mean mu
        2. For i = 1 to K: select the sample that minimizes ||mu - mean(selected + sample)||
        """
        selected_inputs = []
        selected_targets = []
        unique_classes = torch.unique(targets)

        for c in unique_classes:
            mask = (targets == c)
            class_features = features[mask]
            class_inputs = inputs[mask]

            mu = class_features.mean(dim=0)  # Class mean

            # Greedily select exemplars
            n = class_features.size(0)
            budget = min(budget_per_class, n)
            selected_idx = []
            running_sum = torch.zeros_like(mu)

            for _ in range(budget):
                # For each candidate, compute the new mean if it were selected
                candidates = [i for i in range(n) if i not in selected_idx]
                best_idx = None
                best_dist = float('inf')
                for idx in candidates:
                    new_sum = running_sum + class_features[idx]
                    new_mean = new_sum / (len(selected_idx) + 1)
                    dist = (mu - new_mean).pow(2).sum().item()
                    if dist < best_dist:
                        best_dist = dist
                        best_idx = idx
                selected_idx.append(best_idx)
                running_sum += class_features[best_idx]

            # Add selected exemplars
            for idx in selected_idx:
                selected_inputs.append(class_inputs[idx])
                selected_targets.append(int(c))

        return torch.stack(selected_inputs), torch.tensor(selected_targets)

    @staticmethod
    def kcenter_coreset(features, inputs, targets, budget):
        """
        K-center coreset selection.

        Select samples that minimize the maximum distance from any sample to its
        nearest selected sample, ensuring good coverage of the feature space.

        Greedy approximation:
        1. Select a random initial point
        2. Repeatedly select the point farthest from all selected points
        """
        n = features.size(0)
        if n <= budget:
            return inputs, targets

        # Distance computations on the full set can be memory-intensive
        features_np = features.numpy()
        selected = [np.random.randint(n)]

        for _ in range(budget - 1):
            # Compute distance to the nearest selected point
            selected_features = features_np[selected]
            min_dists = np.full(n, np.inf)
            for feat in selected_features:
                dists = np.linalg.norm(features_np - feat, axis=1)
                min_dists = np.minimum(min_dists, dists)
            # Exclude already selected points
            for idx in selected:
                min_dists[idx] = -1
            # Select the farthest point
            next_idx = int(np.argmax(min_dists))
            selected.append(next_idx)

        indices = torch.tensor(selected)
        return inputs[indices], targets[indices]

    @staticmethod
    def gradient_diversity_selection(model, inputs, targets, criterion, budget):
        """
        Select samples with diverse gradients.

        Samples that produce different gradients provide diverse learning signals,
        which is valuable for maintaining decision-boundary knowledge.
        """
        n = inputs.size(0)
        if n <= budget:
            return inputs, targets

        device = next(model.parameters()).device

        # Compute per-sample gradients
        gradients = []
        model.eval()
        for x, y in zip(inputs, targets):
            model.zero_grad()
            x = x.unsqueeze(0).to(device)
            y = torch.tensor([y]).to(device)
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            # Flatten all gradients into a single vector
            grad = torch.cat([
                p.grad.flatten() for p in model.parameters() if p.grad is not None
            ]).cpu()
            gradients.append(grad)

        gradients = torch.stack(gradients)

        # Use k-center selection in gradient space
        return ExemplarSelector.kcenter_coreset(gradients, inputs, targets, budget)


class ForgettingBasedSelection:
    """
    Track which samples are frequently forgotten and prioritize them.

    Some samples are inherently more forgettable - typically those near
    decision boundaries or with atypical features.
    """

    def __init__(self):
        self.forgetting_counts: Dict[int, int] = {}   # sample_id -> forget count
        self.last_correct: Dict[int, bool] = {}       # sample_id -> was correct last time

    def update(self, sample_ids, predictions, targets):
        """Update forgetting counts based on current predictions."""
        correct = predictions.eq(targets)
        for i, sid in enumerate(sample_ids):
            is_correct = correct[i].item()
            if sid in self.last_correct:
                was_correct = self.last_correct[sid]
                # Forgetting event: was correct before, wrong now
                if was_correct and not is_correct:
                    self.forgetting_counts[sid] = self.forgetting_counts.get(sid, 0) + 1
            self.last_correct[sid] = is_correct

    def select(self, inputs, targets, sample_ids, budget):
        """Select the most frequently forgotten samples."""
        counts = [self.forgetting_counts.get(sid, 0) for sid in sample_ids]
        sorted_indices = np.argsort(counts)[::-1]      # Descending by forget count
        selected = sorted_indices[:budget].copy()      # copy() gives a contiguous index array
        return inputs[selected], targets[selected]
```

iCaRL (Incremental Classifier and Representation Learning) popularized herding-based selection. It combines herding with nearest-mean classification and representation learning updates. The key insight is that class means in feature space provide a robust classifier that's less susceptible to forgetting than linear classifiers.
What if we could replay samples from previous tasks without storing them? Generative replay (also called pseudo-replay) trains a generative model alongside the task model. The generator learns to produce samples resembling past data, which are then used for replay.

The Key Insight:

Instead of storing:

$$M = \{(x_1, y_1), (x_2, y_2), \ldots, (x_K, y_K)\}$$

we learn a generator $G(z)$ that can produce unlimited samples resembling past data:

$$\tilde{x} = G(z), \quad z \sim \mathcal{N}(0, I)$$

The memory requirement shifts from storing $K$ samples to storing the generator parameters, which can be more efficient when data is high-dimensional or when we need unlimited replay samples.
```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Tuple, Optional


class VAEGenerator(nn.Module):
    """
    Variational Autoencoder for generative replay.

    A VAE provides a stable training objective (the ELBO) and a smooth latent
    space, making it suitable for continual learning.

    Architecture:
        Encoder: x -> mu, log sigma^2 -> z (latent)
        Decoder: z -> x' (reconstruction)
    """

    def __init__(self, input_dim: int = 784, hidden_dim: int = 400, latent_dim: int = 50):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid()
        )
        self.latent_dim = latent_dim

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick for backprop through sampling."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def generate(self, n_samples, device):
        """Generate samples from the prior."""
        z = torch.randn(n_samples, self.latent_dim).to(device)
        return self.decode(z)


class GenerativeReplay:
    """
    Deep Generative Replay (DGR) for continual learning.

    Jointly trains a task solver and a generator. After each task, the generator
    produces pseudo-samples of old data for replay.

    Key insight: the generator must ALSO be trained with replay, or it forgets
    how to generate old data.

    Reference: Shin et al., "Continual Learning with Deep Generative Replay"
    """

    def __init__(self, solver, input_dim=784, generator_hidden=400,
                 latent_dim=50, replay_ratio=0.5):
        self.solver = solver
        self.generator = VAEGenerator(
            input_dim=input_dim, hidden_dim=generator_hidden, latent_dim=latent_dim
        )
        self.replay_ratio = replay_ratio
        # Previous generator, kept frozen to generate samples of old data
        self.old_generator: Optional[VAEGenerator] = None

    def vae_loss(self, recon_x, x, mu, logvar):
        """VAE ELBO loss: reconstruction + KL divergence."""
        # Reconstruction (binary cross-entropy for images)
        recon_loss = F.binary_cross_entropy(
            recon_x.view(-1, 784), x.view(-1, 784), reduction='sum'
        )
        # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kl_loss

    def train_task(self, dataloader: DataLoader, task_id: int, epochs: int,
                   solver_criterion, solver_optimizer, generator_optimizer):
        """Train both solver and generator on the new task with replay."""
        device = next(self.solver.parameters()).device
        self.generator = self.generator.to(device)

        for epoch in range(epochs):
            for inputs, targets in dataloader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                batch_size = inputs.size(0)
                # Flatten real inputs so they match the generator's output shape
                flat_inputs = inputs.view(batch_size, -1)

                # === Generate replay samples ===
                if self.old_generator is not None:
                    n_replay = int(batch_size * self.replay_ratio)
                    with torch.no_grad():
                        # Generate pseudo-inputs from the old generator
                        replay_inputs = self.old_generator.generate(n_replay, device)
                        # Pseudo-labels come from the solver (a key part of DGR)
                        replay_outputs = self.solver(replay_inputs)
                        replay_targets = replay_outputs.argmax(dim=1)
                    # Combine real and replayed data
                    combined_inputs = torch.cat([
                        flat_inputs[:batch_size - n_replay], replay_inputs
                    ])
                    combined_targets = torch.cat([
                        targets[:batch_size - n_replay], replay_targets
                    ])
                else:
                    combined_inputs = flat_inputs
                    combined_targets = targets
                    n_replay = 0

                # === Train solver ===
                solver_optimizer.zero_grad()
                solver_outputs = self.solver(combined_inputs.view(-1, 784))
                solver_loss = solver_criterion(solver_outputs, combined_targets)
                solver_loss.backward()
                solver_optimizer.step()

                # === Train generator ===
                generator_optimizer.zero_grad()
                # Reconstruction of current-task data
                recon, mu, logvar = self.generator(inputs)
                current_vae_loss = self.vae_loss(recon, inputs, mu, logvar)
                # Replay: also reconstruct old (generated) samples
                if self.old_generator is not None:
                    with torch.no_grad():
                        replay_gen = self.old_generator.generate(n_replay, device)
                    recon_replay, mu_r, logvar_r = self.generator(replay_gen)
                    replay_vae_loss = self.vae_loss(recon_replay, replay_gen, mu_r, logvar_r)
                    total_gen_loss = current_vae_loss + replay_vae_loss
                else:
                    total_gen_loss = current_vae_loss
                total_gen_loss.backward()
                generator_optimizer.step()

        # After the task: freeze a copy of the generator for the next task's replay
        self.old_generator = copy.deepcopy(self.generator)
        self.old_generator.eval()
        for p in self.old_generator.parameters():
            p.requires_grad = False


class DualMemoryGenerativeReplay:
    """
    Brain-inspired replay with dual memory systems.

    Mimics hippocampal-cortical memory consolidation:
    - a fast-learning "hippocampus" (generator) captures recent experiences
    - a slow-learning "neocortex" (solver) consolidates through replay

    The generator learns quickly but forgets; interleaved replay transfers
    knowledge to the more stable solver.
    """

    def __init__(self, solver, generator, consolidation_steps: int = 100):
        self.solver = solver
        self.generator = generator  # "Hippocampus"
        self.consolidation_steps = consolidation_steps
        # Different learning rates: fast for the generator, slow for the solver
        self.solver_lr = 0.0001     # Slow learning
        self.generator_lr = 0.01    # Fast learning (but prone to forgetting)

    def consolidate(self):
        """
        Offline consolidation phase (like sleep): the generator replays
        experiences to the solver without any new data.
        """
        device = next(self.solver.parameters()).device
        for _ in range(self.consolidation_steps):
            # Generate a replay batch
            with torch.no_grad():
                replay_inputs = self.generator.generate(64, device)
                replay_targets = self.solver(replay_inputs).argmax(dim=1)
            # Train the solver on the replayed memories
            # (implementation omitted for brevity)
            pass
```

A critical challenge: the generator itself must be continually learned! If we only train it on new data, it forgets how to generate old data. The solution is recursive: use the old generator to produce replay samples for training the new generator. This creates a chain of generators, and any quality degradation compounds over tasks.
State-of-the-art continual learning often combines multiple strategies. Pure replay and pure regularization each have limitations; hybrids leverage the strengths of both.

Key Hybrid Strategies:
```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict


class DarkExperienceReplay:
    """
    Dark Experience Replay (DER) and DER++.

    Regular replay stores (input, target). DER stores (input, logits), where the
    logits are the model's output at the time of storage. During replay we distill
    toward these stored logits, preserving the model's full output distribution,
    not just the class label. DER++ additionally applies standard cross-entropy
    on the stored labels.

    Reference: Buzzega et al., "Dark Experience for General Continual Learning"
    """

    def __init__(self, model: nn.Module, buffer_size: int = 2000,
                 alpha: float = 0.5, beta: float = 0.5, use_plus: bool = True):
        """
        Args:
            model: Neural network model
            buffer_size: Maximum number of stored samples
            alpha: Weight for logit distillation
            beta: Weight for label loss (DER++ only)
            use_plus: DER++ if True, plain DER otherwise
        """
        self.model = model
        self.alpha = alpha
        self.beta = beta
        self.use_plus = use_plus
        # Buffer stores: (input, target, logits)
        self.buffer: List[Tuple[torch.Tensor, int, torch.Tensor]] = []
        self.buffer_size = buffer_size
        self.num_seen = 0  # Total samples seen so far, for reservoir sampling

    def add_to_buffer(self, inputs, targets, logits):
        """Store samples together with their output logits (reservoir sampling)."""
        for x, y, l in zip(inputs, targets, logits):
            self.num_seen += 1
            if len(self.buffer) < self.buffer_size:
                self.buffer.append((x.clone().cpu(), y.item(), l.clone().detach().cpu()))
            else:
                # Reservoir sampling over all samples seen so far
                idx = random.randint(0, self.num_seen - 1)
                if idx < self.buffer_size:
                    self.buffer[idx] = (x.clone().cpu(), y.item(), l.clone().detach().cpu())

    def sample_buffer(self, batch_size, device):
        """Sample (inputs, targets, stored_logits) from the buffer."""
        if not self.buffer:
            return None, None, None
        samples = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        inputs = torch.stack([s[0] for s in samples]).to(device)
        targets = torch.tensor([s[1] for s in samples]).to(device)
        logits = torch.stack([s[2] for s in samples]).to(device)
        return inputs, targets, logits

    def training_step(self, current_inputs, current_targets, criterion, optimizer):
        """
        DER/DER++ training step.

        Loss = L_current + alpha * L_logit_distill + beta * L_buffer_ce (DER++)
        """
        device = current_inputs.device
        optimizer.zero_grad()
        losses = {}

        # Current task loss
        current_outputs = self.model(current_inputs)
        loss_current = criterion(current_outputs, current_targets)
        losses['current'] = loss_current.item()
        total_loss = loss_current

        # Buffer replay
        buf_inputs, buf_targets, buf_logits = self.sample_buffer(current_inputs.size(0), device)
        if buf_inputs is not None:
            buf_outputs = self.model(buf_inputs)
            # DER: distillation toward stored logits (MSE between current and stored logits)
            loss_logit = F.mse_loss(buf_outputs, buf_logits)
            losses['logit_distill'] = loss_logit.item()
            total_loss = total_loss + self.alpha * loss_logit
            # DER++: also add cross-entropy on buffer samples
            if self.use_plus:
                loss_buffer_ce = criterion(buf_outputs, buf_targets)
                losses['buffer_ce'] = loss_buffer_ce.item()
                total_loss = total_loss + self.beta * loss_buffer_ce

        total_loss.backward()
        optimizer.step()

        # Store current samples with their logits
        with torch.no_grad():
            self.add_to_buffer(current_inputs, current_targets, current_outputs.detach())

        losses['total'] = total_loss.item()
        return losses


class GDumb:
    """
    GDumb: a simple but surprisingly effective baseline.

    Strategy:
    1. Greedily fill memory with seen samples.
    2. When a new task arrives, retrain from scratch on the memory only.

    No replay during training - just store and retrain. Its success shows the
    importance of exemplar quality over replay sophistication.

    Reference: Prabhu et al., "GDumb: A Simple Approach that Questions Our
    Progress in Continual Learning"
    """

    def __init__(self, model_fn, buffer_size: int = 2000, train_epochs: int = 100):
        self.model_fn = model_fn          # Function that creates a fresh model
        self.buffer_size = buffer_size
        self.train_epochs = train_epochs
        # Class-balanced buffer
        self.buffer: Dict[int, List[torch.Tensor]] = {}

    def add_to_buffer(self, inputs, targets):
        """Add samples while maintaining class balance."""
        for x, y in zip(inputs, targets):
            y_item = y.item()
            if y_item not in self.buffer:
                self.buffer[y_item] = []
            n_classes = len(self.buffer)
            max_per_class = self.buffer_size // n_classes
            if len(self.buffer[y_item]) < max_per_class:
                self.buffer[y_item].append(x.clone())
            else:
                # Random replacement
                if random.random() < 0.5:
                    idx = random.randint(0, len(self.buffer[y_item]) - 1)
                    self.buffer[y_item][idx] = x.clone()

    def get_buffer_dataset(self):
        """Return all buffered samples as lists."""
        inputs, targets = [], []
        for label, samples in self.buffer.items():
            for sample in samples:
                inputs.append(sample)
                targets.append(label)
        return inputs, targets

    def train_on_buffer(self, device):
        """Train a fresh model on the buffer content only."""
        model = self.model_fn().to(device)
        optimizer = torch.optim.Adam(model.parameters())
        criterion = nn.CrossEntropyLoss()

        inputs, targets = self.get_buffer_dataset()
        if not inputs:
            return model
        inputs = torch.stack(inputs).to(device)
        targets = torch.tensor(targets).to(device)

        model.train()
        for epoch in range(self.train_epochs):
            # Shuffle
            perm = torch.randperm(len(inputs))
            inputs = inputs[perm]
            targets = targets[perm]
            # Train in mini-batches
            for i in range(0, len(inputs), 64):
                batch_x = inputs[i:i + 64]
                batch_y = targets[i:i + 64]
                optimizer.zero_grad()
                out = model(batch_x)
                loss = criterion(out, batch_y)
                loss.backward()
                optimizer.step()
        return model
```

DER's insight is that class labels discard information. Storing the full logit vector preserves inter-class relationships and model confidence. When replaying, distilling toward stored logits transfers this rich information, not just the argmax label. This is especially powerful when classes are similar or the model was uncertain.
Implementing replay in production systems requires careful consideration of practical constraints:
| Factor | Experience Replay | Generative Replay | Hybrid (DER) |
|---|---|---|---|
| Memory for 1K samples (MNIST) | ~3 MB | ~1-2 MB (generator) | ~6 MB (+ logits) |
| Memory for 1K samples (ImageNet) | ~600 MB | ~50-100 MB (generator) | ~1.2 GB (+ logits) |
| Sample quality | Perfect (real data) | Depends on generator | Perfect + distillation |
| Privacy concerns | High (stores data) | Lower (learned representation) | High (stores data) |
| Computational cost | Low (just storage) | High (generator training) | Medium |
| Scalability with tasks | Good | Generator may degrade | Very good |
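The storage figures above can be sanity-checked with back-of-the-envelope arithmetic. The helper below is a rough sketch (assuming uncompressed float32 storage and ignoring metadata, which are my assumptions, not figures from the text) that reproduces the order of magnitude of the MNIST and ImageNet rows; DER-style buffers additionally store one logit vector per sample.

```python
def replay_buffer_megabytes(n_samples, sample_shape, n_logits=0, bytes_per_value=4):
    """Approximate buffer size: samples * (input values + stored logits) * bytes per value."""
    values_per_sample = 1
    for dim in sample_shape:
        values_per_sample *= dim
    values_per_sample += n_logits          # DER-style buffers also keep a logit vector
    return n_samples * values_per_sample * bytes_per_value / 1e6

# 1,000 MNIST images (1x28x28), float32: roughly 3 MB
print(f"MNIST buffer:    {replay_buffer_megabytes(1000, (1, 28, 28)):.1f} MB")

# 1,000 ImageNet-sized images (3x224x224), float32: roughly 600 MB
print(f"ImageNet buffer: {replay_buffer_megabytes(1000, (3, 224, 224)):.1f} MB")
```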
There's typically a logarithmic relationship between buffer size and performance. Doubling memory from 500 to 1000 samples helps much more than doubling from 5000 to 10000. The first few hundred samples per class are crucial; additional samples provide diminishing returns.
We have explored the landscape of replay-based continual learning. The key insights: experience replay directly rehearses old tasks from a small exemplar buffer; careful exemplar selection (herding, coresets, forgetting statistics) stretches a fixed memory budget; generative replay removes the need to store raw data but requires the generator itself to be protected from forgetting; and hybrids like DER add distillation toward stored logits to preserve richer information than labels alone.
What's Next:

In the next page, we explore dynamic architectures: methods that add new parameters or modules for new tasks rather than forcing all tasks into fixed capacity. This fundamentally different approach can achieve zero forgetting by isolating task-specific knowledge.
You now have comprehensive knowledge of replay-based continual learning: from basic experience replay through advanced methods like DER and generative replay. You understand the trade-offs between memory, privacy, and performance, and can select appropriate replay strategies for various production constraints.