Traditional variational inference requires that we can evaluate the density q(z) for any value of z. This requirement arises because the ELBO contains terms like log q(z). However, this constraint severely limits the class of distributions we can use.
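For reference, the ELBO can be written as

$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(z)}[\log p(x, z)] - \mathbb{E}_{q_\phi(z)}[\log q_\phi(z)],$$

and the second (entropy) term is exactly where the density q_φ(z) must be evaluated.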
Implicit distributions are defined only through their sampling procedure—we can draw samples z ~ q, but we cannot evaluate q(z) for a given z. Many powerful generative models are implicit: GANs, diffusion models with ancestral sampling, and neural samplers. Implicit variational inference extends VI to work with these flexible but density-free distributions.
By the end of this page, you will understand why implicit distributions are desirable, master density ratio estimation as a key enabling technique, learn about adversarial variational inference and its variants, understand the tradeoffs between implicit and explicit VI, and see how to implement implicit VI methods.
To understand implicit VI, we must first understand the distinction between explicit and implicit distributions.
Explicit (Prescribed) Distributions:
An explicit distribution has a known density function: $$q_\phi(z) = \frac{1}{Z_\phi} \tilde{q}_\phi(z)$$
where either Z_φ is known (e.g., Gaussian) or intractable but cancels out in the ELBO.
Examples: Gaussian, Mixture of Gaussians, Normalizing Flows
Implicit Distributions:
An implicit distribution is defined only through a sampling procedure: $$z = G_\phi(\epsilon), \quad \epsilon \sim p(\epsilon)$$
where G_φ is a neural network (generator) and p(ε) is a simple noise distribution. We can sample from this distribution easily, but there's no formula for q_φ(z).
Examples: GAN generators, Neural samplers, Some diffusion models
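As a minimal sketch (the architecture and noise dimension here are illustrative, not prescribed by any particular method), an implicit sampler in PyTorch might look like this:

```python
import torch
import torch.nn as nn

class ImplicitSampler(nn.Module):
    """Defines q_phi(z) purely through a sampling procedure z = G_phi(eps)."""

    def __init__(self, noise_dim=16, latent_dim=2, hidden_dim=64):
        super().__init__()
        self.noise_dim = noise_dim
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def sample(self, num_samples):
        eps = torch.randn(num_samples, self.noise_dim)  # eps ~ p(eps) = N(0, I)
        return self.net(eps)                            # z = G_phi(eps)

# Sampling is trivial...
sampler = ImplicitSampler()
z = sampler.sample(128)   # shape (128, 2)
# ...but there is no sampler.log_prob(z): the density of z is not available.
```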
| Property | Explicit | Implicit |
|---|---|---|
| Density evaluation | ✓ Available | ✗ Not available |
| Sampling | ✓ Usually easy | ✓ By construction |
| Expressiveness | Limited by tractability | Virtually unlimited |
| ELBO computation | Direct | Requires estimation |
| Mode coverage | Can undercover | Can overcover or miss modes |
| Training stability | Generally stable | Can be unstable (GAN-like) |
Why Use Implicit Distributions?
The key motivation is expressiveness. Explicit distributions face a fundamental tradeoff: the richer the family, the harder it becomes to keep the density tractable.
A neural network G_φ can in principle represent any continuous mapping from noise to samples. This means implicit distributions can capture highly complex, multimodal, and strongly correlated posteriors that simple explicit families cannot.
GANs use implicit distributions for generation. The generator G(z) defines an implicit distribution over images—you can sample images but not evaluate their probability. Implicit VI brings this flexibility to posterior approximation, using a neural sampler to approximate p(z|x).
The key insight enabling implicit VI is that even without q(z) directly, we can estimate log q(z)/p(z) using density ratio estimation.
The Core Idea:
Consider two distributions p(z) (prior) and q(z) (variational). We can estimate their ratio using a discriminator/classifier:
Given samples from both distributions, train a classifier D(z) to distinguish them:
At optimality, the discriminator reveals the density ratio: $$D^*(z) = \frac{q(z)}{q(z) + p(z)} \quad \Rightarrow \quad \frac{q(z)}{p(z)} = \frac{D^*(z)}{1 - D^*(z)}$$
Derivation:
The optimal discriminator minimizes binary cross-entropy: $$\mathcal{L}_D = -\mathbb{E}_{q(z)}[\log D(z)] - \mathbb{E}_{p(z)}[\log(1 - D(z))]$$
Taking the functional derivative and setting to zero: $$\frac{\delta \mathcal{L}_D}{\delta D(z)} = -\frac{q(z)}{D(z)} + \frac{p(z)}{1-D(z)} = 0$$
Solving: $D^*(z) = \frac{q(z)}{p(z) + q(z)}$
Log-ratio form: $$\log \frac{q(z)}{p(z)} = \log \frac{D^*(z)}{1 - D^*(z)} = \sigma^{-1}(D^*(z))$$
where σ⁻¹ is the logit function. In practice, if the discriminator network outputs raw logits rather than probabilities, those logits are themselves the estimate of the log-ratio.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DensityRatioEstimator(nn.Module):
    """
    Estimates log density ratio log(q(z)/p(z)) using a discriminator.

    The discriminator is trained to classify samples from q vs p.
    """

    def __init__(self, latent_dim, hidden_dims=[256, 256]):
        super().__init__()
        layers = []
        prev_dim = latent_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.LayerNorm(h_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.1),
            ])
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, 1))  # Output logits
        self.network = nn.Sequential(*layers)

    def forward(self, z):
        """
        Returns discriminator logits.
        For optimal D: logits ≈ log(q(z)/p(z))
        """
        return self.network(z).squeeze(-1)

    def log_density_ratio(self, z):
        """Estimate log q(z)/p(z)."""
        return self.forward(z)

    def train_step(self, z_q, z_p, optimizer):
        """
        Single training step.

        Args:
            z_q: Samples from q (variational)
            z_p: Samples from p (prior)
            optimizer: Discriminator optimizer
        """
        optimizer.zero_grad()

        # Discriminator outputs for both
        logits_q = self.forward(z_q)
        logits_p = self.forward(z_p)

        # Binary cross-entropy loss
        # q samples are "real" (label=1), p samples are "fake" (label=0)
        loss_q = F.binary_cross_entropy_with_logits(
            logits_q, torch.ones_like(logits_q)
        )
        loss_p = F.binary_cross_entropy_with_logits(
            logits_p, torch.zeros_like(logits_p)
        )
        loss = loss_q + loss_p

        loss.backward()
        optimizer.step()

        # Compute accuracy for monitoring
        with torch.no_grad():
            acc_q = (logits_q > 0).float().mean()
            acc_p = (logits_p < 0).float().mean()

        return {
            'loss': loss.item(),
            'acc_q': acc_q.item(),
            'acc_p': acc_p.item(),
        }
```

The density ratio estimator is only an approximation. With finite samples and limited discriminator capacity, the ratio estimate is biased. This bias propagates to VI objectives, potentially degrading inference quality. Careful discriminator design and sufficient training are essential.
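As a quick sanity check of the estimator above (the distributions and training settings here are arbitrary, chosen only for illustration), we can train it on two Gaussians whose log ratio is known in closed form and compare:

```python
# q = N(1, 1), p = N(0, 1) in 1D, so the true log ratio is
# log q(z) - log p(z) = z - 0.5.
torch.manual_seed(0)
estimator = DensityRatioEstimator(latent_dim=1, hidden_dims=[64, 64])
opt = torch.optim.Adam(estimator.parameters(), lr=1e-3)

for step in range(2000):
    z_q = torch.randn(256, 1) + 1.0   # samples from q
    z_p = torch.randn(256, 1)         # samples from p
    estimator.train_step(z_q, z_p, opt)

estimator.eval()  # disable dropout for evaluation
z_test = torch.linspace(-2, 3, 6).unsqueeze(-1)
with torch.no_grad():
    est = estimator.log_density_ratio(z_test)
true = z_test.squeeze(-1) - 0.5
print(torch.stack([true, est], dim=-1))  # the two columns should roughly agree
```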
Adversarial Variational Bayes (AVB) is a foundational implicit VI method that combines VAE-style inference with GAN-style density ratio estimation.
The Setup:
We have three components: a decoder p_θ(x|z), an encoder that maps an observation x together with auxiliary noise ε to a latent code z, and a discriminator D_ω(x, z).
The encoder defines an implicit approximate posterior q_ψ(z|x)—we can sample from it but cannot evaluate its density.
The ELBO with Density Ratio:
$$\mathcal{L} = \mathbb{E}_{q_\psi(z|x)}[\log p_\theta(x|z)] - \mathbb{E}_{q_\psi(z|x)}\left[\log \frac{q_\psi(z|x)}{p(z)}\right]$$
The first term is the reconstruction loss (standard). The second term is the KL, which requires the intractable ratio. We replace it with the discriminator:
$$\mathcal{L} \approx \mathbb{E}_{q_\psi(z|x)}[\log p_\theta(x|z)] - \mathbb{E}_{q_\psi(z|x)}[D_\omega(x, z)]$$
Training Procedure:
AVB alternates between:
Discriminator update: Train D_ω to distinguish (x, z) pairs in which z is drawn from the encoder q_ψ(z|x) from pairs in which z is drawn from the prior p(z)
Generator/Encoder update: Maximize ELBO using discriminator's ratio estimate
Decoder update: Maximize reconstruction term
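Below is a minimal sketch of this alternating loop, assuming the `AVB` class implemented in the listing that follows; the toy data, optimizer settings, and 1:1 update ratio are illustrative choices, not part of the method itself.

```python
import torch

# Toy data loader; in practice x would be real observations (e.g. flattened images).
data = torch.rand(1024, 784)
loader = torch.utils.data.DataLoader(data, batch_size=128, shuffle=True)

model = AVB(input_dim=784, latent_dim=8)
opt_disc = torch.optim.Adam(model.discriminator.parameters(), lr=2e-4)
opt_gen = torch.optim.Adam(
    list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=1e-4
)

for epoch in range(10):
    for x in loader:
        # Step 1: improve the density-ratio (KL) estimate
        d_loss = model.train_discriminator(x, opt_disc)
        # Steps 2-3: update encoder and decoder against the estimated ELBO
        stats = model.train_generator(x, opt_gen)
```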
The main advantage over a standard VAE is that the approximate posterior is not restricted to a tractable family such as a diagonal Gaussian: the implicit encoder can represent arbitrarily complex, multimodal posteriors.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AVBEncoder(nn.Module):
    """
    Implicit encoder for Adversarial Variational Bayes.

    Maps (x, noise) -> z, defining an implicit q(z|x).
    """

    def __init__(self, input_dim, noise_dim, latent_dim, hidden_dim=256):
        super().__init__()
        self.noise_dim = noise_dim
        self.network = nn.Sequential(
            nn.Linear(input_dim + noise_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x, num_samples=1):
        """Sample from implicit posterior q(z|x)."""
        batch_size = x.shape[0]

        # Sample noise
        eps = torch.randn(batch_size, num_samples, self.noise_dim, device=x.device)

        # Concatenate x with noise
        x_expanded = x.unsqueeze(1).expand(-1, num_samples, -1)
        inputs = torch.cat([x_expanded, eps], dim=-1)

        # Generate latent samples
        z = self.network(inputs.view(-1, inputs.shape[-1]))
        z = z.view(batch_size, num_samples, -1)

        return z


class AVBDiscriminator(nn.Module):
    """
    Discriminator for AVB.

    Distinguishes (x, z) pairs from encoder vs prior.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, z):
        """Return logits for density ratio estimation."""
        # Handle multiple z samples per x
        if z.dim() == 3:
            batch_size, num_samples, latent_dim = z.shape
            x_expanded = x.unsqueeze(1).expand(-1, num_samples, -1)
            inputs = torch.cat([x_expanded, z], dim=-1)
            inputs = inputs.view(-1, inputs.shape[-1])
            logits = self.network(inputs)
            return logits.view(batch_size, num_samples)
        else:
            inputs = torch.cat([x, z], dim=-1)
            return self.network(inputs).squeeze(-1)


class AVB(nn.Module):
    """
    Adversarial Variational Bayes.

    Combines implicit encoder with discriminator-based KL estimation.
    """

    def __init__(self, input_dim, latent_dim, noise_dim=32, hidden_dim=256):
        super().__init__()
        self.encoder = AVBEncoder(input_dim, noise_dim, latent_dim, hidden_dim)
        self.discriminator = AVBDiscriminator(input_dim, latent_dim, hidden_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.latent_dim = latent_dim

    def train_discriminator(self, x, optimizer):
        """Train discriminator to estimate density ratio."""
        optimizer.zero_grad()
        batch_size = x.shape[0]

        # Samples from q(z|x)
        with torch.no_grad():
            z_q = self.encoder(x, num_samples=1).squeeze(1)

        # Samples from p(z) = N(0, I)
        z_p = torch.randn(batch_size, self.latent_dim, device=x.device)

        # Discriminator outputs
        logits_q = self.discriminator(x, z_q)
        logits_p = self.discriminator(x, z_p)

        # Binary cross-entropy
        loss = F.binary_cross_entropy_with_logits(
            logits_q, torch.ones_like(logits_q)
        ) + F.binary_cross_entropy_with_logits(
            logits_p, torch.zeros_like(logits_p)
        )

        loss.backward()
        optimizer.step()

        return loss.item()

    def train_generator(self, x, optimizer):
        """Train encoder and decoder to maximize ELBO."""
        optimizer.zero_grad()

        # Sample from encoder
        z = self.encoder(x, num_samples=1).squeeze(1)

        # Reconstruction loss
        x_recon = self.decoder(z)
        recon_loss = F.mse_loss(x_recon, x, reduction='none').sum(-1)

        # KL estimated by discriminator
        # log q(z|x)/p(z) ≈ D(x, z)
        kl_estimate = self.discriminator(x, z)

        # ELBO = -recon_loss - KL
        loss = (recon_loss + kl_estimate).mean()

        loss.backward()
        optimizer.step()

        return {
            'loss': loss.item(),
            'recon_loss': recon_loss.mean().item(),
            'kl_estimate': kl_estimate.mean().item(),
        }
```

SVGD is a particle-based implicit VI method that represents the posterior through a set of particles, updated to minimize KL divergence to the target.
The Key Idea:
Instead of parameterizing a density q(z), maintain a set of particles {z₁, z₂, ..., zₙ}. Update particles to make their empirical distribution close to the posterior:
$$z_i^{(t+1)} = z_i^{(t)} + \epsilon \cdot \phi^*(z_i^{(t)})$$
where φ*(z) is the optimal perturbation direction that maximally decreases the KL divergence.
Stein's Identity:
The magic of SVGD comes from Stein's identity. For smooth functions and distributions:
$$\mathbb{E}_p[\mathcal{A}_p \phi(z)] = 0$$
where A_p is the Stein operator. This identity allows us to characterize p without knowing its normalizing constant.
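Concretely, for a (vector-valued) test function φ, the Stein operator used in SVGD is

$$\mathcal{A}_p \phi(z) = \phi(z)\, \nabla_z \log p(z)^\top + \nabla_z \phi(z),$$

which depends on p only through the score ∇_z log p(z), so the normalizing constant of p never needs to be computed.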
The SVGD Update:
The optimal perturbation in a reproducing kernel Hilbert space (RKHS) is:
$$\phi^*(z) = \frac{1}{n}\sum_{j=1}^{n} \left[ k(z_j, z) \nabla_{z_j} \log p(z_j|x) + \nabla_{z_j} k(z_j, z) \right]$$
where k(·, ·) is a kernel (typically RBF).
The update has two terms: a driving term k(z_j, z) ∇_{z_j} log p(z_j|x) that pulls particles toward high-probability regions of the posterior, and a repulsive term ∇_{z_j} k(z_j, z) that pushes particles away from one another.
This balance between attraction (to posterior modes) and repulsion (maintaining coverage) is what makes SVGD effective.
```python
import torch
import torch.nn as nn


class SVGD:
    """
    Stein Variational Gradient Descent for implicit VI.

    Maintains particles that are iteratively updated to
    approximate the posterior.
    """

    def __init__(self, log_posterior_fn, num_particles=100, latent_dim=10):
        """
        Args:
            log_posterior_fn: Function computing log p(z|x) up to constant
            num_particles: Number of particles
            latent_dim: Dimensionality of latent space
        """
        self.log_posterior = log_posterior_fn
        self.num_particles = num_particles
        self.latent_dim = latent_dim

        # Initialize particles from prior
        self.particles = torch.randn(num_particles, latent_dim, requires_grad=True)

    def rbf_kernel(self, z1, z2, bandwidth=None):
        """
        RBF kernel with median heuristic for bandwidth.
        """
        # Compute pairwise distances
        diff = z1.unsqueeze(1) - z2.unsqueeze(0)  # (n, m, d)
        dist_sq = (diff ** 2).sum(-1)  # (n, m)

        # Median heuristic for bandwidth
        if bandwidth is None:
            median_dist = torch.median(dist_sq.detach())
            bandwidth = median_dist / (2 * torch.log(torch.tensor(z1.shape[0] + 1.0)))
            bandwidth = torch.clamp(bandwidth, min=1e-5)

        # RBF kernel
        K = torch.exp(-dist_sq / (2 * bandwidth))

        # Gradient of kernel w.r.t. z1
        grad_K = -diff / bandwidth * K.unsqueeze(-1)  # (n, m, d)

        return K, grad_K

    def svgd_update(self, x, step_size=0.01):
        """
        Perform one SVGD update step.

        Args:
            x: Conditioning observation
            step_size: Step size for particle updates
        """
        n = self.particles.shape[0]

        # Compute log posterior gradient for each particle
        particles = self.particles.detach().requires_grad_(True)
        log_probs = self.log_posterior(x, particles)
        grad_log_probs = torch.autograd.grad(
            log_probs.sum(), particles, create_graph=False
        )[0]  # (n, d)

        # Compute kernel matrix and gradients
        K, grad_K = self.rbf_kernel(particles, particles)
        # K: (n, n), grad_K: (n, n, d)

        # SVGD update direction
        # phi(z_i) = (1/n) * sum_j [K(z_j, z_i) * grad_log_p(z_j) + grad_K(z_j, z_i)]
        phi = (K @ grad_log_probs + grad_K.sum(dim=0)) / n  # (n, d)

        # Update particles
        self.particles = (self.particles + step_size * phi).detach().requires_grad_(True)

        return {
            'mean_log_prob': log_probs.mean().item(),
            'particle_std': self.particles.std(dim=0).mean().item(),
        }

    def run(self, x, num_steps=1000, step_size=0.01):
        """Run SVGD for multiple steps."""
        for t in range(num_steps):
            stats = self.svgd_update(x, step_size)
            if t % 100 == 0:
                print(f"Step {t}: log_prob={stats['mean_log_prob']:.3f}, "
                      f"std={stats['particle_std']:.3f}")
        return self.particles.detach()

    def sample(self, num_samples=None):
        """Return particle samples (with optional subsampling)."""
        if num_samples is None or num_samples >= self.num_particles:
            return self.particles.detach()
        indices = torch.randperm(self.num_particles)[:num_samples]
        return self.particles[indices].detach()
```

SVGD is deterministic given initial particles. It can capture multimodality if initialized appropriately. The computational cost scales as O(n²) due to pairwise kernel computations, limiting scalability to large particle counts. Variants like SSVGD (subset SVGD) address this.
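A minimal usage sketch of the class above, assuming a toy unnormalized 2D Gaussian-mixture posterior (the target and hyperparameters are purely illustrative):

```python
import torch

def log_posterior(x, z):
    """Unnormalized log density of a two-component Gaussian mixture in 2D.
    The observation x is ignored in this toy example."""
    centers = torch.tensor([[-2.0, 0.0], [2.0, 0.0]])
    sq_dists = ((z.unsqueeze(1) - centers) ** 2).sum(-1)  # (n, 2)
    # log-sum over components of N(z; center, I), up to an additive constant
    return torch.logsumexp(-0.5 * sq_dists, dim=1)

svgd = SVGD(log_posterior, num_particles=200, latent_dim=2)
particles = svgd.run(x=None, num_steps=500, step_size=0.05)
print(particles.mean(0), particles.std(0))  # particles should cover both modes
```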
Several important variants extend the basic implicit VI framework:
Importance Weighted AVB (IWAE-style):
We can use importance weighting even with implicit proposals. Although we cannot evaluate q(z|x), we can use self-normalized importance sampling:
$$\log p(x) \geq \mathbb{E}_{z^{1:k} \sim q}\left[ \log \frac{1}{k} \sum_{i=1}^k \frac{p_\theta(x, z^i)}{\hat{r}(z^i|x)} \right]$$
where r̂ is an auxiliary distribution that approximates q.
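A minimal sketch of how this quantity might be computed, assuming a hypothetical `encoder(x, k)` that returns k implicit-posterior samples and a hypothetical `log_joint(x, z)` for log p_θ(x, z), and using a moment-matched diagonal Gaussian as r̂ (one simple choice among many):

```python
import math
import torch

def iw_bound_estimate(x, encoder, log_joint, k=64):
    """Importance-weighted objective with an implicit proposal.

    Assumed (hypothetical) interfaces:
      encoder(x, k)   -> z of shape (k, latent_dim), samples from q(z|x)
      log_joint(x, z) -> log p_theta(x, z), shape (k,)
    """
    z = encoder(x, k)                                    # z^{1:k} ~ q(z|x)

    # Auxiliary r_hat: a diagonal Gaussian moment-matched to the k samples
    mu, std = z.mean(0), z.std(0) + 1e-6
    log_r = (-0.5 * ((z - mu) / std) ** 2
             - torch.log(std)
             - 0.5 * math.log(2 * math.pi)).sum(-1)      # log r_hat(z^i | x)

    # log (1/k) * sum_i p_theta(x, z^i) / r_hat(z^i | x)
    log_w = log_joint(x, z) - log_r
    return torch.logsumexp(log_w, dim=0) - math.log(k)
```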
Contrastive Predictive Coding (CPC):
CPC uses contrastive objectives that implicitly perform density ratio estimation:
$$\mathcal{L}_{\text{CPC}} = \mathbb{E}\left[ \log \frac{f(x, z)}{\sum_{j=1}^K f(x, z_j)} \right]$$
This objective trains an encoder without requiring explicit densities.
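A minimal sketch of this contrastive objective as an InfoNCE-style loss with a bilinear critic (the critic form and shapes are illustrative assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BilinearCritic(nn.Module):
    """Scores f(x, z) = exp(x^T W z) through its logit x^T W z."""

    def __init__(self, x_dim, z_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(x_dim, z_dim) * 0.01)

    def forward(self, x, z):
        # x: (batch, x_dim), z: (batch, z_dim)
        # scores[i, j] = x_i^T W z_j; diagonal entries are the positive pairs
        return x @ self.W @ z.T

def infonce_loss(critic, x, z):
    scores = critic(x, z)                      # (batch, batch)
    labels = torch.arange(x.shape[0])          # positive pair for row i is column i
    # -E[ log f(x_i, z_i) / sum_j f(x_i, z_j) ], with the batch as the K candidates
    return F.cross_entropy(scores, labels)
```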
Kernel Implicit VI:
Replaces the parametric discriminator with kernel-based estimation, for example matching q to the target by minimizing a maximum mean discrepancy (MMD) rather than training a classifier.
Normalizing Flow + Implicit Base:
Hybrid approach: sample a base variable from an implicit distribution and transform it with a normalizing flow.
This combines the flexibility of implicit distributions with the tractability of flows for the transformed distribution: the flow's Jacobian correction is computable even though the base density is not.
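A minimal sketch of the hybrid construction, assuming a RealNVP-style affine coupling layer (the base sampler and architecture are illustrative, not tied to a specific published method):

```python
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    """RealNVP-style coupling layer: transforms half of z conditioned on the other half."""

    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        self.half = dim // 2
        self.net = nn.Sequential(
            nn.Linear(self.half, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * (dim - self.half)),
        )

    def forward(self, z0):
        z_a, z_b = z0[:, :self.half], z0[:, self.half:]
        log_s, t = self.net(z_a).chunk(2, dim=-1)
        log_s = torch.tanh(log_s)                 # keep scales well-behaved
        z_b = z_b * torch.exp(log_s) + t
        log_det = log_s.sum(-1)                   # tractable Jacobian correction
        return torch.cat([z_a, z_b], dim=-1), log_det

# z0 comes from an implicit sampler (density unknown), e.g. the ImplicitSampler
# sketched earlier; z1 is still implicit, but the flow tells us exactly how the
# (unknown) density changed: log q1(z1) = log q0(z0) - log_det.
```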
| Method | Density Estimation | Scalability | Stability | Key Strength |
|---|---|---|---|---|
| AVB | Adversarial | High | Medium | Expressive encoder |
| SVGD | Not needed | Medium | High | Particle diversity |
| Kernel IVI | MMD | High | High | Training stability |
| CPC/NCE | Contrastive | Very High | High | Self-supervised |
| Flow + Implicit | Hybrid | Medium | Medium | Best of both |
Implicit VI isn't always the right choice. Here's guidance on when it shines versus when explicit methods are preferable:
Use implicit VI when the true posterior is strongly non-Gaussian or multimodal and explicit families visibly underfit it, or when your model already contains implicit components.
Otherwise, start with explicit VI (a standard VAE or normalizing flows). If you observe significant approximation errors or posterior collapse, consider implicit methods. The increased training complexity of implicit VI is only worthwhile when explicit methods demonstrably fail.
We've explored implicit variational inference as an approach to VI when density evaluation is intractable or overly restrictive. Here are the key takeaways:

- Implicit distributions trade density evaluation for essentially unlimited expressiveness.
- Density ratio estimation with a discriminator is the technique that makes ELBO-style objectives computable without q(z).
- AVB applies this idea to VAE-style inference, while SVGD sidesteps ratio estimation entirely with interacting particles.
- The added flexibility comes at the cost of extra estimation error and less stable training.
You now understand implicit variational inference and when it's appropriate. Next, we'll explore hierarchical VI, which addresses inference in models with multiple levels of latent variables.