With the VAE objective established, we now turn to the architectural components that bring it to life. The VAE consists of two neural networks working in tandem: an encoder that maps observations to distributions over latent codes, and a decoder that maps latent codes back to distributions over observations.
These networks are not arbitrary—their design choices profoundly impact training dynamics, reconstruction quality, generation fidelity, and latent space structure. This page provides a comprehensive treatment of both networks, examining architectural patterns, output parameterization, and the subtle design decisions that separate effective VAEs from poorly functioning ones.
Understanding encoder and decoder design is essential for adapting VAEs to different data modalities (images, text, audio) and for diagnosing common training pathologies.
By the end of this page, you will: (1) Understand the encoder's role in parameterizing the approximate posterior, (2) Master output layer design for mean and variance, (3) Understand the decoder's role in parameterizing the likelihood, (4) Select appropriate likelihood functions for different data types, (5) Design encoder-decoder pairs for images, sequences, and other modalities, and (6) Implement complete encoder and decoder networks in PyTorch.
The encoder implements the approximate posterior $q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x})$. Given an observation $\mathbf{x}$, it outputs the parameters of a probability distribution over latent codes $\mathbf{z}$.
From a probabilistic perspective, the encoder performs amortized variational inference. Instead of optimizing variational parameters independently for each datapoint (as in classical VI), the encoder learns a function that maps any $\mathbf{x}$ to its posterior parameters. This amortization is what makes VAEs scalable.
The most common choice uses a diagonal Gaussian posterior:
$$q_{\boldsymbol{\phi}}(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}_{\boldsymbol{\phi}}(\mathbf{x}), \text{diag}(\boldsymbol{\sigma}^2_{\boldsymbol{\phi}}(\mathbf{x})))$$
The encoder network produces two outputs: the posterior mean $\boldsymbol{\mu}_{\boldsymbol{\phi}}(\mathbf{x})$ and the log-variance $\log \boldsymbol{\sigma}^2_{\boldsymbol{\phi}}(\mathbf{x})$, each a vector with one entry per latent dimension.
Why log-variance? Variance must be positive. Outputting $\log \sigma^2$ allows the network to produce any real value, which is mapped to positive variance via $\sigma^2 = \exp(\log \sigma^2)$. This improves numerical stability and gradient flow.
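As a quick standalone numeric illustration (not tied to any particular model), the exp mapping turns any real-valued network output into a valid positive variance:

```python
import torch

# A log-variance head can emit any real number...
raw_outputs = torch.tensor([-5.0, 0.0, 3.0])

# ...and exp maps each value to a strictly positive variance.
variance = torch.exp(raw_outputs)
std = torch.exp(0.5 * raw_outputs)  # sigma = exp(log_var / 2)

print(variance)                           # all positive, spanning orders of magnitude
print(torch.allclose(std ** 2, variance)) # sigma^2 recovers the variance
```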
1. Backbone Network
The backbone processes raw input into a feature representation. Architecture depends on data modality: convolutional networks for images, recurrent networks or Transformers for sequences, and MLPs for tabular data.
2. Projection Heads
The backbone features are projected to posterior parameters via separate linear layers:
$$\boldsymbol{\mu} = W_{\mu} \mathbf{h} + \mathbf{b}_{\mu}, \qquad \log \boldsymbol{\sigma}^2 = W_{\sigma} \mathbf{h} + \mathbf{b}_{\sigma}$$
Using separate heads (not shared weights) allows independent learning of location and scale.
3. Output Constraints

The mean head is left unconstrained, while the log-variance is typically clamped to a fixed range (e.g., $[-10, 10]$) to keep the resulting variance numerically stable.
```python
import torch
import torch.nn as nn
from typing import Tuple


class ConvEncoder(nn.Module):
    """
    Convolutional encoder for image VAE.

    Progressively downsamples input through conv layers,
    then projects to latent mean and log-variance.
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_dim: int = 256,
        hidden_dims: list = [32, 64, 128, 256, 512]
    ):
        super().__init__()
        self.latent_dim = latent_dim

        # Build convolutional backbone
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels, h_dim,
                        kernel_size=3, stride=2, padding=1
                    ),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(0.2)
                )
            )
            in_channels = h_dim
        self.backbone = nn.Sequential(*modules)

        # Calculate flattened dimension after convolutions
        # For 64x64 input with 5 stride-2 layers: 64 -> 32 -> 16 -> 8 -> 4 -> 2
        # Final: hidden_dims[-1] * 2 * 2 = 512 * 4 = 2048
        self.flatten_dim = hidden_dims[-1] * 4  # Adjust based on input size

        # Projection layers for mean and log-variance
        self.fc_mu = nn.Linear(self.flatten_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_dim, latent_dim)

        # Initialize projection layers carefully
        nn.init.xavier_normal_(self.fc_mu.weight)
        nn.init.xavier_normal_(self.fc_logvar.weight)
        nn.init.zeros_(self.fc_logvar.bias)  # Start with unit variance

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode input to latent distribution parameters.

        Args:
            x: Input images [batch, channels, height, width]

        Returns:
            mu: Latent mean [batch, latent_dim]
            log_var: Latent log-variance [batch, latent_dim]
        """
        # Extract features through backbone
        h = self.backbone(x)
        h = h.view(h.size(0), -1)  # Flatten

        # Project to distribution parameters
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)

        # Clamp log_var for numerical stability
        # Prevents variance from being too small (collapse) or too large (explosion)
        log_var = torch.clamp(log_var, min=-10.0, max=10.0)

        return mu, log_var

    @staticmethod
    def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: sample z = mu + sigma * epsilon.

        Args:
            mu: Mean of q(z|x)
            log_var: Log variance of q(z|x)

        Returns:
            z: Sampled latent code
        """
        std = torch.exp(0.5 * log_var)  # sigma = sqrt(exp(log_var))
        eps = torch.randn_like(std)     # epsilon ~ N(0, I)
        return mu + std * eps
```

Several important design choices affect encoder quality. Let's examine each:
The latent dimension $d$ controls the information bottleneck: too small, and the encoder cannot retain enough information for faithful reconstruction; too large, and many dimensions go unused while the regularizing effect of the bottleneck weakens.
Rule of thumb: Start with latent dimension that's 1-2 orders of magnitude smaller than input dimension. For 64×64×3 images (12,288 dimensions), try latent dimensions 32-256.
The encoder and decoder should have roughly balanced capacity. If the decoder is much more powerful, it may reconstruct well without using the latent code, causing posterior collapse. If the encoder is much more powerful, it may encode too much information, hurting generation.
Encoder initialization affects early training dynamics. In particular, initializing the log-variance head's bias to zero starts training at unit variance, matching the standard Gaussian prior.
| Data Type | Backbone Architecture | Key Considerations | Typical Latent Dim |
|---|---|---|---|
| Small Images (32×32) | 3-4 conv layers, stride 2 | Avoid too aggressive downsampling | 32-128 |
| Medium Images (64-128px) | ResNet-style, 5-6 blocks | Use skip connections for gradient flow | 64-256 |
| Large Images (256+px) | Hierarchical, multi-scale | Consider hierarchical latent spaces | 256-512 or hierarchical |
| Text Sequences | Transformer encoder or LSTM | Aggregate sequence to fixed vector | 64-512 |
| Tabular Data | MLP with layer normalization | Handle mixed feature types | 8-64 |
| Audio/Spectrograms | 1D or 2D convolutions | Long-range temporal dependencies | 64-256 |
Some practitioners use alternative variance parameterizations: (1) Softplus variance: $\sigma^2 = \text{softplus}(\text{output}) = \log(1 + e^{\text{output}})$, which is smoother than exp near zero. (2) Constant variance: fix $\sigma^2 = 1$ and learn only the mean—simpler but less expressive. (3) Learned minimum variance: $\sigma^2 = \sigma_{\min}^2 + \text{softplus}(\text{output})$, which prevents variance collapse.
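The first and third alternatives are easy to sketch; the function names `softplus_var` and `min_var` below are illustrative, not from any library:

```python
import torch
import torch.nn.functional as F

def softplus_var(raw: torch.Tensor) -> torch.Tensor:
    # sigma^2 = log(1 + e^raw): positive, smoother than exp near zero
    return F.softplus(raw)

def min_var(raw: torch.Tensor, sigma_min_sq: float = 1e-4) -> torch.Tensor:
    # Floor the variance at sigma_min^2 so it can never collapse to zero
    return sigma_min_sq + F.softplus(raw)

raw = torch.tensor([-5.0, 0.0, 5.0])
print(softplus_var(raw))  # small but positive for negative inputs
print(min_var(raw))       # never drops below sigma_min^2
```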
The decoder implements the likelihood model $p_{\boldsymbol{\theta}}(\mathbf{x}|\mathbf{z})$. Given a latent code $\mathbf{z}$, it outputs parameters of a probability distribution over observations $\mathbf{x}$.
The decoder defines the generative process: starting from a simple prior distribution, how do we transform latent codes into complex data? The decoder is the 'artist'—it learns to paint images (or write text, or generate audio) given abstract codes.
The decoder typically mirrors the encoder architecture: convolutions become transposed convolutions, downsampling becomes upsampling, and the channel progression is reversed.
This symmetry is not required but often works well, ensuring similar representational capacity in both directions.
The decoder's output layer is crucial—it must produce valid parameters for the chosen likelihood distribution:
For Binary/Bernoulli Data (e.g., binarized MNIST): output raw logits and let the loss apply the sigmoid (e.g., `BCEWithLogitsLoss`) for numerical stability.

For Continuous Data with Gaussian Likelihood: output a mean head and, optionally, a log-variance head, as in the encoder.

For Discrete/Categorical Data: output one logit per category and normalize with a softmax.

For Normalized Images [0, 1]: apply a sigmoid to constrain outputs to the unit interval.
```python
import torch
import torch.nn as nn
from typing import Tuple, Union


class ConvDecoder(nn.Module):
    """
    Convolutional decoder for image VAE.

    Progressively upsamples from latent code to image,
    outputting likelihood parameters.
    """

    def __init__(
        self,
        out_channels: int = 3,
        latent_dim: int = 256,
        hidden_dims: list = [512, 256, 128, 64, 32],
        output_type: str = 'bernoulli'  # 'bernoulli', 'gaussian', 'continuous_bernoulli'
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.output_type = output_type

        # Initial projection from latent to spatial features
        # For 64x64 output with 5 upsample layers, start at 2x2
        self.initial_size = 2
        self.initial_channels = hidden_dims[0]
        self.fc = nn.Linear(
            latent_dim,
            hidden_dims[0] * self.initial_size * self.initial_size
        )

        # Build transposed convolutional blocks (upsampling)
        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1],
                        kernel_size=4, stride=2, padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(0.2)
                )
            )
        self.decoder_blocks = nn.Sequential(*modules)

        # Final output layer - depends on likelihood type
        if output_type == 'bernoulli':
            # Output logits, BCEWithLogits applies sigmoid internally
            self.output_layer = nn.ConvTranspose2d(
                hidden_dims[-1], out_channels,
                kernel_size=4, stride=2, padding=1
            )
        elif output_type == 'gaussian':
            # Output mean and log-variance
            self.mean_layer = nn.ConvTranspose2d(
                hidden_dims[-1], out_channels,
                kernel_size=4, stride=2, padding=1
            )
            self.logvar_layer = nn.ConvTranspose2d(
                hidden_dims[-1], out_channels,
                kernel_size=4, stride=2, padding=1
            )
        else:
            # Default: single output with sigmoid for [0,1] continuous
            self.output_layer = nn.Sequential(
                nn.ConvTranspose2d(
                    hidden_dims[-1], out_channels,
                    kernel_size=4, stride=2, padding=1
                ),
                nn.Sigmoid()  # Constrain to [0, 1]
            )

    def forward(
        self, z: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Decode latent code to observation likelihood parameters.

        Args:
            z: Latent codes [batch, latent_dim]

        Returns:
            For bernoulli: logits [batch, channels, height, width]
            For gaussian: (mean, log_var) tuple
            Otherwise: reconstruction [batch, channels, height, width]
        """
        # Project latent to initial spatial features
        h = self.fc(z)
        h = h.view(-1, self.initial_channels, self.initial_size, self.initial_size)

        # Upsample through decoder blocks
        h = self.decoder_blocks(h)

        # Generate output based on likelihood type
        if self.output_type == 'gaussian':
            mean = self.mean_layer(h)
            log_var = self.logvar_layer(h)
            log_var = torch.clamp(log_var, min=-10.0, max=2.0)
            return mean, log_var
        else:
            return self.output_layer(h)


class MLPDecoder(nn.Module):
    """
    Simple MLP decoder for tabular or flattened data.
    """

    def __init__(
        self,
        output_dim: int,
        latent_dim: int = 64,
        hidden_dims: list = [256, 512, 1024]
    ):
        super().__init__()
        layers = []
        in_dim = latent_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, h_dim),
                nn.LayerNorm(h_dim),
                nn.LeakyReLU(0.2)
            ])
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)
```

The choice of likelihood function $p_{\boldsymbol{\theta}}(\mathbf{x}|\mathbf{z})$ significantly impacts VAE behavior. Each likelihood makes assumptions about your data distribution.
Use for: Binary data or data that can be interpreted as probabilities
$$p(\mathbf{x}|\mathbf{z}) = \prod_i p_i^{x_i} (1-p_i)^{1-x_i}$$
where $p_i = \sigma(f_{\theta}(\mathbf{z})_i)$
Reconstruction loss: Binary cross-entropy
$$-\log p = -\sum_i [x_i \log p_i + (1-x_i) \log(1-p_i)]$$
Issues: Bernoulli assumes each pixel is independent given $\mathbf{z}$. This conditional independence ignores spatial correlations, leading to blurry reconstructions.
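A minimal sketch of the Bernoulli reconstruction term computed from decoder logits; summing over pixels and averaging over the batch matches the ELBO convention used later in this page:

```python
import torch
import torch.nn.functional as F

batch, dims = 4, 784                             # e.g. flattened 28x28 binarized MNIST
logits = torch.randn(batch, dims)                # decoder outputs (pre-sigmoid)
x = torch.randint(0, 2, (batch, dims)).float()   # binary targets

# -log p(x|z) = -sum_i [x_i log p_i + (1 - x_i) log(1 - p_i)]
# BCEWithLogits applies the sigmoid internally for numerical stability.
nll = F.binary_cross_entropy_with_logits(
    logits, x, reduction='sum'
) / batch  # sum over pixels, average over the batch

print(nll)
```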
Use for: Continuous real-valued data
$$p(\mathbf{x}|\mathbf{z}) = \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}_{\theta}(\mathbf{z}), \sigma^2 \mathbf{I})$$
Reconstruction loss: Mean squared error (with fixed variance)
$$-\log p \propto \frac{1}{2\sigma^2} ||\mathbf{x} - \boldsymbol{\mu}||^2$$
The variance parameter $\sigma^2$ acts as a weight between the reconstruction and KL terms: a smaller fixed $\sigma^2$ scales up the MSE, so choosing $\sigma^2$ is effectively choosing the KL weight $\beta$. It can also be learned, either globally or per output dimension.
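An illustrative sketch (the function name `gaussian_nll` is ours, not a canonical API) showing that halving $\sigma^2$ exactly doubles the reconstruction weight:

```python
import torch

def gaussian_nll(x, mu, sigma_sq):
    # -log N(x; mu, sigma^2 I) = ||x - mu||^2 / (2 sigma^2) + const
    return ((x - mu) ** 2).sum() / (2 * sigma_sq)

x = torch.randn(4, 12288)   # e.g. flattened 64x64x3 images
mu = torch.randn(4, 12288)  # decoder means

# Smaller sigma^2 weights reconstruction more heavily relative to the KL term
loose = gaussian_nll(x, mu, sigma_sq=1.0)
tight = gaussian_nll(x, mu, sigma_sq=0.5)
print(tight / loose)  # exactly 2: same data, doubled reconstruction weight
```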
Use for: Integer-valued data (pixel intensities 0-255)
$$p(x_i|\mathbf{z}) = \sigma\left(\frac{x_i + 0.5 - \mu_i}{s}\right) - \sigma\left(\frac{x_i - 0.5 - \mu_i}{s}\right)$$
Why it's better for images: Images are discretized (256 values), not truly continuous. Treating them as continuous (Gaussian) wastes model capacity. Discretized logistic respects the discrete nature while remaining differentiable.
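A per-pixel sketch of the formula above, where $\sigma$ denotes the logistic CDF (the sigmoid) and $s$ the scale. This is illustrative only: a full implementation would also extend the edge bins for 0 and 255 to $\pm\infty$ and typically use a mixture of logistics.

```python
import torch

def discretized_logistic_logprob(x, mu, log_scale):
    """
    x:  integer pixel values in [0, 255], as float
    mu: predicted bin centers (same shape as x)
    log_scale: log of the logistic scale s
    """
    s = torch.exp(log_scale)
    # The CDF of the logistic distribution is the sigmoid
    cdf_plus = torch.sigmoid((x + 0.5 - mu) / s)
    cdf_minus = torch.sigmoid((x - 0.5 - mu) / s)
    # Probability mass assigned to the width-1 bin around x
    return torch.log((cdf_plus - cdf_minus).clamp(min=1e-12))

x = torch.tensor([0.0, 128.0, 255.0])
mu = torch.tensor([1.0, 128.0, 250.0])
log_scale = torch.zeros(3)  # s = 1
print(discretized_logistic_logprob(x, mu, log_scale))
```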
| Likelihood | Best For | Pros | Cons | Reconstruction Loss |
|---|---|---|---|---|
| Bernoulli | Binarized data, probabilities | Simple, numerically stable | Ignores correlations, blurry | BCE |
| Gaussian (fixed σ²) | Continuous data | Simple, MSE loss | Ignores correlations, tuning σ² | MSE |
| Gaussian (learned σ²) | Heteroscedastic data | Adaptive uncertainty | More parameters, can collapse | NLL with variance |
| Discretized Logistic | Integer images | Respects discretization | More complex implementation | Mixture NLL |
| Continuous Bernoulli | [0,1] continuous data | Correct normalization | Less common, newer | Modified BCE |
VAEs are often criticized for producing blurry samples. This is partly caused by the reconstruction loss averaging over uncertainty. When multiple reconstructions are plausible, the model learns to produce the mean—which is blurry. Solutions include: better likelihood functions (discretized logistic), hierarchical VAEs, adversarial objectives, or flow-based decoders.
Let's assemble the encoder and decoder into a complete VAE, including the reparameterization trick (covered in depth in a later page) and loss computation.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple


class VAE(nn.Module):
    """
    Complete Variational Autoencoder.

    Combines encoder and decoder with reparameterization trick
    and provides loss computation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_dim: int = 256,
        hidden_dims: list = [32, 64, 128, 256, 512],
        image_size: int = 64
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size

        # ============= ENCODER =============
        encoder_layers = []
        ch = in_channels
        for h_dim in hidden_dims:
            encoder_layers.append(
                nn.Sequential(
                    nn.Conv2d(ch, h_dim, kernel_size=4, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(0.2)
                )
            )
            ch = h_dim
        self.encoder = nn.Sequential(*encoder_layers)

        # Calculate flattened size after encoder
        self.feat_size = image_size // (2 ** len(hidden_dims))
        self.flatten_dim = hidden_dims[-1] * self.feat_size * self.feat_size

        # Latent distribution parameters
        self.fc_mu = nn.Linear(self.flatten_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_dim, latent_dim)

        # ============= DECODER =============
        self.fc_decode = nn.Linear(
            latent_dim,
            hidden_dims[-1] * self.feat_size * self.feat_size
        )

        decoder_layers = []
        hidden_dims_rev = hidden_dims[::-1]
        for i in range(len(hidden_dims_rev) - 1):
            decoder_layers.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims_rev[i], hidden_dims_rev[i + 1],
                        kernel_size=4, stride=2, padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims_rev[i + 1]),
                    nn.LeakyReLU(0.2)
                )
            )
        # Final layer to image channels
        decoder_layers.append(
            nn.ConvTranspose2d(
                hidden_dims_rev[-1], in_channels,
                kernel_size=4, stride=2, padding=1
            )
        )
        self.decoder = nn.Sequential(*decoder_layers)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Initialize log_var bias to ~0 (unit variance)
        nn.init.zeros_(self.fc_logvar.bias)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input to latent distribution parameters."""
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        log_var = torch.clamp(log_var, min=-10.0, max=10.0)
        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample z using reparameterization trick."""
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + std * eps
        else:
            # At inference, use mean (deterministic)
            return mu

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent code to reconstruction logits."""
        h = self.fc_decode(z)
        h = h.view(-1, self.encoder[-1][0].out_channels, self.feat_size, self.feat_size)
        return self.decoder(h)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Complete forward pass.

        Returns:
            Dictionary with 'recon', 'mu', 'log_var', 'z'
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return {
            'recon': recon,
            'mu': mu,
            'log_var': log_var,
            'z': z
        }

    def loss_function(
        self,
        x: torch.Tensor,
        outputs: Dict[str, torch.Tensor],
        beta: float = 1.0
    ) -> Dict[str, torch.Tensor]:
        """
        Compute VAE loss: reconstruction + beta * KL.

        Args:
            x: Original input
            outputs: Dictionary from forward pass
            beta: KL weight (1.0 = standard VAE)

        Returns:
            Dictionary with 'loss', 'recon_loss', 'kl_loss'
        """
        recon = outputs['recon']
        mu = outputs['mu']
        log_var = outputs['log_var']

        # Reconstruction loss (BCE with logits)
        recon_loss = F.binary_cross_entropy_with_logits(
            recon, x, reduction='sum'
        ) / x.size(0)

        # KL divergence
        kl_loss = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - log_var.exp()
        ) / x.size(0)

        # Total loss
        loss = recon_loss + beta * kl_loss

        return {
            'loss': loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss
        }

    @torch.no_grad()
    def generate(self, num_samples: int, device: torch.device) -> torch.Tensor:
        """Generate samples from the prior."""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        samples = self.decode(z)
        return torch.sigmoid(samples)  # Convert logits to probabilities

    @torch.no_grad()
    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct input (deterministic, using mean)."""
        mu, _ = self.encode(x)
        recon = self.decode(mu)
        return torch.sigmoid(recon)
```

Building effective VAEs requires attention to architectural details that significantly impact performance. Here are empirically-validated best practices:
Batch Normalization: Common in image VAEs. Works well but has quirks: its statistics differ between training and evaluation modes, which can make samples generated in eval mode look different from training-time reconstructions, and it couples examples within a batch.

Layer Normalization: More stable for generation: it normalizes each example independently, so there is no train/eval discrepancy.

Group Normalization: Good compromise for small batches: like LayerNorm it is batch-independent, but it retains per-channel-group structure similar to BatchNorm.
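As a sketch, swapping between these normalizations in a conv block is a one-line change; the block layout mirrors the encoder shown earlier, and the group count of 8 is an illustrative choice:

```python
import torch
import torch.nn as nn

def conv_block(in_ch: int, out_ch: int, norm: str = 'batch') -> nn.Sequential:
    if norm == 'batch':
        norm_layer = nn.BatchNorm2d(out_ch)   # per-channel, batch statistics
    elif norm == 'group':
        norm_layer = nn.GroupNorm(8, out_ch)  # batch-independent, 8 groups
    else:
        norm_layer = nn.Identity()
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
        norm_layer,
        nn.LeakyReLU(0.2)
    )

block = conv_block(3, 64, norm='group')
out = block(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 64, 16, 16])
```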
LeakyReLU (0.1-0.2 slope): Standard choice for VAEs
Swish/SiLU: Increasingly popular
ELU: Alternative to LeakyReLU
When scaling VAEs: (1) Double channels as spatial size halves, (2) Keep latent dimension roughly 1-4% of flattened feature size before projection, (3) Balance encoder and decoder depth, (4) For very deep networks, always use residual connections.
Different data modalities require specialized encoder-decoder designs. Let's examine key patterns:
The architectures we've shown are optimized for images. Other modalities bring their own considerations:
Encoder: Recurrent or Transformer architecture
Decoder: Autoregressive generation
Challenge: Posterior collapse is severe—decoder can ignore $\mathbf{z}$ and use only autoregressive context. Solutions include KL annealing, decoder weakening, bag-of-words auxiliary loss.
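KL annealing is commonly implemented as a schedule on the KL weight $\beta$; below is a minimal linear-warmup sketch (the schedule shape and warmup length are illustrative choices, not a prescribed recipe):

```python
def kl_weight(step: int, warmup_steps: int = 10000, beta_max: float = 1.0) -> float:
    """Linearly anneal beta from 0 to beta_max over warmup_steps."""
    return beta_max * min(1.0, step / warmup_steps)

# Early in training the KL term is nearly off, letting the decoder
# learn to use z before the posterior is pulled toward the prior.
print(kl_weight(0))      # 0.0
print(kl_weight(5000))   # 0.5
print(kl_weight(20000))  # 1.0 (capped at beta_max)
```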
Encoder: Graph neural network
Decoder: Graph generation
```python
import torch
import torch.nn as nn
from typing import Optional, Tuple


class LSTMEncoder(nn.Module):
    """LSTM encoder for sequence data."""

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 256,
        hidden_dim: int = 512,
        latent_dim: int = 128,
        num_layers: int = 2
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )
        # Bidirectional doubles the hidden size
        self.fc_mu = nn.Linear(hidden_dim * 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim * 2, latent_dim)

    def forward(
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Token indices [batch, seq_len]
            lengths: Sequence lengths [batch]
        """
        embedded = self.embedding(x)

        # Pack sequence for efficient LSTM computation
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)

        # Concatenate final hidden states from both directions
        # hidden shape: [num_layers * 2, batch, hidden_dim]
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # [batch, hidden*2]

        mu = self.fc_mu(hidden)
        log_var = self.fc_logvar(hidden)
        return mu, log_var


class LSTMDecoder(nn.Module):
    """Autoregressive LSTM decoder for sequence generation."""

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 256,
        hidden_dim: int = 512,
        latent_dim: int = 128,
        num_layers: int = 2,
        max_len: int = 100
    ):
        super().__init__()
        self.max_len = max_len
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Project latent to initial hidden state
        self.z_to_hidden = nn.Linear(latent_dim, hidden_dim * num_layers)
        self.z_to_cell = nn.Linear(latent_dim, hidden_dim * num_layers)

        # LSTM takes embedded token + latent code
        self.lstm = nn.LSTM(
            embed_dim + latent_dim, hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.output_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(
        self,
        z: torch.Tensor,
        target: Optional[torch.Tensor] = None,
        teacher_forcing_ratio: float = 1.0
    ) -> torch.Tensor:
        """
        Args:
            z: Latent code [batch, latent_dim]
            target: Target sequence for teacher forcing [batch, seq_len]
            teacher_forcing_ratio: Probability of using target as next input

        Returns:
            Logits [batch, seq_len, vocab_size]
        """
        batch_size = z.size(0)

        # Initialize hidden state from latent
        h0 = self.z_to_hidden(z).view(batch_size, self.num_layers, self.hidden_dim)
        h0 = h0.transpose(0, 1).contiguous()
        c0 = self.z_to_cell(z).view(batch_size, self.num_layers, self.hidden_dim)
        c0 = c0.transpose(0, 1).contiguous()

        if target is not None:
            # Teacher forcing: feed target tokens
            seq_len = target.size(1)
            embedded = self.embedding(target)  # [batch, seq_len, embed_dim]

            # Concatenate z at each step
            z_expanded = z.unsqueeze(1).expand(-1, seq_len, -1)
            lstm_input = torch.cat([embedded, z_expanded], dim=2)

            lstm_out, _ = self.lstm(lstm_input, (h0, c0))
            logits = self.output_proj(lstm_out)
            return logits
        else:
            # Autoregressive generation (sampling or greedy)
            # Implementation would go here for generation
            raise NotImplementedError("Generation mode not shown")
```

We've now covered the complete architecture of VAE encoder and decoder networks: the encoder amortizes posterior inference, the decoder defines the likelihood over data, and the choice of likelihood determines the reconstruction loss.
What's Next:
With encoder-decoder architecture understood, we turn to the latent space structure—examining what the learned representations look like, how the prior and posterior interact, and why the geometry of the latent space determines VAE capabilities for generation, interpolation, and manipulation.
You now have complete knowledge of VAE encoder-decoder design. You can implement VAEs for different data types, select appropriate likelihood functions, and make informed architectural decisions. Next, we explore what happens inside the learned latent space.