Fine-tuning faces a fundamental challenge: overfitting to the target dataset. Pre-trained models have enormous capacity—millions or billions of parameters. Target datasets are often small. This combination creates a recipe for memorization rather than generalization.
Regularization is the art of constraining model capacity to match data complexity. In fine-tuning, regularization serves dual purposes: it prevents overfitting to the small target dataset, and it keeps the adapted model from drifting too far from the pre-trained weights, preserving the knowledge that made transfer attractive in the first place.
This page explores regularization techniques specifically tailored for transfer learning: weight decay strategies, dropout configurations, data augmentation policies, and specialized methods like mixup and label smoothing.
By the end of this page, you will understand how standard regularization techniques adapt for fine-tuning, implement effective data augmentation for transfer learning, and apply advanced methods like mixup and label smoothing.
Weight decay (L2 regularization) penalizes large weight magnitudes. In fine-tuning, it has a special interpretation: it pulls weights back toward zero (or toward pre-trained values with a modification).
Standard L2 Regularization: $$\mathcal{L}_{total} = \mathcal{L}_{task} + \frac{\lambda}{2} \|\theta\|^2$$
This penalizes deviation from zero. But for fine-tuning, we might want to penalize deviation from the pre-trained weights instead:
L2-SP (L2 Starting Point): $$\mathcal{L}_{total} = \mathcal{L}_{task} + \frac{\lambda}{2} \|\theta - \theta^*\|^2$$
where θ* are the pre-trained weights. This encourages staying near the pre-trained solution.
Decoupled Weight Decay (AdamW):
Traditional optimizers apply weight decay as part of the gradient update. AdamW decouples them, leading to better generalization:
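In simplified form (omitting Adam's bias correction, and writing $\mathrm{AdamStep}$ for the moment-normalized step), the two updates differ as follows:

$$\text{Adam + L2:}\quad \theta_{t+1} = \theta_t - \eta\,\mathrm{AdamStep}(g_t + \lambda\theta_t)$$
$$\text{AdamW:}\quad \theta_{t+1} = \theta_t - \eta\,\mathrm{AdamStep}(g_t) - \eta\lambda\theta_t$$

In the first case the penalty gradient is rescaled by the adaptive moments; in the second it shrinks the weights directly, which is what tends to generalize better.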
```python
import torch
import torch.nn as nn
from torch.optim import AdamW


class L2SPRegularizer:
    """
    L2-SP: L2 regularization toward Starting Point (pre-trained weights).

    More effective than standard L2 for fine-tuning because it preserves
    pre-trained knowledge as the reference point.
    """

    def __init__(self, model: nn.Module, alpha: float = 0.01, beta: float = 0.01):
        """
        Args:
            alpha: Regularization strength for pre-trained layers
            beta: Regularization strength for new layers (toward zero)
        """
        self.alpha = alpha
        self.beta = beta

        # Store pre-trained weights
        self.pretrained_weights = {}
        self.new_layer_names = []

    def register_pretrained(self, model: nn.Module,
                            new_layer_patterns: list = ['fc', 'classifier']):
        """Store pre-trained weights and identify new layers."""
        for name, param in model.named_parameters():
            is_new = any(pattern in name for pattern in new_layer_patterns)
            if is_new:
                self.new_layer_names.append(name)
            else:
                self.pretrained_weights[name] = param.clone().detach()

    def penalty(self, model: nn.Module) -> torch.Tensor:
        """Compute L2-SP penalty."""
        loss = 0.0
        for name, param in model.named_parameters():
            if name in self.pretrained_weights:
                # Regularize toward pre-trained weights
                diff = param - self.pretrained_weights[name]
                loss += self.alpha * (diff ** 2).sum()
            elif name in self.new_layer_names:
                # Regularize new layers toward zero
                loss += self.beta * (param ** 2).sum()
        return loss


def create_optimizer_with_decay(
    model: nn.Module,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    no_decay_patterns: list = ['bias', 'LayerNorm', 'layer_norm', 'bn']
) -> AdamW:
    """
    Create AdamW optimizer with proper weight decay configuration.

    Excludes bias and normalization parameters from weight decay
    as these should not be regularized.
    """
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(pattern in name for pattern in no_decay_patterns):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]

    return AdamW(param_groups, lr=learning_rate)
```

| Scenario | Weight Decay | Notes |
|---|---|---|
| Large target dataset | 0.0001 - 0.001 | Less regularization needed |
| Small target dataset | 0.01 - 0.1 | Strong regularization important |
| BERT-family models | 0.01 - 0.1 | Transformers benefit from higher decay |
| Vision CNNs | 0.0001 - 0.01 | Standard CV range |
| Very different domains | Higher | More regularization preserves features |
Bias terms and normalization-layer parameters (BatchNorm γ, β; LayerNorm) should not have weight decay applied. These parameters play a different role than weight matrices, and regularizing them toward zero can hurt performance.
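As a rough sketch of how these pieces fit together in a training step, assuming the L2SPRegularizer and create_optimizer_with_decay defined above, plus a placeholder model and train_loader:

```python
import torch.nn as nn

# Placeholder setup: `model` is a pre-trained network whose new head is
# named 'fc' or 'classifier', and `train_loader` yields (inputs, labels).
regularizer = L2SPRegularizer(model, alpha=0.01, beta=0.01)
regularizer.register_pretrained(model, new_layer_patterns=['fc', 'classifier'])

# Weight decay is set to 0 here because the L2-SP penalty already
# regularizes the weights (toward the pre-trained values, not toward zero).
optimizer = create_optimizer_with_decay(model, learning_rate=1e-4, weight_decay=0.0)
criterion = nn.CrossEntropyLoss()

model.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs = model(inputs)
    # Task loss plus the L2-SP penalty toward the starting point
    loss = criterion(outputs, labels) + regularizer.penalty(model)
    loss.backward()
    optimizer.step()
```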
Dropout randomly zeros activations during training, forcing the network to learn redundant representations. For fine-tuning, dropout placement and rate require careful consideration.
Key Principles:
Higher dropout for fine-tuning: During pre-training, the sheer scale of the data acts as an implicit regularizer; that protection disappears when fine-tuning on a small target dataset, so additional dropout helps.
Dropout in new layers: Focus dropout on newly added layers (classifiers, adapters) which are prone to overfitting.
Consistent dropout in backbone: If the pre-trained model uses dropout, keep it. Don't add dropout to layers that weren't designed for it.
Dropout Placement:
```python
import torch
import torch.nn as nn


class FineTuningHead(nn.Module):
    """
    Classification head with dropout for fine-tuning.

    Two dropout layers: after pooled features and before the final
    classifier. This is more aggressive than typical
    training-from-scratch heads.
    """

    def __init__(
        self,
        in_features: int,
        num_classes: int,
        hidden_dim: int = 512,
        dropout_rate: float = 0.3,  # Higher than typical 0.1-0.2
        use_hidden_layer: bool = True
    ):
        super().__init__()

        if use_hidden_layer:
            self.classifier = nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout_rate),
                nn.Linear(hidden_dim, num_classes)
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Linear(in_features, num_classes)
            )

    def forward(self, x):
        return self.classifier(x)


def adjust_dropout_for_finetuning(
    model: nn.Module,
    new_dropout_rate: float = 0.3,
    only_layers: list = None
):
    """
    Adjust dropout rates in an existing model for fine-tuning.

    Args:
        model: The model to modify
        new_dropout_rate: New dropout probability
        only_layers: If provided, only modify these layer name patterns
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            if only_layers is None or any(pattern in name for pattern in only_layers):
                module.p = new_dropout_rate
                print(f"Set {name} dropout to {new_dropout_rate}")


class DropConnect(nn.Module):
    """
    DropConnect: Drop connections (weights) instead of activations.

    Can be more effective than dropout for fine-tuning as it
    regularizes the weight matrix directly.
    """

    def __init__(self, linear: nn.Linear, drop_prob: float = 0.2):
        super().__init__()
        self.linear = linear
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.training:
            # Create a random mask over the weight matrix
            mask = torch.bernoulli(
                torch.ones_like(self.linear.weight) * (1 - self.drop_prob)
            )
            # Apply masked weights, rescaled to keep expected magnitude
            return nn.functional.linear(
                x,
                self.linear.weight * mask / (1 - self.drop_prob),
                self.linear.bias
            )
        return self.linear(x)
```

Small dataset (<5K samples): 0.3-0.5 dropout in new layers. Medium dataset (5K-50K): 0.2-0.3. Large dataset (>50K): 0.1-0.2, or keep the pre-trained model's rates. If validation loss diverges from training loss, increase dropout.
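As a usage sketch, assuming a torchvision ResNet-50 backbone and a hypothetical 10-class target task (the weights enum name may differ across torchvision versions):

```python
from torchvision import models

# Hypothetical setup: ResNet-50 pre-trained on ImageNet, 10 target classes
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Replace the original 1000-class head with a dropout-regularized head
backbone.fc = FineTuningHead(
    in_features=backbone.fc.in_features,  # 2048 for ResNet-50
    num_classes=10,
    dropout_rate=0.4,  # small dataset -> more aggressive dropout
)

# ResNet-50 has no dropout modules in the backbone, so this is a no-op here;
# for architectures that do (e.g. ViT, EfficientNet), it raises their rates.
adjust_dropout_for_finetuning(backbone, new_dropout_rate=0.2)
```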
Data augmentation artificially expands the training set by creating modified versions of examples. For fine-tuning, augmentation is crucial when target data is limited.
Principles for Fine-Tuning:
Match pre-training augmentations: If the pre-trained model saw certain augmentations, continue using them.
Domain-appropriate augmentations: Medical images shouldn't be flipped vertically (anatomy isn't symmetric). Satellite images can be rotated 360°.
Stronger augmentation for smaller data: AutoAugment, RandAugment, and similar policies help with limited data.
Test-time augmentation (TTA): Average predictions over several augmented versions of each input at inference time for a modest accuracy boost (a minimal sketch follows the code below).
```python
from torchvision import transforms
import random

# Standard augmentation for ImageNet-pretrained models
imagenet_finetune_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Stronger augmentation for small datasets
strong_augmentation = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2),
])


class RandAugmentTransform:
    """
    RandAugment: Randomly select N augmentations with magnitude M.

    Simple yet effective; good default for fine-tuning.
    """

    def __init__(self, n_ops: int = 2, magnitude: int = 9):
        self.n_ops = n_ops
        self.magnitude = magnitude
        self.augmentations = [
            'rotate', 'translateX', 'translateY', 'shearX', 'shearY',
            'brightness', 'contrast', 'saturation', 'sharpness', 'posterize'
        ]

    def __call__(self, img):
        ops = random.sample(self.augmentations, self.n_ops)
        for op in ops:
            img = self._apply_op(img, op, self.magnitude)
        return img

    def _apply_op(self, img, op, magnitude):
        # Implementation details for each operation
        # (Simplified - use torchvision.transforms.RandAugment in practice)
        return img  # Placeholder


def get_augmentation_for_domain(domain: str, dataset_size: int):
    """
    Select appropriate augmentation based on domain and data size.
    """
    if domain == "natural_images":
        if dataset_size < 5000:
            return strong_augmentation
        return imagenet_finetune_transforms

    elif domain == "medical":
        # Conservative augmentation for medical images
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
            transforms.RandomRotation(10),  # Slight rotation only
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    elif domain == "satellite":
        # Satellite images can be heavily augmented
        return transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(180),  # Any rotation is valid
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    return imagenet_finetune_transforms
```

Inappropriate augmentations can harm fine-tuning. Horizontally flipping text makes it unreadable. Aggressive color jitter on medical images changes diagnostic information. Always consider what transformations preserve the semantic meaning in your domain.
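The pipelines above cover training-time augmentation; the snippet below is a minimal sketch of test-time augmentation, averaging softmax outputs over a few augmented copies of a single image. The number of copies and the choice of transform are illustrative, and the function name is hypothetical:

```python
import torch


@torch.no_grad()
def predict_with_tta(model, image, augment, n_augments: int = 5):
    """
    Average class probabilities over several augmented views of one image.

    Args:
        model: Trained classifier
        image: A single PIL image
        augment: A transform pipeline (e.g. imagenet_finetune_transforms)
        n_augments: Number of augmented views to average over
    """
    model.eval()
    probs = []
    for _ in range(n_augments):
        x = augment(image).unsqueeze(0)  # Add batch dimension
        probs.append(torch.softmax(model(x), dim=-1))
    return torch.stack(probs).mean(dim=0)  # Shape: (1, num_classes)
```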
Beyond standard techniques, several advanced methods improve fine-tuning:
Mixup: Blend pairs of training examples and their labels: $$\tilde{x} = \lambda x_i + (1-\lambda) x_j$$ $$\tilde{y} = \lambda y_i + (1-\lambda) y_j$$
where λ ~ Beta(α, α). This creates soft training targets that improve generalization.
Label Smoothing: Replace hard labels (0 or 1) with soft labels: $$y_{smooth} = (1-\epsilon)y + \epsilon/K$$
where ε is the smoothing factor and K is the number of classes. With ε = 0.1 and K = 5, for example, the true class receives probability 0.92 and every other class 0.02. This prevents overconfident predictions.
CutMix: Cut and paste patches between images, mixing labels proportionally. More effective than Mixup for vision tasks.
```python
import torch
import torch.nn as nn
import numpy as np


def mixup_data(x, y, alpha=0.4):
    """
    Mixup: Creates mixed inputs and targets.

    Returns:
        mixed_x: lambda * x + (1 - lambda) * x[shuffled]
        y_a, y_b: Original labels for mixup loss
        lam: Mixing coefficient
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Compute mixup loss."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


class LabelSmoothingLoss(nn.Module):
    """
    Cross-entropy loss with label smoothing.

    Prevents overconfident predictions, improving generalization.
    Note: this variant spreads the smoothing mass over the K-1
    non-target classes rather than all K classes.
    """

    def __init__(self, num_classes: int, smoothing: float = 0.1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.num_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=-1))


def cutmix_data(x, y, alpha=1.0):
    """
    CutMix: Cut and paste patches between images.

    More effective than Mixup for localized features.
    """
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # Get bounding box (inputs are NCHW)
    H, W = x.size(2), x.size(3)
    cut_rat = np.sqrt(1 - lam)
    cut_h = int(H * cut_rat)
    cut_w = int(W * cut_rat)

    cy = np.random.randint(H)
    cx = np.random.randint(W)

    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)

    # Apply cutmix: paste the patch from shuffled images
    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]

    # Adjust lambda to match the actual patch area
    lam = 1 - ((bby2 - bby1) * (bbx2 - bbx1) / (H * W))

    return x, y, y[index], lam
```

| Technique | Parameter | Recommended Range | Notes |
|---|---|---|---|
| Mixup | α (beta param) | 0.2 - 0.4 | Higher = more mixing |
| Label Smoothing | ε (smoothing) | 0.1 | Standard value works well |
| CutMix | α (beta param) | 1.0 | Full mixing typically used |
| Combine | Prob of each | 0.5 | Randomly choose one per batch |
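Following the last row of the table, here is a minimal training-step sketch that picks Mixup or CutMix at random for each batch. It assumes the mixup_data, cutmix_data, mixup_criterion, and LabelSmoothingLoss helpers defined above, plus a placeholder model, optimizer, and train_loader (and 10 classes for illustration):

```python
import random

# Label smoothing and mixing can be combined; both act on the targets
criterion = LabelSmoothingLoss(num_classes=10, smoothing=0.1)

model.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()

    # Randomly choose Mixup or CutMix for this batch (p = 0.5 each)
    if random.random() < 0.5:
        inputs, y_a, y_b, lam = mixup_data(inputs, labels, alpha=0.4)
    else:
        inputs, y_a, y_b, lam = cutmix_data(inputs, labels, alpha=1.0)

    outputs = model(inputs)
    # Loss is the lambda-weighted combination of both label sets
    loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
    loss.backward()
    optimizer.step()
```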
Module Complete:
You have now mastered the complete toolkit for fine-tuning strategies: from full and selective fine-tuning, through learning rate orchestration, catastrophic forgetting mitigation, to comprehensive regularization. These techniques form the foundation for effective transfer learning in real-world applications.
Congratulations! You've completed Module 3: Fine-Tuning Strategies. You now have a comprehensive understanding of how to adapt pre-trained models effectively for new tasks while preserving valuable knowledge and preventing overfitting.