By 2017, GANs could generate convincing 64×64 images, but attempts to scale to higher resolutions consistently failed. Training 1024×1024 generators directly led to collapse, mode dropping, and images that fell apart under any close inspection. The fundamental challenge: generating high-resolution images requires learning both global structure and fine details simultaneously, an optimization task that proved too difficult for direct training.
Progressive Growing of GANs (ProGAN), introduced by Karras et al. at NVIDIA in 2017, solved this problem with an elegantly simple insight: don't try to learn everything at once. Instead, start training on 4×4 images (where getting the global structure right is easy), then progressively add higher-resolution layers while training continues.
This curriculum learning approach was transformative. ProGAN was the first method to generate photorealistic 1024×1024 faces—images so convincing they could fool humans. More importantly, the progressive training methodology became a foundational technique adopted by subsequent architectures, including StyleGAN.
By the end of this page, you will understand the challenges of high-resolution GAN training, the progressive growing methodology, smooth layer transitions (fade-in), minibatch standard deviation for diversity, equalized learning rates, and how to implement progressive training from scratch.
To understand why ProGAN was necessary, we must first understand why direct high-resolution training fails. The difficulties compound quickly as resolution grows: every doubling quadruples the pixel count, forces the batch size down to fit in memory, and gives the discriminator far more fine detail with which to separate fakes from real images.
```python
import numpy as np


def analyze_resolution_scaling():
    """
    Quantify why high-resolution training is fundamentally harder.
    (A deliberately crude model -- the point is the trend, not the exact numbers.)
    """
    resolutions = [4, 8, 16, 32, 64, 128, 256, 512, 1024]

    for res in resolutions:
        pixels = res * res * 3  # RGB

        # Typical generator parameters scale with resolution
        # (roughly doubles per doubled resolution)
        if res <= 4:
            params_g = 0.5e6
        else:
            base_res_idx = resolutions.index(res)
            params_g = 0.5e6 * (2 ** (base_res_idx - 1))

        # Discriminator has similar scaling
        params_d = params_g * 0.8

        # Memory per image (bytes, float32)
        memory_per_image = pixels * 4

        # Maximum practical batch size (assuming a 32 GB GPU)
        # The 4x factor for activations is a crude lower bound;
        # real activation memory is far larger at high resolution.
        max_batch = min(512, 32e9 / (memory_per_image * 4))

        print(f"Resolution {res}x{res}:")
        print(f"  Pixels: {pixels:,}")
        print(f"  G params: {params_g/1e6:.1f}M  (D: {params_d/1e6:.1f}M)")
        print(f"  Memory/image: {memory_per_image/1024:.1f} KB")
        print(f"  Max batch (approx): {int(max_batch)}")
        print()


analyze_resolution_scaling()

# The real-world trend (numbers from actual ProGAN-scale training) looks like:
#   64x64:      ~12K pixels,  ~2M-param G,   batches of hundreds are fine
#   256x256:   ~197K pixels,  ~8M-param G,   batch sizes start to get tight
#   1024x1024: ~3.1M pixels, ~23M-param G,   only a handful of images per batch

# The key insight of ProGAN:
# Don't try to solve the hard problem directly.
# Build up to it incrementally, using the solution
# to easier problems as initialization.
```

ProGAN's core insight comes from curriculum learning: students learn better when they start with easy examples and progress to hard ones. For GANs, 'easy' means low resolution (where getting the structure right is all that matters) and 'hard' means high resolution (where details become critical). By training incrementally, each stage builds on the stable foundation of the previous one.
The progressive growing algorithm trains the generator and discriminator starting at 4×4 resolution, then progressively adds higher-resolution layers. At each stage, both networks learn to handle the current resolution before moving to the next.
The training schedule: start at 4×4, then double the resolution at each stage (8×8, 16×16, …, up to 1024×1024), giving each stage a fixed budget of real images (800K per resolution in the schedule used later on this page, with larger budgets at the highest resolutions) before growing again. The first half of each new stage fades the new layers in; the second half stabilizes at full strength, as sketched below.
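As a rough sketch of that schedule (the resolution list, per-stage budget, and 50/50 fade-in/stabilize split mirror the trainer code further down this page; the helper name `progressive_stages` is ours, not from the original code):

```python
def progressive_stages(max_resolution=1024, images_per_stage=800_000):
    """Enumerate (resolution, phase, image budget) stages for progressive training."""
    stages = [(4, "stabilize", images_per_stage)]  # 4x4 has no fade-in
    resolution = 4
    while resolution < max_resolution:
        resolution *= 2
        stages.append((resolution, "fade-in", images_per_stage // 2))
        stages.append((resolution, "stabilize", images_per_stage // 2))
    return stages


for res, phase, budget in progressive_stages(max_resolution=64):
    print(f"{res:>4}x{res:<4} {phase:<9} {budget:,} images")
```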
The crucial innovation is the smooth transition (fade-in) when adding new layers. Without smooth transitions, adding new random layers causes training collapse.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProgressiveGenerator(nn.Module):
    """
    Progressive GAN generator that grows during training.

    Architecture:
    - Latent z → 4x4 feature map → progressive upsampling blocks
    - Each block: upsample → conv → conv
    - New blocks are "faded in" smoothly using alpha blending
    """

    def __init__(self, latent_dim=512, max_resolution=1024):
        super().__init__()
        self.latent_dim = latent_dim

        # Number of channels at each resolution
        # Decreases as resolution increases (memory constraint)
        self.channels = {
            4: 512, 8: 512, 16: 512, 32: 512,
            64: 256, 128: 128, 256: 64, 512: 32, 1024: 16
        }

        # Initial 4x4 block: project the latent vector to a 4x4 feature map
        self.initial = nn.Linear(latent_dim, 512 * 4 * 4)

        # Blocks for each resolution (added progressively)
        self.blocks = nn.ModuleDict()
        self.to_rgb = nn.ModuleDict()  # Separate toRGB for each resolution

        # Build all blocks (but only use up to current resolution during training)
        prev_channels = 512
        for res in [8, 16, 32, 64, 128, 256, 512, 1024]:
            if res > max_resolution:
                break
            out_channels = self.channels[res]
            self.blocks[str(res)] = self._make_block(prev_channels, out_channels)
            self.to_rgb[str(res)] = nn.Conv2d(out_channels, 3, 1)
            prev_channels = out_channels

        # Initial toRGB for 4x4
        self.to_rgb['4'] = nn.Conv2d(512, 3, 1)

        # Fade-in alpha (0 = use old layers, 1 = use new layers)
        self.alpha = 1.0

        # Current training resolution
        self.current_resolution = 4

    def _make_block(self, in_channels, out_channels):
        """Create an upsampling block: upsample → conv → conv."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(0.2),
        )

    def forward(self, z):
        """
        Generate an image at the current resolution with fade-in support.

        During the transition from resolution n to 2n:
        - Run both the old path (upsample existing) and the new path (new block)
        - Blend: output = (1 - alpha) * old_rgb + alpha * new_rgb
        """
        batch_size = z.size(0)

        # Project the latent vector to the initial 4x4 feature map
        x = F.leaky_relu(self.initial(z), 0.2).view(batch_size, 512, 4, 4)

        # If we're only at 4x4
        if self.current_resolution == 4:
            return torch.tanh(self.to_rgb['4'](x))

        # Process through blocks up to the current resolution
        resolutions = [8, 16, 32, 64, 128, 256, 512, 1024]
        is_new_block = False
        for i, res in enumerate(resolutions):
            if res > self.current_resolution:
                break

            # Is this the newly added block (being faded in)?
            is_new_block = (res == self.current_resolution) and (self.alpha < 1.0)

            if is_new_block:
                # Fade-in: blend the old upsampled output with the new block's output

                # Old path: upsample the previous resolution's RGB output
                prev_res = resolutions[i - 1] if i > 0 else 4
                old_rgb = self.to_rgb[str(prev_res)](x)
                old_rgb = F.interpolate(old_rgb, scale_factor=2, mode='nearest')

                # New path: apply the new block
                x = self.blocks[str(res)](x)
                new_rgb = self.to_rgb[str(res)](x)

                # Blend
                rgb = (1 - self.alpha) * old_rgb + self.alpha * new_rgb
            else:
                # Normal forward pass through the block
                x = self.blocks[str(res)](x)

        if not is_new_block:
            rgb = self.to_rgb[str(self.current_resolution)](x)

        return torch.tanh(rgb)

    def grow(self, new_resolution):
        """Transition to a higher resolution."""
        self.current_resolution = new_resolution
        self.alpha = 0.0  # Start with the full old path

    def update_alpha(self, alpha):
        """Update fade-in progress (0 to 1)."""
        self.alpha = min(1.0, alpha)
```

Understanding the fade-in mechanism:
When adding a new resolution block, we can't just replace the old output with the new one—this would cause a sudden change that destabilizes training. Instead, we blend:
$$\text{output} = (1 - \alpha) \cdot \text{upsample}(\text{RGB}_{\text{old}}) + \alpha \cdot \text{RGB}_{\text{new}}$$
α starts at 0 (purely old path) and linearly increases to 1 (purely new path) over many training iterations. This gives the new layers time to learn without disrupting what the network has already learned.
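For concreteness, here is a tiny sketch of the linear α schedule and the blend itself (the helper names are ours; the 400K-image fade-in length corresponds to half of an 800K-image stage, matching the trainer later on this page):

```python
import torch
import torch.nn.functional as F


def fade_in_alpha(images_seen, fade_in_images):
    """Linear ramp of alpha from 0 to 1 over the fade-in phase."""
    return min(1.0, images_seen / fade_in_images)


def blend_outputs(old_rgb_low, new_rgb, alpha):
    """output = (1 - alpha) * upsample(old_rgb) + alpha * new_rgb"""
    old_rgb = F.interpolate(old_rgb_low, scale_factor=2, mode='nearest')
    return (1 - alpha) * old_rgb + alpha * new_rgb


# Halfway through a 400K-image fade-in phase:
alpha = fade_in_alpha(200_000, 400_000)   # 0.5
old = torch.randn(2, 3, 16, 16)           # RGB from the existing 16x16 path
new = torch.randn(2, 3, 32, 32)           # RGB from the new 32x32 block
out = blend_outputs(old, new, alpha)      # shape [2, 3, 32, 32]
```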
The fade-in mechanism is critical for stable progressive training. Let's visualize exactly how it works for both the generator and discriminator.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FadeInGenerator(nn.Module):
    """
    Detailed implementation of fade-in for the generator.

    Transition from N×N to 2N×2N:

    Phase 1 (α = 0): full old path
        latent → blocks(N) → toRGB(N) → ↑2 → output(2N)

    Phase 2 (0 < α < 1): blending
        latent → blocks(N) ─┬→ toRGB(N) → ↑2 → ×(1-α) ─┐
                            │                           +→ output(2N)
                            └→ block(2N) → toRGB(2N) → ×α ─┘

    Phase 3 (α = 1): full new path
        latent → blocks(N) → block(2N) → toRGB(2N) → output(2N)
    """

    def __init__(self):
        super().__init__()
        # Example: transitioning from 16x16 to 32x32

        # Existing blocks (already trained); expects z of shape [B, 512, 1, 1]
        self.blocks_16 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 1, 0),  # 4x4
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 8x8
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 16x16
            nn.LeakyReLU(0.2),
        )

        # New block being faded in
        self.block_32 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.LeakyReLU(0.2),
        )

        # ToRGB converters
        self.to_rgb_16 = nn.Conv2d(64, 3, 1)  # From 16x16 features
        self.to_rgb_32 = nn.Conv2d(32, 3, 1)  # From 32x32 features

    def forward(self, z, alpha=1.0):
        # Process through the existing blocks
        x_16 = self.blocks_16(z)  # 16x16 features

        if alpha == 0.0:
            # Pure old path: just upsample the 16x16 RGB
            rgb_16 = self.to_rgb_16(x_16)
            return F.interpolate(rgb_16, scale_factor=2, mode='bilinear')
        elif alpha == 1.0:
            # Pure new path: use the new block
            x_32 = self.block_32(x_16)
            return self.to_rgb_32(x_32)
        else:
            # Blending: both paths are active
            # Old path: RGB from 16x16, upsampled to 32x32
            rgb_16 = self.to_rgb_16(x_16)
            rgb_old = F.interpolate(rgb_16, scale_factor=2, mode='bilinear')

            # New path: RGB from 32x32 features
            x_32 = self.block_32(x_16)
            rgb_new = self.to_rgb_32(x_32)

            # Weighted blend
            return (1 - alpha) * rgb_old + alpha * rgb_new


class FadeInDiscriminator(nn.Module):
    """
    Discriminator fade-in mirrors the generator.

    Transition from 2N×2N input down to N×N processing:

        input(2N) ─┬→ fromRGB(2N) → block(2N→N) → ×α ─────┐
                   │                                        +→ blocks(N) → output
                   └→ ↓2 → fromRGB(N) → ×(1-α) ────────────┘
    """

    def __init__(self):
        super().__init__()
        # FromRGB converters
        self.from_rgb_32 = nn.Conv2d(3, 32, 1)
        self.from_rgb_16 = nn.Conv2d(3, 64, 1)

        # New block (processes the new higher-res input)
        self.block_32 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2),
        )

        # Existing blocks (already trained)
        self.blocks_16_down = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2),  # 8x8
            nn.Conv2d(128, 256, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2),  # 4x4
            nn.Conv2d(256, 512, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1),
        )

    def forward(self, img, alpha=1.0):
        # img is 32x32
        if alpha == 0.0:
            # Pure old path: downsample to 16x16, use the old fromRGB
            img_16 = F.avg_pool2d(img, 2)
            x = self.from_rgb_16(img_16)
        elif alpha == 1.0:
            # Pure new path: use the new block
            x = self.from_rgb_32(img)
            x = self.block_32(x)
        else:
            # Blending
            # Old path: downsample, then fromRGB
            img_16 = F.avg_pool2d(img, 2)
            x_old = self.from_rgb_16(img_16)

            # New path: fromRGB, then block
            x_new = self.from_rgb_32(img)
            x_new = self.block_32(x_new)

            # x_new and x_old are both 16x16 with 64 channels
            x = (1 - alpha) * x_old + alpha * x_new

        return self.blocks_16_down(x)
```

The fade-in mechanism ensures that newly added layers start by contributing almost nothing to the output (α ≈ 0). As training progresses and α increases to 1, they gradually take over responsibility for the new resolution's details. The old path acts as a stable scaffold whose output the new path must first match and then surpass. This is far more stable than letting randomly initialized layers take full control immediately.
Beyond progressive growing, ProGAN introduced several additional techniques that became standard in subsequent GAN architectures.
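One of these techniques, equalized learning rate (listed in the learning objectives above), stores each layer's weights with unit-variance initialization and rescales them at run time by the layer's He constant, so every layer learns at a similar effective rate. A minimal sketch, assuming a wrapper we call `EqualizedConv2d` (the class name is ours, not from the original code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EqualizedConv2d(nn.Module):
    """Conv layer with runtime weight scaling (equalized learning rate).

    Weights are initialized from N(0, 1) and multiplied by the He constant
    sqrt(2 / fan_in) on every forward pass, so gradients have a comparable
    scale across layers regardless of their fan-in.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels))
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = (2.0 / fan_in) ** 0.5
        self.padding = padding

    def forward(self, x):
        return F.conv2d(x, self.weight * self.scale, self.bias, padding=self.padding)


# Drop-in replacement for nn.Conv2d inside the G and D blocks:
layer = EqualizedConv2d(64, 128, 3, padding=1)
y = layer(torch.randn(1, 64, 16, 16))  # [1, 128, 16, 16]
```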
Minibatch Standard Deviation combats mode collapse by giving the discriminator explicit information about variation within a batch.
The idea: if the generator produces identical outputs (mode collapse), the discriminator can detect this by noticing zero variation. Conversely, real image batches have natural variation.
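A quick way to see the signal this exposes (a toy illustration, not part of the ProGAN layer itself):

```python
import torch

varied = torch.randn(8, 3, 4, 4)                        # batch with natural variation
collapsed = torch.randn(1, 3, 4, 4).repeat(8, 1, 1, 1)  # generator stuck on one output

print(varied.std(dim=0).mean().item())     # clearly non-zero
print(collapsed.std(dim=0).mean().item())  # 0.0 -- no variation across the batch
```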
Implementation:
```python
import torch
import torch.nn as nn


class MinibatchStdDev(nn.Module):
    """
    Minibatch standard deviation layer.

    Appends a constant feature channel that encodes the variation
    within the current minibatch. This helps the discriminator
    detect mode collapse.
    """

    def __init__(self, group_size=4, num_features=1):
        """
        Args:
            group_size: Compute std over groups of this size
                        (helps with small batches)
            num_features: Number of std features to compute
                          (1 in original ProGAN)
        """
        super().__init__()
        self.group_size = group_size
        self.num_features = num_features

    def forward(self, x):
        batch_size, channels, height, width = x.shape

        # Ensure group_size doesn't exceed batch_size
        # (batch_size must be divisible by group_size for the reshape below)
        group_size = min(self.group_size, batch_size)

        # Reshape: [B, C, H, W] -> [G, M, F, C//F, H, W]
        # G = samples per group, M = number of groups, F = num_features
        y = x.view(
            group_size,
            -1,
            self.num_features,
            channels // self.num_features,
            height,
            width
        )

        # Standard deviation across the samples in each group
        y = y.float()                               # Ensure float for stability
        y = y - y.mean(dim=0, keepdim=True)         # Center within each group
        y = (y.pow(2).mean(dim=0) + 1e-8).sqrt()    # Std: [M, F, C//F, H, W]

        # Average over channels and spatial dimensions, then over features
        y = y.mean(dim=[2, 3, 4])                   # [M, F]
        y = y.mean(dim=1, keepdim=True)             # [M, 1]

        # Tile to one constant channel per sample: [B, 1, H, W]
        y = y.view(-1, 1, 1, 1).repeat(group_size, 1, height, width)
        y = y.to(x.dtype)

        # Concatenate as a new channel
        return torch.cat([x, y], dim=1)


# Usage in the discriminator
class ProGANDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # ... previous layers ...

        # Add before the final layers (at 4x4 resolution)
        self.minibatch_std = MinibatchStdDev()

        # The final conv now has +1 input channel
        self.final_conv = nn.Conv2d(512 + 1, 512, 3, padding=1)

    def forward(self, x):
        # ... process down to 4x4 ...

        # Add minibatch std as an extra feature channel
        x = self.minibatch_std(x)   # [B, 513, 4, 4]
        x = self.final_conv(x)      # [B, 512, 4, 4]

        # ... final output ...


# The discriminator can now "see" batch diversity:
# - Real batches: high std (natural variation)
# - Mode collapse: near-zero std (all samples identical)
```

The training schedule determines how long to train at each resolution and how quickly to fade in new layers. Getting this right is crucial for training success.
```python
import torch
import torch.optim as optim
from torch.utils.data import DataLoader


class ProgressiveTrainer:
    """
    Complete progressive GAN training with scheduling.
    """

    def __init__(self, generator, discriminator, dataset, device='cuda'):
        self.G = generator.to(device)
        self.D = discriminator.to(device)
        self.device = device

        # Training schedule: images per resolution
        self.schedule = {
            4: 800_000,      # 800K images at 4x4
            8: 800_000,      # 800K images at 8x8
            16: 800_000,     # etc.
            32: 800_000,
            64: 800_000,
            128: 800_000,
            256: 1_600_000,  # More training at higher resolutions
            512: 1_600_000,
            1024: 1_600_000,
        }

        # Half of the time at each resolution is fade-in,
        # the other half is stabilization at full alpha
        self.fade_in_fraction = 0.5

        # Dataset that can return images at different resolutions
        self.dataset = dataset

        # Current training state
        self.current_resolution = 4
        self.images_shown = 0
        self.current_alpha = 1.0

    def get_batch_size(self, resolution):
        """Scale batch size inversely with resolution for memory."""
        sizes = {
            4: 128, 8: 128, 16: 128, 32: 64, 64: 32,
            128: 16, 256: 8, 512: 4, 1024: 2
        }
        return sizes[resolution]

    def get_learning_rate(self, resolution):
        """Slightly reduce LR at higher resolutions for stability."""
        base_lr = 0.001
        if resolution >= 512:
            return base_lr * 0.25
        if resolution >= 256:
            return base_lr * 0.5
        return base_lr

    def train(self):
        """Main training loop across all resolutions."""
        resolutions = sorted(self.schedule.keys())

        for res in resolutions:
            print(f"\n{'='*50}")
            print(f"Training at {res}x{res}")
            print(f"{'='*50}")

            # Update the dataset to return images at this resolution
            self.dataset.set_resolution(res)

            # Create a dataloader with an appropriate batch size
            batch_size = self.get_batch_size(res)
            dataloader = DataLoader(
                self.dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=4,
                drop_last=True
            )

            # Create optimizers with a resolution-specific LR
            lr = self.get_learning_rate(res)
            opt_G = optim.Adam(self.G.parameters(), lr=lr, betas=(0, 0.99))
            opt_D = optim.Adam(self.D.parameters(), lr=lr, betas=(0, 0.99))

            # Grow the networks to the new resolution
            if res > 4:
                self.G.grow(res)
                self.D.grow(res)

            # Train at this resolution
            images_this_res = 0
            total_images = self.schedule[res]
            fade_in_images = int(total_images * self.fade_in_fraction)

            while images_this_res < total_images:
                for real_batch in dataloader:
                    real_batch = real_batch.to(self.device)
                    current_batch_size = real_batch.size(0)

                    # Update alpha for fade-in
                    if images_this_res < fade_in_images and res > 4:
                        alpha = images_this_res / fade_in_images
                    else:
                        alpha = 1.0
                    self.G.alpha = alpha
                    self.D.alpha = alpha

                    # Train the discriminator
                    opt_D.zero_grad()

                    z = torch.randn(current_batch_size, 512).to(self.device)
                    fake = self.G(z)

                    real_score = self.D(real_batch)
                    fake_score = self.D(fake.detach())

                    # WGAN-GP loss
                    d_loss = fake_score.mean() - real_score.mean()

                    # Gradient penalty
                    gp = self.compute_gradient_penalty(real_batch, fake.detach())
                    d_loss = d_loss + 10 * gp

                    d_loss.backward()
                    opt_D.step()

                    # Train the generator
                    opt_G.zero_grad()

                    z = torch.randn(current_batch_size, 512).to(self.device)
                    fake = self.G(z)
                    fake_score = self.D(fake)

                    g_loss = -fake_score.mean()
                    g_loss.backward()
                    opt_G.step()

                    # Update counters
                    images_this_res += current_batch_size
                    self.images_shown += current_batch_size

                    # Logging
                    if images_this_res % 10000 < current_batch_size:
                        phase = "fade-in" if alpha < 1.0 else "stabilize"
                        print(f"[{res}x{res}] {images_this_res:,}/{total_images:,} "
                              f"({phase}, α={alpha:.3f}) "
                              f"D: {d_loss.item():.4f}, G: {g_loss.item():.4f}")

                    if images_this_res >= total_images:
                        break

            # Save a checkpoint after each resolution
            self.save_checkpoint(res)

    def compute_gradient_penalty(self, real, fake):
        """WGAN-GP gradient penalty."""
        # Note: this alpha is the interpolation coefficient, unrelated to fade-in alpha
        alpha = torch.rand(real.size(0), 1, 1, 1, device=self.device)
        interpolated = alpha * real + (1 - alpha) * fake
        interpolated.requires_grad_(True)

        d_interpolated = self.D(interpolated)

        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradients = gradients.view(gradients.size(0), -1)
        gradient_norm = gradients.norm(2, dim=1)
        return ((gradient_norm - 1) ** 2).mean()

    def save_checkpoint(self, resolution):
        """Save a model checkpoint."""
        torch.save({
            'resolution': resolution,
            'generator': self.G.state_dict(),
            'discriminator': self.D.state_dict(),
            'images_shown': self.images_shown,
        }, f'progan_checkpoint_{resolution}x{resolution}.pth')
```

| Resolution | Images (M) | Batch Size | GPU Hours* |
|---|---|---|---|
| 4×4 | 0.8M | 128 | ~2h |
| 8×8 | 0.8M | 128 | ~4h |
| 16×16 | 0.8M | 128 | ~6h |
| 32×32 | 0.8M | 64 | ~10h |
| 64×64 | 0.8M | 32 | ~16h |
| 128×128 | 0.8M | 16 | ~24h |
| 256×256 | 1.6M | 8 | ~48h |
| 512×512 | 1.6M | 4 | ~72h |
| 1024×1024 | 1.6M | 2 | ~96h |
Full ProGAN training to 1024×1024 takes approximately 2 weeks on a single high-end GPU (NVIDIA V100). The original paper used 8 GPUs for faster training. Modern implementations can be faster with improved infrastructure, but high-resolution GAN training remains computationally intensive.
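Summing the per-resolution estimates in the table above gives a quick sanity check on that figure:

```python
# GPU-hour column from the table above
hours = [2, 4, 6, 10, 16, 24, 48, 72, 96]
print(sum(hours), round(sum(hours) / 24, 1))  # 278 GPU hours ≈ 11.6 days, i.e. roughly two weeks
```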
ProGAN achieved unprecedented results in high-resolution image generation and established methodology that influenced all subsequent work.
Counter-intuitively, progressive training often requires fewer total images than direct training. The early resolutions establish correct global structure quickly, and the higher resolutions then 'just' add detail. Direct training wastes many iterations as the generator learns and unlearns structure while simultaneously struggling with fine details.
Progressive GAN demonstrated that curriculum learning applies powerfully to generative modeling. By breaking the hard problem of high-resolution generation into a sequence of easier problems, ProGAN achieved what direct training could not.
ProGAN's progressive training and stabilization techniques directly enabled StyleGAN, which we'll study next. StyleGAN kept the progressive training framework while introducing revolutionary changes to how the generator uses the latent space—achieving even more impressive results and unprecedented control over generated images.