All the GAN variants we've studied so far generate images from noise without explicit control over what is generated. The latent code z determines the output, but we can't easily specify 'generate a cat' or 'generate a smiling person.' Conditional GANs (cGANs) solve this fundamental limitation by incorporating additional information—class labels, text descriptions, input images, or other conditions—into both the generator and discriminator.
The conditional GAN formulation, introduced by Mirza and Osindero in 2014, is deceptively simple but incredibly powerful:
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x|c)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z|c)|c))]$$
where c is the conditioning variable. This small modification enables a vast ecosystem of applications: class-conditional image generation, image-to-image translation, text-to-image synthesis, style transfer, and much more.
This page explores the theory and practice of conditional GANs, from simple class conditioning to complex image translation architectures like pix2pix and CycleGAN.
By the end of this page, you will understand class-conditional generation, conditioning strategies (concatenation, projection, normalization), image-to-image translation with pix2pix, unpaired translation with CycleGAN, auxiliary classifier GANs (ACGAN), and projection-based conditioning.
The fundamental idea of conditional GANs is to provide both the generator and discriminator with additional information c:
This dual conditioning is essential. If only G is conditioned, nothing forces it to actually use the condition—it might learn to ignore c. By conditioning D, we create an adversarial incentive: if G ignores c, the discriminator can tell that the output doesn't match the condition.
```python
import torch
import torch.nn as nn


class ConditionalGenerator(nn.Module):
    """
    Basic conditional generator with class embedding.

    Conditioning strategy: concatenate a learned class embedding with z.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_channels=3, ngf=64):
        super().__init__()
        self.ngf = ngf

        # Embed class labels to same dimension as latent
        self.embed = nn.Embedding(num_classes, latent_dim)

        # Generator takes z + embedded label
        self.fc = nn.Linear(latent_dim * 2, ngf * 8 * 4 * 4)

        self.main = nn.Sequential(
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z, labels):
        """
        z: [batch, latent_dim] - noise vector
        labels: [batch] - integer class labels
        """
        # Embed class labels
        label_embedding = self.embed(labels)  # [batch, latent_dim]

        # Concatenate z and label embedding
        combined = torch.cat([z, label_embedding], dim=1)  # [batch, latent_dim*2]

        # Project to spatial dimensions
        x = self.fc(combined)
        x = x.view(-1, self.ngf * 8, 4, 4)
        return self.main(x)


class ConditionalDiscriminator(nn.Module):
    """
    Conditional discriminator that judges reality AND class match.

    Conditioning strategy: embed label, expand to image size, concatenate.
    """

    def __init__(self, num_classes=10, img_channels=3, img_size=64, ndf=64):
        super().__init__()
        self.img_size = img_size

        # Embed class to image-sized tensor
        self.embed = nn.Embedding(num_classes, img_size * img_size)

        # Discriminator takes image + embedded label channel
        self.main = nn.Sequential(
            # Input: (img_channels + 1) x img_size x img_size
            nn.Conv2d(img_channels + 1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """
        img: [batch, channels, H, W]
        labels: [batch] - integer class labels
        """
        # Create label image
        label_embedding = self.embed(labels)  # [batch, H*W]
        label_image = label_embedding.view(-1, 1, self.img_size, self.img_size)

        # Concatenate as additional channel
        combined = torch.cat([img, label_image], dim=1)
        return self.main(combined).view(-1)


# Training a conditional GAN
def train_cgan(G, D, dataloader, num_classes, epochs=100):
    """
    Conditional GAN training loop.

    Key difference from unconditional: both G and D receive class labels.
    """
    criterion = nn.BCELoss()
    optimG = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimD = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images, real_labels in dataloader:
            batch_size = real_images.size(0)

            # ==== Train Discriminator ====
            D.zero_grad()

            # Real samples with correct labels
            real_output = D(real_images, real_labels)
            real_loss = criterion(real_output, torch.ones(batch_size))

            # Fake samples with conditioning
            z = torch.randn(batch_size, 100)
            # Generate with same labels as real batch (or random)
            fake_labels = real_labels  # Use same labels
            fake_images = G(z, fake_labels)

            fake_output = D(fake_images.detach(), fake_labels)
            fake_loss = criterion(fake_output, torch.zeros(batch_size))

            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimD.step()

            # ==== Train Generator ====
            G.zero_grad()
            output = D(fake_images, fake_labels)
            g_loss = criterion(output, torch.ones(batch_size))
            g_loss.backward()
            optimG.step()
```

Both G and D must be conditioned. If only G is conditioned, it can learn to ignore the condition since D only checks reality. If only D is conditioned, G has no gradient signal about how to use the condition. The adversarial relationship ensures G uses the condition correctly because D will reject outputs that don't match.
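Once trained, the payoff of conditioning is controllable sampling: the same noise with different labels, or different noise with the same label, steers generation. A minimal inference sketch, assuming the `ConditionalGenerator` defined above (the `sample_class` helper is illustrative, not part of the original code):

```python
import torch

def sample_class(G, class_idx, n_samples=8, latent_dim=100):
    """Generate n_samples images of a single requested class."""
    G.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim)
        # Every sample in the batch gets the same class label
        labels = torch.full((n_samples,), class_idx, dtype=torch.long)
        return G(z, labels)  # with the defaults above: [n_samples, 3, 64, 64] in [-1, 1]

# To sweep all classes with a single shared noise vector:
# z = torch.randn(1, 100).repeat(10, 1)
# grid = G(z, torch.arange(10))
```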
There are multiple ways to inject conditioning information into generator and discriminator. The choice affects training stability, output quality, and scalability to many classes.
Concatenation is the simplest approach: embed the condition and concatenate it with the input.
**Pros:** Simple, intuitive, works reasonably well.
**Cons:** Doesn't scale well to many classes; label information can be ignored in deep layers.
```python
# Generator: concat at input
combined = torch.cat([z, label_embed], dim=1)  # [B, z_dim + embed_dim]
x = self.fc(combined)

# Discriminator: concat as channel
# (embed_dim must equal H*W so the label reshapes to one image-sized channel)
label_channel = label_embed.view(B, 1, H, W)
combined = torch.cat([image, label_channel], dim=1)  # [B, C+1, H, W]
```

| Strategy | Pros | Cons | Used In |
|---|---|---|---|
| Concatenation | Simple, intuitive | Doesn't scale, can be ignored | Original cGAN |
| Projection | Theoretically motivated, scales well | More complex discriminator | SNGAN, SA-GAN |
| Conditional BN | Effective, deep integration | Class-specific parameters | BigGAN, ResNet-based |
| All Three | Maximum conditioning power | Most complex | BigGAN (all combined) |
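The table's two non-trivial strategies can be sketched in a few lines. Below is an illustrative (not canonical) implementation of a projection head — the unconditional logit plus the inner product between a class embedding and pooled discriminator features, in the spirit of the projection discriminator used by SNGAN and SA-GAN — and conditional batch normalization with class-specific scale and shift, as used in BigGAN. Module names and dimensions are assumptions:

```python
import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Projection conditioning: logit = psi(phi(x)) + <embed(y), phi(x)>."""
    def __init__(self, feat_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(feat_dim, 1)              # unconditional logit
        self.embed = nn.Embedding(num_classes, feat_dim)  # class embedding

    def forward(self, features, labels):
        # features: [B, feat_dim] globally pooled discriminator features
        out = self.linear(features).squeeze(1)
        # Class-conditional term: inner product with the label embedding
        out = out + (self.embed(labels) * features).sum(dim=1)
        return out  # raw logit; pair with hinge or BCE-with-logits loss

class ConditionalBatchNorm2d(nn.Module):
    """BatchNorm whose affine parameters are selected by the class label."""
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gamma = nn.Embedding(num_classes, num_features)
        self.beta = nn.Embedding(num_classes, num_features)
        nn.init.ones_(self.gamma.weight)   # start as identity transform
        nn.init.zeros_(self.beta.weight)

    def forward(self, x, labels):
        out = self.bn(x)
        gamma = self.gamma(labels).view(-1, out.size(1), 1, 1)
        beta = self.beta(labels).view(-1, out.size(1), 1, 1)
        return gamma * out + beta
```

Because the class information enters as an inner product rather than an extra input channel, the projection head scales gracefully to many classes: only the embedding table grows.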
Auxiliary Classifier GAN (ACGAN) adds a classification head to the discriminator, providing an additional training signal. Instead of just judging real/fake, the discriminator also predicts the class of the input.
ACGAN objective:
$$L_S = \mathbb{E}[\log P(S=\text{real}|x_{\text{real}})] + \mathbb{E}[\log P(S=\text{fake}|x_{\text{fake}})]$$ $$L_C = \mathbb{E}[\log P(C=c|x_{\text{real}})] + \mathbb{E}[\log P(C=c|x_{\text{fake}})]$$
The classification objective L_C forces G to generate class-recognizable images: D is trained to maximize L_S + L_C, while G is trained to maximize L_C − L_S.
```python
import torch
import torch.nn as nn


class ACGANDiscriminator(nn.Module):
    """
    Discriminator with auxiliary classifier.

    Outputs:
    1. Real/fake probability (standard GAN output)
    2. Class probabilities (auxiliary classifier)
    """

    def __init__(self, num_classes, ndf=64):
        super().__init__()

        # Shared convolutional backbone
        self.features = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Head 1: Real/Fake discrimination
        self.disc_head = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

        # Head 2: Class classification
        self.aux_head = nn.Sequential(
            nn.Conv2d(ndf * 8, num_classes, 4, 1, 0, bias=False),
            nn.Flatten()  # No softmax here - use CrossEntropyLoss which includes it
        )

    def forward(self, x):
        features = self.features(x)
        disc_output = self.disc_head(features).view(-1)
        class_output = self.aux_head(features)
        return disc_output, class_output


def train_acgan(G, D, dataloader, num_classes, epochs=100):
    """
    ACGAN training with both adversarial and classification loss.
    """
    adv_loss = nn.BCELoss()
    aux_loss = nn.CrossEntropyLoss()
    optimG = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimD = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images, real_labels in dataloader:
            batch_size = real_images.size(0)

            # ==== Train Discriminator ====
            D.zero_grad()

            # Real images
            disc_real, class_real = D(real_images)
            d_loss_real = adv_loss(disc_real, torch.ones(batch_size))
            d_loss_class_real = aux_loss(class_real, real_labels)

            # Fake images
            z = torch.randn(batch_size, 100)
            fake_labels = torch.randint(0, num_classes, (batch_size,))
            fake_images = G(z, fake_labels)

            disc_fake, class_fake = D(fake_images.detach())
            d_loss_fake = adv_loss(disc_fake, torch.zeros(batch_size))
            d_loss_class_fake = aux_loss(class_fake, fake_labels)

            # Total D loss: adversarial + classification for both real and fake
            d_loss = d_loss_real + d_loss_fake + d_loss_class_real + d_loss_class_fake
            d_loss.backward()
            optimD.step()

            # ==== Train Generator ====
            G.zero_grad()
            disc_output, class_output = D(fake_images)
            g_loss_adv = adv_loss(disc_output, torch.ones(batch_size))
            g_loss_aux = aux_loss(class_output, fake_labels)

            # G wants to fool D AND produce correct class
            g_loss = g_loss_adv + g_loss_aux
            g_loss.backward()
            optimG.step()
```

ACGAN adds a classification task; the projection discriminator uses class embeddings directly. Empirically, projection works better for many-class scenarios (e.g., ImageNet with 1000 classes) because ACGAN's classifier can overfit to easy classes. BigGAN uses projection with cBN, not ACGAN.
pix2pix (Isola et al., 2017) is a landmark conditional GAN for paired image-to-image translation. Given input-output image pairs (e.g., satellite → map, sketch → photo, day → night), pix2pix learns to transform images from one domain to another.
Key innovations: a U-Net generator with skip connections to preserve spatial structure, a PatchGAN discriminator that judges local patches rather than whole images, and an L1 reconstruction loss combined with the adversarial loss.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetGenerator(nn.Module):
    """
    U-Net generator for pix2pix.

    Encoder-decoder architecture with skip connections.
    Skip connections help preserve spatial details from input.

    Structure:
        Input → Enc1 → Enc2 → ... → EncN → DecN → ... → Dec2 → Dec1 → Output
                  ↓      ↓                   ↑             ↑
                  └──────┴────── skips ──────┴─────────────┘
    """

    def __init__(self, in_channels=3, out_channels=3, ngf=64):
        super().__init__()

        # Encoder
        self.enc1 = self._encoder_block(in_channels, ngf, bn=False)  # 64
        self.enc2 = self._encoder_block(ngf, ngf * 2)                # 128
        self.enc3 = self._encoder_block(ngf * 2, ngf * 4)            # 256
        self.enc4 = self._encoder_block(ngf * 4, ngf * 8)            # 512
        self.enc5 = self._encoder_block(ngf * 8, ngf * 8)            # 512
        self.enc6 = self._encoder_block(ngf * 8, ngf * 8)            # 512
        self.enc7 = self._encoder_block(ngf * 8, ngf * 8)            # 512
        self.enc8 = self._encoder_block(ngf * 8, ngf * 8, bn=False)  # 512 (bottleneck)

        # Decoder with skip connections
        self.dec1 = self._decoder_block(ngf * 8, ngf * 8, dropout=True)
        self.dec2 = self._decoder_block(ngf * 16, ngf * 8, dropout=True)  # *2 for skip
        self.dec3 = self._decoder_block(ngf * 16, ngf * 8, dropout=True)
        self.dec4 = self._decoder_block(ngf * 16, ngf * 8)
        self.dec5 = self._decoder_block(ngf * 16, ngf * 4)
        self.dec6 = self._decoder_block(ngf * 8, ngf * 2)
        self.dec7 = self._decoder_block(ngf * 4, ngf)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def _encoder_block(self, in_ch, out_ch, bn=True):
        layers = [nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False)]
        if bn:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def _decoder_block(self, in_ch, out_ch, dropout=False):
        layers = [
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder pass - save for skip connections
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        e6 = self.enc6(e5)
        e7 = self.enc7(e6)
        e8 = self.enc8(e7)

        # Decoder pass with skip connections
        d1 = self.dec1(e8)
        d2 = self.dec2(torch.cat([d1, e7], dim=1))  # Skip from e7
        d3 = self.dec3(torch.cat([d2, e6], dim=1))
        d4 = self.dec4(torch.cat([d3, e5], dim=1))
        d5 = self.dec5(torch.cat([d4, e4], dim=1))
        d6 = self.dec6(torch.cat([d5, e3], dim=1))
        d7 = self.dec7(torch.cat([d6, e2], dim=1))

        return self.final(torch.cat([d7, e1], dim=1))


class PatchGANDiscriminator(nn.Module):
    """
    PatchGAN discriminator for pix2pix.

    Instead of outputting a single real/fake score, outputs a grid where
    each cell judges whether a 70x70 patch is real or fake. This focuses
    the discriminator on local texture/style, leaving global structure
    to the L1 loss.
    """

    def __init__(self, in_channels=6, ndf=64):  # 6 = input + target
        super().__init__()
        self.model = nn.Sequential(
            # 70x70 PatchGAN
            nn.Conv2d(in_channels, ndf, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 1),  # No sigmoid - use BCEWithLogitsLoss
        )

    def forward(self, input_img, target_img):
        # Concatenate input and target
        x = torch.cat([input_img, target_img], dim=1)
        return self.model(x)


# pix2pix loss
def pix2pix_loss(G, D, input_img, target_img, lambda_l1=100):
    """
    pix2pix objective: adversarial + L1.

    cGAN loss: the adversarial loss.
    L1 loss: encourages the output to be similar to the target.

    The L1 term is crucial - without it, outputs are realistic but may
    not match the target content.
    """
    fake_img = G(input_img)

    # Discriminator loss
    pred_real = D(input_img, target_img)
    pred_fake = D(input_img, fake_img.detach())
    loss_D_real = F.binary_cross_entropy_with_logits(
        pred_real, torch.ones_like(pred_real)
    )
    loss_D_fake = F.binary_cross_entropy_with_logits(
        pred_fake, torch.zeros_like(pred_fake)
    )
    loss_D = (loss_D_real + loss_D_fake) * 0.5

    # Generator loss
    pred_fake_for_G = D(input_img, fake_img)
    loss_G_adv = F.binary_cross_entropy_with_logits(
        pred_fake_for_G, torch.ones_like(pred_fake_for_G)
    )
    loss_G_l1 = F.l1_loss(fake_img, target_img)
    loss_G = loss_G_adv + lambda_l1 * loss_G_l1

    return loss_D, loss_G
```

pix2pix requires paired training data—for each input, the exact corresponding output. But many translation tasks have no pairs: horses ↔ zebras, summer ↔ winter, Monet ↔ photos. CycleGAN (Zhu et al., 2017) enables unpaired image-to-image translation through cycle consistency.
The insight:
If we translate horse → zebra → horse, we should get back the original horse. This cycle consistency constraint prevents the model from mapping all horses to a single zebra or learning arbitrary transformations.
Two generators, two discriminators: G translates X → Y and F translates Y → X, while D_Y judges whether images look like domain Y and D_X judges whether they look like domain X.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CycleGAN:
    """
    CycleGAN for unpaired image-to-image translation.

    Key losses:
    1. Adversarial: G(x) should look like Y, F(y) should look like X
    2. Cycle consistency: F(G(x)) ≈ x and G(F(y)) ≈ y
    3. Identity (optional): G(y) ≈ y and F(x) ≈ x
    """

    def __init__(self, G, F, D_X, D_Y, lambda_cycle=10, lambda_identity=0.5):
        """
        G:   Generator X → Y
        F:   Generator Y → X (stored as self.F; inside methods, bare F
             still refers to torch.nn.functional)
        D_X: Discriminator for domain X
        D_Y: Discriminator for domain Y
        """
        self.G = G
        self.F = F
        self.D_X = D_X
        self.D_Y = D_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def generator_loss(self, real_X, real_Y):
        """Generator loss for both G and F."""
        # ----- Forward cycle: X → Y → X -----
        fake_Y = self.G(real_X)            # G(X)
        reconstructed_X = self.F(fake_Y)   # F(G(X))

        # ----- Backward cycle: Y → X → Y -----
        fake_X = self.F(real_Y)            # F(Y)
        reconstructed_Y = self.G(fake_X)   # G(F(Y))

        # ----- Adversarial loss -----
        # G wants D_Y to think fake_Y is real
        pred_fake_Y = self.D_Y(fake_Y)
        loss_G_adv = F.mse_loss(pred_fake_Y, torch.ones_like(pred_fake_Y))

        # F wants D_X to think fake_X is real
        pred_fake_X = self.D_X(fake_X)
        loss_F_adv = F.mse_loss(pred_fake_X, torch.ones_like(pred_fake_X))

        # ----- Cycle consistency loss -----
        loss_cycle_X = F.l1_loss(reconstructed_X, real_X)
        loss_cycle_Y = F.l1_loss(reconstructed_Y, real_Y)
        loss_cycle = loss_cycle_X + loss_cycle_Y

        # ----- Identity loss (optional, helps preserve color) -----
        # G should be identity on Y, F should be identity on X
        identity_Y = self.G(real_Y)
        identity_X = self.F(real_X)
        loss_identity = F.l1_loss(identity_Y, real_Y) + F.l1_loss(identity_X, real_X)

        # Total generator loss
        loss_G = (loss_G_adv + loss_F_adv
                  + self.lambda_cycle * loss_cycle
                  + self.lambda_identity * self.lambda_cycle * loss_identity)

        return loss_G, {
            'adv_G': loss_G_adv.item(),
            'adv_F': loss_F_adv.item(),
            'cycle': loss_cycle.item()
        }

    def discriminator_loss(self, real_X, real_Y):
        """Discriminator loss for both D_X and D_Y."""
        # Generate fakes
        fake_Y = self.G(real_X).detach()
        fake_X = self.F(real_Y).detach()

        # D_Y: distinguish real_Y from fake_Y
        pred_real_Y = self.D_Y(real_Y)
        pred_fake_Y = self.D_Y(fake_Y)
        loss_D_Y = 0.5 * (
            F.mse_loss(pred_real_Y, torch.ones_like(pred_real_Y)) +
            F.mse_loss(pred_fake_Y, torch.zeros_like(pred_fake_Y))
        )

        # D_X: distinguish real_X from fake_X
        pred_real_X = self.D_X(real_X)
        pred_fake_X = self.D_X(fake_X)
        loss_D_X = 0.5 * (
            F.mse_loss(pred_real_X, torch.ones_like(pred_real_X)) +
            F.mse_loss(pred_fake_X, torch.zeros_like(pred_fake_X))
        )

        return loss_D_X + loss_D_Y
```

The cycle consistency constraint is crucial. Without it, G could map all horses to the same zebra: the zebra would be realistic (fooling D_Y), but the mapping would lose all input information. With cycle consistency, if G(horse_1) = zebra_A and G(horse_2) = zebra_A, then F(zebra_A) would have to equal both horse_1 and horse_2—impossible, since F can only produce one output. So G must preserve enough information to reconstruct its input. The cycle constraint thereby forces meaningful correspondence: each horse maps to a different, corresponding zebra, and the transformation preserves structure and identity.

Use pix2pix when you have paired data (each input has exactly one correct output). Use CycleGAN when you only have unpaired collections from each domain. pix2pix typically produces more accurate results because it has ground truth to learn from, but CycleGAN is applicable to many more tasks where pairs don't exist.
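The class above only computes losses; wiring it into a training step is straightforward. A minimal sketch — the helper name, optimizer setup, and update order are assumptions, though they follow common CycleGAN practice:

```python
import torch

def cyclegan_train_step(model, optim_G, optim_D, real_X, real_Y):
    """One CycleGAN update: generators first, then both discriminators.

    model:   an object exposing generator_loss / discriminator_loss
             (e.g. the CycleGAN class defined above)
    optim_G: optimizer over the parameters of both generators
    optim_D: optimizer over the parameters of both discriminators
    """
    # ---- Update generators G and F ----
    optim_G.zero_grad()
    loss_G, logs = model.generator_loss(real_X, real_Y)
    loss_G.backward()
    optim_G.step()

    # ---- Update discriminators D_X and D_Y ----
    optim_D.zero_grad()
    loss_D = model.discriminator_loss(real_X, real_Y)
    loss_D.backward()
    optim_D.step()

    return {'loss_G': loss_G.item(), 'loss_D': loss_D.item(), **logs}

# Typical setup (generator/discriminator architectures are up to you):
# optim_G = torch.optim.Adam(
#     list(G.parameters()) + list(F_gen.parameters()), lr=2e-4, betas=(0.5, 0.999))
# optim_D = torch.optim.Adam(
#     list(D_X.parameters()) + list(D_Y.parameters()), lr=2e-4, betas=(0.5, 0.999))
```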
Conditional GANs transform GANs from random image samplers into controllable generation systems. From class labels to input images, conditioning enables a vast array of practical applications.
You've now completed the GAN Variants module, covering the major architectural innovations from DCGAN through StyleGAN and conditional generation. These foundations underpin virtually all modern GAN-based systems. The next module on Flow-Based Models explores a complementary approach to generative modeling with exact likelihood computation.