GAN training is notoriously difficult. Unlike supervised learning where loss steadily decreases, GAN training involves two networks in dynamic competition. The loss curves oscillate, plateau, and sometimes diverge entirely. What makes training stable? Why do some runs produce stunning results while others collapse? Understanding training dynamics is essential for successfully deploying GANs.
This page examines the theoretical and practical aspects of GAN training: the alternating optimization procedure, convergence challenges, instability sources, and the techniques that have made modern GANs reliable enough for production use.
By the end of this page, you will understand: the alternating gradient descent procedure, why GANs don't converge like standard neural networks, common training failure modes, and practical techniques for stable training including learning rate tuning, architectural choices, and regularization methods.
GAN training uses alternating gradient descent: we optimize the discriminator and generator in turn, each taking steps while the other is fixed.
The Training Loop:
```
for each training iteration:
    # Step 1: Train Discriminator
    for k discriminator steps:
        Sample real batch x from data
        Sample noise z from prior
        Generate fake batch G(z)
        Compute D loss: L_D = -log(D(x)) - log(1 - D(G(z)))
        Update D parameters: θ_D ← θ_D - α∇L_D

    # Step 2: Train Generator
    Sample noise z from prior
    Compute G loss: L_G = -log(D(G(z)))   # non-saturating
    Update G parameters: θ_G ← θ_G - α∇L_G
```
Why Alternating?
Simultaneous gradient updates create circular dependencies. If both networks update at once, each is chasing a moving target. Alternating updates provide a stable optimization target for each step.
The k:1 Ratio:
The original paper suggested training D for k steps per G step (k ≈ 5). The intuition: D should be "optimally" discriminating so G receives meaningful gradients. In practice, k=1 often works with modern architectures.
"""Standard GAN Training Loop with Monitoring"""import torchimport torch.nn as nn def train_gan(generator, discriminator, dataloader, g_optimizer, d_optimizer, num_epochs, device, d_steps_per_g_step=1): """ Complete GAN training loop with best practices. """ criterion = nn.BCEWithLogitsLoss() for epoch in range(num_epochs): for batch_idx, (real_data, _) in enumerate(dataloader): batch_size = real_data.size(0) real_data = real_data.to(device) # Labels real_labels = torch.ones(batch_size, 1, device=device) fake_labels = torch.zeros(batch_size, 1, device=device) # ======================================== # Train Discriminator (k steps) # ======================================== for _ in range(d_steps_per_g_step): d_optimizer.zero_grad() # Real samples d_real = discriminator(real_data) d_loss_real = criterion(d_real, real_labels) # Fake samples z = torch.randn(batch_size, generator.latent_dim, device=device) fake_data = generator(z).detach() # Don't backprop to G d_fake = discriminator(fake_data) d_loss_fake = criterion(d_fake, fake_labels) d_loss = d_loss_real + d_loss_fake d_loss.backward() d_optimizer.step() # ======================================== # Train Generator (1 step) # ======================================== g_optimizer.zero_grad() z = torch.randn(batch_size, generator.latent_dim, device=device) fake_data = generator(z) d_fake = discriminator(fake_data) # Non-saturating loss: maximize log(D(G(z))) g_loss = criterion(d_fake, real_labels) # Trick: use real labels g_loss.backward() g_optimizer.step() # Epoch monitoring print(f"Epoch {epoch}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")Standard gradient descent converges to local minima under mild conditions. GAN training is fundamentally different—we're finding a Nash equilibrium of a game, not a minimum of a loss.
Why Standard Analysis Fails:
In supervised learning, gradient descent minimizes a single fixed objective, so standard convergence guarantees apply. In a GAN, each network's loss surface shifts whenever the other network updates: there is no one function being jointly minimized, so descent on each player's own loss need not make joint progress.
The Cycling Problem:
In simple games (e.g., $\min_x \max_y xy$), gradient descent doesn't converge; it cycles around the equilibrium. The updates $x \leftarrow x - \alpha y$ and $y \leftarrow y + \alpha x$ rotate the iterates around the saddle point at $(0, 0)$, and each step scales their distance from it by $\sqrt{1 + \alpha^2}$, so they slowly spiral outward. GANs can exhibit similar behavior, which can lead to oscillating losses without actual improvement.
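A minimal sketch (illustrative, not from the original text) makes the cycling visible: simultaneous gradient descent-ascent on $f(x, y) = xy$ never settles at the equilibrium.

```python
"""Illustrative sketch: gradient descent-ascent on f(x, y) = x * y
cycles around (0, 0) instead of converging to it."""
import torch

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)
lr = 0.1

for step in range(201):
    f = x * y                      # x tries to minimize f, y to maximize it
    gx, gy = torch.autograd.grad(f, [x, y])
    with torch.no_grad():
        x -= lr * gx               # descent step for x
        y += lr * gy               # ascent step for y
    if step % 50 == 0:
        # The radius sqrt(x^2 + y^2) grows every step: an outward spiral
        print(f"step {step}: x={x.item():+.3f}, y={y.item():+.3f}, "
              f"radius={(x**2 + y**2).sqrt().item():.3f}")
```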
A common mistake is monitoring GAN training through loss values. Unlike supervised learning, GAN losses don't monotonically decrease. D and G losses oscillate as they compete. The only reliable metric is sample quality—either visual inspection or quantitative measures like FID.
Theoretical Results:
The original GAN paper proved that with infinite capacity and convex objectives:
- The unique global equilibrium is $p_g = p_{\text{data}}$, at which point the optimal discriminator outputs $D(\mathbf{x}) = 1/2$ everywhere.
- If D is trained to optimality at every step, the generator update minimizes the Jensen-Shannon divergence between $p_g$ and $p_{\text{data}}$, and the procedure converges.

However, these conditions never hold in practice:
- Networks have finite capacity, and optimization happens in non-convex parameter space rather than directly over distributions.
- D is only approximately optimized (a few gradient steps), so G receives gradients from an imperfect discriminator.
- Expectations are estimated from minibatches, adding noise to every update.
GAN training can fail in several characteristic ways. Recognizing these failure modes is crucial for debugging.
| Symptom | Likely Cause | Remedies |
|---|---|---|
| D loss → 0, G loss → ∞ | D too strong | Reduce D capacity, add noise to D input, use WGAN |
| Low sample diversity | Mode collapse | Minibatch discrimination, unrolled GAN, feature matching |
| Oscillating losses | Unstable equilibrium | Lower learning rate, spectral normalization, gradient penalty |
| NaN in losses | Gradient explosion | Gradient clipping, smaller learning rate, check data preprocessing |
| Good D, bad samples | Insufficient G capacity | Increase G size, add skip connections, use progressive growing |
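As a concrete example of two remedies from the table, gradient clipping and an early NaN guard can be folded into the discriminator update. This is an illustrative sketch; the `max_norm` value of 1.0 is an assumed example, not a prescribed setting.

```python
"""Sketch: NaN guard plus gradient clipping for the D update.
max_norm=1.0 is an assumed example value."""
import torch
from torch.nn.utils import clip_grad_norm_

def safe_d_step(discriminator, d_loss, d_optimizer, max_norm=1.0):
    # Remedy for "NaN in losses": fail fast instead of corrupting weights
    if torch.isnan(d_loss):
        raise RuntimeError("D loss is NaN; check learning rate and preprocessing")

    d_optimizer.zero_grad()
    d_loss.backward()

    # Remedy for gradient explosion: rescale gradients to a maximum L2 norm
    clip_grad_norm_(discriminator.parameters(), max_norm=max_norm)
    d_optimizer.step()
```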
Years of research have produced a toolkit of techniques for stable GAN training. Here are the most important:
1. Spectral Normalization:
Constrains the Lipschitz constant of each layer by normalizing weights by their largest singular value:
$$\bar{W} = \frac{W}{\sigma(W)}$$
where $\sigma(W)$ is the largest singular value of $W$. This prevents D from becoming too discriminative too quickly.
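For intuition, $\sigma(W)$ can be estimated cheaply with power iteration, which is also how spectral normalization layers estimate it in practice. A minimal sketch (not PyTorch's actual implementation):

```python
"""Minimal sketch: estimate the largest singular value by power iteration.
Illustrative only; conv weights are flattened to 2D in real implementations."""
import torch

def sigma_estimate(W, n_iters=20):
    u = torch.randn(W.size(0))
    for _ in range(n_iters):
        v = W.t() @ u
        v = v / v.norm()
        u = W @ v
        u = u / u.norm()
    return u @ W @ v   # approximates sigma(W)

W = torch.randn(64, 128)
print(sigma_estimate(W), torch.linalg.svdvals(W)[0])  # the two should agree
```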
2. Gradient Penalty (WGAN-GP):
Adds a regularization term encouraging gradients to have unit norm:
$$\mathcal{L}_{GP} = \lambda\, \mathbb{E}_{\hat{\mathbf{x}}}\left[\left(\|\nabla_{\hat{\mathbf{x}}} D(\hat{\mathbf{x}})\|_2 - 1\right)^2\right]$$
where $\hat{\mathbf{x}}$ is interpolated between real and fake samples.
3. Two-Timescale Update Rule (TTUR):
Use different learning rates for G and D; typically D learns faster. A common choice is Adam with lr_D = 4e-4 and lr_G = 1e-4, though the exact values vary by architecture.
This gives D time to provide meaningful gradients to G.
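In code, TTUR is simply two optimizers with different learning rates. The sketch below uses placeholder networks, and the 4e-4/1e-4 split and betas=(0.0, 0.9) are one common configuration, not a universal prescription:

```python
"""TTUR sketch: separate optimizers with different learning rates.
Networks and hyperparameters here are placeholder examples."""
import torch
import torch.nn as nn

generator = nn.Linear(100, 784)        # placeholder networks for illustration
discriminator = nn.Linear(784, 1)

# D gets the larger learning rate so it stays ahead of G
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=4e-4, betas=(0.0, 0.9))
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
```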
"""Key GAN Stabilization Techniques"""import torchimport torch.nn as nnfrom torch.nn.utils import spectral_norm # 1. Spectral Normalizationclass SpectralNormConv(nn.Module): def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1): super().__init__() self.conv = spectral_norm( nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding) ) def forward(self, x): return self.conv(x) # 2. Gradient Penalty (for WGAN-GP)def gradient_penalty(discriminator, real, fake, device): batch_size = real.size(0) # Random interpolation alpha = torch.rand(batch_size, 1, 1, 1, device=device) interpolated = alpha * real + (1 - alpha) * fake interpolated.requires_grad_(True) # Discriminator output on interpolated d_interpolated = discriminator(interpolated) # Compute gradients gradients = torch.autograd.grad( outputs=d_interpolated, inputs=interpolated, grad_outputs=torch.ones_like(d_interpolated), create_graph=True, retain_graph=True )[0] # Gradient penalty gradients = gradients.view(batch_size, -1) gradient_norm = gradients.norm(2, dim=1) penalty = ((gradient_norm - 1) ** 2).mean() return penalty # 3. Label Smoothingdef smooth_labels(labels, smoothing=0.1): """Soft labels: 1 -> 0.9, 0 -> 0.1""" return labels * (1 - smoothing) + 0.5 * smoothing # 4. Instance Noisedef add_instance_noise(x, std=0.1, decay=0.99, step=0): """Add decaying noise to discriminator input""" current_std = std * (decay ** step) noise = torch.randn_like(x) * current_std return x + noise4. Label Smoothing:
Instead of training D with hard labels (0 and 1), use soft labels (0.1 and 0.9). This prevents D from becoming overconfident.
5. Instance Noise:
Add Gaussian noise to D's input, decaying over training. This smooths the decision boundary early on, providing better gradients to G.
6. Historical Averaging:
Add a penalty for weights deviating from their running average:
$$\mathcal{L}_{hist} = \|\theta - \bar{\theta}\|^2$$
This encourages stable, gradual changes rather than wild swings.
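A minimal sketch of historical averaging follows. Using an exponential moving average for $\bar{\theta}$ is an assumption made here for simplicity; the original formulation averages over all past parameter values.

```python
"""Historical averaging sketch: penalize parameters that stray from a
running average. The EMA is an implementation assumption."""
import torch

class HistoricalAverage:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Running average of each parameter, kept outside the autograd graph
        self.avg = [p.detach().clone() for p in model.parameters()]

    def penalty(self, model):
        # L_hist = sum of ||theta - theta_bar||^2 over all parameter tensors
        loss = 0.0
        for p, a in zip(model.parameters(), self.avg):
            loss = loss + ((p - a) ** 2).sum()
        return loss

    @torch.no_grad()
    def update(self, model):
        # theta_bar <- decay * theta_bar + (1 - decay) * theta
        for p, a in zip(model.parameters(), self.avg):
            a.mul_(self.decay).add_(p, alpha=1 - self.decay)
```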
Years of collective experience have crystallized into practical guidelines for GAN training.
When training fails:
1. Verify data preprocessing: images should be in [-1, 1] for a tanh output.
2. Check for NaN early with torch.isnan checks.
3. Try a smaller learning rate.
4. Ensure the batch size is sufficient (≥32).
5. Verify D isn't dominating by checking D accuracy.
6. Compare against a known-working implementation.
Unlike supervised learning, GAN training requires specialized monitoring approaches.
What to Track:
- D(real) and D(fake) separately: D(real) should stay near 0.5-0.7; D(fake) should gradually rise toward 0.5 as G improves
- Gradient norms: watch for explosion (>100) or vanishing (<0.001)
- Sample grids: generate samples from the same fixed z vectors throughout training so progress is comparable across epochs
- FID score: compute periodically (every N epochs) as an objective quality measure
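A sketch of how the first two statistics might be logged each iteration (illustrative; the function names are placeholders, not part of any library):

```python
"""Illustrative monitoring sketch: D output statistics and gradient norms.
All names here are placeholders invented for this example."""
import torch

@torch.no_grad()
def d_output_stats(discriminator, real_data, fake_data):
    # Mean sigmoid output on real vs. fake batches (assumes D outputs logits)
    d_real = torch.sigmoid(discriminator(real_data)).mean().item()
    d_fake = torch.sigmoid(discriminator(fake_data)).mean().item()
    return d_real, d_fake

def grad_norm(model):
    # Total L2 norm of all parameter gradients; call after backward()
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.pow(2).sum().item()
    return total ** 0.5
```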
Healthy Training Signs:
- D and G losses oscillate within a bounded range rather than trending to 0 or diverging
- D(real) and D(fake) hover in a middle band instead of saturating at 0 or 1
- Fixed-z sample grids improve visibly from epoch to epoch
- FID decreases steadily over training
Fréchet Inception Distance (FID) measures the distance between real and generated image distributions in feature space. Lower is better. FID < 10 indicates high quality; FID > 100 indicates poor quality. It's the standard quantitative metric for GAN evaluation.
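One way to compute FID periodically is the third-party torchmetrics package (an assumed dependency, not something used elsewhere on this page; it also requires torch-fidelity installed). The random batches below are stand-ins for real and generated samples:

```python
"""FID sketch using torchmetrics (assumed dependency). By default the
metric expects uint8 images in [0, 255] with 3 channels."""
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=2048)

# Random uint8 batches stand in for real/generated images here
real_images = torch.randint(0, 256, (64, 3, 64, 64), dtype=torch.uint8)
fake_images = torch.randint(0, 256, (64, 3, 64, 64), dtype=torch.uint8)

fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(f"FID: {fid.compute().item():.2f}")   # lower is better
```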
You now understand the dynamics of GAN training—the alternating optimization, convergence challenges, failure modes, and stabilization techniques. Next, we'll examine mode collapse in detail: what causes it, how to detect it, and strategies to prevent it.