The previous pages explored regularization (constraining weight updates) and replay (rehearsing old data). Both approaches share a fundamental assumption: all tasks must fit within fixed network capacity. As tasks accumulate, this fixed capacity becomes a bottleneck—old knowledge competes with new knowledge for limited representational space.

Dynamic architectures take a fundamentally different approach: instead of forcing all tasks into a fixed network, they expand the network as needed. New tasks receive new parameters, ensuring that learning new information cannot overwrite old representations.

This approach offers a powerful guarantee: zero forgetting by construction. If old parameters are never modified, old task performance cannot degrade. However, this comes at the cost of a growing model, introducing new challenges around capacity management and forward transfer.
By the end of this page, you will understand progressive neural networks and lateral connections, parameter isolation strategies (PackNet, Piggyback), dynamic expansion and capacity allocation, sparse networks and structured sparsity for continual learning, and how to balance model growth with computational constraints.
Progressive Neural Networks (PNNs), introduced by Rusu et al. (2016) at DeepMind, represent the seminal work in architectural approaches to continual learning. The key idea is strikingly simple: add a new column of layers for each task, connecting it to all previous columns.

Architecture:

For task $t$, PNN maintains $t$ columns (neural network pathways). Each column has its own parameters, ensuring no interference. To enable forward transfer (leveraging past knowledge for new tasks), lateral connections allow new columns to read from old columns.

Formally, for layer $i$ in column $t$:

$$h_i^{(t)} = f\left( W_i^{(t)} h_{i-1}^{(t)} + \sum_{j<t} U_{i}^{(t:j)} h_{i-1}^{(j)} \right)$$

where:
- $W_i^{(t)}$ are the standard weights for column $t$, layer $i$
- $U_i^{(t:j)}$ are lateral connections from column $j$ to column $t$
- $h_{i-1}^{(j)}$ are activations from previous columns (frozen)
```python
import torch
import torch.nn as nn
from typing import List, Optional


class ProgressiveColumn(nn.Module):
    """
    A single column in a Progressive Neural Network.

    Each column has its own parameters plus lateral connections
    to read from all previous columns (for forward transfer).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        previous_columns: Optional[List['ProgressiveColumn']] = None
    ):
        """
        Args:
            input_dim: Input dimension
            hidden_dims: List of hidden layer dimensions
            output_dim: Output dimension
            previous_columns: List of frozen previous columns for lateral connections
        """
        super().__init__()

        self.previous_columns = previous_columns or []
        n_prev = len(self.previous_columns)

        # Build layers with lateral connections
        dims = [input_dim] + hidden_dims
        self.layers = nn.ModuleList()
        self.lateral_connections = nn.ModuleList()  # U matrices

        for i in range(len(hidden_dims)):
            # Standard layer weights (W)
            self.layers.append(nn.Sequential(
                nn.Linear(dims[i], dims[i + 1]),
                nn.ReLU()
            ))

            # Lateral connections from previous columns (U)
            if n_prev > 0:
                # One adapter per previous column at this layer
                laterals = nn.ModuleList([
                    nn.Linear(dims[i], dims[i + 1], bias=False)
                    for _ in range(n_prev)
                ])
                self.lateral_connections.append(laterals)
            else:
                self.lateral_connections.append(nn.ModuleList())

        # Output head
        self.output_layer = nn.Linear(hidden_dims[-1], output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through this column with lateral connections.

        At each layer: h = f(Wx + Σ U_j * h_prev_j)
        where h_prev_j comes from frozen previous columns.
        """
        # Get activations from previous columns (frozen)
        prev_activations = []
        if self.previous_columns:
            for col in self.previous_columns:
                # Store intermediate activations from previous column
                acts = col.get_intermediate_activations(x)
                prev_activations.append(acts)

        h = x
        for layer_idx, layer in enumerate(self.layers):
            # Standard forward
            h_main = layer(h)

            # Add lateral contributions
            lateral_sum = torch.zeros_like(h_main)
            if prev_activations and layer_idx < len(self.lateral_connections):
                for col_idx, col_acts in enumerate(prev_activations):
                    if layer_idx < len(col_acts):
                        lateral = self.lateral_connections[layer_idx][col_idx]
                        lateral_sum += lateral(col_acts[layer_idx])

            h = h_main + lateral_sum

        return self.output_layer(h)

    def get_intermediate_activations(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Get activations at each layer (for lateral connections)."""
        activations = [x]
        h = x
        for layer in self.layers:
            h = layer(h)
            activations.append(h)
        return activations


class ProgressiveNeuralNetwork:
    """
    Progressive Neural Network for continual learning.

    Architecture grows with each task:
    - Task 1: Column 1 only
    - Task 2: Column 1 (frozen) + Column 2 (trainable) + laterals
    - Task N: Columns 1..N-1 (frozen) + Column N (trainable) + laterals

    Guarantees:
    - Zero forgetting: old columns never modified
    - Forward transfer: laterals leverage previous knowledge

    Trade-off: Model size grows with the number of tasks
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int
    ):
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim

        self.columns: List[ProgressiveColumn] = []
        self.current_task = 0

    def add_task(self) -> ProgressiveColumn:
        """
        Add a new column for a new task.

        Returns the trainable column for the new task.
        Previous columns are frozen.
        """
        # Freeze all existing columns
        for col in self.columns:
            for param in col.parameters():
                param.requires_grad = False

        # Create new column with lateral connections to all previous
        new_column = ProgressiveColumn(
            input_dim=self.input_dim,
            hidden_dims=self.hidden_dims,
            output_dim=self.output_dim,
            previous_columns=self.columns.copy()
        )

        self.columns.append(new_column)
        self.current_task += 1

        return new_column

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Forward through the column for given task."""
        if task_id >= len(self.columns):
            raise ValueError(f"No column for task {task_id}")
        return self.columns[task_id](x)

    def get_trainable_parameters(self):
        """Get parameters of the current (latest) column."""
        if not self.columns:
            return []
        return self.columns[-1].parameters()

    def total_parameters(self) -> int:
        """Count total parameters across all columns."""
        total = 0
        for col in self.columns:
            total += sum(p.numel() for p in col.parameters())
        return total


def demonstrate_progressive_nn():
    """Demonstrate Progressive Neural Network."""
    pnn = ProgressiveNeuralNetwork(
        input_dim=784,
        hidden_dims=[256, 128],
        output_dim=10
    )

    print("Progressive Neural Network Demo")
    print("=" * 50)

    for task_id in range(3):
        column = pnn.add_task()
        n_params = sum(p.numel() for p in column.parameters())

        print(f"\nTask {task_id + 1}:")
        print(f"  New column parameters: {n_params:,}")
        print(f"  Total network parameters: {pnn.total_parameters():,}")
        print(f"  Number of columns: {len(pnn.columns)}")

        # Train column here...
        # optimizer = torch.optim.Adam(pnn.get_trainable_parameters())


demonstrate_progressive_nn()
```

PNN's model size grows linearly in the number of columns, and the lateral connections add a term that grows quadratically with the number of tasks. For 100 tasks, you need 100 columns. This is impractical for long-lived continual learning systems. However, PNN remains conceptually important: it established the architectural paradigm and provides an upper bound on what's achievable with zero forgetting.
Rather than adding entirely new networks, parameter isolation methods allocate subsets of a fixed network to different tasks. Each task uses only a portion of parameters, preventing interference while maintaining a constant model size.

Key Methods in This Family:

- PackNet: iteratively prune and freeze weights, so each task claims a slice of the network's capacity
- Piggyback: freeze a pretrained network and learn a binary weight mask per task
- HAT (Hard Attention to the Task): learn per-task attention masks over neurons/activations

All three are sketched in the code below.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional


class PackNet:
    """
    PackNet: Sequential pruning for continual learning.

    Strategy:
    1. Train on task until convergence
    2. Prune least important weights (by magnitude)
    3. Freeze remaining (important) weights
    4. Next task uses pruned capacity

    Advantage: Fixed model size, zero forgetting for frozen weights
    Limitation: Eventually runs out of capacity

    Reference: Mallya & Lazebnik, "PackNet: Adding Multiple Tasks..."
    """

    def __init__(
        self,
        model: nn.Module,
        prune_ratio: float = 0.75  # Fraction to prune per task
    ):
        """
        Args:
            model: The neural network (will be modified in-place)
            prune_ratio: Fraction of weights to prune after each task
        """
        self.model = model
        self.prune_ratio = prune_ratio

        # Track which weights are free vs frozen
        self.frozen_masks: Dict[str, torch.Tensor] = {}
        self.task_masks: Dict[int, Dict[str, torch.Tensor]] = {}

        # Initialize all weights as free
        for name, param in model.named_parameters():
            if 'weight' in name:
                self.frozen_masks[name] = torch.zeros_like(param, dtype=torch.bool)

    def get_free_mask(self, name: str) -> torch.Tensor:
        """Get mask of weights that are still free (not frozen)."""
        return ~self.frozen_masks[name]

    def prune_and_freeze(self, task_id: int) -> Dict[str, Dict[str, float]]:
        """
        After training: prune unimportant weights, freeze rest.

        Returns dict of sparsity statistics.
        """
        stats = {}
        task_mask = {}

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'weight' not in name:
                    continue

                free_mask = self.get_free_mask(name)
                free_weights = param.data * free_mask.float()

                # Find magnitude threshold for pruning
                abs_weights = free_weights.abs()
                non_zero = abs_weights[free_mask]

                if len(non_zero) == 0:
                    continue

                # Prune bottom prune_ratio by magnitude
                threshold = torch.quantile(non_zero, self.prune_ratio)

                # Free weights above threshold are kept (frozen)
                keep_mask = abs_weights > threshold

                # Update frozen mask: previously frozen OR kept this time
                self.frozen_masks[name] = self.frozen_masks[name] | keep_mask
                task_mask[name] = keep_mask.clone()

                # Zero out only the pruned free weights; previously frozen
                # weights and newly kept weights are preserved
                param.data *= self.frozen_masks[name].float()

                # Statistics
                n_total = param.numel()
                n_frozen = self.frozen_masks[name].sum().item()
                stats[name] = {
                    'frozen_ratio': n_frozen / n_total,
                    'free_remaining': 1 - n_frozen / n_total
                }

        self.task_masks[task_id] = task_mask
        return stats

    def apply_free_gradient_mask(self):
        """
        Zero gradients for frozen weights during training.

        Call after loss.backward() to prevent updates to frozen weights.
        """
        for name, param in self.model.named_parameters():
            if name in self.frozen_masks and param.grad is not None:
                free_mask = self.get_free_mask(name)
                param.grad *= free_mask.float()


class Piggyback(nn.Module):
    """
    Piggyback: Learning task-specific masks for pretrained networks.

    Key idea: Start with a good pretrained network. For each new task,
    learn a binary mask that selects which weights to use. The weights
    themselves never change—only the masks are learned.

    Advantages:
    - Very parameter efficient (only masks stored per task)
    - Leverages strong pretrained features
    - Zero forgetting (weights frozen)

    Reference: Mallya et al., "Piggyback: Adapting a Single Network..."
    """

    def __init__(self, base_model: nn.Module, threshold: float = 0.0):
        """
        Args:
            base_model: Pretrained network (weights will be frozen)
            threshold: Threshold for binarizing masks
        """
        super().__init__()

        # Freeze base model
        self.base_model = base_model
        for param in self.base_model.parameters():
            param.requires_grad = False

        self.threshold = threshold

        # Learnable mask scores for current task
        self.mask_scores: Dict[str, nn.Parameter] = {}
        self.task_masks: Dict[int, Dict[str, torch.Tensor]] = {}
        self.current_task = -1

    def add_task(self) -> List[nn.Parameter]:
        """
        Add mask parameters for a new task.

        Returns list of trainable mask parameters.
        """
        self.current_task += 1

        # Create learnable mask scores for each weight
        trainable = []
        for name, param in self.base_model.named_parameters():
            if 'weight' in name:
                # Initialize mask scores around threshold
                scores = nn.Parameter(
                    torch.randn_like(param) * 0.01 + self.threshold
                )
                self.mask_scores[name] = scores
                trainable.append(scores)

        return trainable

    def get_masked_weight(self, name: str, weight: torch.Tensor) -> torch.Tensor:
        """
        Apply current mask to weight.

        Uses a soft sigmoid mask during training and a hard
        threshold at inference.
        """
        if name not in self.mask_scores:
            return weight

        scores = self.mask_scores[name]

        if self.training:
            # Soft mask for gradients
            mask = torch.sigmoid(scores - self.threshold)
        else:
            # Hard mask at inference
            mask = (scores > self.threshold).float()

        return weight * mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward with masked weights.

        This requires hooking into the base model's forward.
        Implementation depends on model architecture.
        """
        # Simplified: assumes base_model is a nn.Sequential
        # Real implementation would use forward hooks
        h = x
        for name, module in self.base_model.named_modules():
            if hasattr(module, 'weight'):
                # Get masked weight
                weight_name = f"{name}.weight"
                masked_weight = self.get_masked_weight(weight_name, module.weight)

                # Apply layer with masked weight
                if isinstance(module, nn.Linear):
                    h = F.linear(h, masked_weight, module.bias)
                elif isinstance(module, nn.Conv2d):
                    h = F.conv2d(h, masked_weight, module.bias,
                                 module.stride, module.padding)
            elif isinstance(module, nn.ReLU):
                h = F.relu(h)
            # ... handle other module types

        return h

    def consolidate_task(self):
        """Save current masks for the current task."""
        binarized = {}
        for name, scores in self.mask_scores.items():
            binarized[name] = (scores > self.threshold).clone()
        self.task_masks[self.current_task] = binarized

    def set_task(self, task_id: int):
        """Load masks for a specific task."""
        if task_id not in self.task_masks:
            raise ValueError(f"No masks for task {task_id}")

        for name, mask in self.task_masks[task_id].items():
            self.mask_scores[name].data = mask.float() * 2  # Above threshold


class HardAttentionToTask(nn.Module):
    """
    HAT: Hard Attention to the Task

    Learns task-specific attention masks on neurons (not weights).
    Uses gradient-based mask learning with annealing for sharpness.

    Key equations:
        a = σ(s * e)       # attention mask, s = scaling (annealed), e = embeddings
        h_masked = h ⊙ a   # apply mask to activations

    Regularization encourages sparse, task-specific masks.

    Reference: Serrà et al., "Overcoming Catastrophic Forgetting..."
    """

    def __init__(
        self,
        layer_dims: List[int],
        n_tasks: int,
        s_max: float = 400
    ):
        super().__init__()

        self.layer_dims = layer_dims
        self.n_tasks = n_tasks
        self.s_max = s_max

        # Task embeddings for each layer
        # Shape: (n_tasks, layer_dim)
        self.embeddings = nn.ParameterList([
            nn.Parameter(torch.randn(n_tasks, dim) * 0.01)
            for dim in layer_dims
        ])

        # Accumulated mask for determining what's "used"
        self.accumulated_masks: List[torch.Tensor] = [
            torch.zeros(1, dim) for dim in layer_dims
        ]

    def get_mask(
        self,
        task_id: int,
        layer_idx: int,
        s: Optional[float] = None
    ) -> torch.Tensor:
        """
        Get attention mask for a layer given task.

        Args:
            task_id: Current task
            layer_idx: Which layer
            s: Scaling factor (annealed during training)
        """
        s = s or self.s_max
        e = self.embeddings[layer_idx][task_id]
        return torch.sigmoid(s * e)

    def apply_mask(
        self,
        activations: torch.Tensor,
        task_id: int,
        layer_idx: int,
        s: Optional[float] = None
    ) -> torch.Tensor:
        """Apply task-specific mask to activations."""
        mask = self.get_mask(task_id, layer_idx, s)
        return activations * mask.unsqueeze(0)

    def regularization_loss(
        self,
        task_id: int,
        s: float
    ) -> torch.Tensor:
        """
        HAT regularization: encourage sparse masks that don't overlap
        with previously used capacity.

        L_reg = Σ_l (a_l ⊙ a_{<t, l}) · 1

        This penalizes using neurons that previous tasks used.
        """
        loss = 0.0
        for layer_idx, emb in enumerate(self.embeddings):
            mask = torch.sigmoid(s * emb[task_id])
            prev_mask = self.accumulated_masks[layer_idx]

            # Penalize overlap with previous tasks
            overlap = mask * prev_mask.squeeze()
            loss += overlap.sum()

        return loss

    def consolidate_task(self, task_id: int):
        """After task: update accumulated masks."""
        with torch.no_grad():
            for layer_idx, emb in enumerate(self.embeddings):
                mask = torch.sigmoid(self.s_max * emb[task_id])
                self.accumulated_masks[layer_idx] = torch.max(
                    self.accumulated_masks[layer_idx],
                    mask.unsqueeze(0)
                )
```

Binary masks are not differentiable, creating a training challenge. Methods use various solutions: Piggyback uses sigmoid scores with thresholding and straight-through estimators; HAT uses temperature annealing (gradually sharpening masks during training); SupSup learns real-valued scores and takes the top-k with straight-through gradients (as in the implementation later on this page). The choice affects how well masks can be optimized.
While PNN adds entire columns and PackNet uses fixed capacity, dynamic expansion methods find a middle ground: add capacity only when needed, and add it efficiently.

The Key Question:

How do we decide when to add capacity vs. reuse existing capacity?
```python
import torch
import torch.nn as nn
from typing import List, Dict, Tuple, Optional


class DynamicallyExpandableNetwork:
    """
    Dynamically Expandable Networks (DEN).

    Key ideas:
    1. Start with minimal network
    2. Train on each task
    3. If performance plateaus, add neurons (expand)
    4. Use selective retraining for related knowledge
    5. Apply group sparsity to prevent unbounded growth

    Reference: Yoon et al., "Lifelong Learning with DEN"
    """

    def __init__(
        self,
        input_dim: int,
        initial_hidden: int,
        output_dim: int,
        expansion_threshold: float = 0.01,
        sparsity_lambda: float = 0.001
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.expansion_threshold = expansion_threshold
        self.sparsity_lambda = sparsity_lambda

        # Start with minimal network
        self.hidden_sizes = [initial_hidden]
        self.model = self._build_model()

        # Track which neurons belong to which task
        self.neuron_task_map: Dict[int, List[Tuple[int, int, int]]] = {}

    def _build_model(self) -> nn.Sequential:
        """Build network with current architecture."""
        layers = []
        prev_dim = self.input_dim

        for hidden in self.hidden_sizes:
            layers.extend([
                nn.Linear(prev_dim, hidden),
                nn.ReLU()
            ])
            prev_dim = hidden

        layers.append(nn.Linear(prev_dim, self.output_dim))
        return nn.Sequential(*layers)

    def expand_layer(self, layer_idx: int, n_new_neurons: int) -> None:
        """
        Add neurons to a specific layer.

        This requires careful weight matrix surgery:
        - Extend weight matrix of this layer (output dim)
        - Extend weight matrix of next layer (input dim)
        - Initialize new weights appropriately
        """
        # Get current layers
        linear_idx = layer_idx * 2  # Account for ReLU layers
        current_layer = self.model[linear_idx]

        old_out = current_layer.out_features
        new_out = old_out + n_new_neurons

        # Create expanded layer
        new_layer = nn.Linear(current_layer.in_features, new_out)

        # Copy old weights
        with torch.no_grad():
            new_layer.weight[:old_out] = current_layer.weight
            new_layer.bias[:old_out] = current_layer.bias
            # Initialize new weights (small random, or zeros)
            nn.init.kaiming_normal_(new_layer.weight[old_out:])
            nn.init.zeros_(new_layer.bias[old_out:])

        self.model[linear_idx] = new_layer

        # Update next layer's input dimension
        if layer_idx + 1 < len(self.hidden_sizes):
            next_linear_idx = (layer_idx + 1) * 2
            next_layer = self.model[next_linear_idx]
            new_next_layer = nn.Linear(new_out, next_layer.out_features)

            with torch.no_grad():
                new_next_layer.weight[:, :old_out] = next_layer.weight
                new_next_layer.bias[:] = next_layer.bias
                nn.init.zeros_(new_next_layer.weight[:, old_out:])

            self.model[next_linear_idx] = new_next_layer
        else:
            # This was last hidden layer, update output layer
            output_layer = self.model[-1]
            new_output_layer = nn.Linear(new_out, output_layer.out_features)

            with torch.no_grad():
                new_output_layer.weight[:, :old_out] = output_layer.weight
                new_output_layer.bias[:] = output_layer.bias
                nn.init.zeros_(new_output_layer.weight[:, old_out:])

            self.model[-1] = new_output_layer

        self.hidden_sizes[layer_idx] = new_out

    def should_expand(
        self,
        losses: List[float],
        window: int = 10
    ) -> bool:
        """
        Detect if learning has plateaued (semantic drift).

        Returns True if loss hasn't improved by threshold over window.
        """
        if len(losses) < window:
            return False

        recent = losses[-window:]
        improvement = recent[0] - recent[-1]

        return improvement < self.expansion_threshold

    def group_sparsity_loss(self) -> torch.Tensor:
        """
        Group sparsity regularization.

        Encourages entire neurons (groups of weights) to be zero,
        enabling future pruning and preventing unbounded growth.

        L_group = Σ_j ||w_j||_2  (L2 norm per neuron, summed)
        """
        loss = 0.0
        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                # L2 norm of each output neuron's weights
                neuron_norms = torch.norm(module.weight, dim=1)
                loss += neuron_norms.sum()

        return self.sparsity_lambda * loss


class ExpertGate(nn.Module):
    """
    Expert Gate: Mixture of Experts for Continual Learning.

    Architecture:
    - Shared encoder (may be frozen)
    - Multiple expert sub-networks
    - Gating network decides which expert to use
    - New expertise triggers new expert creation

    Key insight: Different tasks may need different expertise.
    Add experts for novel tasks; reuse experts for similar tasks.
    """

    def __init__(
        self,
        input_dim: int,
        expert_hidden: int,
        output_dim: int,
        gate_hidden: int = 64
    ):
        super().__init__()

        self.input_dim = input_dim
        self.expert_hidden = expert_hidden
        self.output_dim = output_dim

        # Experts: list of specialist networks
        self.experts: nn.ModuleList = nn.ModuleList()

        # Gating network: decides which expert to use
        self.gate = nn.Sequential(
            nn.Linear(input_dim, gate_hidden),
            nn.ReLU(),
            nn.Linear(gate_hidden, 1)  # Will expand as experts added
        )

        # Task-to-expert mapping
        self.task_experts: Dict[int, int] = {}

    def add_expert(self) -> int:
        """Add a new expert. Returns the expert index."""
        expert = nn.Sequential(
            nn.Linear(self.input_dim, self.expert_hidden),
            nn.ReLU(),
            nn.Linear(self.expert_hidden, self.expert_hidden),
            nn.ReLU(),
            nn.Linear(self.expert_hidden, self.output_dim)
        )
        self.experts.append(expert)

        # Expand gate output dimension
        n_experts = len(self.experts)
        if n_experts > 1:
            old_gate = self.gate
            old_output = old_gate[-1]
            new_gate_output = nn.Linear(old_output.in_features, n_experts)

            # Copy old weights
            with torch.no_grad():
                new_gate_output.weight[:n_experts - 1] = old_output.weight
                new_gate_output.bias[:n_experts - 1] = old_output.bias

            self.gate = nn.Sequential(
                old_gate[0],  # Linear
                old_gate[1],  # ReLU
                new_gate_output
            )

        return len(self.experts) - 1

    def forward(
        self,
        x: torch.Tensor,
        task_id: Optional[int] = None
    ) -> torch.Tensor:
        """
        Forward pass.

        If task_id given, use assigned expert.
        Otherwise, use gating to select expert.
        """
        if task_id is not None and task_id in self.task_experts:
            # Use assigned expert
            expert_idx = self.task_experts[task_id]
            return self.experts[expert_idx](x)
        else:
            # Use gating (soft or hard)
            gate_scores = self.gate(x)  # (batch, n_experts)
            gate_weights = torch.softmax(gate_scores, dim=1)

            # Weighted combination of experts
            output = 0
            for i, expert in enumerate(self.experts):
                expert_out = expert(x)
                output = output + gate_weights[:, i:i + 1] * expert_out

            return output

    def assign_expert(self, task_id: int, expert_idx: int):
        """Assign a task to use a specific expert."""
        self.task_experts[task_id] = expert_idx


class AdaptiveNetwork:
    """
    Adaptive network that decides between reuse and expansion.

    Decision process:
    1. Try to learn new task with existing capacity
    2. Measure fit quality (loss, gradient magnitude)
    3. If poor fit, expand; if good fit, reuse
    4. Prune redundant capacity periodically
    """

    def __init__(
        self,
        base_model: nn.Module,
        expansion_rate: float = 0.25,
        fit_threshold: float = 0.5
    ):
        self.model = base_model
        self.expansion_rate = expansion_rate
        self.fit_threshold = fit_threshold

    def evaluate_fit(
        self,
        dataloader,
        criterion,
        n_batches: int = 10
    ) -> float:
        """
        Evaluate how well current capacity fits new task.

        Uses gradient magnitude as proxy: high gradients suggest
        the network needs significant changes (poor fit).
        """
        self.model.train()
        total_grad_norm = 0.0
        n_seen = 0

        for i, (inputs, targets) in enumerate(dataloader):
            if i >= n_batches:
                break

            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            # Compute gradient norm
            grad_norm = 0.0
            for p in self.model.parameters():
                if p.grad is not None:
                    grad_norm += p.grad.norm().item() ** 2
            grad_norm = grad_norm ** 0.5

            total_grad_norm += grad_norm
            n_seen += 1
            self.model.zero_grad()

        return total_grad_norm / max(n_seen, 1)

    def decide_expansion(self, grad_norm: float) -> bool:
        """Decide whether to expand based on fit quality."""
        return grad_norm > self.fit_threshold
```

Recent work on sparse neural networks and the Lottery Ticket Hypothesis provides new perspectives on continual learning. The key insight: neural networks contain sparse subnetworks (often just 5-10% of weights) that can match full network performance when trained in isolation.

The Lottery Ticket Connection:

If a network contains multiple winning 'lottery tickets' (sparse trainable subnetworks), perhaps different tickets can be assigned to different tasks. This enables:

- Task-specific sparse subnetworks within shared parameters
- Minimal interference (sparse networks have little overlap)
- Capacity for many tasks (each using a small fraction of weights)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional


class SupSup(nn.Module):
    """
    Superposition of Supermasks (SupSup).

    Key idea: Keep weights fixed at random initialization.
    For each task, learn a sparse binary mask (supermask).
    Different tasks have different masks, no interference.

    Stunning fact: Random networks with learned masks can
    perform comparably to trained networks!

    Advantages:
    - Zero forgetting (weights never change)
    - Very parameter efficient (only binary masks per task)
    - Unlimited task capacity (theoretically)

    Reference: Wortsman et al., "Supermasks in Superposition"
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        sparsity: float = 0.5  # Fraction of weights to keep
    ):
        super().__init__()
        self.sparsity = sparsity

        # Build fixed random network
        dims = [input_dim] + hidden_dims + [output_dim]
        self.layers = nn.ModuleList()

        for i in range(len(dims) - 1):
            layer = nn.Linear(dims[i], dims[i + 1])
            # Freeze weights at initialization
            layer.weight.requires_grad = False
            layer.bias.requires_grad = False
            self.layers.append(layer)

        # Task masks: task_id -> list of mask scores per layer
        self.mask_scores: Dict[int, List[nn.Parameter]] = {}
        self.current_task = -1

    def add_task(self) -> List[nn.Parameter]:
        """
        Add trainable mask parameters for a new task.

        Returns list of mask parameters to optimize.
        """
        self.current_task += 1

        mask_params = []
        for layer in self.layers:
            # Mask scores: same shape as weights
            scores = nn.Parameter(
                torch.randn_like(layer.weight) * 0.01
            )
            mask_params.append(scores)

        self.mask_scores[self.current_task] = mask_params
        return mask_params

    def get_mask(
        self,
        scores: torch.Tensor,
        training: bool = True
    ) -> torch.Tensor:
        """
        Convert mask scores to sparse binary mask.

        Uses top-k selection to enforce sparsity.
        Straight-through estimator for gradients.
        """
        k = int(scores.numel() * self.sparsity)

        # Get top-k absolute scores
        flat = scores.view(-1)
        _, indices = torch.topk(flat.abs(), k)

        # Create binary mask
        mask = torch.zeros_like(flat)
        mask[indices] = 1.0
        mask = mask.view_as(scores)

        if training:
            # Straight-through: forward is binary, backward uses scores
            return mask - scores.detach() + scores

        return mask

    def forward(
        self,
        x: torch.Tensor,
        task_id: Optional[int] = None
    ) -> torch.Tensor:
        """Forward with task-specific mask."""
        task_id = task_id if task_id is not None else self.current_task

        if task_id not in self.mask_scores:
            raise ValueError(f"No mask for task {task_id}")

        masks = self.mask_scores[task_id]

        h = x
        for i, layer in enumerate(self.layers):
            mask = self.get_mask(masks[i], self.training)
            masked_weight = layer.weight * mask
            h = F.linear(h, masked_weight, layer.bias)

            # Apply ReLU to all but last layer
            if i < len(self.layers) - 1:
                h = F.relu(h)

        return h


class SparseOverlap:
    """
    Analysis of overlap between sparse task masks.

    Key insight: With sufficient sparsity, random masks have low
    overlap, naturally preventing interference.

    For two random masks with density s, the expected fraction of
    weights shared is s². With s = 0.1 (10% density), overlap is only 1%!
    """

    @staticmethod
    def compute_overlap(mask1: torch.Tensor, mask2: torch.Tensor) -> float:
        """Compute overlap ratio (intersection over union) between two binary masks."""
        intersection = (mask1 * mask2).sum()
        union = torch.maximum(mask1, mask2).sum()

        if union == 0:
            return 0.0
        return (intersection / union).item()

    @staticmethod
    def expected_overlap(sparsity: float, n_tasks: int) -> float:
        """
        Compute expected overlap for random sparse masks.

        For n tasks, the probability that a specific weight is
        used by all tasks is sparsity^n.
        """
        return sparsity ** n_tasks

    @staticmethod
    def max_tasks_estimate(sparsity: float, overlap_threshold: float = 0.01) -> int:
        """
        Estimate the number of tasks at which the expected all-task
        overlap (sparsity^n) drops below the threshold.

        sparsity^n < threshold  =>  n > log(threshold) / log(sparsity)
        """
        import math
        if sparsity >= 1.0 or sparsity <= 0.0:
            return 1
        return int(math.log(overlap_threshold) / math.log(sparsity))


class SparseSplitNetwork:
    """
    Network with structured sparsity splits.

    Idea: Divide network into non-overlapping sparse regions.
    Each task gets one region, guaranteeing zero overlap.

    Trade-off: More structured than random, but capacity is
    strictly bounded by number of regions.
    """

    def __init__(
        self,
        model: nn.Module,
        n_splits: int = 10  # Max number of tasks
    ):
        self.model = model
        self.n_splits = n_splits

        # Create non-overlapping index assignments
        self.weight_assignments = {}
        for name, param in model.named_parameters():
            if 'weight' in name:
                n = param.numel()
                # Assign each weight to a split
                assignments = torch.randperm(n) % n_splits
                self.weight_assignments[name] = assignments.view_as(param)

        self.task_split = {}
        self.next_split = 0

    def assign_task(self, task_id: int) -> int:
        """Assign task to next available split."""
        if task_id in self.task_split:
            return self.task_split[task_id]

        if self.next_split >= self.n_splits:
            raise RuntimeError("No more capacity!")

        self.task_split[task_id] = self.next_split
        self.next_split += 1
        return self.task_split[task_id]

    def get_task_mask(self, task_id: int) -> Dict[str, torch.Tensor]:
        """Get binary mask for task's assigned split."""
        split = self.task_split[task_id]

        masks = {}
        for name, assignments in self.weight_assignments.items():
            masks[name] = (assignments == split).float()
        return masks
```

SupSup's finding that randomly-initialized, frozen networks can achieve competitive performance purely through learned masks is profound. It suggests that much of what we 'learn' in training is actually identifying which parts of an already-capable random network to use. For continual learning, this means task-specific subnetworks can be found without modifying shared weights at all.
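As a quick sanity check of the overlap arithmetic in `SparseOverlap` above, the plain-Python snippet below works through the numbers (no torch required; the values are illustrative):

```python
# Expected fraction of weights shared by two independent random masks of density s
# is s**2, matching SparseOverlap.expected_overlap(s, 2) above.
for s in (0.5, 0.1, 0.05):
    print(f"density={s:.2f}  expected pairwise overlap={s ** 2:.4f}")
# density=0.10 gives 0.0100, i.e. the 1% figure quoted in the docstring.
```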
Let's compare the architectural approaches we've covered:
| Method | Model Growth | Forgetting | Forward Transfer | Backward Transfer | Main Limitation |
|---|---|---|---|---|---|
| Progressive NN | Linear (full network) | Zero | Via laterals | None | Rapid model growth |
| PackNet | None (fixed) | Zero (frozen) | Reuses frozen weights | None | Capacity exhaustion |
| Piggyback | Masks only | Zero | From pretrained | None | Requires good pretrained model |
| HAT | Masks only | Near-zero | Learned | Limited | Complex mask optimization |
| DEN | Adaptive (neurons) | Low (selective) | Selective retraining | Possible | Complex pipeline |
| SupSup | Masks only | Zero | Limited | None | Random initialization quality |
Selection Guidelines:

Choose Progressive NN when:
- You have < 10 tasks and compute is not a constraint
- Maximum forward transfer is critical
- Zero forgetting is mandatory

Choose PackNet when:
- You have a fixed compute/memory budget
- Task sequence length is known and limited
- You want simple implementation

Choose Piggyback/SupSup when:
- You have a strong pretrained model OR don't mind random features
- Storage is the primary concern (only masks stored)
- Many tasks expected

Choose DEN when:
- Task sequence is long/unknown
- Some backward transfer is desirable
- You can handle implementation complexity
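The sketch below restates these guidelines as a simple rule-of-thumb chooser. The function and its inputs are illustrative, not a standard API, and real decisions will weigh these factors more gradually.

```python
def choose_architecture(n_tasks_expected: int,
                        zero_forgetting_required: bool,
                        has_pretrained_model: bool,
                        fixed_memory_budget: bool,
                        wants_backward_transfer: bool) -> str:
    """Rule-of-thumb mapping from requirements to the methods discussed above."""
    if wants_backward_transfer:
        return "DEN (selective retraining allows limited backward transfer)"
    if n_tasks_expected < 10 and zero_forgetting_required and not fixed_memory_budget:
        return "Progressive NN (maximum forward transfer, zero forgetting)"
    if has_pretrained_model or n_tasks_expected > 50:
        return "Piggyback / SupSup (store only per-task masks)"
    if fixed_memory_budget:
        return "PackNet (fixed model size, simple to implement)"
    return "HAT (learned per-task attention masks)"


print(choose_architecture(5, True, False, False, False))  # -> Progressive NN
```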
We have explored the landscape of architectural approaches to continual learning. Let's consolidate the key insights:

- Expanding or partitioning the architecture gives zero forgetting by construction: parameters used by old tasks are simply never modified.
- The price is paid elsewhere: model growth (Progressive NNs), eventual capacity exhaustion (PackNet), reliance on a strong pretrained or random backbone (Piggyback, SupSup), or implementation complexity (DEN, HAT).
- Mask-based isolation is extremely storage-efficient, since only a per-task mask must be kept alongside shared weights.
- Dynamic expansion occupies the middle ground, adding capacity only when a new task genuinely needs it, which requires a reliable signal for when existing capacity no longer fits.
What's Next:

In the final page of this module, we explore evaluation protocols for continual learning—how to properly measure forgetting, forward transfer, and overall performance across task sequences.
You now have comprehensive knowledge of architectural approaches to continual learning. You understand how progressive networks, parameter isolation, mask learning, and dynamic expansion each address the stability-plasticity dilemma with different trade-offs. This knowledge enables you to design architectures for continual learning systems with appropriate guarantees for your application.