The standard VAE provides a principled framework for generative modeling, but it has known limitations: blurry samples, posterior collapse, limited expressiveness, and entangled representations. Researchers have developed numerous variants that address these issues, each with distinct trade-offs.
This page surveys the most important VAE variants, organized by what problem they solve. Understanding this landscape helps you select the right approach for your application and provides insight into how generative modeling has evolved.
We'll cover variants that modify the objective function, the posterior family, the latent structure, and the conditioning mechanism—each representing a different axis of improvement to the base VAE.
By the end of this page, you will: (1) Understand β-VAE and how it promotes disentanglement, (2) Master VQ-VAE and its discrete latent space, (3) Understand hierarchical VAEs for complex data, (4) Know how to condition VAEs on labels or other information, (5) Survey additional variants: WAE, VAE-GAN, NVAE, and more.
β-VAE (Higgins et al., 2017) is perhaps the simplest and most influential VAE variant. It modifies the objective by adding a hyperparameter $\beta$ that weighs the KL divergence term:
$$\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}_{q}[\log p(\mathbf{x}|\mathbf{z})] - \beta \cdot D_{\text{KL}}(q(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))$$
With $\beta > 1$, the stronger KL regularization creates an information bottleneck: the model can only afford to encode the most important factors of variation, which encourages individual latent dimensions to align with independent generative factors.
Pros: a one-hyperparameter change to the standard VAE; larger $\beta$ empirically improves disentanglement of the latent dimensions.
Cons: reconstruction quality degrades as $\beta$ grows, and the best value of $\beta$ must be tuned per dataset.
β-VAE can also be derived from a rate-distortion perspective: the KL term measures the rate (how much information the latent code carries about the input), the reconstruction term measures distortion, and β sets the trade-off between the two.
```python
import torch
import torch.nn.functional as F
from typing import Dict


class BetaVAE:
    """
    β-VAE: VAE with weighted KL term for disentanglement.

    Loss = Reconstruction + β * KL

    β = 1: Standard VAE
    β > 1: Stronger disentanglement, weaker reconstruction
    β < 1: Weaker disentanglement, stronger reconstruction
    """

    def __init__(self, model, beta: float = 4.0):
        self.model = model
        self.beta = beta

    def loss_function(
        self,
        x: torch.Tensor,
        outputs: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Compute β-VAE loss."""
        recon = outputs['recon']
        mu = outputs['mu']
        log_var = outputs['log_var']

        # Reconstruction loss
        recon_loss = F.binary_cross_entropy_with_logits(
            recon.view(x.size(0), -1),
            x.view(x.size(0), -1),
            reduction='sum'
        ) / x.size(0)

        # KL divergence
        kl_loss = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - log_var.exp()
        ) / x.size(0)

        # β-weighted total loss
        total_loss = recon_loss + self.beta * kl_loss

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'beta': self.beta
        }


class AnnealedBetaVAE:
    """
    β-VAE with KL annealing: gradually increase β during training.
    Helps prevent posterior collapse at the start of training.
    """

    def __init__(
        self,
        model,
        beta_start: float = 0.0,
        beta_end: float = 4.0,
        anneal_steps: int = 10000
    ):
        self.model = model
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.anneal_steps = anneal_steps
        self.step = 0

    @property
    def beta(self) -> float:
        """Current β value based on training step."""
        if self.step >= self.anneal_steps:
            return self.beta_end
        # Linear annealing
        progress = self.step / self.anneal_steps
        return self.beta_start + progress * (self.beta_end - self.beta_start)

    def update_step(self):
        """Call after each training step."""
        self.step += 1

    def loss_function(self, x, outputs):
        """Compute loss with current annealed β."""
        recon = outputs['recon']
        mu = outputs['mu']
        log_var = outputs['log_var']

        recon_loss = F.binary_cross_entropy_with_logits(
            recon.view(x.size(0), -1),
            x.view(x.size(0), -1),
            reduction='sum'
        ) / x.size(0)

        kl_loss = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - log_var.exp()
        ) / x.size(0)

        current_beta = self.beta
        total_loss = recon_loss + current_beta * kl_loss
        self.update_step()

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'beta': current_beta
        }
```
| Variant | Key Idea | Improvement |
|---|---|---|
| β-TCVAE | Decompose KL into interpretable terms, weight only total correlation | More principled disentanglement |
| FactorVAE | Add discriminator to penalize total correlation directly | Better disentanglement metrics |
| DIP-VAE | Regularize covariance of q(z) toward identity | Direct independence constraint |
| Anneal-VAE | Gradually increase β during training | Avoids early posterior collapse |
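To make one of the regularizers in the table above concrete, here is a minimal sketch of a DIP-VAE-I-style penalty that pushes the covariance of the encoder means toward the identity matrix; the function name and the λ defaults are illustrative choices, not taken from a reference implementation.

```python
import torch


def dip_vae_regularizer(mu: torch.Tensor,
                        lambda_od: float = 10.0,
                        lambda_d: float = 5.0) -> torch.Tensor:
    """DIP-VAE-I style penalty: push Cov[mu] toward the identity matrix.

    mu: encoder means for a batch, shape [batch, latent_dim].
    lambda_od / lambda_d: weights for off-diagonal and diagonal terms
    (illustrative defaults; tune per dataset).
    """
    mu_centered = mu - mu.mean(dim=0, keepdim=True)
    # Batch estimate of the covariance of the aggregate posterior means
    cov = mu_centered.T @ mu_centered / (mu.size(0) - 1)

    diag = torch.diagonal(cov)
    off_diag = cov - torch.diag_embed(diag)

    # Off-diagonal entries -> 0 (independence), diagonal entries -> 1
    return lambda_od * off_diag.pow(2).sum() + lambda_d * (diag - 1).pow(2).sum()
```

In training, such a term is simply added to the standard ELBO loss with its own weight.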
VQ-VAE (van den Oord et al., 2017) takes a radically different approach: instead of continuous Gaussian latents, it uses a discrete codebook of embedding vectors.
The latent representation is a grid of indices into the codebook—completely discrete.
Advantages: the discrete codes must be used for reconstruction, so posterior collapse does not occur; samples are sharp; and the compressed index grid is a natural input for a powerful autoregressive prior (e.g., PixelCNN or a transformer) trained afterwards.
The nearest-neighbor lookup (an argmin over codebook distances) is non-differentiable. VQ-VAE uses the straight-through estimator:
$$\text{Forward: } \mathbf{z}_q = \mathbf{e}_{k^*}, \quad k^* = \text{argmin}_k\, ||\mathbf{z}_e - \mathbf{e}_k||^2$$
$$\text{Backward: } \nabla_{\mathbf{z}_e} = \nabla_{\mathbf{z}_q} \quad \text{(straight-through)}$$
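In code, the straight-through estimator is essentially a one-liner: quantize in the forward pass, but let gradients flow to the encoder as if quantization were the identity. A minimal sketch (the same pattern appears inside the full `VectorQuantizer` below):

```python
import torch


def straight_through_quantize(z_e: torch.Tensor, codebook: torch.Tensor):
    """Nearest-neighbor quantization with straight-through gradients.

    z_e: encoder outputs, shape [N, D]; codebook: embeddings, shape [K, D].
    """
    # Pairwise squared distances ||z_e - e_k||^2
    distances = torch.cdist(z_e, codebook, p=2).pow(2)
    indices = distances.argmin(dim=1)
    z_q = codebook[indices]

    # Forward pass uses z_q; backward pass routes gradients straight to z_e
    z_q = z_e + (z_q - z_e).detach()
    return z_q, indices
```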
The VQ-VAE loss has three components:
$$\mathcal{L} = \underbrace{||\mathbf{x} - \hat{\mathbf{x}}||^2}_{\text{Reconstruction}} + \underbrace{||\text{sg}[\mathbf{z}_e] - \mathbf{e}||^2}_{\text{Codebook loss}} + \underbrace{\beta\, ||\mathbf{z}_e - \text{sg}[\mathbf{e}]||^2}_{\text{Commitment loss}}$$
where $\text{sg}[\cdot]$ is the stop-gradient operator.
Reconstruction Loss: Train decoder to reconstruct from quantized codes.
Codebook Loss: Move codebook entries toward encoder outputs (trains codebook).
Commitment Loss: Keep encoder outputs close to codebook entries (trains encoder to "commit" to codes). $\beta$ typically 0.25.
An alternative to codebook loss: update codebook entries via EMA of assigned encoder outputs:
$$\mathbf{e}_k \leftarrow \gamma \mathbf{e}_k + (1-\gamma) \bar{\mathbf{z}}_{e,k}$$
where $\bar{\mathbf{z}}_{e,k}$ is the mean of encoder outputs assigned to code $k$. Often more stable than gradient-based codebook update.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """
    Vector Quantization layer for VQ-VAE.
    Maps continuous encoder outputs to discrete codebook indices.
    """

    def __init__(
        self,
        num_embeddings: int = 512,
        embedding_dim: int = 64,
        commitment_cost: float = 0.25,
        use_ema: bool = True,
        ema_decay: float = 0.99
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.use_ema = use_ema

        # Codebook embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

        if use_ema:
            self.ema_decay = ema_decay
            # EMA cluster size and sum for updating embeddings
            # (detach so the buffers stay outside the autograd graph)
            self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
            self.register_buffer('ema_embedding_sum',
                                 self.embedding.weight.detach().clone())

    def forward(self, z_e: torch.Tensor):
        """
        Args:
            z_e: Encoder output [batch, height, width, embedding_dim]
                 or [batch, seq_len, embedding_dim]

        Returns:
            z_q: Quantized vectors (same shape as z_e)
            loss: VQ loss (codebook + commitment)
            encoding_indices: Indices into codebook
        """
        # Flatten to [N, embedding_dim] for distance computation
        z_e_flat = z_e.reshape(-1, self.embedding_dim)

        # Compute distances: ||z_e - e_k||^2
        # = ||z_e||^2 + ||e_k||^2 - 2 * z_e @ e_k^T
        distances = (
            z_e_flat.pow(2).sum(dim=1, keepdim=True)
            + self.embedding.weight.pow(2).sum(dim=1)
            - 2 * z_e_flat @ self.embedding.weight.T
        )

        # Find nearest codebook entry
        encoding_indices = distances.argmin(dim=1)

        # Quantize: look up embeddings
        z_q_flat = self.embedding(encoding_indices)
        z_q = z_q_flat.view_as(z_e)

        # Compute loss
        if self.training:
            if self.use_ema:
                # EMA update of embeddings
                self._ema_update(z_e_flat, encoding_indices)
                # Only commitment loss needed with EMA
                loss = self.commitment_cost * F.mse_loss(z_e, z_q.detach())
            else:
                # Codebook loss + commitment loss
                codebook_loss = F.mse_loss(z_q, z_e.detach())
                commitment_loss = F.mse_loss(z_e, z_q.detach())
                loss = codebook_loss + self.commitment_cost * commitment_loss
        else:
            loss = torch.tensor(0.0, device=z_e.device)

        # Straight-through estimator: copy gradients from z_q to z_e
        z_q = z_e + (z_q - z_e).detach()

        return z_q, loss, encoding_indices.view(z_e.shape[:-1])

    @torch.no_grad()  # EMA bookkeeping should not build an autograd graph
    def _ema_update(self, z_e_flat: torch.Tensor, indices: torch.Tensor):
        """Update codebook using exponential moving average."""
        # One-hot encode indices
        encodings = F.one_hot(indices, self.num_embeddings).float()

        # Update cluster size
        cluster_size = encodings.sum(dim=0)
        self.ema_cluster_size.mul_(self.ema_decay).add_(
            cluster_size, alpha=1 - self.ema_decay
        )

        # Update embedding sum
        embedding_sum = encodings.T @ z_e_flat
        self.ema_embedding_sum.mul_(self.ema_decay).add_(
            embedding_sum, alpha=1 - self.ema_decay
        )

        # Normalize to get new embeddings
        n = self.ema_cluster_size.sum()
        cluster_size_normalized = (
            (self.ema_cluster_size + 1e-5)
            / (n + self.num_embeddings * 1e-5) * n
        )
        self.embedding.weight.data = (
            self.ema_embedding_sum / cluster_size_normalized.unsqueeze(1)
        )


class VQVAE(nn.Module):
    """Complete VQ-VAE model."""

    def __init__(
        self,
        in_channels: int = 3,
        hidden_dims: list = [128, 256],
        num_embeddings: int = 512,
        embedding_dim: int = 64
    ):
        super().__init__()

        # Encoder
        encoder_layers = []
        ch = in_channels
        for h_dim in hidden_dims:
            encoder_layers.extend([
                nn.Conv2d(ch, h_dim, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(h_dim),
                nn.ReLU()
            ])
            ch = h_dim
        encoder_layers.append(nn.Conv2d(ch, embedding_dim, kernel_size=1))
        self.encoder = nn.Sequential(*encoder_layers)

        # Vector Quantizer
        self.vq = VectorQuantizer(num_embeddings, embedding_dim)

        # Decoder
        decoder_layers = [nn.Conv2d(embedding_dim, hidden_dims[-1], kernel_size=1)]
        for i in range(len(hidden_dims) - 1, -1, -1):
            out_ch = hidden_dims[i-1] if i > 0 else in_channels
            decoder_layers.extend([
                nn.ConvTranspose2d(hidden_dims[i], out_ch,
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch) if i > 0 else nn.Identity(),
                nn.ReLU() if i > 0 else nn.Identity()
            ])
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        # Encode
        z_e = self.encoder(x)
        z_e = z_e.permute(0, 2, 3, 1)  # [B, H, W, D]

        # Quantize
        z_q, vq_loss, indices = self.vq(z_e)
        z_q = z_q.permute(0, 3, 1, 2)  # [B, D, H, W]

        # Decode
        x_recon = self.decoder(z_q)

        return x_recon, vq_loss, indices
```
Standard VAEs use a single latent space. Hierarchical VAEs use multiple layers of latent variables at different scales, enabling modeling of complex, multi-level structure in data.
Complex data (like high-resolution images) has structure at multiple scales: global composition and layout, object-level shape, and fine local texture.
A single latent vector struggles to capture all levels. Hierarchical VAEs dedicate different latent layers to different scales.
The Ladder VAE (Sønderby et al., 2016) uses a bidirectional architecture:
Generative model (top-down): $$p(\mathbf{x}, \mathbf{z}_{1:L}) = p(\mathbf{z}_L) \prod_{l=1}^{L-1} p(\mathbf{z}_l | \mathbf{z}_{l+1}) \cdot p(\mathbf{x} | \mathbf{z}_1)$$
Inference model (bottom-up): $$q(\mathbf{z}_{1:L} | \mathbf{x}) = q(\mathbf{z}_1 | \mathbf{x}) \prod_{l=2}^{L} q(\mathbf{z}_l | \mathbf{z}_{l-1}, \mathbf{x})$$
The key innovation: inference combines bottom-up features from the data with top-down features from the generative model.
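One way to realize this combination, in the spirit of Ladder VAE's precision-weighted merge, is sketched below; treat it as an illustrative fusion rule rather than the paper's exact parameterization.

```python
import torch


def precision_weighted_merge(mu_bu, logvar_bu, mu_td, logvar_td):
    """Combine bottom-up (data-driven) and top-down (prior-driven) Gaussian
    parameters into a single posterior, weighting each by its precision.

    Sketch of a Ladder-VAE-style merge; all arguments have shape
    [batch, latent_dim].
    """
    prec_bu = torch.exp(-logvar_bu)   # 1 / sigma_bu^2
    prec_td = torch.exp(-logvar_td)   # 1 / sigma_td^2

    var = 1.0 / (prec_bu + prec_td)
    mu = var * (mu_bu * prec_bu + mu_td * prec_td)
    return mu, torch.log(var)
```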
NVAE (Vahdat & Kautz, 2020) is a state-of-the-art hierarchical VAE that achieves image generation quality rivaling GANs:
Key innovations: deep hierarchies of latent groups built from carefully designed residual cells (depthwise separable convolutions with squeeze-and-excitation), spectral regularization to stabilize training, and a residual parameterization of the approximate posterior relative to the prior.
ELBO for Hierarchical VAEs:
$$\mathcal{L} = \mathbb{E}_q[\log p(\mathbf{x}|\mathbf{z}_{1:L})] - \sum_{l=1}^{L} D_{\text{KL}}(q(\mathbf{z}_l | \mathbf{z}_{<l}, \mathbf{x}) \,||\, p(\mathbf{z}_l | \mathbf{z}_{>l}))$$
The KL is computed between inference and generative distributions at each level.
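Assuming diagonal Gaussian distributions at every level (as elsewhere on this page), the per-level KL terms have a closed form and can simply be summed. The helper below is a minimal sketch with a hypothetical input convention.

```python
import torch


def hierarchical_kl(q_params, p_params):
    """Sum per-level KL terms for a hierarchical VAE.

    q_params / p_params: lists of (mu, log_var) pairs, one per latent level,
    each tensor shaped [batch, latent_dim_l]. Both distributions are assumed
    to be diagonal Gaussians.
    """
    total_kl = 0.0
    for (mu_q, lv_q), (mu_p, lv_p) in zip(q_params, p_params):
        # KL(N(mu_q, var_q) || N(mu_p, var_p)) per dimension
        kl = 0.5 * (
            lv_p - lv_q
            + (lv_q.exp() + (mu_q - mu_p).pow(2)) / lv_p.exp()
            - 1.0
        )
        total_kl = total_kl + kl.sum(dim=1).mean()  # sum over dims, average over batch
    return total_kl
```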
VD-VAE (Child, 2021) further scales the hierarchical approach, using a very deep stack of latent groups with careful KL balancing:
| Model | Latent Groups | Key Features | FID (CelebA) |
|---|---|---|---|
| Standard VAE | 1 | Single latent vector | ~70 |
| Ladder VAE | 5-10 | Bidirectional inference | ~40 |
| BIVA | 15-30 | Bidirectional + skip connections | ~25 |
| NVAE | 30-50 | Residual cells, spectral norm | ~7 |
| VD-VAE | 50-78 | Very deep, careful KL balancing | ~5 |
Conditional VAEs (CVAEs) extend VAEs to generate data conditioned on additional information—labels, attributes, text, or other modalities.
The CVAE models $p(\mathbf{x}|\mathbf{c})$ where $\mathbf{c}$ is the conditioning information:
Generative model: $$p(\mathbf{x}|\mathbf{c}) = \int p(\mathbf{x}|\mathbf{z}, \mathbf{c}) p(\mathbf{z}|\mathbf{c}) d\mathbf{z}$$
Inference model: $$q(\mathbf{z}|\mathbf{x}, \mathbf{c})$$
CVAE ELBO: $$\log p(\mathbf{x}|\mathbf{c}) \geq \mathbb{E}_{q(\mathbf{z}|\mathbf{x},\mathbf{c})}[\log p(\mathbf{x}|\mathbf{z}, \mathbf{c})] - D_{\text{KL}}(q(\mathbf{z}|\mathbf{x}, \mathbf{c}) || p(\mathbf{z}|\mathbf{c}))$$
1. Concatenation: concatenate the condition (e.g., a label embedding) with the encoder input and with the latent code before decoding—the approach used in the implementation below.
2. Adaptive normalization (FiLM): predict per-channel scale and shift parameters from the condition and apply them to intermediate features (see the sketch after this list).
3. Cross-attention: let encoder and decoder features attend over an embedded condition such as a text sequence.
4. Prior conditioning: replace the fixed $\mathcal{N}(0, I)$ prior with a learned conditional prior $p(\mathbf{z}|\mathbf{c})$.
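As a concrete illustration of option 2, here is a minimal FiLM-style conditioning block; the class name and shapes are illustrative, not from a specific CVAE implementation.

```python
import torch
import torch.nn as nn


class FiLMBlock(nn.Module):
    """Minimal FiLM-style conditioning: the condition predicts a per-channel
    scale (gamma) and shift (beta) applied to intermediate features."""

    def __init__(self, feature_dim: int, cond_dim: int):
        super().__init__()
        self.to_gamma_beta = nn.Linear(cond_dim, 2 * feature_dim)

    def forward(self, h: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # h: features [batch, feature_dim]; c: condition embedding [batch, cond_dim]
        gamma, beta = self.to_gamma_beta(c).chunk(2, dim=-1)
        return gamma * h + beta
```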
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalEncoder(nn.Module):
    """Encoder that conditions on class labels."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        latent_dim: int,
        num_classes: int
    ):
        super().__init__()

        # Class embedding
        self.class_embed = nn.Embedding(num_classes, hidden_dim)

        # MLP that takes x concatenated with class embedding
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor, c: torch.Tensor):
        """
        Args:
            x: Input [batch, input_dim]
            c: Class labels [batch]
        """
        c_embed = self.class_embed(c)          # [batch, hidden_dim]
        x_c = torch.cat([x, c_embed], dim=-1)  # [batch, input_dim + hidden_dim]

        h = self.encoder(x_c)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)

        return mu, log_var


class ConditionalDecoder(nn.Module):
    """Decoder that conditions on class labels."""

    def __init__(
        self,
        output_dim: int,
        hidden_dim: int,
        latent_dim: int,
        num_classes: int
    ):
        super().__init__()

        self.class_embed = nn.Embedding(num_classes, hidden_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, z: torch.Tensor, c: torch.Tensor):
        """
        Args:
            z: Latent code [batch, latent_dim]
            c: Class labels [batch]
        """
        c_embed = self.class_embed(c)
        z_c = torch.cat([z, c_embed], dim=-1)
        return self.decoder(z_c)


class CVAE(nn.Module):
    """Complete Conditional VAE."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        latent_dim: int = 64,
        num_classes: int = 10
    ):
        super().__init__()
        self.encoder = ConditionalEncoder(
            input_dim, hidden_dim, latent_dim, num_classes
        )
        self.decoder = ConditionalDecoder(
            input_dim, hidden_dim, latent_dim, num_classes
        )
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x: torch.Tensor, c: torch.Tensor):
        mu, log_var = self.encoder(x, c)
        z = self.reparameterize(mu, log_var)
        recon = self.decoder(z, c)
        return recon, mu, log_var

    @torch.no_grad()
    def generate(self, c: torch.Tensor, num_samples: int = 1):
        """Generate samples for a given class."""
        device = next(self.parameters()).device

        # Expand class to match num_samples
        c = c.unsqueeze(0).expand(num_samples)

        # Sample from prior
        z = torch.randn(num_samples, self.latent_dim, device=device)

        # Decode conditioned on class
        return torch.sigmoid(self.decoder(z, c))
```
CVAEs are widely used for: (1) Class-conditional generation: Generate images of specific classes (digits, objects), (2) Attribute manipulation: Condition on attributes to control style, (3) Multi-modal learning: Condition on text to generate images, (4) Data augmentation: Generate additional samples for minority classes, (5) Style transfer: Condition on style while preserving content.
The VAE literature contains many more variants. Here we survey additional important ones:
WAE (Tolstikhin et al., 2018) replaces KL divergence with Wasserstein distance for matching aggregate posterior to prior:
$$\mathcal{L}_{\text{WAE}} = \mathbb{E}[c(\mathbf{x}, \hat{\mathbf{x}})] + \lambda \cdot D_Z(q(\mathbf{z}) || p(\mathbf{z}))$$
where $D_Z$ can be either the maximum mean discrepancy (WAE-MMD) or an adversarially estimated divergence using a latent-space discriminator (WAE-GAN).
Advantage: Often produces sharper samples than VAE.
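For the MMD choice of $D_Z$, a minimal sketch of the penalty is shown below, using a single RBF kernel with a hypothetical bandwidth; practical WAE-MMD implementations typically sum over several bandwidths.

```python
import torch


def rbf_mmd(z_q: torch.Tensor, z_p: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Biased MMD estimate between encoded latents z_q ~ q(z) and prior
    samples z_p ~ p(z), using a single RBF kernel of bandwidth `sigma`."""

    def kernel(a, b):
        # Pairwise squared distances between rows of a and b
        d2 = torch.cdist(a, b, p=2).pow(2)
        return torch.exp(-d2 / (2 * sigma ** 2))

    return (kernel(z_q, z_q).mean()
            + kernel(z_p, z_p).mean()
            - 2 * kernel(z_q, z_p).mean())
```

During training this would be computed between a batch of encoded latents and an equally sized batch drawn from the prior, then added to the reconstruction cost with weight $\lambda$.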
VAE-GAN (Larsen et al., 2016) combines VAE with GAN:
$$\mathcal{L} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{KL}} + \mathcal{L}_{\text{GAN}}$$
Result: Sharper images than VAE, more stable than pure GAN.
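A minimal sketch of the generator-side objective is given below, assuming a `discriminator` that maps images to real/fake logits; the `gamma` weight and the plain pixel-space reconstruction are simplifications (the original paper also reconstructs discriminator features).

```python
import torch
import torch.nn.functional as F


def vae_gan_generator_loss(x, x_recon, mu, log_var, discriminator,
                           gamma: float = 1.0):
    """Sketch of the decoder/generator-side VAE-GAN objective:
    reconstruction + KL + an adversarial term for fooling the discriminator."""
    recon_loss = F.mse_loss(x_recon, x, reduction='mean')
    kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())

    # Non-saturating GAN loss: push D(x_recon) toward the "real" label
    logits_fake = discriminator(x_recon)
    adv_loss = F.binary_cross_entropy_with_logits(
        logits_fake, torch.ones_like(logits_fake)
    )
    return recon_loss + kl_loss + gamma * adv_loss
```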
IWAE (Burda et al., 2016) uses multiple samples for tighter bound:
$$\mathcal{L}_{K} = \mathbb{E}_{\mathbf{z}_1, ..., \mathbf{z}_K \sim q}\left[\log \frac{1}{K}\sum_{k=1}^{K} \frac{p(\mathbf{x}, \mathbf{z}_k)}{q(\mathbf{z}_k|\mathbf{x})}\right]$$
As $K \to \infty$, the bound approaches true log-likelihood.
Trade-off: Better log-likelihood, but may learn weaker inference networks.
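The bound itself is easy to compute stably with a log-sum-exp; the sketch below assumes the per-sample log densities have already been evaluated (a hypothetical input convention).

```python
import math
import torch


def iwae_bound(log_p_joint: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
    """IWAE bound from K importance samples per data point.

    log_p_joint: log p(x, z_k), shape [batch, K]
    log_q:       log q(z_k | x), shape [batch, K]
    Returns the bound averaged over the batch.
    """
    log_w = log_p_joint - log_q                      # importance log-weights
    k = log_w.size(1)
    # log( (1/K) * sum_k exp(log_w_k) ), computed stably
    bound = torch.logsumexp(log_w, dim=1) - math.log(k)
    return bound.mean()
```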
| Variant | Key Modification | Main Benefit | Main Drawback |
|---|---|---|---|
| β-VAE | Weight KL by β > 1 | Disentanglement | Worse reconstruction |
| VQ-VAE | Discrete codebook latents | No posterior collapse, sharp samples | Needs autoregressive prior for generation |
| Hierarchical VAE | Multiple latent layers | Multi-scale modeling | Complex training, more hyperparameters |
| CVAE | Condition on labels/attributes | Controlled generation | Needs labeled data |
| WAE | Wasserstein instead of KL | Sharper samples | MMD/GAN training complexity |
| VAE-GAN | Add GAN discriminator | Sharp + principled | GAN training instability |
| IWAE | Multiple importance samples | Tighter likelihood bound | Weaker inference network |
| Flow-VAE | Normalizing flow posterior | More expressive posterior | Computational cost |
AAE (Adversarial Autoencoder): Replace KL with adversarial matching of aggregate posterior to prior.
RAE (Regularized Autoencoder): Add ex-post density estimation to deterministic autoencoder.
Optimus: VAE for text using pretrained BERT encoder and GPT-2 decoder.
dVAE (discrete VAE): the discrete VAE component of DALL-E, used for image tokenization.
Perceiver VAE: Attention-based architecture for multi-modal data.
3D-VAE: VAEs for 3D shapes, point clouds, meshes.
Graph VAE: VAEs for molecular graphs and network structures.
With so many variants, how do you choose? Here's a decision framework based on your requirements:
1. Generation Quality Priority: use a hierarchical VAE (NVAE, VD-VAE) or a hybrid such as VAE-GAN; a VQ-VAE paired with a strong autoregressive prior is also competitive.
2. Disentanglement Priority: use β-VAE or one of its refinements (β-TCVAE, FactorVAE, DIP-VAE).
3. Conditional Generation: use a CVAE, choosing the conditioning mechanism (concatenation, FiLM, cross-attention, conditional prior) that matches your condition type.
4. Discrete Representations: use VQ-VAE; its code grid is a natural interface to autoregressive or transformer priors.
5. Computational Budget: with limited compute, start with a standard VAE or β-VAE; hierarchical models add significant training cost and hyperparameters.
6. Downstream Task: match the latent structure to the task—disentangled continuous latents for controllable editing, discrete codes for token-based downstream models.
Don't reach for complex variants immediately. Standard VAE with proper hyperparameters (learning rate, β, latent dimension, architecture) often performs better than poorly-tuned advanced variants. Master the baseline first, then add complexity as needed for specific requirements.
The VAE field continues to evolve. Here are current trends and emerging directions:
Recent work pushes VAEs to larger scales:
Models like NVAE and VD-VAE show VAEs can match GANs with sufficient scale.
VAE + Diffusion: latent diffusion models run the diffusion process inside a VAE's compressed latent space, combining VAE efficiency with diffusion sample quality.
VAE + Transformers: VQ-VAE token grids are modeled with autoregressive transformers for generation, and transformer encoders/decoders replace CNNs and RNNs inside VAEs.
VAE + Pretrained Models: pretrained encoders and decoders (as in Optimus with BERT and GPT-2) give VAEs strong representations without training from scratch.
After years of GANs dominating image generation, VAEs are experiencing a renaissance. Hierarchical VAEs match GAN quality, while VAE-based latent spaces power state-of-the-art diffusion models. The principled probabilistic framework and stable training make VAEs increasingly attractive for production systems.
We've surveyed the rich landscape of VAE variants, each addressing specific limitations of the standard model: β-VAE for disentanglement, VQ-VAE for discrete codes and sharp samples, hierarchical VAEs for multi-scale structure, CVAEs for controlled generation, and hybrids such as WAE, VAE-GAN, and IWAE for sharper samples and tighter bounds.
Module Complete:
You've now completed the comprehensive Variational Autoencoders module. You understand the ELBO objective, encoder-decoder architecture, latent space structure, the reparameterization trick, and the diverse ecosystem of VAE variants. You're equipped to implement, train, and deploy VAEs for a wide range of generative modeling applications.
Congratulations! You now have comprehensive knowledge of Variational Autoencoders—from theoretical foundations to practical variants. You can select the right VAE variant for your application, implement it correctly, diagnose training issues, and leverage the learned representations. VAEs represent one of the most elegant intersections of deep learning and probabilistic modeling.