Adversarial domain adaptation draws inspiration from GANs: rather than explicitly minimizing a distribution distance, we train a domain discriminator to distinguish source from target, while the feature extractor learns to fool this discriminator.
The result: features that are domain-invariant by construction. If the discriminator cannot tell source from target features, a downstream classifier has no domain-specific signal left to exploit.
The Minimax Game: $$\min_{G, C} \max_D \; \mathcal{L}_{cls}(G, C) - \lambda \mathcal{L}_{domain}(G, D)$$
where $G$ is the feature extractor, $C$ the label classifier, $D$ the domain discriminator, and $\lambda$ controls the trade-off between task accuracy and domain confusion.
This page covers the DANN architecture and gradient reversal, ADDA and other adversarial variants, training dynamics and stability, and advanced techniques like class-conditional adversarial adaptation.
DANN introduced the elegant gradient reversal layer (GRL), enabling end-to-end training of adversarial domain adaptation.
                 ┌─────────────────────┐
      X ───────▶ │ Feature Extractor G │
                 └──────────┬──────────┘
                            │
              ┌─────────────┴─────────────┐
              ▼                           ▼
      ┌──────────────┐           ┌────────────────┐
      │ Classifier C │           │ GRL → Domain D │
      └──────┬───────┘           └───────┬────────┘
             │                           │
             ▼                           ▼
         ŷ (task)                   d̂ (domain)
The GRL is the key innovation. During the forward pass it acts as the identity; during the backward pass it multiplies incoming gradients by $-\lambda$.
Mathematically, define: $$R_\lambda(x) = x$$ $$\frac{\partial R_\lambda}{\partial x} = -\lambda I$$
This makes the feature extractor maximize domain confusion while still minimizing classification loss.
```python
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function


class GradientReversal(Function):
    """
    Gradient Reversal Layer for adversarial domain adaptation.

    Forward: Identity
    Backward: Negate gradients (scaled by lambda)

    This elegant trick enables end-to-end training where the
    feature extractor learns to maximize domain confusion.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse gradients during backprop
        return grad_output.neg() * ctx.lambda_, None


class GRL(nn.Module):
    """Gradient Reversal Layer module wrapper."""

    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversal.apply(x, self.lambda_)

    def set_lambda(self, lambda_):
        self.lambda_ = lambda_


class DANN(nn.Module):
    """
    Domain-Adversarial Neural Network.

    Three components:
    1. Feature extractor G - shared representation learning
    2. Label classifier C - predicts task labels
    3. Domain discriminator D - predicts source vs target

    Training objective:
    - Minimize classification loss on source
    - Maximize domain discrimination loss (via GRL)
    """

    def __init__(self, backbone, num_classes, feature_dim=256):
        super().__init__()

        # Feature extractor
        self.feature_extractor = backbone

        # Bottleneck layer
        self.bottleneck = nn.Sequential(
            nn.Linear(backbone.output_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Label classifier
        self.classifier = nn.Linear(feature_dim, num_classes)

        # Domain discriminator (with GRL)
        self.grl = GRL(lambda_=1.0)
        self.domain_classifier = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 1)  # Binary: source (0) vs target (1)
        )

    def forward(self, x, alpha=1.0):
        """
        Forward pass returning both task and domain predictions.

        Args:
            x: Input images
            alpha: GRL strength (typically scheduled from 0 to 1)
        """
        # Extract features
        features = self.feature_extractor(x)
        features = self.bottleneck(features)

        # Task prediction
        class_output = self.classifier(features)

        # Domain prediction (through GRL)
        self.grl.set_lambda(alpha)
        grl_features = self.grl(features)
        domain_output = self.domain_classifier(grl_features)

        return class_output, domain_output


def train_dann(model, source_loader, target_loader, optimizer,
               num_epochs, device):
    """
    Training loop for DANN with progressive GRL scheduling.
    """
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        model.train()

        # GRL schedule: gradually increase from 0 to 1
        p = epoch / num_epochs
        alpha = 2.0 / (1 + np.exp(-10 * p)) - 1

        for (x_s, y_s), (x_t, _) in zip(source_loader, target_loader):
            x_s, y_s = x_s.to(device), y_s.to(device)
            x_t = x_t.to(device)

            # Combined batch
            x = torch.cat([x_s, x_t], dim=0)

            # Domain labels
            domain_labels = torch.cat([
                torch.zeros(len(x_s)),  # source = 0
                torch.ones(len(x_t))    # target = 1
            ]).to(device)

            # Forward
            class_output, domain_output = model(x, alpha=alpha)

            # Classification loss (source only)
            loss_cls = criterion_cls(class_output[:len(x_s)], y_s)

            # Domain loss (both domains)
            loss_domain = criterion_domain(
                domain_output.squeeze(), domain_labels
            )

            # Total loss (domain loss is reversed by GRL)
            loss = loss_cls + loss_domain

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}: cls_loss={loss_cls:.4f}, "
              f"domain_loss={loss_domain:.4f}, alpha={alpha:.3f}")
```

ADDA uses a different training procedure:
Stage 1: Pre-training. Train a source encoder and classifier on labeled source data with a standard supervised loss.
Stage 2: Adversarial Adaptation. Freeze the source encoder and classifier, and train a separate target encoder so that a domain discriminator cannot distinguish target features from the fixed source features. ADDA uses a standard GAN loss with inverted labels rather than gradient reversal, as sketched below.
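To make Stage 2 concrete, here is a minimal sketch of one adversarial update, assuming a pretrained `source_encoder`, a fresh `target_encoder`, a binary `discriminator`, and their optimizers (all names are illustrative, not from the original ADDA code):

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def adda_step(source_encoder, target_encoder, discriminator,
              x_s, x_t, opt_d, opt_t):
    """One ADDA stage-2 update: train the discriminator,
    then train the target encoder with inverted labels."""
    # Discriminator step: separate frozen source features from target features
    with torch.no_grad():
        f_s = source_encoder(x_s)   # source encoder stays frozen in stage 2
        f_t = target_encoder(x_t)
    d_s, d_t = discriminator(f_s), discriminator(f_t)
    loss_d = bce(d_s, torch.ones_like(d_s)) + bce(d_t, torch.zeros_like(d_t))
    opt_d.zero_grad(); loss_d.backward(); opt_d.step()

    # Target-encoder step: inverted labels, i.e. make D call target "source"
    d_t = discriminator(target_encoder(x_t))
    loss_t = bce(d_t, torch.ones_like(d_t))
    opt_t.zero_grad(); loss_t.backward(); opt_t.step()
    return loss_d.item(), loss_t.item()
```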
Key Differences from DANN: untied source and target encoders, a two-stage rather than end-to-end procedure, and an inverted-label GAN objective in place of the GRL.
| Method | Key Innovation | Training | Notes |
|---|---|---|---|
| DANN | Gradient reversal layer | End-to-end | Simple, widely used |
| ADDA | Separate encoders | Two-stage | More flexible |
| CyCADA | Cycle-consistent pixel adaptation | Multi-stage | Image translation + adaptation |
| CDAN | Conditional adversarial | End-to-end | Class-aware alignment |
| MCD | Maximum classifier discrepancy | Alternating | No explicit domain discriminator |
MCD takes a different approach: use two classifiers that try to disagree on target samples.
Idea: two classifiers trained on the same source features carve out slightly different decision boundaries. Target samples that fall outside the source support make the classifiers disagree, so their disagreement acts as an implicit domain discriminator.
Training Steps:
1. Train $G$, $C_1$, and $C_2$ to minimize classification loss on labeled source data.
2. Fix $G$; train $C_1$ and $C_2$ to maximize their discrepancy on target samples while staying accurate on source.
3. Fix $C_1$ and $C_2$; train $G$ to minimize the discrepancy, pulling target features inside the source support.
This finds target features that both classifiers confidently agree on.
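Here is a minimal sketch of the discrepancy term used in steps 2 and 3, following the common choice of L1 distance between softmax outputs (the function name is illustrative):

```python
import torch
import torch.nn.functional as F

def discrepancy(logits_1, logits_2):
    """L1 distance between the two classifiers' softmax outputs.
    Maximized w.r.t. C1, C2 (step 2); minimized w.r.t. G (step 3)."""
    p1 = F.softmax(logits_1, dim=1)
    p2 = F.softmax(logits_2, dim=1)
    return (p1 - p2).abs().mean()
```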
Adversarial domain adaptation inherits the training challenges of GANs.
Mode Collapse: The feature extractor may map all inputs to the same representation, trivially fooling the discriminator but destroying task-relevant information.
Discriminator Dominance: If the discriminator learns too quickly, the gradients reaching the feature extractor vanish and it stops improving.
Oscillation: Generator and discriminator may oscillate without converging to domain-invariant features.
```python
import numpy as np
import torch
import torch.nn as nn


def grl_lambda_schedule(epoch, num_epochs, method='progressive'):
    """
    Schedule for GRL lambda parameter.

    Starting with low lambda and increasing helps stability:
    - Generator first learns good features for classification
    - Then gradually incorporates domain invariance
    """
    progress = epoch / num_epochs

    if method == 'progressive':
        # Sigmoid schedule from DANN paper
        return 2.0 / (1 + np.exp(-10 * progress)) - 1
    elif method == 'linear':
        return progress
    elif method == 'step':
        return 1.0 if progress > 0.5 else 0.0
    else:
        return 1.0


class SpectralNorm(nn.Module):
    """
    Spectral normalization for discriminator stability.

    Constrains the Lipschitz constant of the discriminator,
    preventing it from becoming too powerful too quickly.
    """

    def __init__(self, module, name='weight', n_power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.n_power_iterations = n_power_iterations
        self._make_params()

    def _make_params(self):
        w = getattr(self.module, self.name)
        height = w.size(0)
        width = w.view(height, -1).size(1)

        # Keep the unnormalized weight as the trainable parameter and
        # re-attach the normalized weight to the module at each forward.
        self.weight_bar = nn.Parameter(w.data)
        del self.module._parameters[self.name]

        self.register_buffer('u', w.new_empty(height).normal_(0, 1))
        self.register_buffer('v', w.new_empty(width).normal_(0, 1))

    def _update_vectors(self):
        # Power iteration estimates the top singular vectors of W
        with torch.no_grad():
            w_mat = self.weight_bar.view(self.weight_bar.size(0), -1)
            for _ in range(self.n_power_iterations):
                self.v = nn.functional.normalize(
                    torch.mv(w_mat.t(), self.u), dim=0
                )
                self.u = nn.functional.normalize(
                    torch.mv(w_mat, self.v), dim=0
                )

    def forward(self, *args, **kwargs):
        self._update_vectors()
        w_mat = self.weight_bar.view(self.weight_bar.size(0), -1)
        sigma = torch.dot(self.u, torch.mv(w_mat, self.v))
        # Divide by the spectral norm; keep the result a plain tensor so
        # gradients flow back to weight_bar (re-wrapping it in nn.Parameter
        # every forward pass would cut the graph).
        setattr(self.module, self.name, self.weight_bar / sigma)
        return self.module(*args, **kwargs)


class GradientPenalty:
    """
    Gradient penalty for Wasserstein GAN-style training.

    Enforces Lipschitz constraint on discriminator by
    penalizing gradients with norm far from 1.
    """

    def __init__(self, lambda_gp=10.0):
        self.lambda_gp = lambda_gp

    def __call__(self, discriminator, real_features, fake_features):
        batch_size = real_features.size(0)

        # Random interpolation
        alpha = torch.rand(batch_size, 1, device=real_features.device)
        interpolated = alpha * real_features + (1 - alpha) * fake_features
        interpolated.requires_grad_(True)

        # Discriminator output on interpolated
        d_interpolated = discriminator(interpolated)

        # Gradient w.r.t. interpolated
        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated),
            create_graph=True,
            retain_graph=True
        )[0]

        # Gradient penalty
        gradient_norm = gradients.view(batch_size, -1).norm(2, dim=1)
        penalty = ((gradient_norm - 1) ** 2).mean()

        return self.lambda_gp * penalty
```

Standard adversarial adaptation aligns marginal feature distributions. But this can misalign class boundaries!
If domains have different class proportions, aligning marginals forces features from different classes to overlap:
$$P_S(Z) = P_T(Z) \not\Rightarrow P_S(Z|Y) = P_T(Z|Y)$$
CDAN (Conditional Domain Adversarial Networks) addresses this by conditioning the discriminator on the classifier's predictions:
$$D(G(x), F(x))$$
where $F(x)$ is the classifier's probability vector.
Implementation: Use outer product of features and predictions: $$h(x) = G(x) \otimes F(x)$$
This creates class-aware features for the discriminator.
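As a sketch, assuming $G(x)$ is a (batch, d) feature matrix and $F(x)$ a (batch, K) softmax matrix, the conditioning map can be written as (function name illustrative):

```python
import torch

def multilinear_map(features, predictions):
    # Outer product per sample: (batch, K, 1) x (batch, 1, d) -> (batch, K, d)
    outer = torch.bmm(predictions.unsqueeze(2), features.unsqueeze(1))
    # Flatten to a (batch, K * d) input for the domain discriminator
    return outer.flatten(start_dim=1)
```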
CDAN also weights the adversarial loss by prediction entropy. High-entropy (uncertain) predictions contribute less, focusing adaptation on confident samples where class information is reliable.
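A sketch of that weighting, following the $w(x) = 1 + e^{-H(\hat{y})}$ form of CDAN's entropy conditioning (helper name illustrative):

```python
import torch

def entropy_weight(predictions, eps=1e-8):
    # Shannon entropy of each softmax prediction vector
    h = -(predictions * (predictions + eps).log()).sum(dim=1)
    # Confident (low-entropy) samples receive larger adversarial weight
    return 1.0 + torch.exp(-h)
```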
Another approach: a (K+1)-way discriminator instead of a binary one.
Training: the discriminator classifies source features into their $K$ true classes and target features into an extra $(K{+}1)$-th "target" class; the feature extractor is trained adversarially so that target features are pushed toward one of the $K$ source classes instead.
This ensures target features align with specific source classes, not just the aggregated source distribution.
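One plausible instantiation of the (K+1)-way losses, sketched under the assumption that the discriminator emits K+1 logits with class K meaning "target" (an illustration, not a specific published recipe):

```python
import torch
import torch.nn.functional as F

def k_plus_one_losses(disc_logits_s, y_s, disc_logits_t, K):
    # Discriminator: true class for source, extra class K for target
    target_labels = torch.full((disc_logits_t.size(0),), K,
                               dtype=torch.long, device=disc_logits_t.device)
    loss_disc = (F.cross_entropy(disc_logits_s, y_s)
                 + F.cross_entropy(disc_logits_t, target_labels))

    # Feature extractor (adversarial): push target samples away from
    # the "target" class, toward the K source classes
    p_tgt = F.softmax(disc_logits_t, dim=1)[:, K].clamp(max=1 - 1e-6)
    loss_feat = -torch.log1p(-p_tgt).mean()
    return loss_disc, loss_feat
```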
The final page covers domain generalization—learning models that work on any target domain without seeing target data during training.