Between the extremes of frozen feature extraction and full fine-tuning lies a powerful middle ground: selective fine-tuning. This approach recognizes that different layers in a neural network learn different types of features, and not all layers benefit equally from adaptation to a new task.
The intuition is compelling: early layers in deep networks tend to learn universal, low-level features (edges, textures, basic patterns) that transfer well across tasks. Later layers learn increasingly task-specific, high-level representations that may need modification for new domains. Selective fine-tuning exploits this hierarchy.
By choosing which layers to update and which to freeze, we can:
- preserve the general-purpose features learned during pre-training,
- adapt the task-specific representations that actually need to change,
- reduce overfitting when target data is limited, and
- cut the compute and memory cost of training.
This page explores the science behind layer-wise transfer, practical strategies for layer selection, and the implementation patterns that make selective fine-tuning effective.
By the end of this page, you will understand the theoretical basis for layer-wise transfer, master techniques like gradual unfreezing and discriminative learning rates, and implement robust selective fine-tuning pipelines for both vision and language models.
To understand why selective fine-tuning works, we must first understand what different layers learn. Decades of research in deep learning visualization have revealed a consistent pattern: neural networks develop hierarchical representations where feature abstraction increases with depth.
Vision Networks:
In convolutional networks trained on image classification, layers develop a clear hierarchy:
Layers 1-2 (Early Layers): Gabor filters, edge detectors, simple color gradients. These features are nearly universal: every vision task needs to detect edges.
Layers 3-4 (Mid-Early Layers): Texture patterns, combinations of edges, basic shapes. Still fairly general but beginning to show domain influence.
Layers 5-7 (Mid Layers): Object parts, complex textures, spatial arrangements. More task-specific features emerge.
Layers 8+ (Late Layers): Object-level concepts, scene understanding, highly abstract representations. The most task-specific; they may not transfer well.
Language Models:
Transformers exhibit a similar hierarchy, though the interpretation differs:
Early Layers (1-3): Positional encoding, basic syntax, local context. Word-level processing.
Middle Layers (4-8): Phrase-level semantics, basic reasoning, entity recognition. Compositional understanding.
Late Layers (9-12): Task-specific representations, complex reasoning, high-level abstractions. Most influenced by pre-training task.
| Layer Depth | Vision Features | Language Features | Transferability |
|---|---|---|---|
| Very Early (1-2) | Edges, gradients, colors | Position, basic tokens | Highest (near-universal) |
| Early (3-4) | Textures, simple shapes | Syntax, local patterns | High (domain-general) |
| Middle (5-7) | Object parts, complex patterns | Phrases, entities | Medium (some task specificity) |
| Late (8-10) | Object-level, scene context | Reasoning, abstractions | Lower (task-specific) |
| Very Late (11-12+) | Task-specific concepts | Pre-training task bias | Lowest (may need replacement) |
The Generalization Gradient:
This hierarchy creates what we call the generalization gradient—a continuous decrease in feature universality as depth increases. The practical implication: earlier layers benefit more from pre-training preservation; later layers benefit more from target adaptation.
Research by Yosinski et al. (2014) quantified this phenomenon. They measured how well features transfer between two halves of ImageNet (random splits) as a function of the depth at which the network was split: transfer performance degrades as the split point moves deeper, both because higher layers are more task-specific and because separating co-adapted middle layers hurts optimization.
This gradient informs our selective fine-tuning strategy: preserve early layers, adapt late layers.
Early layers learn features that are expensive to learn (require massive data) but universally useful. Late layers learn features that are cheaper to learn but highly task-specific. Selective fine-tuning maximizes the benefit of pre-training by reusing the expensive general features while cheaply adapting the task-specific ones.
Selecting which layers to fine-tune is both art and science. Several strategies have proven effective in practice, each suited to different scenarios.
Strategy 1: Top-K Layers
The simplest approach: freeze the first N-K layers and fine-tune only the last K layers. This works well when domain similarity is moderate and you want to adapt high-level representations.
Guidelines for choosing K: with very similar domains or little data, fine-tune only the last 1-2 layers; with moderate domain shift and more data, extend to the last 2-4 layers or roughly the top half of the network. The domain similarity-data size matrix later in this page gives fuller guidance.
Strategy 2: Block-Based Selection
Many architectures organize layers into blocks (ResNet stages, Transformer encoder blocks). Fine-tune entire blocks rather than individual layers to maintain internal consistency.
For ResNet, the natural blocks are the four stages (layer1 through layer4) plus the stem and the classifier; a typical choice is to fine-tune layer3, layer4, and fc while freezing the rest.
For BERT, the natural blocks are the embedding layer and the 12 encoder layers; a typical choice is to fine-tune the last few encoder layers plus the pooler and classifier head. The code below implements these selection strategies.
```python
import torch
import torch.nn as nn
from torchvision import models
from typing import List, Optional


class SelectiveFineTuning:
    """
    Implements multiple layer selection strategies for selective fine-tuning.

    Strategies:
    1. Top-K layers: Fine-tune only the last K layers
    2. Block-based: Fine-tune specific architectural blocks
    3. Pattern-based: Fine-tune layers matching a name pattern
    4. Gradual unfreezing: Progressively unfreeze layers during training
    """

    @staticmethod
    def freeze_all(model: nn.Module):
        """Freeze all parameters in the model."""
        for param in model.parameters():
            param.requires_grad = False

    @staticmethod
    def unfreeze_all(model: nn.Module):
        """Unfreeze all parameters in the model."""
        for param in model.parameters():
            param.requires_grad = True

    @staticmethod
    def freeze_by_name_pattern(model: nn.Module, patterns: List[str]):
        """
        Freeze layers whose names match any of the given patterns.

        Example patterns:
        - "layer1" freezes all params in layer1
        - "conv" freezes all conv layers
        - "bn" freezes all batch norm layers
        """
        for name, param in model.named_parameters():
            if any(pattern in name for pattern in patterns):
                param.requires_grad = False

    @staticmethod
    def unfreeze_by_name_pattern(model: nn.Module, patterns: List[str]):
        """Unfreeze layers whose names match any of the given patterns."""
        for name, param in model.named_parameters():
            if any(pattern in name for pattern in patterns):
                param.requires_grad = True

    @staticmethod
    def freeze_top_k_strategy(
        model: nn.Module,
        num_layers_to_finetune: int
    ) -> None:
        """
        Freeze all layers except the last K layers.

        For ResNet: layers are indexed as layer1, layer2, layer3, layer4, fc
        For BERT: layers are indexed as embeddings, layer.0, layer.1, ..., classifier
        """
        # First freeze everything
        SelectiveFineTuning.freeze_all(model)

        # Get all unique top-level layer prefixes, in registration order
        layer_prefixes = []
        for name, _ in model.named_parameters():
            prefix = name.split('.')[0]
            if prefix not in layer_prefixes:
                layer_prefixes.append(prefix)

        # Unfreeze the last K layers
        layers_to_unfreeze = layer_prefixes[-num_layers_to_finetune:]
        SelectiveFineTuning.unfreeze_by_name_pattern(model, layers_to_unfreeze)

        print(f"Unfroze layers: {layers_to_unfreeze}")

    @staticmethod
    def freeze_resnet_blocks(
        model: nn.Module,
        unfreeze_stages: List[int],
        unfreeze_fc: bool = True
    ) -> None:
        """
        Fine-tune specific ResNet stages.

        ResNet structure:
        - conv1, bn1: Initial convolution
        - layer1, layer2, layer3, layer4: Four stages
        - fc: Final classifier

        Args:
            unfreeze_stages: List of stage numbers to unfreeze (1-4)
            unfreeze_fc: Whether to unfreeze the classifier
        """
        SelectiveFineTuning.freeze_all(model)

        # Unfreeze specified stages
        for stage_num in unfreeze_stages:
            stage_name = f"layer{stage_num}"
            SelectiveFineTuning.unfreeze_by_name_pattern(model, [stage_name])

        # Optionally unfreeze classifier
        if unfreeze_fc:
            for param in model.fc.parameters():
                param.requires_grad = True

        # Print summary (counts parameter tensors, not individual weights)
        unfrozen_count = sum(p.requires_grad for p in model.parameters())
        total_count = sum(1 for _ in model.parameters())
        print(f"Unfroze {unfrozen_count}/{total_count} parameter tensors")

    @staticmethod
    def freeze_bert_layers(
        model,  # transformers.BertModel or similar
        unfreeze_from_layer: int,
        total_layers: int = 12
    ) -> None:
        """
        Fine-tune BERT layers starting from a specific layer.

        Args:
            unfreeze_from_layer: First layer to unfreeze (0-indexed)
            total_layers: Total number of encoder layers
        """
        # Freeze embeddings
        for param in model.embeddings.parameters():
            param.requires_grad = False

        # Freeze/unfreeze encoder layers
        for i, layer in enumerate(model.encoder.layer):
            if i >= unfreeze_from_layer:
                for param in layer.parameters():
                    param.requires_grad = True
            else:
                for param in layer.parameters():
                    param.requires_grad = False

        # Unfreeze pooler (if it exists); a task classifier head, if present,
        # should be unfrozen separately
        if hasattr(model, 'pooler'):
            for param in model.pooler.parameters():
                param.requires_grad = True

        layers_unfrozen = total_layers - unfreeze_from_layer
        print(f"Unfroze layers {unfreeze_from_layer} to {total_layers-1} ({layers_unfrozen} layers)")


def create_selective_model_resnet(
    num_classes: int,
    unfreeze_stages: List[int] = [3, 4],
    pretrained: bool = True
) -> nn.Module:
    """
    Create a ResNet-50 with selective fine-tuning configuration.

    Default: Freeze stages 1-2, fine-tune stages 3-4 and classifier.
    This is optimal for moderate domain shift with sufficient data.
    """
    model = models.resnet50(pretrained=pretrained)

    # Replace classifier
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Apply selective freezing
    SelectiveFineTuning.freeze_resnet_blocks(
        model,
        unfreeze_stages=unfreeze_stages,
        unfreeze_fc=True
    )

    return model


# Example: Count trainable parameters
def count_parameters(model: nn.Module) -> dict:
    """Count trainable vs total parameters."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    frozen = total - trainable
    return {
        "trainable": trainable,
        "frozen": frozen,
        "total": total,
        "trainable_percent": 100 * trainable / total
    }
```

Strategy 3: Gradual Unfreezing
Rather than fixing which layers to fine-tune from the start, gradual unfreezing progressively unfreezes layers during training. This technique, popularized by the ULMFiT paper (Howard & Ruder, 2018), is particularly effective for language models.
The algorithm: first train only the new classifier head with all pre-trained layers frozen; then unfreeze the last layer group and continue training; repeat, unfreezing one more group (moving from output toward input) at each stage, until the whole network is trainable.
Benefits: the randomly initialized head stabilizes before it can send large, noisy gradients into the pre-trained layers, the early general-purpose features are disturbed as little as possible, and the staged schedule acts as a regularizer when target data is limited.
Strategy 4: Discriminative Learning Rates
Instead of binary freeze/unfreeze, apply different learning rates to different layers. Earlier layers get smaller learning rates; later layers get larger ones. This allows all layers to adapt, but at different speeds.
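As a minimal sketch of the mechanics (assuming a torchvision ResNet-50 and purely illustrative rates), PyTorch optimizers accept per-parameter-group learning rates directly; a fuller implementation appears later on this page:

```python
import torch.nn as nn
from torch.optim import AdamW
from torchvision import models

model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)  # new task head

# Later (more task-specific) layers get larger learning rates.
# Parameters not passed to the optimizer (stem, layer1, layer2) are simply
# never updated, which effectively freezes them.
optimizer = AdamW([
    {"params": model.layer3.parameters(), "lr": 1e-4},
    {"params": model.layer4.parameters(), "lr": 3e-4},
    {"params": model.fc.parameters(), "lr": 1e-3},
], weight_decay=1e-4)
```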
When freezing layers containing batch normalization, you must also set those layers to eval mode. Otherwise, running statistics will be recomputed from target data while the scale/shift weights remain frozen—a mismatch that degrades performance. Call layer.eval() on each frozen BN layer after every call to model.train().
Gradual unfreezing deserves deeper exploration as it represents one of the most sophisticated selective fine-tuning strategies. Let's examine its theoretical motivation and practical implementation.
Theoretical Motivation:
Why does unfreezing layers gradually help? Consider what happens when you fine-tune all layers simultaneously: the new classifier head is randomly initialized, so early in training it produces large, essentially random gradients. Those gradients propagate back into the early layers and can corrupt general-purpose features before the head has learned anything useful.
Gradual unfreezing addresses this by stabilizing the foundation before modifying it: the head is trained first on top of frozen features, then the layers closest to the output are adapted, and only once those have settled are the earlier, more general layers allowed to move.
This output-to-input unfreezing order creates a stable adaptation path through the parameter space.
```python
import torch
import torch.nn as nn
from torchvision import models
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from typing import List, Tuple


class GradualUnfreezingTrainer:
    """
    Implements gradual unfreezing training strategy.

    The approach:
    1. Start with classifier only
    2. Progressively unfreeze layers from output to input
    3. Use warmup at each unfreezing step
    4. Optionally apply discriminative learning rates
    """

    def __init__(
        self,
        model: nn.Module,
        layer_groups: List[nn.Module],
        classifier: nn.Module,
        base_lr: float = 1e-3,
        lr_decay_factor: float = 2.5,
        epochs_per_stage: int = 1,
        warmup_steps: int = 100
    ):
        """
        Args:
            model: The complete model
            layer_groups: List of layer groups (modules), ordered from input to output
            classifier: The classifier head
            base_lr: Learning rate for the classifier and the group closest to the output
            lr_decay_factor: LR division factor for each deeper layer
            epochs_per_stage: Epochs to train before unfreezing next group
            warmup_steps: Warmup steps after each unfreezing
        """
        self.model = model
        self.layer_groups = layer_groups
        self.classifier = classifier
        self.base_lr = base_lr
        self.lr_decay_factor = lr_decay_factor
        self.epochs_per_stage = epochs_per_stage
        self.warmup_steps = warmup_steps

        # Track unfreezing state
        self.unfrozen_depth = 0  # 0 = only classifier, 1 = last layer + classifier, etc.

        # Initially freeze all layer groups
        for group in layer_groups:
            for param in group.parameters():
                param.requires_grad = False

        # Unfreeze classifier
        for param in classifier.parameters():
            param.requires_grad = True

        self.optimizer = None
        self.scheduler = None

    def _create_param_groups_with_discriminative_lr(self) -> List[dict]:
        """
        Create parameter groups with discriminative learning rates.

        Later layers get higher LR (base_lr).
        Earlier layers get lower LR (base_lr / decay^depth).
        """
        param_groups = []

        # Add unfrozen layer groups (from deepest to shallowest)
        for i in range(self.unfrozen_depth):
            layer_idx = len(self.layer_groups) - 1 - i
            layer = self.layer_groups[layer_idx]

            # Depth from output (0 = closest to output)
            depth = i
            lr = self.base_lr / (self.lr_decay_factor ** depth)

            param_groups.append({
                "params": [p for p in layer.parameters() if p.requires_grad],
                "lr": lr,
                "name": f"layer_{layer_idx}"
            })

        # Add classifier with highest LR
        param_groups.append({
            "params": self.classifier.parameters(),
            "lr": self.base_lr,
            "name": "classifier"
        })

        return param_groups

    def unfreeze_next_layer(self):
        """
        Unfreeze the next layer group (moving from output toward input).

        Returns True if a layer was unfrozen, False if all layers are unfrozen.
        """
        if self.unfrozen_depth >= len(self.layer_groups):
            return False

        # Unfreeze the next layer (counting from output)
        layer_idx = len(self.layer_groups) - 1 - self.unfrozen_depth
        layer = self.layer_groups[layer_idx]

        for param in layer.parameters():
            param.requires_grad = True

        self.unfrozen_depth += 1

        # Recreate optimizer with new parameter groups
        param_groups = self._create_param_groups_with_discriminative_lr()
        self.optimizer = AdamW(param_groups, weight_decay=1e-4)

        # Create new scheduler
        # (In practice, you'd configure this based on remaining epochs)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=1000)

        print(f"Unfroze layer {layer_idx}, now training {self.unfrozen_depth} layer groups + classifier")
        self._print_lr_summary()

        return True

    def _print_lr_summary(self):
        """Print current learning rates for each parameter group."""
        print("Current learning rates:")
        for group in self.optimizer.param_groups:
            print(f"  {group['name']}: {group['lr']:.6f}")

    def train_with_gradual_unfreezing(
        self,
        train_loader,
        val_loader,
        total_epochs: int,
        criterion: nn.Module,
        device: str = "cuda"
    ) -> dict:
        """
        Full training loop with gradual unfreezing.

        Unfreezes layers based on epochs_per_stage.
        """
        self.model = self.model.to(device)
        history = {"train_loss": [], "val_acc": [], "unfreezing_events": []}

        # Initialize with just the classifier trainable
        param_groups = self._create_param_groups_with_discriminative_lr()
        self.optimizer = AdamW(param_groups, weight_decay=1e-4)

        current_stage_epochs = 0

        for epoch in range(total_epochs):
            # Check if we should unfreeze the next layer group
            if current_stage_epochs >= self.epochs_per_stage:
                if self.unfreeze_next_layer():
                    history["unfreezing_events"].append(epoch)
                current_stage_epochs = 0

            # Training epoch
            self.model.train()
            total_loss = 0

            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(train_loader)
            history["train_loss"].append(avg_loss)

            # Validation
            val_acc = self._validate(val_loader, device)
            history["val_acc"].append(val_acc)

            print(f"Epoch {epoch+1}/{total_epochs}: Loss={avg_loss:.4f}, Val Acc={val_acc:.2f}%")

            current_stage_epochs += 1

        return history

    def _validate(self, val_loader, device: str) -> float:
        """Compute validation accuracy."""
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        return 100.0 * correct / total


# Example usage with ResNet
def setup_gradual_unfreezing_resnet(
    num_classes: int = 10
) -> Tuple[nn.Module, GradualUnfreezingTrainer]:
    """
    Setup ResNet-50 for gradual unfreezing.

    Layer groups:
    - Group 0: conv1 + bn1 + relu + maxpool
    - Group 1: layer1
    - Group 2: layer2
    - Group 3: layer3
    - Group 4: layer4
    """
    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Define layer groups (input to output order)
    layer_groups = [
        nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool),
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
    ]

    classifier = model.fc

    trainer = GradualUnfreezingTrainer(
        model=model,
        layer_groups=layer_groups,
        classifier=classifier,
        base_lr=1e-3,
        lr_decay_factor=2.5,
        epochs_per_stage=2
    )

    return model, trainer
```

Practical Insights from ULMFiT:
The ULMFiT paper established several best practices for gradual unfreezing in language models:
Slanted Triangular Learning Rates (STLR): A learning rate schedule that quickly increases, then slowly decreases. This helps the model rapidly converge early in each stage, then refine.
Layer-wise LR Decay: Each layer gets a learning rate that's η/(2.6^l) where l is depth from output. This ensures early layers change slowly.
Epoch-based Unfreezing: Unfreeze one layer per epoch for language models. Vision models may need 2-3 epochs per layer due to different convergence dynamics.
Full Fine-tuning Phase: After all layers are unfrozen, continue training for several epochs with discriminative rates to fully adapt.
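To make the schedule concrete, here is a small sketch of the slanted triangular learning rate following the ULMFiT formulation; `cut_frac` and `ratio` follow the paper's notation, and the resulting value is meant to replace each group's discriminative base rate at every step:

```python
import math

def slanted_triangular_lr(step: int, total_steps: int, max_lr: float,
                          cut_frac: float = 0.1, ratio: float = 32.0) -> float:
    """Short linear increase to max_lr, then a long linear decrease (ULMFiT STLR)."""
    cut = max(1, math.floor(total_steps * cut_frac))
    if step < cut:
        p = step / cut
    else:
        p = 1 - (step - cut) / (cut * (1 / cut_frac - 1))
    return max_lr * (1 + p * (ratio - 1)) / ratio

# Usage sketch: rescale every parameter group's discriminative base LR each step.
# for group, base_lr in zip(optimizer.param_groups, base_lrs):
#     group["lr"] = slanted_triangular_lr(step, total_steps, base_lr)
```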
Gradual unfreezing shines when: (1) You have limited target data and need maximum regularization, (2) The domain shift is significant and early adaptation could be destabilizing, (3) You're fine-tuning very large models where stability is crucial. For large datasets with minor domain shift, static selective fine-tuning or full fine-tuning may be equally effective and simpler.
Discriminative learning rates extend the idea of layer selection to a continuous spectrum. Rather than binary freeze/unfreeze decisions, we assign different learning rates to different parts of the network.
The Core Idea:
If earlier layers need less adaptation (their features are more general), they should learn more slowly. If later layers need more adaptation (their features are more task-specific), they should learn faster.
Mathematically, for a network with L layers and base learning rate η:
$$\eta_l = \eta \cdot \alpha^{L-l}$$
where η_l is the learning rate of layer l (with l = L being the layer closest to the output) and α ∈ (0, 1) is the decay factor that controls how quickly rates shrink toward the input.
Example: for a 12-layer BERT with η = 3e-5 and α = 0.9, layer 12 trains at 3.0e-5, layer 9 at about 2.2e-5, layer 6 at about 1.6e-5, and layer 1 at about 0.9e-5. The table below extends this comparison to other decay factors.
```python
import math
from typing import List, Dict, Any

import torch.nn as nn
from torch.optim import AdamW


def get_discriminative_lr_params(
    model: nn.Module,
    base_lr: float = 1e-3,
    lr_decay_factor: float = 0.9,
    group_strategy: str = "layer"  # "layer", "block", or "depth"
) -> List[Dict[str, Any]]:
    """
    Create parameter groups with discriminative learning rates.

    Args:
        model: Neural network model
        base_lr: Learning rate for the output layer
        lr_decay_factor: Multiplicative decay per layer from output
        group_strategy: How to group parameters
            - "layer": Each named parameter group
            - "block": Architectural blocks (e.g., ResNet stages)
            - "depth": Based on estimated depth from output

    Returns:
        List of parameter group dicts for optimizer
    """
    param_groups = []

    if group_strategy == "layer":
        # Simple approach: group by first name component
        layer_params = {}
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            layer_name = name.split('.')[0]
            if layer_name not in layer_params:
                layer_params[layer_name] = []
            layer_params[layer_name].append(param)

        # Assign LRs based on position (later = higher LR)
        layer_names = list(layer_params.keys())
        num_layers = len(layer_names)

        for i, layer_name in enumerate(layer_names):
            depth_from_output = num_layers - 1 - i
            lr = base_lr * (lr_decay_factor ** depth_from_output)
            param_groups.append({
                "params": layer_params[layer_name],
                "lr": lr,
                "name": layer_name
            })

    elif group_strategy == "depth":
        # A depth-aware grouping would analyze the model's module graph to
        # compute each parameter's distance from the output; it is left as a
        # placeholder here.
        raise NotImplementedError("depth-based grouping is not implemented in this example")

    return param_groups


def get_bert_discriminative_lr_params(
    model,
    classifier,
    base_lr: float = 2e-5,
    lr_decay_factor: float = 0.95
) -> List[Dict[str, Any]]:
    """
    Create discriminative LR parameter groups for BERT-like models.

    Structure assumed:
    - model.embeddings
    - model.encoder.layer[0..11]
    - classifier (separate head)
    """
    param_groups = []

    # Classifier gets base LR
    param_groups.append({
        "params": list(classifier.parameters()),
        "lr": base_lr,
        "name": "classifier"
    })

    # Pooler (if it exists) gets base LR
    if hasattr(model, 'pooler'):
        param_groups.append({
            "params": list(model.pooler.parameters()),
            "lr": base_lr,
            "name": "pooler"
        })

    # Encoder layers: decay from output to input
    num_layers = len(model.encoder.layer)
    for i, layer in enumerate(reversed(model.encoder.layer)):
        layer_num = num_layers - 1 - i
        depth_from_output = i + 1  # +1 because classifier is at depth 0
        lr = base_lr * (lr_decay_factor ** depth_from_output)

        param_groups.append({
            "params": list(layer.parameters()),
            "lr": lr,
            "name": f"encoder_layer_{layer_num}"
        })

    # Embeddings get the lowest LR
    embedding_depth = num_layers + 1
    embedding_lr = base_lr * (lr_decay_factor ** embedding_depth)
    param_groups.append({
        "params": list(model.embeddings.parameters()),
        "lr": embedding_lr,
        "name": "embeddings"
    })

    return param_groups


def get_resnet_discriminative_lr_params(
    model: nn.Module,
    base_lr: float = 1e-3,
    lr_decay_factor: float = 0.8
) -> List[Dict[str, Any]]:
    """
    Create discriminative LR parameter groups for ResNet.

    Groups:
    - classifier (fc): base_lr
    - layer4: base_lr * decay^1
    - layer3: base_lr * decay^2
    - layer2: base_lr * decay^3
    - layer1: base_lr * decay^4
    - stem (conv1, bn1): base_lr * decay^5
    """
    param_groups = []

    # Define groups from output to input
    groups = [
        ("fc", model.fc.parameters()),
        ("layer4", model.layer4.parameters()),
        ("layer3", model.layer3.parameters()),
        ("layer2", model.layer2.parameters()),
        ("layer1", model.layer1.parameters()),
        ("stem", [p for n, p in model.named_parameters() if n.startswith(('conv1', 'bn1'))])
    ]

    for depth, (name, params) in enumerate(groups):
        lr = base_lr * (lr_decay_factor ** depth)
        param_list = list(params)
        if param_list:  # Only add if there are parameters
            param_groups.append({
                "params": param_list,
                "lr": lr,
                "name": name
            })

    return param_groups


class DiscriminativeLRScheduler:
    """
    Learning rate scheduler that maintains discriminative ratios.

    Standard schedulers modify all groups uniformly, which can disturb
    the discriminative ratio. This scheduler preserves the relative rates
    while decaying.
    """

    def __init__(
        self,
        optimizer: AdamW,
        decay_method: str = "cosine",
        total_steps: int = 10000,
        warmup_steps: int = 500,
        min_lr_ratio: float = 0.1
    ):
        self.optimizer = optimizer
        self.decay_method = decay_method
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.min_lr_ratio = min_lr_ratio

        # Store initial LRs so that the group ratios can be preserved
        self.initial_lrs = [group['lr'] for group in optimizer.param_groups]
        self.current_step = 0

    def step(self):
        """Update learning rates for all groups."""
        self.current_step += 1

        if self.current_step <= self.warmup_steps:
            # Linear warmup
            scale = self.current_step / self.warmup_steps
        else:
            # Decay phase
            progress = (self.current_step - self.warmup_steps) / (
                self.total_steps - self.warmup_steps
            )
            progress = min(1.0, progress)

            if self.decay_method == "cosine":
                decay_factor = 0.5 * (1 + math.cos(math.pi * progress))
            elif self.decay_method == "linear":
                decay_factor = 1 - progress
            else:
                decay_factor = 1.0

            # Ensure we don't go below min_lr_ratio * initial
            scale = max(self.min_lr_ratio, decay_factor)

        # Apply the factor while preserving discriminative ratios
        for group, initial_lr in zip(self.optimizer.param_groups, self.initial_lrs):
            group['lr'] = initial_lr * scale
```

Choosing the Decay Factor:
The decay factor α controls how aggressively learning rates decrease with depth:
α = 0.95-0.99 (Gentle decay): Small differences between layers. Use when domain shift is minimal and you want all layers to adapt somewhat uniformly.
α = 0.85-0.95 (Moderate decay): Clear but not extreme differences. Good default for most transfer scenarios.
α = 0.7-0.85 (Aggressive decay): Large differences; early layers learn very slowly. Use when preserving low-level features is critical.
α < 0.7 (Very aggressive): Early layers essentially frozen. Consider explicit freezing instead for computational efficiency.
| Layer | α = 0.95 | α = 0.90 | α = 0.80 | α = 0.70 |
|---|---|---|---|---|
| Layer 12 (output) | 3.00e-5 | 3.00e-5 | 3.00e-5 | 3.00e-5 |
| Layer 9 | 2.57e-5 | 2.19e-5 | 1.54e-5 | 1.03e-5 |
| Layer 6 | 2.21e-5 | 1.59e-5 | 0.79e-5 | 0.35e-5 |
| Layer 3 | 1.89e-5 | 1.16e-5 | 0.41e-5 | 0.12e-5 |
| Layer 1 (input) | 1.71e-5 | 0.94e-5 | 0.26e-5 | 0.06e-5 |
Discriminative LR is conceptually similar to selective freezing but offers finer control. Instead of binary decisions, you express a preference for how much each layer should adapt. This often works better because even 'universal' features can benefit from tiny adjustments. The key insight: very low learning rates (not zero) for early layers often outperform complete freezing.
The optimal layer selection strategy depends critically on the relationship between source and target domains. Let's analyze how domain shift should inform our choices.
Quantifying Domain Shift:
Domain shift can be characterized along several dimensions: how different the inputs look (low-level appearance and statistics), how different the label space is (overlapping, disjoint, or differently grained classes), and how different the task itself is (e.g., classification versus detection, or sentiment versus entailment).
Each type of shift has different implications for layer selection.
The Domain Similarity-Data Size Matrix:
The interaction between domain similarity and target dataset size creates a decision matrix:
| Data Size | High Similarity | Medium Similarity | Low Similarity |
|---|---|---|---|
| Very Small (<1K) | Freeze all, train classifier | Freeze most, fine-tune last 1-2 | Consider domain adaptation instead |
| Small (1K-10K) | Freeze most, fine-tune last 2-3 | Gradual unfreezing | Gradual unfreezing + regularization |
| Medium (10K-100K) | Fine-tune last 50% | Fine-tune last 60-70% | Fine-tune most layers |
| Large (>100K) | Fine-tune all + discriminative LR | Fine-tune all + discriminative LR | Full fine-tuning optimal |
Practical Example: Medical Imaging
Consider fine-tuning ImageNet-pretrained ResNet-50 for chest X-ray classification:
Domain Analysis: chest X-rays are grayscale, high-resolution images whose low-level structure (edges, gradients, textures) is still well served by ImageNet's early filters, but whose high-level content (anatomical structures, pathologies) has no counterpart among ImageNet's object classes. The shift is mild at the bottom of the network and substantial at the top.
Recommendation: freeze the stem and stage 1, fine-tune stages 2-4 with discriminative learning rates (lower for stage 2, higher for stage 4), and train the new classifier at the full learning rate.
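A sketch of that recommendation using the helpers defined earlier on this page (`SelectiveFineTuning.freeze_resnet_blocks` and `get_resnet_discriminative_lr_params`); the class count and learning rates here are illustrative, not prescriptive:

```python
import torch.nn as nn
from torch.optim import AdamW
from torchvision import models

model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 14)  # illustrative number of findings

# Freeze the stem and stage 1; fine-tune stages 2-4 plus the classifier.
SelectiveFineTuning.freeze_resnet_blocks(model, unfreeze_stages=[2, 3, 4], unfreeze_fc=True)

# Discriminative rates: smallest for stage 2, largest for stage 4 and the classifier.
# Frozen groups receive no gradients, so the optimizer leaves them unchanged.
param_groups = get_resnet_discriminative_lr_params(model, base_lr=1e-3, lr_decay_factor=0.8)
optimizer = AdamW(param_groups, weight_decay=1e-4)
```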
These guidelines are informed by research but should be validated empirically for your specific use case. A simple experiment: try freezing at different depths and compare validation performance. The optimal depth is often surprising and dataset-specific. Budget time for this exploration in your project timeline.
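One way to structure that sweep, sketched with the `create_selective_model_resnet` helper from earlier; `train_and_evaluate` is a stand-in for your own short training-plus-validation routine:

```python
# Compare a few freeze depths; more unfrozen stages means deeper adaptation.
candidate_configs = {
    "classifier_only": [],
    "stage_4": [4],
    "stages_3_4": [3, 4],
    "stages_2_4": [2, 3, 4],
}

results = {}
for name, stages in candidate_configs.items():
    model = create_selective_model_resnet(num_classes=10, unfreeze_stages=stages)
    results[name] = train_and_evaluate(model)  # assumed helper: returns val accuracy

best = max(results, key=results.get)
print(f"Best unfreezing configuration: {best} ({results[best]:.2f}% validation accuracy)")
```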
Batch normalization layers require special attention during selective fine-tuning. The interaction between frozen weights and batch statistics is subtle but critical.
The Problem:
Batch normalization has two types of parameters: learnable affine parameters (the scale γ and shift β, updated by gradient descent) and running statistics (the running mean and variance, updated during every forward pass in train mode, regardless of gradients).
When you freeze a BN layer's learnable parameters but the layer is in train mode, the running statistics still update based on the target data. This creates a mismatch: the scale/shift parameters were tuned for source statistics but are being applied with target statistics.
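The distinction is visible directly in PyTorch: γ and β are registered as parameters, while the running statistics are buffers. A quick inspection, as a sketch:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(16)

print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias'] -> learnable gamma and beta, updated by the optimizer

print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked'] -> updated by forward
# passes in train mode, unaffected by requires_grad or the optimizer
```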
```python
import torch
import torch.nn as nn


def freeze_bn_layers(model: nn.Module):
    """
    Properly freeze batch normalization layers.

    This sets both:
    1. requires_grad = False for learnable params (γ, β)
    2. Layer mode to eval (prevents running stats update)
    """
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            # Freeze learnable parameters
            module.weight.requires_grad = False
            module.bias.requires_grad = False
            # Set to eval mode to freeze running statistics
            module.eval()


def set_bn_eval_mode(model: nn.Module):
    """
    Set BN layers to eval mode while keeping the rest in train mode.

    Call this after model.train() to ensure BN layers stay in eval mode.
    """
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()


class SelectiveTrainMode:
    """
    Context manager that selectively sets training mode.

    Usage:
        with SelectiveTrainMode(model, freeze_bn=True):
            # BN in eval, rest in train
            output = model(input)
    """

    def __init__(self, model: nn.Module, freeze_bn: bool = True):
        self.model = model
        self.freeze_bn = freeze_bn
        self.bn_training_states = {}

    def __enter__(self):
        self.model.train()
        if self.freeze_bn:
            for name, module in self.model.named_modules():
                if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                    self.bn_training_states[name] = module.training
                    module.eval()
        return self.model

    def __exit__(self, *args):
        # Restore the original training mode of each BN layer
        for name, module in self.model.named_modules():
            if name in self.bn_training_states:
                module.train(self.bn_training_states[name])


class FineTuningBatchNormStrategy:
    """
    Different strategies for handling batch normalization during fine-tuning.
    """

    @staticmethod
    def strategy_freeze_all_bn(model: nn.Module):
        """
        Strategy 1: Freeze all BN layers.

        Best when:
        - Source and target domains have similar statistics
        - You want maximum preservation of pre-trained behavior
        - Training data is limited
        """
        freeze_bn_layers(model)

    @staticmethod
    def strategy_freeze_early_bn(model: nn.Module, freeze_until: str):
        """
        Strategy 2: Freeze BN in early layers only.

        Best when:
        - You want to preserve low-level feature statistics
        - But adapt high-level batch statistics

        Args:
            freeze_until: Name prefix to freeze up to (e.g., "layer2")
        """
        freeze_early = True
        for name, module in model.named_modules():
            if freeze_until in name:
                freeze_early = False
            if isinstance(module, (nn.BatchNorm2d,)):
                if freeze_early:
                    module.weight.requires_grad = False
                    module.bias.requires_grad = False
                    module.eval()
                else:
                    module.weight.requires_grad = True
                    module.bias.requires_grad = True

    @staticmethod
    def strategy_bn_calibration(
        model: nn.Module,
        calibration_loader,
        device: str = "cuda"
    ):
        """
        Strategy 3: Calibrate BN statistics on the target domain.

        After freezing weights, run forward passes on target data to
        update running statistics. This adapts the normalization to the
        target domain without changing learnable parameters.

        Best when:
        - Domain shift is primarily in low-level statistics
        - You want frozen weights but adapted normalization
        """
        model = model.to(device)
        model.train()  # Enable running stats update

        # Freeze all parameters
        for param in model.parameters():
            param.requires_grad = False

        # Reset running statistics
        for module in model.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                module.reset_running_stats()
                module.momentum = 0.1

        # Calibration pass
        with torch.no_grad():
            for inputs, _ in calibration_loader:
                inputs = inputs.to(device)
                _ = model(inputs)

        # Set to eval mode after calibration
        model.eval()
        return model

    @staticmethod
    def strategy_adaptive_bn(model: nn.Module, momentum: float = 0.1):
        """
        Strategy 4: Let BN statistics adapt during fine-tuning.

        Use a higher momentum to adapt quickly to the target domain,
        combined with discriminative LR for BN parameters.

        Best when:
        - Target domain has different statistics
        - You have enough data for stable estimation
        """
        for module in model.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                module.momentum = momentum
                module.weight.requires_grad = True
                module.bias.requires_grad = True
```

The most common bug in selective fine-tuning is freezing BN parameters but leaving the layers in train mode. This causes running statistics to update during fine-tuning, creating a mismatch between the learned scale/shift and the statistics. The symptom: good training performance but poor test performance, especially on data similar to the source domain. Always call module.eval() on frozen BN layers.
Selective fine-tuning provides a powerful toolkit for balancing knowledge preservation and task adaptation. The key principles: preserve the early, general-purpose layers; adapt the late, task-specific layers; use gradual unfreezing or discriminative learning rates when stability or limited data is a concern; and handle batch normalization deliberately, freezing its parameters and its running statistics together.
What's Next:
With a foundation in both full and selective fine-tuning, we now turn to the critical question of learning rate strategies. The next page explores warmup schedules, decay policies, slanted triangular rates, and the art of finding optimal learning rates for each fine-tuning scenario.
You now understand selective fine-tuning: the theoretical basis in layer hierarchies, practical strategies from top-K to gradual unfreezing, and critical considerations like batch normalization handling. This prepares you for the learning rate strategies covered next.