Traditional variational inference optimizes variational parameters separately for each observation—a process that can be computationally prohibitive when dealing with millions of data points. Amortized inference fundamentally changes this paradigm by training a neural network to directly output approximate posterior parameters given an observation.
Instead of solving a new optimization problem for each data point, we solve one optimization problem to learn an inference network (also called an encoder or recognition network) that generalizes across all observations. This amortization of computational cost is what makes modern approaches like Variational Autoencoders practical at scale.
By the end of this page, you will understand the amortization principle, master the design of inference networks, learn about the amortization gap and its implications, explore encoder architectures, and understand how amortized inference enables scalable probabilistic models.
The Core Problem:
In standard VI, for each observation xₙ, we optimize variational parameters φₙ:
$$\phi_n^* = \arg\max_{\phi_n} \mathcal{L}(\phi_n; x_n, \theta)$$
For a dataset of N observations, this requires N separate optimizations. The computational cost is O(N × cost-per-optimization), which becomes prohibitive for large datasets.
The Amortization Solution:
Instead of optimizing φₙ directly, we learn a function q(z|x; ψ) parameterized by ψ that maps any observation to approximate posterior parameters:
$$\text{Traditional: } x_n \rightarrow \text{optimize } \phi_n \rightarrow q_{\phi_n}(z)$$
$$\text{Amortized: } x_n \rightarrow \text{neural network } \psi \rightarrow q_\psi(z|x_n)$$
The key insight: the computational cost of optimization is amortized across all data points. We pay the cost once during training, then get instant inference at test time.
| Aspect | Traditional VI | Amortized VI |
|---|---|---|
| Parameters per sample | Unique φₙ for each xₙ | Shared ψ across all x |
| Training complexity | O(N × iterations) | O(iterations) with minibatches |
| Test-time inference | Full optimization required | Single forward pass |
| Memory requirement | O(N × dim(φ)) | O(dim(ψ)) - independent of N |
| Generalization | No generalization | Generalizes to unseen x |
| Optimality guarantee | Local optimum per sample | Average optimality (amortization gap) |
Think of amortized inference like compiling a program. Traditional VI is like interpreting code at runtime—flexible but slow. Amortized VI compiles the inference procedure into a neural network—there's upfront cost, but execution is fast. The compiled network is an 'inference program' that runs in constant time regardless of dataset size.
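To make the contrast concrete, here is a minimal sketch (not from the original text) of the two regimes. The `elbo` function is a hypothetical toy objective whose likelihood assumes x and z share dimensionality, and `encoder` stands for any trained inference network, such as the `GaussianEncoder` defined later on this page; all names and hyperparameters here are illustrative assumptions.

```python
import torch

def elbo(x, mu, logvar):
    # Toy ELBO: standard-normal prior, unit-variance Gaussian likelihood with
    # x and z of equal dimension. Any differentiable ELBO could replace this.
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
    log_lik = -((x - z) ** 2).sum(-1)
    kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(-1)
    return (log_lik - kl).sum()

def traditional_vi(x_n, latent_dim, steps=200, lr=0.05):
    # Traditional VI: a fresh inner optimization loop for every observation x_n.
    mu = torch.zeros(1, latent_dim, requires_grad=True)
    logvar = torch.zeros(1, latent_dim, requires_grad=True)
    opt = torch.optim.Adam([mu, logvar], lr=lr)
    for _ in range(steps):  # this cost is paid again for each data point
        opt.zero_grad()
        (-elbo(x_n, mu, logvar)).backward()
        opt.step()
    return mu.detach(), logvar.detach()

def amortized_vi(x_n, encoder):
    # Amortized VI: one trained encoder serves every observation.
    with torch.no_grad():
        return encoder(x_n)  # a single forward pass, no inner loop
```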
The inference network (encoder) maps observations to variational parameters. Its architecture depends on the data modality and the chosen variational family.
Standard Gaussian Encoder:
The most common choice outputs the mean and (log) variance of a Gaussian approximate posterior:
$$q_\psi(z|x) = \mathcal{N}(z; \mu_\psi(x), \text{diag}(\sigma_\psi^2(x)))$$
where μ_ψ(x) and σ_ψ(x) are neural network outputs.
Architecture Components:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class GaussianEncoder(nn.Module):
    """
    Inference network for amortized VI.
    Maps observations to Gaussian posterior parameters.
    """
    def __init__(self, input_dim, hidden_dims, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Build backbone network
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.LayerNorm(h_dim),
                nn.GELU(),
            ])
            prev_dim = h_dim
        self.backbone = nn.Sequential(*layers)

        # Separate heads for mean and log-variance
        self.mean_head = nn.Linear(prev_dim, latent_dim)
        self.logvar_head = nn.Linear(prev_dim, latent_dim)

        # Initialize log-variance head to produce small initial variances
        nn.init.zeros_(self.logvar_head.weight)
        nn.init.constant_(self.logvar_head.bias, -1.0)

    def forward(self, x):
        """
        Encode observation to posterior parameters.
        Returns: mean, log_variance
        """
        h = self.backbone(x)
        mean = self.mean_head(h)
        logvar = self.logvar_head(h)
        # Clamp log-variance for numerical stability
        logvar = torch.clamp(logvar, min=-10, max=10)
        return mean, logvar

    def sample(self, x, num_samples=1):
        """
        Sample from the approximate posterior using reparameterization.
        """
        mean, logvar = self.forward(x)
        std = torch.exp(0.5 * logvar)
        # Reparameterization trick
        eps = torch.randn(x.shape[0], num_samples, self.latent_dim, device=x.device)
        z = mean.unsqueeze(1) + std.unsqueeze(1) * eps
        return z, mean, logvar

    def log_prob(self, x, z):
        """Compute log q(z|x) for given z samples."""
        mean, logvar = self.forward(x)
        std = torch.exp(0.5 * logvar)
        dist = Normal(mean, std)
        return dist.log_prob(z).sum(dim=-1)


class ConvolutionalEncoder(nn.Module):
    """
    Convolutional inference network for image data.
    Used in VAEs for image generation.
    """
    def __init__(self, in_channels, latent_dim, base_channels=32):
        super().__init__()
        self.conv_layers = nn.Sequential(
            # 32x32 -> 16x16
            nn.Conv2d(in_channels, base_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.LeakyReLU(0.2),
            # 16x16 -> 8x8
            nn.Conv2d(base_channels, base_channels * 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(base_channels * 2),
            nn.LeakyReLU(0.2),
            # 8x8 -> 4x4
            nn.Conv2d(base_channels * 2, base_channels * 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(base_channels * 4),
            nn.LeakyReLU(0.2),
            # 4x4 -> 2x2
            nn.Conv2d(base_channels * 4, base_channels * 8, 4, stride=2, padding=1),
            nn.BatchNorm2d(base_channels * 8),
            nn.LeakyReLU(0.2),
        )
        # Compute flattened size (assuming 32x32 input)
        self.flat_dim = base_channels * 8 * 2 * 2
        self.fc_mean = nn.Linear(self.flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flat_dim, latent_dim)

    def forward(self, x):
        h = self.conv_layers(x)
        h = h.view(h.size(0), -1)
        mean = self.fc_mean(h)
        logvar = self.fc_logvar(h)
        return mean, logvar
```

We output log-variance rather than variance directly. This ensures positivity (variance = exp(log-variance) > 0) and provides better numerical properties during optimization: gradient steps in log space change the variance multiplicatively, which keeps updates well scaled whether the variance is large or small and stabilizes training.
Amortization comes with a fundamental tradeoff: the inference network cannot compute the optimal variational parameters for every possible observation. This creates an amortization gap—the difference between the true optimal ELBO and the amortized ELBO.
Formal Definition:
$$\text{Amortization Gap} = \mathbb{E}_{x}\left[ \max_{\phi} \mathcal{L}(\phi; x) - \mathcal{L}(\psi; x) \right]$$
The first term is the optimal ELBO achievable by per-sample optimization; the second is what the amortized encoder achieves. The gap is always non-negative.
Sources of the Amortization Gap:

- Limited encoder capacity: a single network must map every observation to good variational parameters and may lack the flexibility to do so everywhere.
- Average-case training: the encoder is trained to maximize the ELBO on average over the dataset, so individual observations can receive suboptimal parameters.
- Generalization error: observations unlike the training data may be encoded poorly even when the gap on training data is small.
Measuring the Amortization Gap:
To diagnose amortization gap issues:

1. Compute the ELBO for a batch of held-out samples using the encoder's output directly.
2. Starting from that output, run additional gradient steps on the per-sample parameters (μ, log σ²).
3. Compare the two ELBO values; the difference estimates the gap (see the code below).
If the gap is significant, the encoder is leaving performance on the table. This can be addressed through architectural improvements or semi-amortized schemes.
```python
import torch
import torch.optim as optim


def measure_amortization_gap(encoder, decoder, x_samples, latent_dim,
                             num_refinement_steps=100, lr=0.01):
    """
    Measure the amortization gap by comparing amortized vs. optimized ELBO.

    Args:
        encoder: Trained inference network
        decoder: Generative model (likelihood)
        x_samples: Test samples to evaluate
        latent_dim: Dimensionality of latent space
        num_refinement_steps: Optimization steps for per-sample refinement
        lr: Learning rate for refinement

    Returns:
        amortized_elbo: ELBO using encoder output
        optimized_elbo: ELBO after per-sample optimization
        gap: Difference (always >= 0)
    """
    device = x_samples.device
    batch_size = x_samples.shape[0]

    # Step 1: Compute amortized ELBO
    with torch.no_grad():
        mean_amort, logvar_amort = encoder(x_samples)
        std_amort = torch.exp(0.5 * logvar_amort)
        # Sample and compute ELBO
        z = mean_amort + std_amort * torch.randn_like(std_amort)
        log_likelihood = decoder.log_prob(x_samples, z)
        kl = -0.5 * (1 + logvar_amort - mean_amort.pow(2) - logvar_amort.exp()).sum(-1)
        amortized_elbo = (log_likelihood - kl).mean().item()

    # Step 2: Per-sample optimization (initialize from encoder)
    mean_opt = mean_amort.clone().detach().requires_grad_(True)
    logvar_opt = logvar_amort.clone().detach().requires_grad_(True)
    optimizer = optim.Adam([mean_opt, logvar_opt], lr=lr)

    for _ in range(num_refinement_steps):
        optimizer.zero_grad()
        std = torch.exp(0.5 * logvar_opt)
        z = mean_opt + std * torch.randn_like(std)
        log_likelihood = decoder.log_prob(x_samples, z)
        kl = -0.5 * (1 + logvar_opt - mean_opt.pow(2) - logvar_opt.exp()).sum(-1)
        loss = -(log_likelihood - kl).mean()  # Negative ELBO
        loss.backward()
        optimizer.step()

    # Compute optimized ELBO
    with torch.no_grad():
        std = torch.exp(0.5 * logvar_opt)
        z = mean_opt + std * torch.randn_like(std)
        log_likelihood = decoder.log_prob(x_samples, z)
        kl = -0.5 * (1 + logvar_opt - mean_opt.pow(2) - logvar_opt.exp()).sum(-1)
        optimized_elbo = (log_likelihood - kl).mean().item()

    gap = optimized_elbo - amortized_elbo
    return {
        'amortized_elbo': amortized_elbo,
        'optimized_elbo': optimized_elbo,
        'amortization_gap': gap,
        'gap_percentage': 100 * gap / abs(optimized_elbo) if optimized_elbo != 0 else 0
    }
```

Semi-amortized inference bridges the gap between fully amortized and fully optimized approaches. The idea: use the encoder to provide a good initialization, then refine with a few gradient steps.
The Procedure:

1. Run the encoder to obtain initial parameters (μ, log σ²) for the observation.
2. Take a small number of gradient ascent steps on the ELBO with respect to those parameters.
3. Use the refined parameters for sampling and ELBO evaluation (see `SemiAmortizedVI` below).

This achieves the best of both worlds:

- Posterior quality close to per-sample optimization, because refinement corrects the encoder's errors for each observation.
- Cost far below optimizing from scratch, because the encoder supplies a good initialization and only a few steps are needed.
Iterative Amortization (Learned Refinement):
A more sophisticated approach uses a second network to perform refinement:
$$\phi_{k+1} = \phi_k + g_\theta(x, \phi_k, \nabla_{\phi} \mathcal{L})$$
where g_θ is a learned update function. This can be viewed as a learned optimizer for the variational parameters: rather than following raw ELBO gradients with a fixed rule, the network learns updates that reach good parameters in a handful of iterations (see `IterativeAmortization` below).
Amortized Variational Filtering (AVF):
For sequential data, AVF uses an RNN that takes the previous belief state and new observation to produce updated posterior parameters. This amortizes inference across time while maintaining coherent beliefs.
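As a rough illustration of this filtering idea (not the published AVF algorithm; the class name, interface, and Gaussian filtering posterior below are assumptions made for the sketch), an RNN can carry a belief state across time and emit posterior parameters at each step:

```python
import torch
import torch.nn as nn

class FilteringEncoder(nn.Module):
    """Illustrative amortized filtering encoder: a GRU carries the belief state
    forward in time and emits q(z_t | x_{1:t}) parameters at every step."""
    def __init__(self, obs_dim, hidden_dim, latent_dim):
        super().__init__()
        self.rnn = nn.GRUCell(obs_dim + latent_dim, hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

    def forward(self, x_seq):
        # x_seq: (batch, T, obs_dim)
        batch, T, _ = x_seq.shape
        h = x_seq.new_zeros(batch, self.hidden_dim)       # belief state
        z_prev = x_seq.new_zeros(batch, self.latent_dim)  # previous latent sample
        params = []
        for t in range(T):
            # Update the belief with the new observation and the previous latent
            h = self.rnn(torch.cat([x_seq[:, t], z_prev], dim=-1), h)
            mean, logvar = self.mean_head(h), self.logvar_head(h)
            # Reparameterized sample carried to the next time step
            z_prev = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
            params.append((mean, logvar))
        return params  # one (mean, logvar) per time step
```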
```python
import torch
import torch.nn as nn


class SemiAmortizedVI(nn.Module):
    """
    Semi-amortized inference: encoder initialization + gradient refinement.
    Reduces amortization gap while maintaining efficiency.
    """
    def __init__(self, encoder, decoder, latent_dim,
                 num_refinement_steps=5, lr=0.1):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim
        self.num_steps = num_refinement_steps
        self.lr = lr

    def forward(self, x, refine=True):
        """
        Compute approximate posterior, optionally with refinement.

        Args:
            x: Input observations
            refine: Whether to perform gradient refinement

        Returns:
            mean, logvar: Final posterior parameters
        """
        # Initialize from encoder
        mean, logvar = self.encoder(x)

        if not refine or self.num_steps == 0:
            return mean, logvar

        # Gradient refinement
        mean = mean.detach().requires_grad_(True)
        logvar = logvar.detach().requires_grad_(True)

        for step in range(self.num_steps):
            # Compute ELBO gradient
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mean + std * eps

            log_lik = self.decoder.log_prob(x, z)
            kl = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(-1)
            elbo = log_lik - kl
            elbo_sum = elbo.sum()

            # Compute gradients
            grad_mean, grad_logvar = torch.autograd.grad(
                elbo_sum, [mean, logvar], create_graph=False
            )

            # Gradient ascent step
            mean = mean + self.lr * grad_mean
            logvar = logvar + self.lr * grad_logvar

            # Detach for next iteration
            mean = mean.detach().requires_grad_(True)
            logvar = logvar.detach().requires_grad_(True)

        return mean.detach(), logvar.detach()


class IterativeAmortization(nn.Module):
    """
    Learned iterative refinement for amortized inference.
    Uses a refinement network to update posterior parameters.
    """
    def __init__(self, encoder, refinement_net, num_iterations=3):
        super().__init__()
        self.encoder = encoder
        self.refinement_net = refinement_net  # Maps (x, phi, grad) -> delta_phi
        self.num_iterations = num_iterations

    def forward(self, x, decoder):
        # Initial encoding
        mean, logvar = self.encoder(x)

        for i in range(self.num_iterations):
            # Treat the current parameters as leaf variables for this iteration
            # (detaching avoids the error from calling requires_grad_ on
            # non-leaf encoder outputs)
            mean = mean.detach().requires_grad_(True)
            logvar = logvar.detach().requires_grad_(True)

            # Compute gradient of ELBO w.r.t. current parameters
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
            log_lik = decoder.log_prob(x, z)
            kl = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(-1)
            elbo = (log_lik - kl).sum()

            grad_mean, grad_logvar = torch.autograd.grad(
                elbo, [mean, logvar], retain_graph=True
            )

            # Refinement network computes update
            phi = torch.cat([mean, logvar], dim=-1)
            grad = torch.cat([grad_mean, grad_logvar], dim=-1)
            delta = self.refinement_net(x, phi, grad)
            delta_mean, delta_logvar = delta.chunk(2, dim=-1)

            mean = mean + delta_mean
            logvar = logvar + delta_logvar
            mean = mean.detach()
            logvar = logvar.detach()

        return mean, logvar
```

Beyond simple factorial Gaussians, amortized inference extends to complex posterior structures including hierarchical models, structured latent spaces, and graph-based dependencies.
Hierarchical Amortization:
For hierarchical latent variable models with multiple levels:
$$p(x, z_1, z_2) = p(x|z_1)p(z_1|z_2)p(z_2)$$
The inference network can output parameters for the entire hierarchy:
$$q(z_1, z_2|x) = q(z_1|x, z_2)q(z_2|x)$$
This requires the encoder to output parameters for all latent levels, often with bottom-up and top-down information flow.
Structured Posterior Amortization:
For latent variables with structure (sequences, graphs, sets):

- Sequences: recurrent or attention-based encoders emit posterior parameters for each timestep.
- Graphs: graph neural network encoders respect node and edge structure when producing latent-variable parameters.
- Sets: permutation-invariant encoders (e.g., pooling over elements) ensure the posterior does not depend on element ordering.
Flow-Based Amortization:
Combining amortization with normalizing flows:

- The encoder outputs the parameters of a simple base distribution (typically a diagonal Gaussian), and optionally flow parameters conditioned on x.
- Base samples are pushed through a sequence of invertible transformations, and log q(z|x) is corrected by the log-determinant of each transformation.
- This yields far more flexible posteriors while preserving single-forward-pass inference, as sketched below.
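Here is a minimal sketch of a flow-based amortized posterior using planar flows on top of the Gaussian encoder interface from earlier. The `PlanarFlow` and `FlowPosterior` classes are illustrative assumptions, not a library API, and the constraint that guarantees invertibility of the planar parameters is omitted for brevity.

```python
import math
import torch
import torch.nn as nn

class PlanarFlow(nn.Module):
    """One planar transform f(z) = z + u * tanh(w^T z + b) (illustrative)."""
    def __init__(self, dim):
        super().__init__()
        self.u = nn.Parameter(0.01 * torch.randn(dim))
        self.w = nn.Parameter(0.01 * torch.randn(dim))
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, z):
        lin = z @ self.w + self.b                         # (batch,)
        f_z = z + self.u * torch.tanh(lin).unsqueeze(-1)  # (batch, dim)
        # log|det df/dz| = log|1 + u^T psi(z)|, psi(z) = tanh'(lin) * w
        psi = (1 - torch.tanh(lin) ** 2).unsqueeze(-1) * self.w
        log_det = torch.log(torch.abs(1 + psi @ self.u) + 1e-8)
        return f_z, log_det

class FlowPosterior(nn.Module):
    """Amortized flow posterior: the encoder supplies a diagonal-Gaussian base
    distribution, which a stack of flows transforms into a richer q(z|x)."""
    def __init__(self, encoder, latent_dim, num_flows=4):
        super().__init__()
        self.encoder = encoder
        self.flows = nn.ModuleList([PlanarFlow(latent_dim) for _ in range(num_flows)])

    def sample(self, x):
        mean, logvar = self.encoder(x)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)            # reparameterized base sample
        # log q0(z|x) under the diagonal Gaussian base distribution
        log_q = -0.5 * (((z - mean) / std) ** 2 + logvar + math.log(2 * math.pi)).sum(-1)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q = log_q - log_det                       # change-of-variables correction
        return z, log_q
```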
```python
import torch
import torch.nn as nn


class HierarchicalEncoder(nn.Module):
    """
    Encoder for hierarchical latent variable models.
    Outputs parameters for multiple levels in a bottom-up pass,
    then refines with top-down information.
    """
    def __init__(self, input_dim, hidden_dim, latent_dims):
        """
        Args:
            input_dim: Observation dimensionality
            hidden_dim: Hidden layer size
            latent_dims: List of latent dims per level [z1_dim, z2_dim, ...]
        """
        super().__init__()
        self.num_levels = len(latent_dims)
        self.latent_dims = latent_dims

        # Bottom-up encoders
        self.bottom_up = nn.ModuleList()
        prev_dim = input_dim
        for z_dim in latent_dims:
            self.bottom_up.append(nn.Sequential(
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim),
            ))
            prev_dim = hidden_dim

        # Top-down refinement
        self.top_down = nn.ModuleList()
        for i in range(self.num_levels - 1):
            self.top_down.append(nn.Sequential(
                nn.Linear(hidden_dim + latent_dims[i + 1], hidden_dim),
                nn.GELU(),
            ))

        # Parameter heads for each level
        self.mean_heads = nn.ModuleList([
            nn.Linear(hidden_dim, z_dim) for z_dim in latent_dims
        ])
        self.logvar_heads = nn.ModuleList([
            nn.Linear(hidden_dim, z_dim) for z_dim in latent_dims
        ])

    def forward(self, x):
        """
        Encode to hierarchical posterior parameters.
        Returns: List of (mean, logvar) tuples, one per level
        """
        batch_size = x.shape[0]

        # Bottom-up pass: compute features at each level
        features = []
        h = x
        for encoder in self.bottom_up:
            h = encoder(h)
            features.append(h)

        # Top-down pass: refine with higher-level information
        posteriors = []

        # Highest level uses bottom-up features directly
        mean_top = self.mean_heads[-1](features[-1])
        logvar_top = self.logvar_heads[-1](features[-1])
        posteriors.append((mean_top, logvar_top))

        # Sample highest level
        z_sample = mean_top + torch.exp(0.5 * logvar_top) * torch.randn_like(mean_top)

        # Lower levels incorporate higher-level samples
        for i in range(self.num_levels - 2, -1, -1):
            # Combine bottom-up features with top-down sample
            combined = torch.cat([features[i], z_sample], dim=-1)
            h = self.top_down[i](combined)
            mean = self.mean_heads[i](h)
            logvar = self.logvar_heads[i](h)
            posteriors.insert(0, (mean, logvar))
            # Sample for next level down
            z_sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

        return posteriors  # [(mean_z1, logvar_z1), (mean_z2, logvar_z2), ...]
```

Training amortized inference networks presents unique challenges. Here are key considerations for practical implementations:
Encoder-Decoder Balance:
The encoder and decoder must be balanced in capacity:

- A decoder that is too powerful can model the data while ignoring z, which invites posterior collapse.
- An encoder that is too weak enlarges the amortization gap and produces uninformative latent codes.
- In practice, capacity increases on one side should usually be matched on the other.
Posterior Collapse:
A common failure mode where q(z|x) ≈ p(z) regardless of input. Symptoms:

- The KL term falls to near zero for most or all latent dimensions.
- Few latent dimensions remain "active" (per-dimension KL above a small threshold).
- Reconstructions and samples barely respond to changes in z, because the decoder has learned to ignore the latent code.

Common mitigations such as KL annealing and free bits are implemented in the trainer below.
```python
import torch
import torch.nn as nn
import numpy as np


class VAETrainer:
    """
    Training utilities for amortized variational inference.
    Includes techniques to prevent posterior collapse and improve training.
    """
    def __init__(self, encoder, decoder, latent_dim,
                 kl_annealing_epochs=100, free_bits=0.0):
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim
        self.kl_annealing_epochs = kl_annealing_epochs
        self.free_bits = free_bits  # Minimum KL per dimension
        self.current_epoch = 0

    def kl_weight(self):
        """Linear annealing from 0 to 1."""
        if self.kl_annealing_epochs == 0:
            return 1.0
        return min(1.0, self.current_epoch / self.kl_annealing_epochs)

    def compute_elbo(self, x, num_samples=1):
        """
        Compute ELBO with KL annealing and free bits.
        """
        # Encode
        mean, logvar = self.encoder(x)
        std = torch.exp(0.5 * logvar)

        # Sample using reparameterization
        eps = torch.randn(x.shape[0], num_samples, self.latent_dim, device=x.device)
        z = mean.unsqueeze(1) + std.unsqueeze(1) * eps

        # Reconstruction loss
        x_expanded = x.unsqueeze(1).expand(-1, num_samples, -1)
        log_lik = self.decoder.log_prob(x_expanded, z).mean(dim=1)

        # KL divergence with free bits
        kl_per_dim = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1)
        if self.free_bits > 0:
            # Clamp KL per dimension to at least free_bits
            kl_per_dim = torch.clamp(kl_per_dim, min=self.free_bits)
        kl = kl_per_dim.sum(dim=-1)

        # Apply KL annealing
        beta = self.kl_weight()
        elbo = log_lik - beta * kl

        return {
            'elbo': elbo.mean(),
            'log_likelihood': log_lik.mean(),
            'kl': kl.mean(),
            'beta': beta,
            'active_dims': (kl_per_dim > 0.1).float().mean() * self.latent_dim
        }

    def step(self):
        """Call at end of each epoch."""
        self.current_epoch += 1

    def monitor_collapse(self, x_batch):
        """
        Check for signs of posterior collapse.
        Returns warning if detected.
        """
        with torch.no_grad():
            mean, logvar = self.encoder(x_batch)
            kl_per_dim = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1)
            active_dims = (kl_per_dim.mean(0) > 0.1).sum().item()
            total_kl = kl_per_dim.sum(-1).mean().item()

        if total_kl < 0.5:
            return f"⚠️ Posterior collapse warning: KL={total_kl:.3f}, Active dims={active_dims}"
        return None
```

We've explored amortized inference as an efficient approach to variational inference that enables scalability to massive datasets. Here are the key takeaways:

- Amortization replaces per-sample optimization with a shared inference network, turning test-time inference into a single forward pass.
- The inference network outputs variational parameters (typically a Gaussian mean and log-variance) and is trained jointly with the generative model.
- Amortization introduces a gap relative to per-sample optimization; the gap can be measured directly and reduced with semi-amortized or iterative refinement.
- Amortized inference extends to hierarchical, structured, and flow-based posteriors.
- Practical training must guard against posterior collapse, using tools such as KL annealing, free bits, and monitoring of active dimensions.
You now understand amortized inference and its crucial role in scaling variational methods. Next, we'll explore implicit variational inference, which avoids explicit density specification altogether.