Despite DCGAN's architectural improvements, GAN training remained fundamentally unstable. Even with perfect architecture, practitioners struggled with mode collapse, training oscillation, and the frustrating reality that GAN loss values were meaningless—a decreasing generator loss didn't necessarily mean better image quality.
The Wasserstein GAN (WGAN), introduced by Arjovsky, Chintala, and Bottou in 2017, addressed these problems not through architectural changes, but through a theoretical insight: the original GAN's training objective was fundamentally flawed, resting on a distance metric that becomes degenerate when distributions don't overlap.
WGAN replaced this objective with the Wasserstein distance (also known as Earth Mover's distance), which provides meaningful gradients even when the distributions are disjoint, a loss value that tracks sample quality, and markedly more stable training.
This page explores the mathematical foundations of WGAN, why it works, and how it fundamentally changed our understanding of generative modeling.
By the end of this page, you will understand the theoretical problems with JS divergence, the mathematical formulation of Wasserstein distance, the Kantorovich-Rubinstein duality that makes WGAN tractable, weight clipping and gradient penalty, and why WGAN fundamentally improved training stability.
The original GAN optimizes the Jensen-Shannon (JS) divergence between the real data distribution P_r and the generator distribution P_g:
$$JS(P_r || P_g) = \frac{1}{2} KL(P_r || M) + \frac{1}{2} KL(P_g || M)$$
where $M = \frac{1}{2}(P_r + P_g)$ is the mixture distribution, and KL is the Kullback-Leibler divergence.
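As a quick numerical check on this definition, the sketch below (plain NumPy; the helper name `js_divergence` is my own) computes JS divergence for discrete distributions and confirms its two boundary behaviors: zero for identical distributions, log 2 for distributions with disjoint support.

```python
import numpy as np

def js_divergence(p, q, eps=1e-12):
    """JS(p || q) = 0.5*KL(p || m) + 0.5*KL(q || m), with m = (p + q)/2."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = 0.5 * (p + q)
    kl = lambda a, b: float(np.sum(np.where(a > 0, a * np.log(a / (b + eps)), 0.0)))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

same = js_divergence([0.5, 0.5], [0.5, 0.5])      # identical -> 0
disjoint = js_divergence([1.0, 0.0], [0.0, 1.0])  # disjoint support -> log 2
print(same, disjoint)  # ~0.0  ~0.6931
```

The second value is the saturation ceiling discussed below: once supports are disjoint, moving P_g closer to P_r (while staying disjoint) does not change the divergence at all.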
This seems reasonable—we want to minimize the 'distance' between distributions. But JS divergence has a critical pathology that makes GAN training fundamentally unstable.
When P_r and P_g have non-overlapping support (which is almost always the case in high dimensions), JS divergence is constant (log 2) and provides zero gradient. This means the discriminator can become perfect, and the generator receives no learning signal at all—even when P_g is 'close' to P_r in some intuitive sense.
Why distributions don't overlap:
Real images lie on a low-dimensional manifold embedded in high-dimensional pixel space. A 64×64 RGB image has 12,288 dimensions, but the manifold of 'realistic faces' might have only a few hundred degrees of freedom (pose, lighting, identity, expression).
The generator's distribution P_g is also a low-dimensional manifold (parameterized by the latent space z). Two low-dimensional manifolds in high-dimensional space have measure zero probability of intersection.
Mathematically: if dim(manifold) < dim(space), then smooth manifolds generically don't intersect except at discrete points.
This means JS divergence is almost always at its maximum value (log 2), providing no useful gradient.
```python
import numpy as np
import matplotlib.pyplot as plt

def demonstrate_js_problem():
    """
    Demonstrate why JS divergence fails for non-overlapping distributions.

    Consider two 1D Gaussians that are gradually moving closer together.
    With JS divergence, we get a discontinuous, uninformative gradient.
    """
    def js_divergence(p, q, eps=1e-10):
        """Compute JS divergence between discrete distributions p and q."""
        m = 0.5 * (p + q)
        kl_pm = np.sum(p * np.log((p + eps) / (m + eps)))
        kl_qm = np.sum(q * np.log((q + eps) / (m + eps)))
        return 0.5 * (kl_pm + kl_qm)

    def wasserstein_distance(p, q, x):
        """Compute Wasserstein distance (1D case: integral of |CDF_p - CDF_q|)."""
        cdf_p = np.cumsum(p) * (x[1] - x[0])
        cdf_q = np.cumsum(q) * (x[1] - x[0])
        return np.sum(np.abs(cdf_p - cdf_q)) * (x[1] - x[0])

    # Create two Gaussians at varying distances
    x = np.linspace(-10, 10, 1000)
    dx = x[1] - x[0]
    distances = np.linspace(0.01, 5, 100)

    js_values = []
    w_values = []

    for d in distances:
        # P_r: Gaussian at 0
        p_r = np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)
        p_r /= (p_r.sum() * dx)  # Normalize

        # P_g: Gaussian at distance d
        p_g = np.exp(-(x - d)**2 / 2) / np.sqrt(2 * np.pi)
        p_g /= (p_g.sum() * dx)  # Normalize

        js_values.append(js_divergence(p_r, p_g))
        w_values.append(wasserstein_distance(p_r, p_g, x))

    # JS divergence saturates quickly; Wasserstein is linear in distance
    print("Key insight:")
    print(f"JS divergence range: [{min(js_values):.4f}, {max(js_values):.4f}]")
    print(f"Wasserstein range: [{min(w_values):.4f}, {max(w_values):.4f}]")
    print("JS saturates at log(2) ≈ 0.693 when distributions barely overlap")
    print("Wasserstein provides smooth, linear gradient in distance")

demonstrate_js_problem()
```

The practical consequence when training a standard GAN:

1. The discriminator quickly learns to perfectly separate P_r and P_g (because they don't overlap in high dimensions)
2. JS divergence = log(2), gradient = 0
3. The generator receives no learning signal
4. Training stalls or requires careful tricks (noise injection, label smoothing, careful learning rate scheduling)

WGAN's key insight: use a distance that provides gradient even when distributions don't overlap.

The Wasserstein-1 distance (also called Earth Mover's Distance or EMD) provides a solution to JS divergence's problems. The intuition is beautiful: imagine P_r as a pile of dirt and P_g as another pile. The Wasserstein distance is the minimum 'work' needed to transform one pile into the other, where work = mass × distance moved.
Formal definition:
$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y) \sim \gamma} [||x - y||]$$
where $\Pi(P_r, P_g)$ is the set of all joint distributions (transport plans) with marginals P_r and P_g.
In words: consider all possible ways to move mass from P_r to P_g, and find the one that minimizes total distance traveled.
Think of two cities with different population distributions. The Wasserstein distance measures the minimum total travel distance needed to move people so that the first city's distribution matches the second. Unlike JS divergence, this is always well-defined and provides smooth gradients—even a small movement of the population is reflected in the distance.
Why Wasserstein distance is superior:
No saturation: Even when distributions don't overlap, W provides a meaningful distance and gradient.
Metric properties: W satisfies all metric axioms (non-negativity, identity, symmetry, triangle inequality), unlike JS which fails triangle inequality.
Weak convergence: W(P_n, P) → 0 implies P_n → P in distribution, which is exactly what we want for generative modeling.
Lipschitz continuity: Small changes in generator parameters lead to small changes in W, enabling stable gradient descent.
```python
import numpy as np
from scipy.stats import wasserstein_distance as scipy_w
import matplotlib.pyplot as plt

def wasserstein_vs_js_visualization():
    """
    Visualize the key difference between Wasserstein and JS divergence.

    We'll move a Gaussian distribution gradually toward another
    and plot how each divergence measure changes.
    """
    # Create the target distribution (fixed)
    x = np.linspace(-5, 10, 1000)
    target = np.exp(-x**2 / 2)  # Gaussian at 0
    target /= target.sum()

    # Move a Gaussian from distance d toward the target
    js_values = []
    w_values = []
    distances = np.linspace(5, 0, 50)  # Start far, end close

    for d in distances:
        # Generator's current distribution
        generated = np.exp(-(x - d)**2 / 2)
        generated /= generated.sum()

        # Compute Wasserstein distance (scipy uses 1D optimal transport)
        w = scipy_w(x, x, target, generated)
        w_values.append(w)

        # Compute JS divergence
        m = 0.5 * (target + generated)
        js = 0.5 * np.sum(target * np.log(target / m + 1e-10)) + \
             0.5 * np.sum(generated * np.log(generated / m + 1e-10))
        js_values.append(js)

    # Key observation: Wasserstein is smooth and monotonic;
    # JS divergence saturates when distributions don't overlap
    return {
        'distances': distances,
        'wasserstein': w_values,  # Linear decrease
        'js': js_values           # Saturated until overlap
    }
```

The fundamental insight: consider two point masses, P_r = δ_0 (mass at 0) and P_g = δ_θ (mass at θ).

JS divergence:
- If θ ≠ 0: JS(P_r, P_g) = log(2) (constant!)
- If θ = 0: JS(P_r, P_g) = 0

This is discontinuous at θ = 0, and ∂JS/∂θ = 0 everywhere else. The generator has no gradient to follow!

Wasserstein distance:
- W(P_r, P_g) = |θ| (linear in distance)
- ∂W/∂θ = sign(θ)

The generator has a clear, consistent gradient: move toward 0. This is why WGAN training is more stable: the loss function provides directional information even when distributions are far apart.

The primal formulation of Wasserstein distance requires solving an optimization problem over the space of all joint distributions—computationally intractable.
The Kantorovich-Rubinstein duality provides an alternative formulation that's actually computable:
$$W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]$$
where the supremum is over all 1-Lipschitz functions f.
A function f is 1-Lipschitz if: $|f(x) - f(y)| \leq ||x - y||$ for all x, y.
In words: among all functions that can't change faster than slope 1, find the one that maximizes the difference between E[f(real)] and E[f(fake)].
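A tiny worked instance of this dual (my own illustration, not from the WGAN paper): take the point masses P_r = δ_0 and P_g = δ_θ from earlier, restrict f to the linear functions f(x) = a·x with |a| ≤ 1 (all 1-Lipschitz), and search over slopes.

```python
import numpy as np

theta = 3.0                        # P_g is a point mass at theta, P_r at 0
slopes = np.linspace(-1, 1, 201)   # candidate 1-Lipschitz functions f(x) = a*x

# Dual objective: E_{P_r}[f] - E_{P_g}[f] = a*0 - a*theta
objective = -slopes * theta

best_value = objective.max()
best_slope = slopes[objective.argmax()]
print(best_slope, best_value)  # -1.0 3.0
```

The supremum is attained at slope a = -1 with value |θ| = 3, exactly W(δ_0, δ_θ). The optimal f "prices" locations so that moving mass from 0 to θ is maximally expensive while respecting the slope-1 constraint.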
The discriminator in WGAN is no longer a classifier outputting probabilities—it's a 'critic' that outputs a scalar score (can be any real number). The critic is trained to maximize the score for real samples and minimize it for fake samples, subject to the Lipschitz constraint. The difference in expected scores IS the Wasserstein distance.
Mathematical derivation:
The original optimal transport problem: $$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \int ||x - y|| d\gamma(x, y)$$
is the primal form of a linear program. By strong duality (under mild conditions), this equals: $$W(P_r, P_g) = \sup_{f} \int f(x) dP_r(x) - \int f(x) dP_g(x)$$
where f ranges over 1-Lipschitz functions. The Lipschitz constraint is dual to the transport constraint that marginals must match P_r and P_g.
Intuition for the dual:
Think of f(x) as the 'price' at location x. The Lipschitz constraint says prices can't change faster than transport cost. The primal asks: 'what's the cheapest way to move mass?' The dual asks: 'what's the largest price difference I can maintain?'
```python
import torch
import torch.nn as nn

# The WGAN objective using Kantorovich-Rubinstein duality

class WGANCritic(nn.Module):
    """
    The critic (not 'discriminator') in WGAN.

    Key differences from standard GAN discriminator:
    1. No sigmoid at output - outputs unbounded real value
    2. Must be approximately 1-Lipschitz
    3. Trained to maximize E[f(real)] - E[f(fake)]
    """
    def __init__(self, nc=3, ndf=64):
        super().__init__()
        # Same architecture as DCGAN, but no sigmoid
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            # NO sigmoid here - output is unbounded
        )

    def forward(self, x):
        return self.main(x).view(-1)

def wgan_loss(critic, real_samples, fake_samples):
    """
    Compute WGAN loss from the Kantorovich-Rubinstein formulation.

    W(P_r, P_g) ≈ E[f(real)] - E[f(fake)]

    Critic wants to MAXIMIZE this (larger scores for real).
    Generator wants to MINIMIZE this (higher scores for fake).
    """
    # Critic scores
    real_scores = critic(real_samples)
    fake_scores = critic(fake_samples)

    # Wasserstein estimate
    w_distance = real_scores.mean() - fake_scores.mean()

    # Critic loss: negative W distance (we maximize via gradient descent on the negative)
    critic_loss = -w_distance

    # Generator loss: negative mean fake score (want to maximize fake scores)
    generator_loss = -fake_scores.mean()

    return critic_loss, generator_loss, w_distance.item()
```

Understanding the loss values:

In standard GAN:
- D outputs probability in [0, 1]
- Losses involve log(D(x)) and log(1 - D(G(z)))
- Loss values are hard to interpret

In WGAN:
- Critic outputs unbounded scalar
- Loss is simply E[f(real)] - E[f(fake)]
- This directly estimates Wasserstein distance!

Implications:
- Lower W distance = better generator (always!)
- Loss curve directly tracks generation quality
- No need for heuristics like comparing sample quality
- Training progress is measurable and interpretable

The theoretical formulation requires the critic to be 1-Lipschitz. In practice, this constraint is challenging to enforce exactly. The original WGAN paper proposed weight clipping; later work introduced the more principled gradient penalty (WGAN-GP).
The original WGAN enforced Lipschitz continuity by clipping all weights to a compact set [-c, c] after each gradient update:
$$w \leftarrow \text{clip}(w, -c, c)$$
The idea: if all weights are bounded, the function can't have arbitrarily large derivatives, approximately enforcing Lipschitz. The paper used c = 0.01.
```python
import torch
import torch.optim as optim

def train_wgan_with_clipping(
    critic, generator, dataloader,
    n_critic=5, clip_value=0.01, epochs=100
):
    """
    Original WGAN training with weight clipping.

    Key differences from standard GAN:
    1. Train critic multiple times per generator step (n_critic=5)
    2. Clip critic weights after each update
    3. Use RMSprop (not Adam) - the paper found Adam unstable
    4. No log in loss - just mean difference
    """
    # RMSprop, not Adam - more stable for WGAN
    optim_critic = optim.RMSprop(critic.parameters(), lr=5e-5)
    optim_gen = optim.RMSprop(generator.parameters(), lr=5e-5)

    for epoch in range(epochs):
        for real_batch, _ in dataloader:
            batch_size = real_batch.size(0)

            # ============================================
            # Train critic for n_critic steps
            # ============================================
            for _ in range(n_critic):
                optim_critic.zero_grad()

                # Sample real and fake
                z = torch.randn(batch_size, 100)
                fake = generator(z).detach()

                # Critic scores
                real_score = critic(real_batch).mean()
                fake_score = critic(fake).mean()

                # WGAN critic loss: maximize real - fake
                critic_loss = -(real_score - fake_score)
                critic_loss.backward()
                optim_critic.step()

                # WEIGHT CLIPPING - the key constraint
                for p in critic.parameters():
                    p.data.clamp_(-clip_value, clip_value)

            # ============================================
            # Train generator
            # ============================================
            optim_gen.zero_grad()
            z = torch.randn(batch_size, 100)
            fake = generator(z)

            # Generator wants to maximize critic score of fakes
            gen_loss = -critic(fake).mean()
            gen_loss.backward()
            optim_gen.step()

        # The Wasserstein distance estimate
        w_dist = real_score.item() - fake_score.item()
        print(f"Epoch {epoch}: W-distance = {w_dist:.4f}")
```

Weight clipping works but has issues: (1) if c is too low, the critic has limited capacity; (2) if c is too high, training becomes unstable; (3) clipping biases the critic toward simple functions; (4) gradient flow can still vanish in deep networks. These motivate the gradient penalty approach.
WGAN-GP requires disabling BatchNorm in the critic. BatchNorm introduces dependencies between samples in a batch, which violates the assumption that the gradient penalty can be computed per-sample. Common alternatives: LayerNorm, InstanceNorm, or Spectral Normalization (discussed later).
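The penalty term itself is short to implement. Below is a minimal sketch of my own (the helper `gradient_penalty` is not taken from the paper's code) of the WGAN-GP idea: penalize the critic's gradient norm, evaluated at random interpolates between real and fake samples, for deviating from 1. The caller would add this to the critic loss scaled by a coefficient λ (the paper uses λ = 10).

```python
import torch

def gradient_penalty(critic, real, fake):
    """E[(||grad_x critic(x_hat)||_2 - 1)^2], x_hat interpolated per-sample."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over remaining dims
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)))
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)

    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores, inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,   # penalty itself must be differentiable
    )[0]
    grads = grads.view(batch_size, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Sanity check: for a linear critic f(x) = x @ w, the gradient is w
# everywhere, so the penalty is exactly (||w|| - 1)^2 for any samples.
w = torch.tensor([2.0, 0.0])
penalty = gradient_penalty(lambda x: x @ w,
                           torch.randn(8, 2), torch.randn(8, 2))
print(penalty.item())  # (2 - 1)^2 = 1.0
```

Because the gradient is taken per-sample at each interpolate, any layer that mixes samples within a batch (BatchNorm in particular) breaks the construction, which is why the normalization alternatives above are needed.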
WGAN fundamentally changes the training dynamics of GANs. Understanding these dynamics is crucial for successful training and debugging.
```python
import matplotlib.pyplot as plt
import numpy as np

def analyze_wgan_training(training_history):
    """
    Analyze WGAN training curves.

    Unlike standard GAN, WGAN losses are interpretable:
    - Critic loss ≈ -W(P_r, P_g) (negated for minimization)
    - Generator loss ≈ -E[f(fake)]
    - W_distance = E[f(real)] - E[f(fake)] should decrease over training
    """
    # Extract metrics
    w_distances = training_history['w_distance']
    critic_losses = training_history['critic_loss']
    gen_losses = training_history['gen_loss']

    # Healthy training signs:
    signs = {
        'w_decreasing': is_decreasing_trend(w_distances),
        'no_oscillation': check_smoothness(w_distances),
        'critic_converging': is_converging(critic_losses),
        'no_explosion': max(abs(np.array(w_distances))) < 100
    }
    return signs

def is_decreasing_trend(values, window=100):
    """Check if values show decreasing trend over time."""
    if len(values) < 2 * window:
        return None
    early_mean = np.mean(values[:window])
    late_mean = np.mean(values[-window:])
    return late_mean < early_mean

def check_smoothness(values, threshold=2.0):
    """Check if values are relatively smooth (no wild oscillations)."""
    diffs = np.abs(np.diff(values))
    return np.std(diffs) / np.mean(diffs) < threshold

def is_converging(values, window=100):
    """Check if values are converging (decreasing variance)."""
    if len(values) < 2 * window:
        return None
    early_var = np.var(values[:window])
    late_var = np.var(values[-window:])
    return late_var < early_var
```

What to look for in WGAN training:

GOOD SIGNS:
- ✓ W_distance steadily decreases
- ✓ Critic loss is relatively stable after initial transient
- ✓ Generated samples improve in tandem with lower W_distance
- ✓ No sudden spikes or plateaus in losses

BAD SIGNS:
- ✗ W_distance increases or oscillates wildly → something wrong with the Lipschitz constraint
- ✗ W_distance stuck at a high value → critic not powerful enough or learning rate too low
- ✗ Losses explode to ±infinity → gradient penalty coefficient too low, or numerical issues
- ✗ W_distance goes to 0 but samples are bad → mode collapse (rare in WGAN but possible)

WGAN's interpretable loss curves are transformative for practitioners. In standard GAN, you must visually inspect samples to gauge progress. In WGAN, you can monitor training curves like any supervised learning task. This enables: automated early stopping, hyperparameter tuning based on validation W-distance, and reliable comparison of different architectures.
While WGAN-GP is effective, computing second-order gradients is expensive. Spectral Normalization (SN) provides an alternative approach to Lipschitz constraint that's computationally cheaper and often more effective.
The key insight: a neural network's Lipschitz constant is bounded by the product of layer-wise Lipschitz constants. For a linear layer with weight matrix W, the Lipschitz constant is the spectral norm σ(W) (largest singular value).
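This composition bound is easy to verify numerically. A small sketch of my own using NumPy's SVD (not part of any GAN library) checks σ(AB) ≤ σ(A)·σ(B) for two random "layer" matrices:

```python
import numpy as np

def spectral_norm(M):
    """Largest singular value of M."""
    return np.linalg.svd(M, compute_uv=False)[0]

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 6))   # stand-in for layer 2 weights
B = rng.standard_normal((6, 5))   # stand-in for layer 1 weights

# Lipschitz constant of the composition is bounded by the product
print(spectral_norm(A @ B) <= spectral_norm(A) * spectral_norm(B))  # True
```

Note the bound is one-directional: normalizing every layer to spectral norm 1 guarantees the network is at most 1-Lipschitz, but the actual constant can be smaller.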
Spectral normalization divides each weight matrix by its spectral norm:
$$W_{SN} = \frac{W}{\sigma(W)}$$
This ensures each layer is 1-Lipschitz, making the entire network 1-Lipschitz (with respect to the composition bound).
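Computing σ(W) exactly via SVD at every training step would be expensive; in practice the spectral norm is estimated by power iteration. A standalone NumPy check (function name mine) confirms that a few iterations recover the true largest singular value of a matrix with a known spectrum:

```python
import numpy as np

def power_iteration_sigma(W, n_iter=50, seed=0):
    """Estimate the largest singular value of W via power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(W.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iter):
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    return u @ W @ v  # u^T W v -> sigma(W)

# Matrix with known singular values 3, 2, 1 (and zeros)
W = np.zeros((4, 6))
W[0, 0], W[1, 1], W[2, 2] = 3.0, 2.0, 1.0

sigma_true = np.linalg.svd(W, compute_uv=False)[0]  # 3.0
sigma_est = power_iteration_sigma(W)
print(sigma_true, round(sigma_est, 6))  # 3.0 3.0
```

Because the singular vectors change only slightly between gradient steps, spectral normalization layers persist u and v across iterations and run just one power-iteration step per forward pass.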
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpectralNorm(nn.Module):
    """
    Spectral Normalization layer wrapper.

    Uses power iteration to estimate the spectral norm efficiently.
    Complexity: O(n*m) per forward pass, where n*m is weight size.
    Much cheaper than gradient penalty: O(n*m) vs O(full backward pass).
    """
    def __init__(self, module, n_power_iterations=1):
        super().__init__()
        self.module = module
        self.n_power_iterations = n_power_iterations

        # Initialize singular vectors
        if not self._made_params():
            self._make_params()

    def _make_params(self):
        w = self.module.weight.data
        height = w.shape[0]
        width = w.view(height, -1).shape[1]

        # Random initialization of singular vectors (1-D, as torch.mv expects)
        u = torch.randn(height)
        u = u / u.norm()
        v = torch.randn(width)
        v = v / v.norm()

        # Register as buffers (not parameters - not trained)
        self.module.register_buffer('u', u)
        self.module.register_buffer('v', v)

    def _made_params(self):
        return hasattr(self.module, 'u') and hasattr(self.module, 'v')

    def _power_iteration(self, w, u, v, n_iterations):
        """
        Power iteration to estimate the largest singular value.

        Converges to σ(W) = ||W||_2 (spectral norm)
        """
        for _ in range(n_iterations):
            v = F.normalize(torch.mv(w.t(), u), dim=0)
            u = F.normalize(torch.mv(w, v), dim=0)

        # Spectral norm estimate
        sigma = torch.dot(u, torch.mv(w, v))
        return sigma, u, v

    def forward(self, x):
        w = self.module.weight
        height = w.shape[0]
        w_mat = w.view(height, -1)

        # Estimate spectral norm
        sigma, u, v = self._power_iteration(
            w_mat, self.module.u, self.module.v,
            self.n_power_iterations
        )

        # Update singular vectors (for next iteration)
        self.module.u = u.detach()
        self.module.v = v.detach()

        # Normalize weight by spectral norm
        w_sn = w / sigma

        # Apply normalized weight
        return F.conv2d(x, w_sn, self.module.bias,
                        self.module.stride, self.module.padding)

# PyTorch provides a spectral_norm utility
def apply_spectral_norm(discriminator):
    """Apply spectral normalization to all conv/linear layers in discriminator."""
    for name, module in discriminator.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.utils.spectral_norm(module)
    return discriminator

# Example: SNGAN discriminator
class SNGANDiscriminator(nn.Module):
    """
    Spectrally normalized GAN discriminator.

    Key architectural choices:
    - Spectral norm on all layers (no GP needed)
    - No BatchNorm (conflicts with SN)
    - Can use standard GAN loss (SN alone stabilizes training)
    """
    def __init__(self, nc=3, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(nc, ndf, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(ndf * 8, 1, 4, 1, 0)),
        )

    def forward(self, x):
        return self.main(x).view(-1)
```

| Method | Computational Cost | Training Stability | Sample Quality | Ease of Use |
|---|---|---|---|---|
| Weight Clipping | Very Low | Moderate | Lower (capacity limited) | Easy (1 hyperparameter) |
| Gradient Penalty | High (2nd order grad) | High | High | Moderate (requires care with BatchNorm) |
| Spectral Norm | Low (power iteration) | Very High | Very High | Very Easy (drop-in wrapper) |
Wasserstein GAN represents a fundamental theoretical advancement in generative modeling. By replacing JS divergence with Wasserstein distance, WGAN solved many of the stability problems that plagued early GANs.
WGAN's principles underpin virtually all modern GAN training. Even when using different losses (hinge loss, relativistic loss), the insights about Lipschitz constraints and gradient behavior remain relevant. The next page explores Progressive GAN, which combines WGAN's stable training with a clever curriculum learning approach to generate high-resolution images.