When domain discrepancy is too large for frozen features but fine-tuning is impractical (limited target data, compute constraints), feature adaptation offers a middle ground. These techniques transform pre-trained features to better align with the target domain while avoiding full model retraining.
Feature adaptation methods operate on the extracted representations, learning transformations that re-align, re-scale, or re-combine the pre-trained features so they better fit the target domain and task.
This page covers the theory and practice of feature adaptation: from simple linear transformations to sophisticated domain alignment techniques.
Understand feature transformation methods, domain alignment losses, adaptation architectures, and when to use feature adaptation versus other transfer strategies.
The simplest adaptation learns a linear transformation of frozen features:
$$z_{\text{adapted}} = W z + b$$
where $W \in \mathbb{R}^{d' \times d}$ and $b \in \mathbb{R}^{d'}$ are learned parameters.
Why linear adaptation works:
Pre-trained features often encode relevant information, but in a "rotated" or "scaled" form relative to what the target task needs. A linear transformation can rotate, rescale, and recombine feature dimensions to expose that information to the downstream classifier.
Comparison to linear probe:
| Method | Transforms Features | Classification | Parameters |
|---|---|---|---|
| Linear Probe | No | Linear | d × K |
| Linear Adaptation + Linear | Yes | Linear | d × d' + d' × K |
Linear adaptation adds a feature transformation layer before classification.
```python
import torch
import torch.nn as nn


class LinearAdapter(nn.Module):
    """
    Linear feature adaptation layer.

    Learns to transform frozen features for the target task.
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: int = None,
        normalize: bool = True,
        residual: bool = False
    ):
        super().__init__()
        output_dim = output_dim or input_dim
        self.transform = nn.Linear(input_dim, output_dim)
        self.normalize = normalize
        self.residual = residual and (input_dim == output_dim)

        # Initialize close to identity for stable training
        if input_dim == output_dim:
            nn.init.eye_(self.transform.weight)
            nn.init.zeros_(self.transform.bias)
        else:
            nn.init.orthogonal_(self.transform.weight)

    def forward(self, x):
        adapted = self.transform(x)
        if self.residual:
            adapted = adapted + x
        if self.normalize:
            adapted = nn.functional.normalize(adapted, p=2, dim=1)
        return adapted


class AdaptedClassifier(nn.Module):
    """Complete pipeline: frozen encoder → adapter → classifier."""
    def __init__(
        self,
        encoder: nn.Module,
        adapter: nn.Module,
        num_classes: int
    ):
        super().__init__()
        self.encoder = encoder
        self.adapter = adapter

        # Freeze encoder
        for p in self.encoder.parameters():
            p.requires_grad = False

        # Classifier on adapted features
        adapter_dim = adapter.transform.out_features
        self.classifier = nn.Linear(adapter_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x)
        adapted = self.adapter(features)
        return self.classifier(adapted)

    def get_trainable_params(self):
        return list(self.adapter.parameters()) + list(self.classifier.parameters())
```

When unlabeled target data is available, we can learn adaptations that explicitly align source and target distributions.
Maximum Mean Discrepancy (MMD) Loss:
Minimize the MMD between adapted source and target features:
$$\mathcal{L}_{\text{MMD}} = \left\| \frac{1}{n_s}\sum_i \phi(z^s_i) - \frac{1}{n_t}\sum_j \phi(z^t_j) \right\|_{\mathcal{H}}^2$$
Training minimizes this distance, encouraging the adaptation to produce similar distributions.
Correlation Alignment (CORAL):
Align second-order statistics (covariances):
$$\mathcal{L}_{\text{CORAL}} = \frac{1}{4d^2} |C_S - C_T|_F^2$$
where $C_S$ and $C_T$ are covariance matrices of adapted source and target features.
Optimal Transport:
Find the transport plan that minimizes the cost of moving source to target:
$$\mathcal{L}_{\text{OT}} = \min_{\gamma} \sum_{i,j} \gamma_{ij} \, \|z^s_i - z^t_j\|^2$$
subject to marginal constraints.
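The entropic relaxation of this problem can be solved with Sinkhorn iterations. Below is a minimal sketch with uniform marginals; the function name and the cost normalization are illustrative choices, and in practice a dedicated library such as POT is typically used:

```python
import torch


def sinkhorn_ot(source: torch.Tensor, target: torch.Tensor,
                reg: float = 0.1, n_iters: int = 200):
    """Entropy-regularized OT between two feature batches (uniform marginals).

    Returns the transport plan gamma of shape (n_s, n_t) and the transport
    cost sum_ij gamma_ij * ||z_i^s - z_j^t||^2 (computed on the normalized cost).
    """
    n_s, n_t = source.size(0), target.size(0)
    cost = torch.cdist(source, target, p=2) ** 2
    cost = cost / (cost.max() + 1e-8)   # normalize cost for numerical stability
    K = torch.exp(-cost / reg)          # Gibbs kernel

    a = torch.full((n_s,), 1.0 / n_s)   # uniform source marginal
    b = torch.full((n_t,), 1.0 / n_t)   # uniform target marginal
    u = torch.ones(n_s)
    v = torch.ones(n_t)
    for _ in range(n_iters):            # alternately rescale rows and columns
        u = a / (K @ v + 1e-8)
        v = b / (K.T @ u + 1e-8)

    gamma = u.unsqueeze(1) * K * v.unsqueeze(0)
    return gamma, (gamma * cost).sum()
```

The resulting `gamma` is a soft matching between source and target samples; its total mass is 1 and its row/column sums approximate the uniform marginals.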
CORAL only matches means and covariances, which is computationally cheap and often sufficient. It's a strong baseline before trying more complex domain alignment methods.
```python
import torch
import torch.nn as nn


def coral_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    CORAL loss: align covariance matrices of source and target.

    Args:
        source: (n_s, d) source features
        target: (n_t, d) target features
    """
    d = source.size(1)

    # Center features
    source_centered = source - source.mean(dim=0, keepdim=True)
    target_centered = target - target.mean(dim=0, keepdim=True)

    # Covariance matrices
    cov_source = (source_centered.T @ source_centered) / (source.size(0) - 1)
    cov_target = (target_centered.T @ target_centered) / (target.size(0) - 1)

    # Frobenius norm of difference
    loss = torch.norm(cov_source - cov_target, p='fro') ** 2
    return loss / (4 * d * d)


def mmd_loss(source: torch.Tensor, target: torch.Tensor,
             kernel: str = "rbf") -> torch.Tensor:
    """
    Maximum Mean Discrepancy loss with RBF kernel.
    """
    def rbf_kernel(x, y, gamma=1.0):
        dist = torch.cdist(x, y, p=2)
        return torch.exp(-gamma * dist ** 2)

    # Median heuristic for bandwidth
    combined = torch.cat([source, target], dim=0)
    dists = torch.cdist(combined, combined)
    gamma = 1.0 / (2 * dists.median() ** 2 + 1e-8)

    K_ss = rbf_kernel(source, source, gamma)
    K_tt = rbf_kernel(target, target, gamma)
    K_st = rbf_kernel(source, target, gamma)

    mmd = K_ss.mean() + K_tt.mean() - 2 * K_st.mean()
    return mmd


class DomainAlignedAdapter(nn.Module):
    """Adapter trained with domain alignment loss."""
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        output_dim: int = 128,
        alignment: str = "coral"  # or "mmd"
    ):
        super().__init__()
        self.adapter = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        self.alignment = alignment

    def forward(self, x):
        return self.adapter(x)

    def alignment_loss(self, source_features, target_features):
        adapted_source = self.adapter(source_features)
        adapted_target = self.adapter(target_features)

        if self.alignment == "coral":
            return coral_loss(adapted_source, adapted_target)
        elif self.alignment == "mmd":
            return mmd_loss(adapted_source, adapted_target)
```

Recent research has developed efficient adapter architectures that modify pre-trained models with minimal parameters.
Bottleneck Adapters:
Insert small trainable modules between frozen layers:
Frozen Layer → Adapter → Frozen Layer
Adapter structure: a down-projection from $d$ to a small bottleneck dimension $r$, a nonlinearity, an up-projection back to $d$, and a residual connection around the whole module.
Total parameters: $2 \times d \times r$, where $r \ll d$.
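To make the count concrete, a quick sanity check (the sizes $d = 768$, $r = 64$ are illustrative, not from the text; biases are omitted to match the $2 \times d \times r$ formula):

```python
# Hypothetical sizes: hidden dimension d and adapter bottleneck r
d, r = 768, 64

adapter_params = 2 * d * r   # down-projection (d×r) + up-projection (r×d)
full_layer_params = d * d    # one full d×d weight matrix, for comparison

print(adapter_params)        # 98304
print(full_layer_params)     # 589824
print(adapter_params / full_layer_params)  # ≈ 0.167, i.e. ~6× fewer parameters
```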
LoRA (Low-Rank Adaptation):
Modify weight matrices with low-rank updates:
$$W' = W + \Delta W = W + BA$$
where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times d}$, and $r \ll d$.
Only $A$ and $B$ are trained; $W$ stays frozen.
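The key convenience of LoRA is that the low-rank update can be folded into the frozen weight after training, so inference pays no extra cost. The identity $W' x = Wx + B(Ax)$ can be checked numerically (shapes below are illustrative):

```python
import torch

d, r = 16, 2
W = torch.randn(d, d)        # frozen weight
B = torch.randn(d, r)        # low-rank factors (trainable in practice)
A = torch.randn(r, d)

x = torch.randn(4, d)
# LoRA-style forward pass: base output plus low-rank update
y_lora = x @ W.T + (x @ A.T) @ B.T
# Equivalent merged weight: W' = W + B @ A
y_merged = x @ (W + B @ A).T

assert torch.allclose(y_lora, y_merged, atol=1e-4)
```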
Comparison:
| Method | Where Applied | Parameters | Integration |
|---|---|---|---|
| Linear Adapter | After encoder | d × d' | Separate layer |
| Bottleneck | Between layers | 2 × d × r | Inserted modules |
| LoRA | Weight matrices | 2 × d × r per layer | Merged with weights |
```python
import torch
import torch.nn as nn


class BottleneckAdapter(nn.Module):
    """
    Bottleneck adapter module for parameter-efficient adaptation.
    """
    def __init__(
        self,
        input_dim: int,
        bottleneck_dim: int = 64,
        activation: str = "relu"
    ):
        super().__init__()
        self.down = nn.Linear(input_dim, bottleneck_dim)
        self.up = nn.Linear(bottleneck_dim, input_dim)
        if activation == "relu":
            self.act = nn.ReLU()
        elif activation == "gelu":
            self.act = nn.GELU()

        # Initialize for near-identity
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x):
        # Residual connection ensures stability
        return x + self.up(self.act(self.down(x)))


class LoRALayer(nn.Module):
    """
    Low-Rank Adaptation for a linear layer.
    """
    def __init__(
        self,
        original_layer: nn.Linear,
        rank: int = 4,
        alpha: float = 1.0
    ):
        super().__init__()
        self.original = original_layer
        in_features = original_layer.in_features
        out_features = original_layer.out_features

        # Low-rank matrices
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scaling = alpha / rank

        # Freeze original
        for p in self.original.parameters():
            p.requires_grad = False

    def forward(self, x):
        # Original computation + low-rank update
        original_out = self.original(x)
        lora_out = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return original_out + lora_out

    def merge_weights(self):
        """Merge LoRA weights into original for inference."""
        self.original.weight.data += (
            self.lora_B @ self.lora_A * self.scaling
        )


def add_adapters_to_model(model: nn.Module,
                          adapter_type: str = "bottleneck",
                          bottleneck_dim: int = 64) -> nn.Module:
    """Recursively wrap each nn.Linear in the model with a bottleneck adapter."""
    # Materialize the iterator first: we mutate the module tree while walking it
    for name, module in list(model.named_children()):
        if isinstance(module, nn.Linear):
            if adapter_type == "bottleneck":
                adapter = BottleneckAdapter(module.out_features, bottleneck_dim)
                setattr(model, name, nn.Sequential(module, adapter))
        else:
            add_adapters_to_model(module, adapter_type, bottleneck_dim)
    return model
```

Training feature adapters requires balancing task performance with domain alignment.
Combined loss function:
$$\mathcal{L} = \mathcal{L}_{\text{task}}(\text{labeled source}) + \lambda \, \mathcal{L}_{\text{align}}(\text{source}, \text{target})$$
where $\mathcal{L}_{\text{task}}$ is the supervised loss on labeled source data, $\mathcal{L}_{\text{align}}$ is a domain alignment loss (e.g., CORAL or MMD) on source and target features, and $\lambda$ controls the trade-off.
Curriculum for λ:
Start with a small λ (focus on the task), then increase it over training, for example with a linear or sigmoid ramp toward its final value.
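One common concrete choice is the sigmoid ramp used in DANN-style training. A minimal sketch (the function name and the loop variables in the comment are illustrative):

```python
import math


def lambda_schedule(step: int, total_steps: int, max_lambda: float = 1.0) -> float:
    """Sigmoid ramp from 0 toward max_lambda over training (DANN-style)."""
    p = step / total_steps  # training progress in [0, 1]
    return max_lambda * (2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0)

# How it would slot into a training step (sketch):
# lam = lambda_schedule(step, total_steps)
# loss = task_loss + lam * alignment_loss
```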
If λ is too high, the model may learn to map all features to the same point (trivial alignment). Monitor feature diversity and classification loss. If task loss suddenly increases, reduce λ.
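Feature diversity is easy to monitor. One simple proxy, sketched below, is the mean pairwise distance between normalized features in a batch (the function name is illustrative):

```python
import torch


def feature_diversity(features: torch.Tensor) -> float:
    """Mean pairwise distance between L2-normalized features in a batch.

    A value drifting toward 0 is a red flag for trivial alignment:
    the adapter is mapping everything to (nearly) the same point.
    """
    z = torch.nn.functional.normalize(features, p=2, dim=1)
    return torch.pdist(z).mean().item()
```

Logging this alongside the task loss makes collapse visible long before accuracy degrades.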
Feature adaptation occupies a specific niche in the transfer learning toolkit.
| Scenario | Frozen Features | Feature Adaptation | Fine-Tuning |
|---|---|---|---|
| High domain similarity | ✓ Best | Unnecessary | Overkill |
| Moderate similarity, limited target labels | May underperform | ✓ Best | Overfitting risk |
| Moderate similarity, ample target labels | Okay | Good | ✓ Best |
| Low similarity, unlabeled target data available | Fails | ✓ Best | Needs labels |
| Low similarity, labeled target data | Fails | Good | ✓ Best |
| Compute constrained | ✓ Best | Good | Too expensive |
Feature adaptation sweet spots:
Combining with other methods:
Feature adaptation can complement:
Module Complete!
You've now completed Module 2: Feature-Based Transfer. You now understand frozen feature extraction, linear probing, feature adaptation, domain alignment losses, and parameter-efficient adapter architectures.
The next module explores Fine-Tuning Strategies—going beyond feature-based methods to update the entire model for optimal target performance.
Congratulations! You've mastered feature-based transfer learning—from frozen features through advanced adaptation techniques. These methods form the foundation for efficient knowledge transfer in modern ML.