Instead of reweighting samples to account for distributional differences, an alternative strategy is to transform data into a shared representation space where source and target distributions are indistinguishable.
The intuition is compelling: if a learned representation $Z = f(X)$ makes source and target look identical, then a classifier trained on source representations should work equally well on target representations. The domain shift is eliminated by design.
Core Objective: $$\min_f \mathcal{L}_{task}(f, g) + \lambda \cdot d(P_S(f(X)), P_T(f(X)))$$
where $g$ is a classifier, $\mathcal{L}_{task}$ is the supervised loss on source data, and $d$ is a distributional distance measure.
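As a concrete reading of this objective, here is a minimal sketch of a single training step, assuming a feature extractor `f`, a classifier `g`, and a pluggable `distance_fn` (a hypothetical placeholder standing in for MMD, CORAL, or any of the other distance measures covered below).

```python
import torch
import torch.nn.functional as F

def adaptation_step(f, g, distance_fn, x_s, y_s, x_t, lam=1.0):
    """One training step: supervised loss on source + lambda * distributional distance."""
    z_s = f(x_s)                                # source representations f(X_S)
    z_t = f(x_t)                                # target representations f(X_T), no labels needed
    task_loss = F.cross_entropy(g(z_s), y_s)    # L_task on labeled source data
    align_loss = distance_fn(z_s, z_t)          # d(P_S(f(X)), P_T(f(X))), estimated from samples
    return task_loss + lam * align_loss
```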
This page covers MMD-based distribution matching, moment matching methods, optimal transport approaches, and the tradeoffs between distribution alignment and task performance.
Maximum Mean Discrepancy (MMD) is a natural choice for distribution matching because it's differentiable and can be estimated from samples.
One of the earliest deep domain adaptation methods, Deep Domain Confusion (DDC), adds an MMD penalty to the standard classification loss:
$$\mathcal{L} = \mathcal{L}_{CE}(y_s, \hat{y}_s) + \lambda \cdot \text{MMD}^2(f(X_S), f(X_T))$$
Architecture: a shared feature extractor, a bottleneck (adaptation) layer on whose features the MMD penalty is computed, and a classifier head trained with source labels (see the `DANModel` code below).
DAN (Deep Adaptation Network) extends DDC by applying multi-kernel MMD (MK-MMD) to several task-specific layers instead of a single-kernel MMD at one layer:
$$\mathcal{L} = \mathcal{L}_{CE} + \lambda \sum_{l \in \mathcal{L}} \text{MK-MMD}(f^{(l)}(X_S), f^{(l)}(X_T))$$
where the sum runs over the set of adapted layers $\mathcal{L}$ and $f^{(l)}$ denotes the features at layer $l$.
```python
import torch
import torch.nn as nn


class MMDLoss(nn.Module):
    """
    Maximum Mean Discrepancy loss for domain adaptation.

    Measures distance between source and target feature distributions
    in a reproducing kernel Hilbert space (RKHS).
    """
    def __init__(self, kernel='rbf', bandwidth_list=None):
        super().__init__()
        self.kernel = kernel
        # Multiple bandwidths for robustness
        self.bandwidth_list = bandwidth_list or [0.1, 1.0, 10.0]

    def forward(self, source_features, target_features):
        """Compute MMD² between source and target features."""
        batch_size = source_features.size(0)

        # Combine for efficient kernel computation
        combined = torch.cat([source_features, target_features], dim=0)

        # Multi-kernel MMD
        mmd = 0.0
        for bandwidth in self.bandwidth_list:
            K = self._rbf_kernel(combined, combined, bandwidth)

            # Split kernel matrix
            K_ss = K[:batch_size, :batch_size]
            K_tt = K[batch_size:, batch_size:]
            K_st = K[:batch_size, batch_size:]

            # MMD² = E[k(s,s')] + E[k(t,t')] - 2E[k(s,t)]
            mmd += (K_ss.mean() + K_tt.mean() - 2 * K_st.mean())

        return mmd / len(self.bandwidth_list)

    def _rbf_kernel(self, X, Y, bandwidth):
        """Compute RBF kernel matrix."""
        XX = X @ X.T
        YY = Y @ Y.T
        XY = X @ Y.T

        X_sqnorms = torch.diag(XX)
        Y_sqnorms = torch.diag(YY)

        r = X_sqnorms.unsqueeze(1) - 2 * XY + Y_sqnorms.unsqueeze(0)
        return torch.exp(-r / (2 * bandwidth ** 2))


class DANModel(nn.Module):
    """
    Deep Adaptation Network for domain adaptation.

    Uses multi-layer MMD matching with multiple kernels.
    """
    def __init__(self, backbone, num_classes, hidden_dim=256):
        super().__init__()
        self.backbone = backbone
        self.bottleneck = nn.Sequential(
            nn.Linear(backbone.output_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU()
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.mmd_loss = MMDLoss()

    def forward(self, x):
        features = self.backbone(x)
        bottleneck = self.bottleneck(features)
        logits = self.classifier(bottleneck)
        return logits, bottleneck

    def compute_loss(self, x_s, y_s, x_t, lambda_mmd=1.0):
        """Training step with classification + MMD loss."""
        logits_s, feat_s = self(x_s)
        _, feat_t = self(x_t)

        # Classification loss on source
        cls_loss = nn.CrossEntropyLoss()(logits_s, y_s)

        # MMD loss for domain alignment
        mmd_loss = self.mmd_loss(feat_s, feat_t)

        total_loss = cls_loss + lambda_mmd * mmd_loss
        return total_loss, {'cls': cls_loss, 'mmd': mmd_loss}
```

A simpler approach than kernel methods: match the moments (mean, variance, higher-order statistics) of the feature distributions.
Central Moment Discrepancy (CMD) matches central moments up to order $K$:
$$\text{CMD}(S, T) = \|\mathbb{E}[S] - \mathbb{E}[T]\| + \sum_{k=2}^{K} \|C_k(S) - C_k(T)\|$$
where $C_k$ is the $k$-th central moment.
Advantages: moment matching is simple and cheap to compute, requires no kernel or bandwidth selection, and avoids the pairwise kernel matrices that MMD needs.
BatchNorm implicitly matches the first two moments (mean and variance) of features. Domain-specific BatchNorm allows each domain to be normalized with its own statistics while the remaining network weights are shared (see `AdaptiveBatchNorm` below).
CORAL (Correlation Alignment) matches second-order statistics (covariance matrices):
$$\mathcal{L}_{CORAL} = \frac{1}{4d^2}\|C_S - C_T\|_F^2$$
where $C_S, C_T$ are feature covariance matrices and $d$ is feature dimension.
```python
import torch
import torch.nn as nn


class CORALLoss(nn.Module):
    """
    CORAL: Correlation Alignment for domain adaptation.

    Minimizes the difference in second-order statistics (covariance
    matrices) between source and target features.
    """
    def forward(self, source, target):
        d = source.size(1)

        # Remove mean
        source_centered = source - source.mean(dim=0)
        target_centered = target - target.mean(dim=0)

        # Covariance matrices
        cov_source = (source_centered.T @ source_centered) / (source.size(0) - 1)
        cov_target = (target_centered.T @ target_centered) / (target.size(0) - 1)

        # Frobenius norm of difference
        loss = torch.norm(cov_source - cov_target, p='fro') ** 2
        return loss / (4 * d * d)


class CentralMomentDiscrepancy(nn.Module):
    """
    CMD: Central Moment Discrepancy loss.

    Matches multiple orders of central moments between distributions.
    """
    def __init__(self, max_order=5):
        super().__init__()
        self.max_order = max_order

    def forward(self, source, target):
        # First moment (mean)
        loss = torch.norm(source.mean(0) - target.mean(0))

        # Higher-order central moments
        s_centered = source - source.mean(0)
        t_centered = target - target.mean(0)

        for k in range(2, self.max_order + 1):
            s_moment = (s_centered ** k).mean(0)
            t_moment = (t_centered ** k).mean(0)
            loss += torch.norm(s_moment - t_moment)

        return loss


class AdaptiveBatchNorm(nn.Module):
    """
    Adaptive BatchNorm for domain adaptation.

    Learns to interpolate between source and target statistics.
    """
    def __init__(self, num_features, momentum=0.1):
        super().__init__()
        self.bn_source = nn.BatchNorm1d(num_features, momentum=momentum)
        self.bn_target = nn.BatchNorm1d(num_features, momentum=momentum)
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, domain='source'):
        if domain == 'source':
            return self.bn_source(x)
        elif domain == 'target':
            return self.bn_target(x)
        else:
            # Inference: use interpolated statistics
            alpha = torch.sigmoid(self.alpha)
            mean = alpha * self.bn_source.running_mean + (1 - alpha) * self.bn_target.running_mean
            var = alpha * self.bn_source.running_var + (1 - alpha) * self.bn_target.running_var
            return (x - mean) / torch.sqrt(var + 1e-5)
```

Optimal transport (OT) provides a principled framework for measuring and minimizing distributional differences.
The Wasserstein distance measures the minimum cost of transforming one distribution into another:
$$W(P_S, P_T) = \inf_{\gamma \in \Pi(P_S, P_T)} \mathbb{E}_{(x_s, x_t) \sim \gamma}[c(x_s, x_t)]$$
where $\Pi(P_S, P_T)$ is the set of couplings with marginals $P_S$ and $P_T$, and $c$ is a ground cost (e.g., squared Euclidean distance).
Key insight: The optimal transport plan $\gamma^*$ tells us which source samples should be "matched" to which target samples.
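The sketch below computes such a transport plan between two batches of features, assuming the POT library (`pip install pot`) is available and that `source_feats` and `target_feats` are NumPy arrays of shape `(n, d)`; `transport_plan` is an illustrative helper name.

```python
import numpy as np
import ot  # POT: Python Optimal Transport

def transport_plan(source_feats, target_feats):
    """Exact OT between two empirical feature sets."""
    ns, nt = len(source_feats), len(target_feats)
    a = np.full(ns, 1.0 / ns)                 # uniform mass on source samples
    b = np.full(nt, 1.0 / nt)                 # uniform mass on target samples
    M = ot.dist(source_feats, target_feats)   # pairwise squared Euclidean costs
    gamma = ot.emd(a, b, M)                   # gamma[i, j]: mass moved from source i to target j
    cost = float((gamma * M).sum())           # empirical transport cost
    return gamma, cost
```

Each row of `gamma` shows which target samples a given source sample is matched to, which is exactly the sample-level correspondence described above.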
JDOT simultaneously optimizes the classifier and transport plan:
$$\min_{f, \gamma} \sum_{i,j} \gamma_{ij} \left[ \alpha \cdot c(x_i^s, x_j^t) + \mathcal{L}(y_i^s, f(x_j^t)) \right]$$
This encourages the transport plan to match source and target samples that are both close in feature space (small transport cost $c$) and consistent in labels, so that the classifier's prediction on a matched target sample agrees with the corresponding source label (a minimal sketch of this joint cost follows).
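A minimal sketch of one JDOT step with the classifier held fixed, assuming `f` maps features to class logits, `y_s` is a `LongTensor` of source labels, and the coupling is solved with POT's `ot.emd` on the joint cost; the helper name `jdot_coupling` and the value of `alpha` are illustrative, not a reference implementation.

```python
import numpy as np
import torch
import torch.nn.functional as F
import ot  # POT: Python Optimal Transport

def jdot_coupling(feat_s, y_s, feat_t, f, alpha=0.1):
    """Build the JDOT joint cost and solve for the coupling with the classifier fixed."""
    # Feature-space cost: alpha * ||x_i^s - x_j^t||^2
    feat_cost = alpha * torch.cdist(feat_s, feat_t) ** 2

    # Label cost: loss of source label y_i^s under the prediction for target sample j
    log_p_t = F.log_softmax(f(feat_t), dim=1)      # (n_t, num_classes)
    label_cost = -log_p_t[:, y_s].T                # (n_s, n_t) cross-entropy terms

    C = (feat_cost + label_cost).detach().cpu().numpy().astype(np.float64)
    ns, nt = C.shape
    gamma = ot.emd(np.full(ns, 1.0 / ns), np.full(nt, 1.0 / nt), C)
    return gamma   # gamma_ij weights the (source, target) pairs used to update f next
```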
Entropically regularized OT enables efficient computation via the Sinkhorn algorithm:
$$W_\epsilon = \min_{\gamma \in \Pi(P_S, P_T)} \langle \gamma, C \rangle - \epsilon H(\gamma)$$
where $C$ is the pairwise cost matrix and $H(\gamma) = -\sum_{i,j} \gamma_{ij} \log \gamma_{ij}$ is the entropy of the coupling matrix.
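A minimal, pure-PyTorch sketch of the Sinkhorn iterations, assuming uniform marginals on both sample sets; `sinkhorn_plan` is an illustrative helper, not a library function.

```python
import torch

def sinkhorn_plan(C, epsilon=0.1, n_iters=200):
    """Approximate the entropically regularized coupling for cost matrix C (n_s x n_t)."""
    n_s, n_t = C.shape
    a = torch.full((n_s,), 1.0 / n_s, device=C.device)   # source marginal
    b = torch.full((n_t,), 1.0 / n_t, device=C.device)   # target marginal
    K = torch.exp(-C / epsilon)    # Gibbs kernel; larger epsilon -> smoother plan
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iters):       # alternate scaling until both marginals match
        u = a / (K @ v)
        v = b / (K.T @ u)
    gamma = u.unsqueeze(1) * K * v.unsqueeze(0)   # coupling: diag(u) K diag(v)
    return gamma
```

Because every operation here is differentiable, the resulting cost $\langle \gamma, C \rangle$ can be backpropagated through and used directly as a training loss.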
OT respects the geometry of the feature space. Unlike MMD which only compares distributions globally, OT provides a transport plan showing which samples correspond across domains. This enables sample-level adaptation and meaningful interpolation.
A critical insight: perfect domain alignment can hurt performance. If we force representations to be identical across domains, we may discard information useful for classification.
We want representations $Z$ that are simultaneously domain-invariant, so that $P_S(Z) \approx P_T(Z)$, and discriminative, so that $Z$ still predicts the label $Y$ well.
But these goals can conflict: some domain-specific features may themselves be predictive.
Over-aggressive alignment causes negative transfer: for example, if label proportions differ between domains, forcing the marginal feature distributions to match pushes some target samples onto source features of the wrong class, and target accuracy can fall below that of a model trained on source data alone.
A solution: match distributions within each class, not globally:
$$\mathcal{L} = \sum_{c=1}^C d(P_S(Z | Y=c), P_T(Z | Y=c))$$
Challenge: we don't have labels in the target domain!
Pseudo-labeling approach: use the classifier's confident predictions on target samples as provisional labels, match the class-conditional distributions using those pseudo-labels, and refresh the pseudo-labels as the model improves (a minimal sketch follows).
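A minimal sketch of class-conditional matching with pseudo-labels, reusing the `MMDLoss` module defined earlier; the 0.9 confidence threshold and the helper's name are illustrative choices.

```python
import torch
import torch.nn.functional as F

def class_conditional_mmd(feat_s, y_s, feat_t, logits_t, mmd_loss, num_classes, threshold=0.9):
    """Sum MMD over classes, using confident target predictions as pseudo-labels."""
    probs_t = F.softmax(logits_t, dim=1)
    conf_t, pseudo_y_t = probs_t.max(dim=1)   # confidence and pseudo-label per target sample
    keep = conf_t > threshold                  # only trust confident predictions

    loss = feat_s.new_zeros(())
    for c in range(num_classes):
        s_c = feat_s[y_s == c]
        t_c = feat_t[keep & (pseudo_y_t == c)]
        if len(s_c) > 1 and len(t_c) > 1:      # need a few samples per class to estimate MMD
            loss = loss + mmd_loss(s_c, t_c)
    return loss
```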
The next page covers adversarial approaches—using a domain discriminator to force the feature extractor to produce domain-invariant representations.