Domain adaptation assumes we have access to target domain data (even if unlabeled). But what if we don't? What if we need to deploy to domains we've never seen and can't anticipate?
Domain Generalization (DG) addresses this challenge: train on multiple source domains such that the model generalizes to any new target domain.
Problem Setting: given labeled data from $K$ source domains $D_1, \dots, D_K$, learn a single model that performs well on an unseen target domain from which no data (labeled or unlabeled) is available during training.
This is arguably the most realistic setting for production ML—we can't collect data from every possible deployment context.
This page covers domain-invariant representation learning, data augmentation strategies, meta-learning for generalization, ensemble and regularization approaches, and evaluation protocols for domain generalization.
The core idea: learn features that are invariant across all training domains. If features don't change across known domains, hopefully they won't change for unseen domains either.
Minimize pairwise MMD between all source domains:
$$\mathcal{L}_{\text{inv}} = \sum_{i<j} \text{MMD}(P_i(Z), P_j(Z))$$
This encourages the feature extractor to remove domain-specific information.
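As a concrete sketch, the pairwise objective above can be implemented with an RBF-kernel MMD estimate. The function names (`mmd_rbf`, `compute_multi_domain_mmd`) and the fixed bandwidth `sigma` are illustrative choices, not a prescribed API:

```python
import torch

def mmd_rbf(x, y, sigma=1.0):
    """Biased squared-MMD estimate between two feature batches, RBF kernel."""
    def kernel(a, b):
        dist = torch.cdist(a, b) ** 2
        return torch.exp(-dist / (2 * sigma ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

def compute_multi_domain_mmd(domain_features, sigma=1.0):
    """Sum of pairwise MMDs over all source-domain pairs."""
    loss = 0.0
    n = len(domain_features)
    for i in range(n):
        for j in range(i + 1, n):
            loss = loss + mmd_rbf(domain_features[i], domain_features[j], sigma)
    return loss
```

In practice the per-domain feature batches come from the shared encoder, and this term is added to the classification loss with a small weight.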
IRM seeks representations where the optimal classifier is the same across all domains:
$$\min_\phi \sum_e \left[ R^e(\phi) + \lambda \left\| \nabla_{w \mid w=1.0}\, R^e(w \cdot \phi) \right\|^2 \right]$$
where $R^e$ is risk in environment/domain $e$, and the penalty ensures the optimal linear classifier on top of $\phi$ is the same for all domains.
Key Insight: If a feature enables the same classifier across domains, it captures invariant (causal) relationships rather than spurious correlations.
IRM is inspired by causality: causal features have invariant relationships with the outcome, while spurious features (confounders) vary across environments. Domain generalization can be viewed as finding causal features.
```python
import torch
import torch.nn as nn


class IRMPenalty(nn.Module):
    """
    Invariant Risk Minimization penalty.

    Encourages the optimal classifier to be the same across all
    environments/domains by penalizing gradient magnitude.
    """
    def __init__(self, lambda_irm=1.0):
        super().__init__()
        self.lambda_irm = lambda_irm

    def forward(self, logits, labels, env_indices):
        """
        Compute the IRM penalty across environments.

        Args:
            logits: Model predictions (before softmax)
            labels: True labels
            env_indices: Environment/domain index for each sample
        """
        unique_envs = torch.unique(env_indices)
        penalties = []

        for env in unique_envs:
            mask = (env_indices == env)
            env_logits = logits[mask]
            env_labels = labels[mask]

            # Create a dummy scalar w to differentiate through
            scale = torch.ones(1, requires_grad=True, device=logits.device)
            env_loss = nn.CrossEntropyLoss()(env_logits * scale, env_labels)

            # Penalty: squared gradient norm at w = 1
            grad = torch.autograd.grad(env_loss, scale, create_graph=True)[0]
            penalties.append(grad ** 2)

        return self.lambda_irm * torch.stack(penalties).mean()


class DomainGeneralizationModel(nn.Module):
    """
    Multi-source domain generalization model.

    Combines:
    - Feature extraction with a domain-invariant objective
    - Classification on source domains
    - Regularization for generalization
    """
    def __init__(self, backbone, num_classes, num_domains):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.output_dim, num_classes)

        # Domain-specific batch norm
        self.domain_bn = nn.ModuleList([
            nn.BatchNorm1d(backbone.output_dim)
            for _ in range(num_domains)
        ])
        self.irm_penalty = IRMPenalty(lambda_irm=1.0)

    def forward(self, x, domain_idx=None):
        features = self.backbone(x)

        # Apply domain-specific normalization during training
        if domain_idx is not None and self.training:
            features = self.domain_bn[domain_idx](features)
        else:
            # Inference: average over domain BNs
            bn_outputs = [bn(features) for bn in self.domain_bn]
            features = torch.stack(bn_outputs).mean(0)

        logits = self.classifier(features)
        return logits, features


def train_dg_model(model, domain_loaders, optimizer, num_epochs, device):
    """Training loop for domain generalization."""
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()

        # Sample batches from all domains
        domain_iters = [iter(loader) for loader in domain_loaders]

        for step in range(min(len(l) for l in domain_loaders)):
            all_features = []
            all_logits = []
            all_labels = []
            all_domains = []

            for d_idx, d_iter in enumerate(domain_iters):
                x, y = next(d_iter)
                x, y = x.to(device), y.to(device)
                logits, features = model(x, domain_idx=d_idx)
                all_features.append(features)
                all_logits.append(logits)
                all_labels.append(y)
                all_domains.append(torch.full_like(y, d_idx))

            features = torch.cat(all_features)
            logits = torch.cat(all_logits)
            labels = torch.cat(all_labels)
            domains = torch.cat(all_domains)

            # Classification loss
            cls_loss = criterion(logits, labels)

            # IRM penalty for invariance
            irm_loss = model.irm_penalty(logits, labels, domains)

            # Domain alignment via pairwise MMD over source-domain features
            mmd_loss = compute_multi_domain_mmd(all_features)

            loss = cls_loss + irm_loss + 0.1 * mmd_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

If we can't access the target domain, we can try to synthesize variations that cover potential target distributions.
Randomize domain-specific aspects during training (e.g., colors, textures, lighting, backgrounds):
Philosophy: If training covers extreme variations, real-world domains likely fall within this coverage.
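In image pipelines, domain randomization is often done with renderer settings or library transforms; as a minimal pure-tensor sketch, each call below samples a new "synthetic domain" via random color scaling, contrast, and noise. The function name and the specific perturbations are illustrative assumptions:

```python
import torch

def randomize_domain(images):
    """Apply random domain-level perturbations to a batch of images (B, C, H, W)."""
    b, c, h, w = images.shape
    # Random per-channel color scaling in [0.5, 1.5)
    color = 0.5 + torch.rand(b, c, 1, 1)
    # Random global contrast around each image's mean
    contrast = 0.5 + torch.rand(b, 1, 1, 1)
    mean = images.mean(dim=(1, 2, 3), keepdim=True)
    out = (images * color - mean) * contrast + mean
    # Additive Gaussian noise as a crude sensor-shift proxy
    out = out + 0.05 * torch.randn_like(out)
    return out.clamp(0.0, 1.0)
```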
Interpolate between samples from different domains:
$$\tilde{x} = \lambda x_i^{d_1} + (1-\lambda) x_j^{d_2}$$ $$\tilde{y} = \lambda y_i + (1-\lambda) y_j$$
This creates implicit "in-between" domains, smoothing the representation space.
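A minimal sketch of the interpolation above, mixing one batch from each of two domains; the function name, the Beta-distributed mixing coefficient, and the one-hot label encoding are conventional MixUp choices, not mandated by the equations:

```python
import torch

def interdomain_mixup(x1, y1, x2, y2, alpha=0.2, num_classes=10):
    """Mix inputs and labels from two different source domains."""
    # Sample the mixing coefficient lambda ~ Beta(alpha, alpha)
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    x_mix = lam * x1 + (1 - lam) * x2
    # One-hot encode so the mixed target is a soft label distribution
    y1_oh = torch.nn.functional.one_hot(y1, num_classes).float()
    y2_oh = torch.nn.functional.one_hot(y2, num_classes).float()
    y_mix = lam * y1_oh + (1 - lam) * y2_oh
    return x_mix, y_mix
```

The mixed targets are soft labels, so training uses a cross-entropy against the full distribution rather than a hard class index.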
Use adversarial perturbations that increase the domain classifier's loss, fooling it about which domain a sample comes from:
$$x' = x + \epsilon \cdot \text{sign}(\nabla_x \mathcal{L}_{domain})$$
This synthesizes samples that push toward other domains while preserving labels.
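The perturbation above is a single FGSM-style step against a domain classifier. A minimal sketch, assuming some trained `domain_classifier` module exists; the function name and `epsilon` default are illustrative:

```python
import torch
import torch.nn as nn

def adversarial_domain_augment(x, domain_labels, domain_classifier, epsilon=0.03):
    """One FGSM step that increases the domain classifier's loss,
    pushing samples away from their own domain while class labels stay intact."""
    x_adv = x.clone().detach().requires_grad_(True)
    logits = domain_classifier(x_adv)
    loss = nn.CrossEntropyLoss()(logits, domain_labels)
    # Gradient of the domain loss w.r.t. the input
    grad = torch.autograd.grad(loss, x_adv)[0]
    return (x_adv + epsilon * grad.sign()).detach()
```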
Meta-learning simulates domain shift during training by using source domains as pseudo train/test splits.
Algorithm:
1. Split the source domains into meta-train and meta-test sets.
2. Inner loop: update a copy of the model on the meta-train domains.
3. Outer loop: evaluate the updated copy on the meta-test domains and use that loss to update the original parameters.
This simulates the generalization gap and learns to minimize it.
Each training episode:
1. Hold out one source domain at random.
2. Train on batches from the remaining domains.
3. Evaluate on the held-out domain.
The model learns to generalize to "unseen" domains by repeatedly practicing this.
```python
import torch
import torch.nn as nn
import copy
import random


class MLDG:
    """
    Meta-Learning Domain Generalization.

    Simulates domain shift by splitting source domains into meta-train
    and meta-test, then optimizing for generalization.
    """
    def __init__(self, model, inner_lr=0.01, meta_lr=0.001):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)

    def meta_update(self, domain_data, device):
        """
        One meta-learning update step.

        Args:
            domain_data: Dict mapping domain_id -> (x, y) batches
        """
        domains = list(domain_data.keys())

        # Split domains into meta-train and meta-test
        meta_test_domain = domains[-1]  # or random selection
        meta_train_domains = domains[:-1]

        # Clone the model for the inner loop
        inner_model = copy.deepcopy(self.model)
        inner_optimizer = torch.optim.SGD(
            inner_model.parameters(), lr=self.inner_lr
        )

        # Inner loop: update on meta-train domains
        for domain in meta_train_domains:
            x, y = domain_data[domain]
            x, y = x.to(device), y.to(device)
            logits = inner_model(x)
            loss = nn.CrossEntropyLoss()(logits, y)
            inner_optimizer.zero_grad()
            loss.backward()
            inner_optimizer.step()

        # Outer loop: evaluate on the meta-test domain
        x_test, y_test = domain_data[meta_test_domain]
        x_test, y_test = x_test.to(device), y_test.to(device)
        logits_test = inner_model(x_test)
        meta_loss = nn.CrossEntropyLoss()(logits_test, y_test)

        # First-order approximation: backprop the meta-loss through the
        # inner model, then copy its gradients onto the original parameters
        inner_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.zero_grad()
        for p_orig, p_inner in zip(
            self.model.parameters(), inner_model.parameters()
        ):
            if p_inner.grad is not None:
                p_orig.grad = p_inner.grad.clone()
        self.meta_optimizer.step()

        return meta_loss.item()


class EpisodicTrainer:
    """
    Episodic training for domain generalization.

    Each episode holds out one domain and optimizes for generalization
    to that domain.
    """
    def __init__(self, model, domains, lr=0.001):
        self.model = model
        self.domains = domains
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def train_episode(self, domain_loaders, device):
        """Train one episode with leave-one-domain-out."""
        # Randomly select the held-out domain
        held_out = random.choice(self.domains)
        train_domains = [d for d in self.domains if d != held_out]

        self.model.train()

        # Train on the other domains
        train_loss = 0.0
        for domain in train_domains:
            x, y = next(iter(domain_loaders[domain]))
            x, y = x.to(device), y.to(device)

            self.optimizer.zero_grad()
            logits = self.model(x)
            loss = nn.CrossEntropyLoss()(logits, y)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

        # Evaluate on the held-out domain
        self.model.eval()
        with torch.no_grad():
            x_test, y_test = next(iter(domain_loaders[held_out]))
            x_test, y_test = x_test.to(device), y_test.to(device)
            logits = self.model(x_test)
            val_loss = nn.CrossEntropyLoss()(logits, y_test)

        return train_loss / len(train_domains), val_loss.item()
```

Train separate models per domain, aggregate at test time:
$$f(x) = \sum_k w_k f_k(x)$$
Weights can be uniform or learned based on similarity to the test sample.
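A minimal sketch of the aggregation equation: average the per-domain experts' probabilities, either uniformly or with caller-supplied weights. The function name and the choice to average probabilities (rather than logits) are assumptions:

```python
import torch
import torch.nn as nn

def ensemble_predict(models, x, weights=None):
    """Aggregate per-domain expert predictions into one distribution.

    weights=None averages experts uniformly; otherwise `weights` is a
    (num_models,) tensor, e.g. based on similarity to the test sample.
    """
    if weights is None:
        weights = torch.full((len(models),), 1.0 / len(models))
    # Stack per-expert probability distributions: (num_models, batch, classes)
    probs = torch.stack([torch.softmax(m(x), dim=-1) for m in models])
    return (weights.view(-1, 1, 1) * probs).sum(0)
```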
Representation Self-Challenging (RSC): remove the most predictive features during training to force the network to rely on more diverse features.
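A minimal sketch of this feature-dropping idea: rank features by the gradient of the true-class score and zero out the top fraction. The function name and the drop fraction are illustrative assumptions:

```python
import torch

def self_challenge_mask(features, logits, labels, drop_fraction=0.33):
    """Zero out the features most responsible for the current prediction,
    forcing the classifier to use the remaining features."""
    # Gradient of the true-class scores w.r.t. the features
    scores = logits.gather(1, labels.unsqueeze(1)).sum()
    grads = torch.autograd.grad(scores, features, retain_graph=True)[0]
    # Drop the k highest-gradient features per sample
    k = int(features.size(1) * drop_fraction)
    _, top_idx = grads.topk(k, dim=1)
    mask = torch.ones_like(features)
    mask.scatter_(1, top_idx, 0.0)
    return features * mask
```

During training, the masked features are fed back through the classifier so the loss is computed on the "challenged" representation.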
Flat minima in the loss landscape generalize better. Techniques include Sharpness-Aware Minimization (SAM), which explicitly optimizes for flat regions, and Stochastic Weight Averaging (SWA), which averages weights along the training trajectory.
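As a minimal sketch of the SAM idea: take an ascent step toward the local loss maximum, compute gradients there, then update the original weights with those gradients. This is a simplified single-function version, not the official implementation; the function name and `rho` default are assumptions:

```python
import torch

def sam_step(model, loss_fn, x, y, base_optimizer, rho=0.05):
    """One Sharpness-Aware Minimization step."""
    # First pass: gradient at the current weights
    loss = loss_fn(model(x), y)
    base_optimizer.zero_grad()
    loss.backward()
    grad_norm = torch.norm(torch.stack([
        p.grad.norm() for p in model.parameters() if p.grad is not None
    ]))
    # Ascent step of size rho toward the local loss maximum
    eps = []
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                eps.append(None)
                continue
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)
            eps.append(e)
    # Second pass: gradient at the perturbed weights
    base_optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    # Restore the original weights, then descend with the "sharp" gradient
    with torch.no_grad():
        for p, e in zip(model.parameters(), eps):
            if e is not None:
                p.sub_(e)
    base_optimizer.step()
    return loss.item()
```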
| Category | Methods | Key Idea |
|---|---|---|
| Invariant Features | IRM, MMD, CORAL | Remove domain-specific information |
| Augmentation | MixUp, StyleAug, RandConv | Expand training distribution coverage |
| Meta-Learning | MLDG, Feature-Critic | Simulate generalization during training |
| Ensemble | DoA, Expert mixture | Combine domain-specific knowledge |
| Regularization | RSC, SAM, Fish | Encourage robust representations |
Proper evaluation of domain generalization is challenging. Key considerations: model selection (hyperparameters, checkpoints) must not use target-domain data, and results should be averaged over multiple held-out domains and random seeds.
For K source domains, the standard protocol is leave-one-domain-out: train on K − 1 domains, test on the held-out one, and repeat for all K choices, reporting per-domain and average accuracy.
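The protocol is simple enough to sketch directly; `train_and_eval` is a hypothetical callback that trains a fresh model on the given domains and returns held-out accuracy:

```python
def leave_one_domain_out(domains, train_and_eval):
    """Leave-one-domain-out evaluation: for each of the K domains, train on
    the other K-1 and evaluate on the held-out one."""
    results = {}
    for held_out in domains:
        train_domains = [d for d in domains if d != held_out]
        results[held_out] = train_and_eval(train_domains, held_out)
    # Average over the K held-out domains
    results["average"] = sum(results.values()) / len(domains)
    return results
```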
- PACS: Photo, Art painting, Cartoon, Sketch (4 domains)
- Office-Home: Art, Clipart, Product, Real-World (4 domains)
- DomainNet: 6 domains, 345 classes (large-scale)
- Wilds: Real-world distribution shifts (multiple datasets)
Recent studies show that many DG methods don't outperform ERM (standard training) when properly tuned. Model selection is hard without target data, and hyperparameters often overfit to specific domain splits. Domain generalization remains an open problem.
You've completed the Domain Adaptation module. You now understand how to detect, measure, and address domain shift using covariate shift correction, distribution matching, adversarial adaptation, and domain generalization techniques.