The Laplace approximation provides a Gaussian centered at the posterior mode—but what if we want more flexibility? What if the posterior is skewed, heavy-tailed, or has complex structure that a single Gaussian cannot capture?
Variational Inference (VI) takes a fundamentally different approach: instead of constructing an approximation at a single point, it frames approximate inference as an optimization problem. We define a family of tractable distributions Q and search for the distribution q*(θ) ∈ Q that is "closest" to the true posterior.
This perspective is transformative. It converts an intractable integration problem into a tractable optimization problem—and over the past two decades, advances in optimization (stochastic gradients, automatic differentiation) have made VI practical for models with millions of parameters.
By the end of this page, you will understand how variational inference frames approximation as optimization, derive the Evidence Lower Bound (ELBO) from first principles, interpret ELBO through the lens of KL divergence and compression, and recognize the fundamental trade-offs in variational families.
The core idea of variational inference is deceptively simple:
Question: Given an intractable posterior p(θ|D), what is the best approximation q(θ) from a tractable family Q?
Answer: The q(θ) that minimizes some measure of dissimilarity to p(θ|D).
The word "variational" comes from the calculus of variations—a branch of mathematics concerned with finding functions that optimize functionals. In our case, we're finding a probability distribution q that optimizes a functional measuring closeness to p.
Choosing the dissimilarity measure:
The natural choice is the Kullback-Leibler (KL) divergence:
$$\text{KL}(q \| p) = \int q(\theta) \log \frac{q(\theta)}{p(\theta|D)} d\theta$$
KL divergence has two key properties:

- Non-negativity: KL(q || p) ≥ 0, with equality if and only if q = p (almost everywhere).
- Asymmetry: in general KL(q || p) ≠ KL(p || q), so the direction of the divergence matters.

We minimize KL(q || p) rather than KL(p || q) for a practical reason: the former requires expectations under q (which we control), while the latter requires expectations under p (which is intractable). This choice has consequences—KL(q || p) tends to be 'mode-seeking' and may underestimate posterior variance.
The variational problem:
$$q^* = \arg\min_{q \in Q} \text{KL}(q(\theta) \| p(\theta|D))$$
But there's a problem: the KL divergence involves p(θ|D), which contains the intractable marginal likelihood in its normalization. We cannot evaluate this objective directly.
The breakthrough is recognizing that we can minimize an equivalent objective that doesn't require knowing p(D). This leads us to the Evidence Lower Bound (ELBO).
Let's derive the ELBO from the KL divergence. Start with:
$$\text{KL}(q \| p) = \int q(\theta) \log \frac{q(\theta)}{p(\theta|D)} d\theta$$
Substitute Bayes' theorem: p(θ|D) = p(D|θ)p(θ)/p(D)
$$\text{KL}(q \| p) = \int q(\theta) \log \frac{q(\theta) \cdot p(D)}{p(D|\theta)p(\theta)} d\theta$$
Split the logarithm:
$$\text{KL}(q \| p) = \int q(\theta) \log \frac{q(\theta)}{p(D|\theta)p(\theta)} d\theta + \int q(\theta) \log p(D) d\theta$$
The second integral equals log p(D) since q integrates to 1:
$$\text{KL}(q \| p) = \int q(\theta) \log \frac{q(\theta)}{p(D|\theta)p(\theta)} d\theta + \log p(D)$$
Rearranging:
$$\log p(D) = \text{KL}(q \| p) - \int q(\theta) \log \frac{q(\theta)}{p(D|\theta)p(\theta)} d\theta$$
Define the Evidence Lower Bound (ELBO):
$$\mathcal{L}(q) = \int q(\theta) \log \frac{p(D|\theta)p(\theta)}{q(\theta)} d\theta = \mathbb{E}_{q}[\log p(D|\theta)] - \text{KL}(q(\theta) \| p(\theta))$$
We now have the fundamental equation:
$$\boxed{\log p(D) = \mathcal{L}(q) + \text{KL}(q \| p(\cdot|D))}$$
Since KL ≥ 0, we have:
$$\mathcal{L}(q) \leq \log p(D)$$
The ELBO is a lower bound on the log evidence—hence its name.
Since log p(D) is constant with respect to q, maximizing the ELBO is equivalent to minimizing KL(q || p). We've converted an intractable minimization (requires p(θ|D)) into a tractable maximization (requires only p(D|θ), p(θ), and q(θ)).
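To make the bound tangible, here is a minimal numerical sketch on a conjugate model (the model, data, and all names below are illustrative assumptions, not taken from the text): with a N(0, 1) prior on θ and N(θ, 1) likelihood, both the exact posterior and log p(D) are available in closed form, so we can confirm that the ELBO never exceeds the log evidence and is tight exactly when q equals the true posterior.

```python
import numpy as np
from scipy.stats import multivariate_normal

# Toy conjugate model (illustrative):  theta ~ N(0, 1),  y_i | theta ~ N(theta, 1)
rng = np.random.default_rng(0)
n = 20
y = rng.normal(1.5, 1.0, size=n)

# Exact log evidence: marginally, y ~ N(0, I + 11^T)
log_evidence = multivariate_normal(mean=np.zeros(n),
                                   cov=np.eye(n) + np.ones((n, n))).logpdf(y)

def elbo(m, s2):
    """ELBO for q(theta) = N(m, s2): E_q[log p(y|theta)] - KL(q || prior)."""
    exp_loglik = -0.5 * n * np.log(2 * np.pi) - 0.5 * np.sum((y - m) ** 2) - 0.5 * n * s2
    kl_to_prior = 0.5 * (s2 + m ** 2 - 1.0 - np.log(s2))
    return exp_loglik - kl_to_prior

# Exact posterior is N(n*ybar/(n+1), 1/(n+1)); there the bound is tight
m_post, s2_post = n * y.mean() / (n + 1), 1.0 / (n + 1)
print(f"log p(D)                = {log_evidence:.4f}")
print(f"ELBO at exact posterior = {elbo(m_post, s2_post):.4f}")  # equal (gap = 0)
print(f"ELBO at q = N(0, 1)     = {elbo(0.0, 1.0):.4f}")         # strictly smaller
```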
The ELBO admits two illuminating decompositions that reveal its structure:
Decomposition 1: Likelihood + KL
$$\mathcal{L}(q) = \underbrace{\mathbb{E}_{q}[\log p(D|\theta)]}_{\text{Expected log-likelihood}} - \underbrace{\text{KL}(q(\theta) \| p(\theta))}_{\text{Complexity penalty}}$$
This form clearly shows the trade-off:

- The expected log-likelihood rewards q for placing mass on parameters that explain the data well.
- The KL term penalizes q for moving far from the prior p(θ).
This is the variational analogue of regularized maximum likelihood! The prior acts as regularization, with the KL divergence playing the role of a penalty term.
Decomposition 2: Energy - Entropy
$$\mathcal{L}(q) = -\underbrace{\mathbb{E}_{q}[-\log p(D, \theta)]}_{\text{Expected negative log-joint (Energy)}} - \underbrace{\mathbb{E}_{q}[\log q(\theta)]}_{\text{Negative entropy}}$$
Simplifying:
$$\mathcal{L}(q) = \mathbb{E}_{q}[\log p(D, \theta)] + H(q)$$
where H(q) = -𝔼_q[log q(θ)] is the entropy of q.
This form reveals a different trade-off:

- The expected log-joint (low energy) pulls q toward concentrating on high-probability configurations of (θ, D).
- The entropy term rewards q for staying spread out.
These two forces balance: energy wants q concentrated, entropy wants q spread out.
The choice of variational family Q determines the expressiveness-tractability trade-off in VI. A more flexible family can better approximate complex posteriors but may be harder to optimize.
The mean-field family:
The most common choice is the mean-field approximation, which assumes the posterior factorizes:
$$q(\theta) = \prod_{j=1}^d q_j(\theta_j)$$
Each parameter is independent under q, with its own variational distribution. This ignores posterior correlations but makes optimization tractable.
Gaussian mean-field:
$$q(\theta) = \prod_{j=1}^d \mathcal{N}(\theta_j | \mu_j, \sigma_j^2)$$
With 2d parameters (means μ and variances σ²), this is a diagonal Gaussian with no correlations.
| Family | Form | Parameters | Captures Correlations? |
|---|---|---|---|
| Mean-Field Gaussian | ∏ N(μⱼ, σⱼ²) | 2d | No |
| Full Covariance Gaussian | N(μ, Σ) | d + d(d+1)/2 | Yes (all) |
| Low-Rank Gaussian | N(μ, D + VVᵀ) | d + d + dk | Yes (rank k) |
| Mixture of Gaussians | Σ πₖ N(μₖ, Σₖ) | K(1 + d + d²) | Yes + multimodal |
| Normalizing Flow | f(z), z ~ N(0,I) | Flow params | Yes (learned) |
The expressiveness-tractability spectrum:
- Mean-field (most tractable, least expressive)
- Full-covariance Gaussian (tractable, moderately expressive)
- Normalizing flows (flexible, computationally expensive)
Because mean-field ignores correlations, it systematically underestimates posterior variance when parameters are correlated. If θ₁ and θ₂ are positively correlated in the true posterior, mean-field treats them as independent, leading to overconfident predictions.
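A quick numerical illustration of this effect (a hypothetical two-dimensional Gaussian "posterior", not a model from this page): for a zero-mean Gaussian with unit marginal variances and correlation ρ, the KL(q || p)-optimal mean-field Gaussian has precision equal to the diagonal of the inverse covariance, i.e. variance 1 - ρ² per coordinate, which shrinks toward zero as the correlation grows.

```python
import numpy as np

# Hypothetical correlated "posterior": p(theta) = N(0, Sigma), unit marginal variances.
# The KL(q || p)-optimal mean-field Gaussian has variance 1 / Lambda_jj per coordinate,
# where Lambda = Sigma^{-1}; here that is 1 - rho^2, always below the true marginal variance.
for rho in [0.0, 0.5, 0.9, 0.99]:
    Sigma = np.array([[1.0, rho], [rho, 1.0]])
    Lambda = np.linalg.inv(Sigma)
    mean_field_var = 1.0 / np.diag(Lambda)
    print(f"rho = {rho:4.2f}:  true marginal var = {Sigma[0, 0]:.2f},  "
          f"mean-field var = {mean_field_var[0]:.3f}")
```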
The variational inference problem is now:
$$\max_\phi \mathcal{L}(q_\phi) = \mathbb{E}_{q_\phi}[\log p(D|\theta)] - \text{KL}(q_\phi(\theta) \| p(\theta))$$
where φ are the variational parameters (e.g., means and variances for Gaussian q).
Coordinate Ascent Variational Inference (CAVI):
The classical approach updates each factor q_j while holding others fixed:
$$\log q_j^*(\theta_j) = \mathbb{E}_{q_{-j}}[\log p(\theta, D)] + \text{const}$$
where q_{-j} denotes all factors except j. This is guaranteed to increase the ELBO at each step (or leave it unchanged at convergence).
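As a concrete sketch, the classic conjugate toy model with unknown mean and precision admits closed-form CAVI updates. Everything below (model, priors, data) is an illustrative assumption, not the example used elsewhere on this page: x_i ~ N(μ, 1/τ) with μ | τ ~ N(μ₀, 1/(λ₀τ)) and τ ~ Gamma(a₀, b₀), approximated by q(μ, τ) = q(μ)q(τ).

```python
import numpy as np

# Illustrative CAVI for a conjugate toy model (assumed for this sketch):
#   x_i | mu, tau ~ N(mu, 1/tau),  mu | tau ~ N(mu0, 1/(lam0*tau)),  tau ~ Gamma(a0, b0)
# Mean-field q(mu) q(tau): each coordinate update is available in closed form.
rng = np.random.default_rng(1)
x = rng.normal(loc=2.0, scale=0.5, size=100)
n, xbar, xsq = len(x), x.mean(), np.sum(x ** 2)

mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0
E_tau = a0 / b0  # initial guess for E_q[tau]

for _ in range(20):
    # Update q(mu) = N(mu_n, 1/lam_n), holding q(tau) fixed
    mu_n = (lam0 * mu0 + n * xbar) / (lam0 + n)
    lam_n = (lam0 + n) * E_tau
    E_mu, E_mu2 = mu_n, mu_n ** 2 + 1.0 / lam_n

    # Update q(tau) = Gamma(a_n, b_n), holding q(mu) fixed
    a_n = a0 + 0.5 * (n + 1)
    b_n = b0 + 0.5 * (lam0 * (E_mu2 - 2 * mu0 * E_mu + mu0 ** 2)
                      + xsq - 2 * E_mu * n * xbar + n * E_mu2)
    E_tau = a_n / b_n

print(f"q(mu):  mean = {mu_n:.3f}, std = {lam_n ** -0.5:.3f}")
print(f"q(tau): E[tau] = {E_tau:.2f}   (data were generated with precision 4.0)")
```

Each pass alternates the two closed-form updates, and the ELBO is guaranteed not to decrease at any step.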
CAVI is deterministic and converges to a local maximum, but:

- each factor update requires closed-form expectations, which in practice restricts it to conditionally conjugate (exponential-family) models, and
- every sweep processes the entire dataset, which scales poorly to large n.
Stochastic Variational Inference (SVI):
For large datasets, we can use stochastic gradient ascent on the ELBO. The key challenge is estimating the gradient:
$$\nabla_\phi \mathcal{L} = \nabla_\phi \mathbb{E}_{q_\phi}[\log p(D|\theta)] - \nabla_\phi \text{KL}(q_\phi \| p)$$
The KL term often has a closed form. The likelihood term requires either:
Score function estimator (REINFORCE): $$\nabla_\phi \mathbb{E}_{q_\phi}[f(\theta)] = \mathbb{E}_{q_\phi}[f(\theta)\, \nabla_\phi \log q_\phi(\theta)]$$
Reparameterization trick (for location-scale families): Write θ = g(ε, φ) where ε ~ p(ε) is parameter-free: $$\nabla_\phi \mathbb{E}_{q_\phi}[f(\theta)] = \mathbb{E}_{p(\epsilon)}[\nabla_\phi f(g(\epsilon, \phi))]$$
Reparameterization has much lower variance and is preferred when applicable.
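The gap in variance is easy to see on a toy objective. The sketch below (an illustrative setup, not from the text) estimates ∂/∂μ 𝔼_q[θ²] for q = N(μ, σ²), whose exact value is 2μ, using one sample per estimate for each estimator:

```python
import numpy as np

# Toy comparison (illustrative): gradient of E_q[theta^2] w.r.t. mu, with q = N(mu, sigma^2).
# Exact answer: d/dmu (mu^2 + sigma^2) = 2*mu
rng = np.random.default_rng(0)
mu, sigma, n = 1.0, 1.0, 100_000
theta = mu + sigma * rng.standard_normal(n)   # reparameterized draws from q

# Score-function (REINFORCE) estimator: f(theta) * d/dmu log q(theta)
score_grads = theta ** 2 * (theta - mu) / sigma ** 2
# Reparameterization estimator: d/dmu f(mu + sigma*eps) = 2 * theta
reparam_grads = 2 * theta

print(f"exact gradient:      {2 * mu:.3f}")
print(f"score-function:      mean = {score_grads.mean():.3f}, per-sample var = {score_grads.var():.1f}")
print(f"reparameterization:  mean = {reparam_grads.mean():.3f}, per-sample var = {reparam_grads.var():.1f}")
```

Both estimators are unbiased, but the per-sample variance of the score-function estimator is several times larger in this toy case, which is the typical pattern that makes reparameterization the default choice.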
```python
import numpy as np
import torch
import torch.nn as nn


class MeanFieldGaussianVI:
    """
    Mean-field Gaussian variational inference with reparameterization.
    """

    def __init__(self, dim, log_joint, n_samples=10):
        """
        Parameters:
        -----------
        dim : int
            Number of parameters
        log_joint : callable
            Function returning log p(θ, D)
        n_samples : int
            Monte Carlo samples for gradient estimation
        """
        self.dim = dim
        self.log_joint = log_joint
        self.n_samples = n_samples

        # Variational parameters: means and log standard deviations
        self.mu = torch.zeros(dim, requires_grad=True)
        self.log_sigma = torch.zeros(dim, requires_grad=True)

    def sample_theta(self, n_samples=None):
        """Sample θ using the reparameterization trick."""
        n = n_samples or self.n_samples
        # θ = μ + σ ⊙ ε, where ε ~ N(0, I)
        epsilon = torch.randn(n, self.dim)
        sigma = torch.exp(self.log_sigma)
        return self.mu + sigma * epsilon

    def entropy(self):
        """Entropy of diagonal Gaussian: H = 0.5 * d * (1 + log(2π)) + Σ log(σ)"""
        return 0.5 * self.dim * (1 + np.log(2 * np.pi)) + self.log_sigma.sum()

    def elbo(self, data):
        """Monte Carlo estimate of the ELBO."""
        theta_samples = self.sample_theta()

        # Expected log joint, E_q[log p(θ, D)]
        log_joints = torch.stack([
            self.log_joint(theta, data) for theta in theta_samples
        ])
        expected_log_joint = log_joints.mean()

        # ELBO = E[log p(θ, D)] + H(q)
        return expected_log_joint + self.entropy()

    def fit(self, data, n_iters=1000, lr=0.01):
        """Optimize the ELBO via stochastic gradient ascent."""
        optimizer = torch.optim.Adam([self.mu, self.log_sigma], lr=lr)
        history = []

        for i in range(n_iters):
            optimizer.zero_grad()
            loss = -self.elbo(data)  # Minimize negative ELBO
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                history.append(-loss.item())
                print(f"Iter {i}: ELBO = {-loss.item():.4f}")

        return history

    def posterior_mean(self):
        return self.mu.detach().numpy()

    def posterior_std(self):
        return torch.exp(self.log_sigma).detach().numpy()


# Example: Bayesian linear regression
def bayesian_linreg_log_joint(theta, data):
    """Log joint for Bayesian linear regression."""
    X, y = data
    w = theta[:-1]
    log_sigma = theta[-1]
    sigma = torch.exp(log_sigma)

    # Likelihood: y ~ N(Xw, σ²I)
    pred = X @ w
    log_lik = -0.5 * torch.sum((y - pred)**2) / sigma**2
    log_lik -= len(y) * log_sigma

    # Prior: w ~ N(0, I), log σ ~ N(0, 1)
    log_prior = -0.5 * torch.sum(w**2) - 0.5 * log_sigma**2

    return log_lik + log_prior
```

The choice to minimize KL(q || p) rather than KL(p || q) has profound implications for the resulting approximation. Understanding this asymmetry is crucial for interpreting VI results.
Forward KL: KL(p || q) — Moment Matching
$$\text{KL}(p \| q) = \int p(\theta) \log \frac{p(\theta)}{q(\theta)} d\theta$$
Minimizing forward KL:

- forces q to place mass wherever p has mass ("zero-avoiding"),
- matches the moments of p when q is an exponential-family distribution,
- tends to overestimate variance and stretch a unimodal q across multiple modes.
Reverse KL: KL(q || p) — Mode Seeking
$$\text{KL}(q \| p) = \int q(\theta) \log \frac{q(\theta)}{p(\theta)} d\theta$$
Minimizing reverse KL:

- heavily penalizes q for placing mass where p is nearly zero ("zero-forcing"),
- drives q to concentrate on a region of high posterior density, typically a single mode,
- tends to underestimate posterior variance.
When the true posterior has multiple modes, reverse KL (standard VI) will collapse q onto a single mode. This is a fundamental limitation: you're guaranteed to underestimate uncertainty. For multimodal posteriors, consider mixture variational families or use forward KL approaches like expectation propagation.
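The following sketch makes the collapse visible on a hypothetical bimodal target (all model choices below are assumptions for illustration): a single Gaussian q = N(m, s²) trained by minimizing a Monte Carlo estimate of the reverse KL settles on one mode, while the forward-KL (moment-matching) solution spreads over both.

```python
import torch

# Hypothetical bimodal target: p(x) = 0.5*N(-3, 1) + 0.5*N(3, 1)
def log_p(x):
    comps = torch.stack([
        torch.distributions.Normal(-3.0, 1.0).log_prob(x),
        torch.distributions.Normal(3.0, 1.0).log_prob(x),
    ])
    return torch.logsumexp(comps, dim=0) - torch.log(torch.tensor(2.0))

# Single-Gaussian q = N(m, s^2), trained by minimizing E_q[log q - log p]
m = torch.tensor(0.5, requires_grad=True)        # small asymmetry breaks the tie
log_s = torch.tensor(0.0, requires_grad=True)
opt = torch.optim.Adam([m, log_s], lr=0.05)

for _ in range(2000):
    opt.zero_grad()
    x = m + torch.exp(log_s) * torch.randn(256)  # reparameterized samples from q
    log_q = torch.distributions.Normal(m, torch.exp(log_s)).log_prob(x)
    kl_estimate = (log_q - log_p(x)).mean()      # Monte Carlo reverse KL
    kl_estimate.backward()
    opt.step()

print(f"reverse-KL fit: m = {m.item():.2f}, s = {torch.exp(log_s).item():.2f}  (one mode)")
print(f"forward-KL fit: m = 0.00, s = {10 ** 0.5:.2f}  (moment matching: mixture variance is 10)")
```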
Variational inference and MCMC/sampling represent two fundamentally different philosophies for approximate inference. Each has distinct advantages and trade-offs.
| Aspect | Variational Inference | MCMC / Sampling |
|---|---|---|
| Nature | Optimization problem | Sampling problem |
| Output | Parametric distribution q(θ) | Samples θ⁽¹⁾, ..., θ⁽ᴺ⁾ |
| Approximation bias | Yes (family restriction) | Asymptotically unbiased |
| Variance | Low (deterministic params) | High (Monte Carlo noise) |
| Convergence diagnostic | ELBO + gradient norms | Trace plots, R̂, ESS |
| Speed | Generally faster | Can be slow to mix |
| Scalability | Better for big data (SGD) | Struggles with large n |
| Multimodality | Difficult (mode-seeking) | Better (if mixing works) |
When to use VI:

- the dataset or model is large and MCMC would be prohibitively slow,
- you need to refit many times (e.g., hyperparameter sweeps, model exploration),
- a fast but somewhat biased posterior approximation is acceptable.
When to use MCMC:

- accurate posterior uncertainty is the priority,
- the dataset is small to moderate, so mixing time is not prohibitive,
- the posterior may be multimodal or poorly matched by any tractable variational family.
In practice, consider using VI for fast initial exploration and hyperparameter selection, then running MCMC from the VI solution for final inference. VI provides a good initialization (near a mode), and MCMC refines with unbiased sampling. This combines the speed of VI with the accuracy of MCMC.
Since 2014, variational inference has experienced a renaissance driven by deep learning and automatic differentiation. Several key developments have expanded its applicability:
Probabilistic programming and automatic VI:
Modern frameworks like Pyro, Stan, and NumPyro provide automatic variational inference:
```python
# Pyro example: automatic guide (variational distribution)
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

def model(data):
    theta = pyro.sample("theta", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(theta, 1.), obs=data)

guide = AutoDiagonalNormal(model)  # Automatic mean-field Gaussian
svi = SVI(model, guide, optim.Adam({"lr": 0.01}), Trace_ELBO())
```
The guide (variational distribution) is automatically constructed, and ELBO gradients are computed via automatic differentiation.
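Training then reduces to repeatedly calling svi.step, which takes one gradient step on the ELBO and returns the current loss estimate. The loop below continues the snippet above; the toy data tensor is an assumption for illustration.

```python
import torch

data = torch.randn(100) + 2.0        # toy observations (illustrative)

pyro.clear_param_store()
for step in range(1000):
    loss = svi.step(data)            # one stochastic ELBO gradient step; returns an estimate of -ELBO
    if step % 200 == 0:
        print(f"step {step}: loss = {loss:.2f}")
```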
Variational inference transforms the intractable problem of computing a posterior into a tractable optimization problem. By choosing an approximating family and maximizing the ELBO, we find the distribution in that family closest to the true posterior in KL divergence.
What's next:
We've seen two deterministic approximations: Laplace (Gaussian at the mode) and variational inference (optimal approximation in a family). Next, we'll explore expectation propagation—a method that uses forward rather than reverse KL, leading to different approximation properties and often better calibrated uncertainties.
You now understand variational inference as an optimization framework—defining a family Q, deriving the ELBO, and optimizing variational parameters. This conceptual foundation underlies modern probabilistic deep learning, from VAEs to Bayesian neural networks.