In variational inference, we approximate an intractable posterior $p(\mathbf{z} | \mathbf{x})$ with a tractable distribution $q(\mathbf{z})$ chosen from a variational family $\mathcal{Q}$. The choice of this family is perhaps the most consequential design decision in variational inference—it determines what approximations are even possible.
The variational family represents a fundamental trade-off: a richer family can approximate the true posterior more closely, but a simpler family is easier to optimize, sample from, and evaluate.
This page develops a rigorous understanding of variational families: their formal definition, parameterization, common choices, and the expressiveness-tractability trade-off that governs their design.
By completing this page, you will: (1) Formally define the variational family and its role in VI, (2) Understand how parameterization connects distributions to optimization, (3) Master the mean-field assumption and its implications, (4) Explore structured and implicit variational families, and (5) Develop intuition for choosing appropriate families in practice.
A variational family $\mathcal{Q}$ is a set of probability distributions over latent variables $\mathbf{z}$, typically parameterized by variational parameters $\phi$:
$$\mathcal{Q} = \{ q_\phi(\mathbf{z}) : \phi \in \Phi \}$$
where $\phi$ denotes the variational parameters and $\Phi$ is the set of admissible parameter values.
The variational inference objective is then:
$$\phi^* = \arg\max_{\phi \in \Phi} \mathcal{L}(\phi)$$
where $\mathcal{L}(\phi)$ is the ELBO (Evidence Lower Bound), derived in a subsequent page.
For a variational family to be useful, it must satisfy several properties: tractable density evaluation (we can compute $\log q_\phi(\mathbf{z})$), tractable sampling (we can draw $\mathbf{z} \sim q_\phi$), and differentiability with respect to $\phi$ so that gradient-based optimization is possible.
The mapping from $\phi$ to $q_\phi$ is central to VI optimization. Consider a Gaussian variational family:
$$q_\phi(\mathbf{z}) = \mathcal{N}(\mathbf{z} | \boldsymbol{\mu}, \mathbf{\Sigma})$$
The parameters $\phi = (\boldsymbol{\mu}, \mathbf{\Sigma})$ determine the distribution. But $\mathbf{\Sigma}$ must be symmetric positive definite, so we often reparameterize:
$$\mathbf{\Sigma} = \mathbf{L} \mathbf{L}^T$$
where $\mathbf{L}$ is lower-triangular with positive diagonal (a Cholesky factor). Now $\phi = (\boldsymbol{\mu}, \mathbf{L})$, and by mapping unconstrained entries to a positive diagonal (e.g., via softplus or exp), optimization over $\mathbf{L}$ becomes unconstrained.
The parameterization choice affects the difficulty of the optimization (constrained vs. unconstrained), numerical stability, and the cost and variance of gradient estimates.
```python
import math

import torch
import torch.nn as nn
import torch.distributions as dist


class GaussianVariationalFamily(nn.Module):
    """
    A full-covariance Gaussian variational family.

    This demonstrates the key components:
    1. Parameterization (mean + Cholesky factor)
    2. Tractable density evaluation
    3. Tractable sampling via reparameterization
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

        # Variational parameters phi = (mu, L_raw)
        #   mu:    mean vector
        #   L_raw: unconstrained parameters for the Cholesky factor
        self.mu = nn.Parameter(torch.zeros(dim))

        # L_raw has dim*(dim+1)/2 parameters for the lower triangle
        n_tril = dim * (dim + 1) // 2
        self.L_raw = nn.Parameter(torch.zeros(n_tril))

    def _get_L(self) -> torch.Tensor:
        """Convert unconstrained L_raw to a valid Cholesky factor."""
        L = torch.zeros(self.dim, self.dim)
        # Fill the lower triangle
        tril_indices = torch.tril_indices(self.dim, self.dim)
        L[tril_indices[0], tril_indices[1]] = self.L_raw
        # Ensure a positive diagonal (softplus for numerical stability),
        # built out-of-place so gradients flow cleanly to L_raw
        diag = torch.nn.functional.softplus(torch.diagonal(L)) + 1e-6
        return torch.tril(L, diagonal=-1) + torch.diag(diag)

    def log_prob(self, z: torch.Tensor) -> torch.Tensor:
        """
        Evaluate log q_phi(z).

        Requirement: tractable density evaluation
        """
        L = self._get_L()
        mvn = dist.MultivariateNormal(self.mu, scale_tril=L)
        return mvn.log_prob(z)

    def sample(self, n_samples: int = 1) -> torch.Tensor:
        """
        Draw samples z ~ q_phi(z) using reparameterization.

        Requirement: tractable sampling + differentiability

        Reparameterization trick: z = mu + L @ epsilon, where epsilon ~ N(0, I)
        """
        L = self._get_L()
        epsilon = torch.randn(n_samples, self.dim)
        z = self.mu + epsilon @ L.T  # Broadcasting over samples
        return z

    def entropy(self) -> torch.Tensor:
        """
        Closed-form entropy for a Gaussian with Sigma = L @ L.T:
            H[q] = (D/2) * (1 + log(2*pi)) + log|det(L)|
        """
        L = self._get_L()
        log_det_L = torch.diagonal(L).log().sum()
        return 0.5 * self.dim * (1 + math.log(2 * math.pi)) + log_det_L


# Example usage
dim = 10
q = GaussianVariationalFamily(dim)

# Sample from the variational distribution
z_samples = q.sample(n_samples=100)
print(f"Sample shape: {z_samples.shape}")  # [100, 10]

# Evaluate log density (detach the samples so they act as fixed evaluation points)
log_probs = q.log_prob(z_samples.detach())
print(f"Log prob shape: {log_probs.shape}")  # [100]

# Compute entropy
H = q.entropy()
print(f"Entropy: {H.item():.4f}")

# Gradients flow from the log-density back to the variational parameters
loss = -log_probs.mean()
loss.backward()
print(f"Gradient on mu: {q.mu.grad[:3]}")
```

The mean-field approximation is the most widely used and historically important variational family. It assumes that the approximate posterior factorizes completely over latent variables:
$$q(\mathbf{z}) = \prod_{j=1}^{D} q_j(z_j)$$
where each factor $q_j(z_j)$ is a distribution over the $j$-th latent variable, independent of all others.
The name "mean-field" comes from statistical physics, where it describes an approximation in which each particle experiences the average (mean) field produced by all other particles, ignoring detailed correlations. In VI, each latent variable "sees" only the expected effect of other variables, not their joint distribution.
For a model with $D$ latent variables, the mean-field family is:
$$\mathcal{Q}_{\text{MF}} = \left\{ q(\mathbf{z}) = \prod_{j=1}^{D} q_j(z_j) : q_j \in \mathcal{Q}_j \right\}$$
where each $\mathcal{Q}_j$ is a family for the individual factor (often chosen based on the variable's support—Gaussian for continuous, Categorical for discrete).
The factorization assumption makes inference dramatically more tractable. For $D$ latent variables, representing a general joint distribution requires a number of parameters that grows rapidly with $D$ (exponentially for discrete variables, and quadratically even for a full-covariance Gaussian). A factorized approximation needs only $O(D)$ parameters. This reduction enables VI to scale to models with millions of latent variables.
For mean-field families, a classical optimization approach is coordinate ascent: iteratively optimize each factor $q_j$ while holding others fixed. Under natural exponential family assumptions, the optimal update has a closed form:
$$\log q_j^*(z_j) = \mathbb{E}_{q_{-j}}[\log p(z_j, \mathbf{z}_{-j}, \mathbf{x})] + \text{const}$$
where $q_{-j} = \prod_{k \neq j} q_k(z_k)$ and $\mathbf{z}_{-j}$ denotes all variables except $z_j$.
In words: The optimal $q_j$ has log-density proportional to the expected log-complete conditional, where the expectation is over the current estimates of all other factors.
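To make this update concrete, the following is a minimal sketch (variable names are my own) of coordinate ascent for a 2D correlated Gaussian target, where each optimal factor turns out to be Gaussian with variance $1/\Lambda_{jj}$ and a mean that depends linearly on the other factor's mean:

```python
import numpy as np

# Target "posterior": a correlated 2D Gaussian p(z) = N(mu, Lambda^{-1}).
# For this target the coordinate ascent update has a closed form:
#   q_j = N(m_j, 1 / Lambda_jj),
#   m_j = mu_j - (1 / Lambda_jj) * sum_{k != j} Lambda_jk * (m_k - mu_k)
mu = np.array([0.0, 0.0])
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])
Lam = np.linalg.inv(Sigma)             # Precision matrix of the target

m = np.array([2.0, -2.0])              # Arbitrary initialization of the factor means
for it in range(5):
    for j in range(2):                 # Update one factor at a time, others held fixed
        k = 1 - j
        m[j] = mu[j] - (Lam[j, k] / Lam[j, j]) * (m[k] - mu[k])
    print(f"iteration {it}: factor means = {m.round(4)}")

print("factor variances:", (1.0 / np.diag(Lam)).round(4))
# The means converge to the true mean, but the factor variances equal
# 1/Lambda_jj = 0.36, smaller than the true marginal variance of 1.0:
# the mean-field fit underestimates uncertainty.
```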
A common instantiation is Gaussian mean-field:
$$q(\mathbf{z}) = \prod_{j=1}^{D} \mathcal{N}(z_j | \mu_j, \sigma_j^2)$$
Variational parameters: $\phi = \{ \mu_1, \sigma_1, \ldots, \mu_D, \sigma_D \}$
This requires only $2D$ parameters regardless of the true posterior's covariance structure.
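As a minimal sketch (assuming PyTorch, with illustrative variable names), a Gaussian mean-field family over $D$ dimensions can be built directly from torch.distributions with exactly $2D$ parameters:

```python
import torch
import torch.distributions as dist

# Gaussian mean-field family in D dimensions: 2D variational parameters
# (one mean and one log-std per coordinate), all coordinates independent.
D = 5
mu = torch.zeros(D, requires_grad=True)
log_sigma = torch.zeros(D, requires_grad=True)

q = dist.Independent(dist.Normal(mu, log_sigma.exp()), 1)  # Treat the D dims as one event
z = q.rsample((100,))          # [100, D] reparameterized samples
log_q = q.log_prob(z)          # [100]; sums the per-dimension log-densities
print(sum(p.numel() for p in (mu, log_sigma)))  # 2 * D = 10 parameters
```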
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# True posterior: correlated 2D Gaussian
true_mean = np.array([0, 0])
true_cov = np.array([[1.0, 0.8],
                     [0.8, 1.0]])  # Correlation = 0.8

# Mean-field approximation: factorized Gaussian that keeps the
# marginal variances but ignores the correlation
mf_mean = np.array([0, 0])
mf_cov = np.array([[1.0, 0.0],   # Diagonal only!
                   [0.0, 1.0]])

# Visualization on a grid
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))

# Density contours (pass Z_true / Z_mf to plt.contour to visualize the
# tilted vs. axis-aligned ellipses)
true_rv = multivariate_normal(true_mean, true_cov)
mf_rv = multivariate_normal(mf_mean, mf_cov)

Z_true = true_rv.pdf(pos)
Z_mf = mf_rv.pdf(pos)

# Key observation: mean-field CANNOT capture the tilted ellipse;
# it always produces axis-aligned ellipses.
#
# Information lost:
# - Direction of maximum variance
# - Conditional relationships: p(z_1 | z_2) is non-trivial in the true posterior
# - In mean-field: q(z_1 | z_2) = q(z_1), completely independent!


# Quantifying the approximation error:
# KL divergence from the true posterior to the mean-field approximation
def kl_gaussians(mu_p, cov_p, mu_q, cov_q):
    """KL(p || q) for multivariate Gaussians."""
    k = len(mu_p)
    cov_q_inv = np.linalg.inv(cov_q)
    trace_term = np.trace(cov_q_inv @ cov_p)
    mean_term = (mu_q - mu_p).T @ cov_q_inv @ (mu_q - mu_p)
    log_det_term = np.log(np.linalg.det(cov_q) / np.linalg.det(cov_p))
    return 0.5 * (trace_term + mean_term - k + log_det_term)


kl = kl_gaussians(true_mean, true_cov, mf_mean, mf_cov)
print(f"KL(true || mean-field) = {kl:.4f} nats")
print("This measures information lost by ignoring correlation")

# For this example with correlation 0.8:
# KL ≈ 0.51 nats, representing substantial information loss
```

Mean-field systematically underestimates posterior uncertainty. By forcing independence, it effectively "shrinks" the posterior. When making predictions, this can lead to overconfident estimates. For uncertainty-critical applications (medical diagnosis, safety systems), mean-field may be inadequate.
When mean-field is too restrictive, we can use structured variational families that capture some (but not all) dependencies. The key idea is to factorize over groups of variables rather than individual variables:
$$q(\mathbf{z}) = \prod_{g \in \mathcal{G}} q_g(\mathbf{z}_g)$$
where $\mathcal{G}$ partitions latent variables into groups, and variables within each group can have arbitrary dependencies.
1. Block Mean-Field
Group correlated variables together: $$q(\mathbf{z}) = q(z_1, z_2) \cdot q(z_3, z_4, z_5) \cdot q(z_6)$$
The first block captures correlation between $z_1$ and $z_2$; the second block captures $(z_3, z_4, z_5)$ jointly.
2. Tridiagonal/Banded Covariance
For sequential data, model adjacent dependencies: $$q(\mathbf{z}) = \mathcal{N}(\mathbf{z} | \boldsymbol{\mu}, \mathbf{\Sigma}_{\text{band}})$$
where $\mathbf{\Sigma}_{\text{band}}$ has non-zero entries only near the diagonal.
3. Chain-Structured Factorization
For time series: $$q(z_1, \ldots, z_T) = q(z_1) \prod_{t=2}^{T} q(z_t | z_{t-1})$$
This captures Markovian dependencies while remaining tractable.
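Here is a minimal sketch of the chain-structured case, assuming linear-Gaussian conditionals $q(z_t \mid z_{t-1}) = \mathcal{N}(a_t z_{t-1} + b_t, \sigma_t^2)$; the class and attribute names are illustrative, not from a library:

```python
import torch
import torch.nn as nn
import torch.distributions as dist


class ChainGaussianFamily(nn.Module):
    """Chain-structured family q(z_1, ..., z_T) = q(z_1) * prod_t q(z_t | z_{t-1})
    with 1D linear-Gaussian conditionals (a sketch, not a library class)."""

    def __init__(self, T: int):
        super().__init__()
        self.T = T
        self.mu0 = nn.Parameter(torch.zeros(()))      # Mean of q(z_1)
        self.log_std = nn.Parameter(torch.zeros(T))   # One std per step
        self.a = nn.Parameter(torch.zeros(T - 1))     # Linear dependence on z_{t-1}
        self.b = nn.Parameter(torch.zeros(T - 1))     # Conditional mean offsets

    def sample_and_log_prob(self, n_samples: int):
        std = torch.exp(self.log_std)
        q1 = dist.Normal(self.mu0, std[0])            # q(z_1)
        z1 = q1.rsample((n_samples,))                 # [n]
        zs, log_q = [z1], q1.log_prob(z1)             # log q(z_1): [n]
        for t in range(1, self.T):
            mean_t = self.a[t - 1] * zs[-1] + self.b[t - 1]   # Mean depends on z_{t-1}
            qt = dist.Normal(mean_t, std[t])                  # q(z_t | z_{t-1})
            zt = qt.rsample()                                 # [n]
            zs.append(zt)
            log_q = log_q + qt.log_prob(zt)
        return torch.stack(zs, dim=-1), log_q                 # z: [n, T], log q(z): [n]


q = ChainGaussianFamily(T=5)
z, log_q = q.sample_and_log_prob(8)
print(z.shape, log_q.shape)   # torch.Size([8, 5]) torch.Size([8])
```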
The most expressive common parametric family for continuous latents is the full-covariance Gaussian:
$$q(\mathbf{z}) = \mathcal{N}(\mathbf{z} | \boldsymbol{\mu}, \mathbf{\Sigma})$$
with $\mathbf{\Sigma}$ a full (dense) covariance matrix.
Parameters: $D$ for mean + $D(D+1)/2$ for covariance = $O(D^2)$ total
Advantages: captures all pairwise correlations; exact when the true posterior is Gaussian; closed-form entropy and reparameterized sampling.
Disadvantages: requires $O(D^2)$ parameters and up to $O(D^3)$ computation for factorizations and log-determinants; impractical for large $D$; still unimodal.
A compromise that scales better is the low-rank plus diagonal Gaussian:
$$\mathbf{\Sigma} = \mathbf{D} + \mathbf{U}\mathbf{U}^T$$
where $\mathbf{D}$ is diagonal and $\mathbf{U} \in \mathbb{R}^{D \times K}$ with $K \ll D$.
Parameters: $D$ for the mean + $D$ for the diagonal + $DK$ for the low-rank factor = $O(DK)$ total
This captures $K$ principal directions of correlation while remaining computationally tractable.
| Family | Parameters | Dependencies | Use Case |
|---|---|---|---|
| Mean-Field | O(D) | None | Large models, quick iteration |
| Block Mean-Field | O(ΣB²) | Within blocks | Grouped latent structure |
| Full Covariance | O(D²) | All pairwise | Small D, accurate uncertainty |
| Low-Rank + Diagonal | O(DK) | K principal directions | Medium D, moderate correlation |
| Banded Covariance | O(DB) | Local (band B) | Sequential/spatial data |
```python
import math

import torch
import torch.nn as nn


class LowRankPlusDiagonalGaussian(nn.Module):
    """
    Low-rank + diagonal covariance structure.

    Covariance: Sigma = D + U @ U.T
    where D is diagonal and U has rank K << D.

    This captures the K most important correlation directions
    while remaining computationally tractable.
    """

    def __init__(self, dim: int, rank: int):
        super().__init__()
        self.dim = dim
        self.rank = rank

        # Mean parameter
        self.mu = nn.Parameter(torch.zeros(dim))

        # Diagonal component (log for positivity)
        self.log_diag = nn.Parameter(torch.zeros(dim))

        # Low-rank component
        self.U = nn.Parameter(torch.randn(dim, rank) * 0.01)

    def _get_covariance_factors(self):
        """Return (D, U) where Sigma = D + U @ U.T."""
        D = torch.diag(torch.exp(self.log_diag))
        return D, self.U

    def sample(self, n_samples: int) -> torch.Tensor:
        """
        Sample via z = mu + sqrt(D) * eps1 + U @ eps2,
        where eps1 ~ N(0, I_D) and eps2 ~ N(0, I_K).

        Because eps1 and eps2 are independent, the resulting covariance
        is exactly D + U @ U.T.
        """
        D_sqrt = torch.exp(0.5 * self.log_diag)
        eps1 = torch.randn(n_samples, self.dim)
        eps2 = torch.randn(n_samples, self.rank)
        z = self.mu + D_sqrt * eps1 + eps2 @ self.U.T
        return z

    def log_prob(self, z: torch.Tensor) -> torch.Tensor:
        """
        Log probability using the Woodbury matrix identity for efficiency.

        For Sigma = D + U @ U.T:
        Sigma^{-1} = D^{-1} - D^{-1} @ U @ (I + U.T @ D^{-1} @ U)^{-1} @ U.T @ D^{-1}

        Cost: O(D * K^2) instead of O(D^3)
        """
        D_inv = torch.exp(-self.log_diag)  # Diagonal of D^{-1}

        # Centered variable
        z_centered = z - self.mu

        # Woodbury: (D + UU^T)^{-1} = D^{-1} - D^{-1}U(I + U^T D^{-1}U)^{-1}U^T D^{-1}
        D_inv_U = D_inv.unsqueeze(1) * self.U           # D^{-1} @ U
        M = torch.eye(self.rank) + self.U.T @ D_inv_U   # I + U^T D^{-1} U
        M_inv = torch.inverse(M)

        # Quadratic form: z^T Sigma^{-1} z
        term1 = (z_centered ** 2 * D_inv).sum(-1)       # z^T D^{-1} z
        v = z_centered @ D_inv_U                        # z^T D^{-1} U, shape [n, K]
        term2 = (v @ M_inv * v).sum(-1)                 # Woodbury correction term
        quad_form = term1 - term2

        # Log determinant: log|D + UU^T| = log|D| + log|M| (matrix determinant lemma)
        log_det = self.log_diag.sum() + torch.logdet(M)

        # Gaussian log prob
        return -0.5 * (self.dim * math.log(2 * math.pi) + log_det + quad_form)


# Example: 1000D latent with rank-10 correlation structure
dim, rank = 1000, 10
q = LowRankPlusDiagonalGaussian(dim, rank)

# Parameters: 1000 (mean) + 1000 (diag) + 1000*10 (U) = 12,000
# vs. full covariance: 1000 + 500,500 = 501,500 parameters!
print(f"Low-rank params: {sum(p.numel() for p in q.parameters())}")
print(f"Full cov params would be: {dim + dim*(dim+1)//2}")
```

The most expressive modern variational families are built from neural networks rather than a single fixed parametric form. Some are implicit, defined by a sampling procedure rather than an explicit density formula; others keep an explicit density but let a network choose its parameters for each observation. The canonical example of the latter is the amortized variational family used in Variational Autoencoders (VAEs).
Instead of learning separate variational parameters for each data point, we learn a recognition network (encoder) that maps observations to variational parameters:
$$q_\phi(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} | \boldsymbol{\mu}_\phi(\mathbf{x}), \text{diag}(\boldsymbol{\sigma}_\phi^2(\mathbf{x})))$$
where $\boldsymbol{\mu}_\phi$ and $\boldsymbol{\sigma}_\phi$ are neural networks with parameters $\phi$.
The key insight: The network "amortizes" the cost of inference across all data points. Instead of optimizing $\phi_n$ for each $\mathbf{x}_n$, we optimize a single $\phi$ that works for all $\mathbf{x}$.
```python
import math

import torch
import torch.nn as nn


class AmortizedGaussianFamily(nn.Module):
    """
    Amortized variational family: q(z|x) parameterized by a neural network.

    Given input x, the encoder outputs the mean and log-variance of q(z|x).
    This is the recognition network used in VAEs.
    """

    def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder network: x -> (mu, log_var)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Separate heads for mean and log-variance
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)

    def encode(self, x: torch.Tensor):
        """Map input to variational parameters."""
        h = self.encoder(x)
        mu = self.mu_head(h)
        log_var = self.logvar_head(h)
        return mu, log_var

    def sample(self, x: torch.Tensor, n_samples: int = 1):
        """
        Sample z ~ q(z|x) using the reparameterization trick.

        z = mu + sigma * epsilon where epsilon ~ N(0, I)
        Gradients flow through mu and sigma (not epsilon).
        """
        mu, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)

        # Reparameterized samples
        eps = torch.randn(n_samples, x.shape[0], self.latent_dim)
        z = mu + std * eps
        return z.squeeze(0) if n_samples == 1 else z

    def log_prob(self, x: torch.Tensor, z: torch.Tensor):
        """Evaluate log q(z|x) for given x and z."""
        mu, log_var = self.encode(x)
        log_prob = -0.5 * (
            log_var
            + (z - mu) ** 2 / torch.exp(log_var)
            + math.log(2 * math.pi)
        )
        return log_prob.sum(-1)  # Sum over latent dimensions

    def kl_divergence(self, x: torch.Tensor):
        """
        Closed-form KL divergence from q(z|x) to a standard normal prior.

        KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2))
        """
        mu, log_var = self.encode(x)
        kl = 0.5 * (torch.exp(log_var) + mu ** 2 - 1 - log_var)
        return kl.sum(-1)  # Sum over latent dimensions


# Example: encoder for 784D input (MNIST) to a 32D latent space
encoder = AmortizedGaussianFamily(input_dim=784, latent_dim=32)

# For a batch of images
x = torch.randn(64, 784)  # Batch of 64 images
z = encoder.sample(x)     # Sample latent codes
print(f"Latent shape: {z.shape}")  # [64, 32]

# Each x_i gets its own variational distribution q(z|x_i),
# but all share the same encoder weights!
```

Normalizing flows transform a simple base distribution through a sequence of invertible transformations to create a complex, expressive distribution:
$$\mathbf{z} = f_K \circ f_{K-1} \circ \cdots \circ f_1(\boldsymbol{\epsilon}), \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$$
The density is computed via the change of variables formula:
$$\log q(\mathbf{z}) = \log p(\boldsymbol{\epsilon}) - \sum_{k=1}^{K} \log \left| \det \frac{\partial f_k}{\partial \mathbf{z}_{k-1}} \right|$$
Flow architectures include planar and radial flows, affine coupling layers (as in RealNVP), and autoregressive flows (MAF, IAF); they differ mainly in how cheaply the Jacobian determinant can be computed. A minimal sketch of a planar flow appears below.
Flows can approximate arbitrary distributions given sufficient depth, making them the most expressive practical variational families.
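As a sketch of the change-of-variables computation in practice, the following stacks a few planar flows, one of the simplest flow architectures; the names are illustrative, and the usual invertibility constraint on $\mathbf{u}$ is omitted for brevity:

```python
import torch
import torch.nn as nn


class PlanarFlow(nn.Module):
    """One planar flow layer: f(z) = z + u * tanh(w^T z + b),
    with log|det J| = log|1 + u^T psi(z)|, psi(z) = (1 - tanh^2(w^T z + b)) * w."""

    def __init__(self, dim: int):
        super().__init__()
        self.u = nn.Parameter(torch.randn(dim) * 0.01)
        self.w = nn.Parameter(torch.randn(dim) * 0.01)
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, z: torch.Tensor):
        # z: [n, dim]; returns transformed z and log|det Jacobian| per sample
        lin = z @ self.w + self.b                                  # [n]
        f_z = z + self.u * torch.tanh(lin).unsqueeze(-1)           # [n, dim]
        psi = (1 - torch.tanh(lin) ** 2).unsqueeze(-1) * self.w    # [n, dim]
        log_det = torch.log(torch.abs(1 + psi @ self.u) + 1e-8)    # [n]
        return f_z, log_det


# Stack K flows on a standard-normal base:
# log q(z_K) = log N(eps) - sum_k log|det J_k|
dim, K, n = 2, 4, 16
flows = nn.ModuleList([PlanarFlow(dim) for _ in range(K)])
eps = torch.randn(n, dim)
log_q = torch.distributions.Normal(0.0, 1.0).log_prob(eps).sum(-1)
z = eps
for flow in flows:
    z, log_det = flow(z)
    log_q = log_q - log_det
print(z.shape, log_q.shape)   # torch.Size([16, 2]) torch.Size([16])
```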
The evolution of variational families traces a path from simple (mean-field) to complex (flows). Modern research pushes toward families that are both expressive and computationally tractable. The ideal family would: (1) approximate any posterior, (2) have tractable density and sampling, (3) have cheap gradient computation, and (4) scale to high dimensions. No family achieves all four perfectly—design is always a trade-off.
The choice of variational family should be driven by the dimensionality of the latent space, the dependency structure expected in the posterior, how the resulting uncertainty estimates will be used, and the available computational budget. The table below collects common recommendations.
| Scenario | Recommended Family | Rationale |
|---|---|---|
| Millions of parameters, fast iteration | Mean-field Gaussian | Minimal overhead, scales best |
| Moderate D (~1000), correlations matter | Low-rank + diagonal | Captures principal correlations |
| Small D (~100), accurate uncertainty needed | Full-covariance Gaussian | Optimal for Gaussian posteriors |
| Deep generative model (VAE) | Amortized + flow | Neural expressiveness + fast inference |
| Multi-modal posterior expected | Mixture or flow | Mean-field/Gaussian will fail |
| Sequential/temporal latent structure | Structured (chain) family | Matches model dependencies |
| Discrete latent variables | Categorical/Gumbel-Softmax | Must match support |
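For the discrete case in the last row, here is a minimal sketch using PyTorch's RelaxedOneHotCategorical (the Gumbel-Softmax / Concrete distribution); the setup is illustrative and shows how relaxed samples keep gradients flowing to the variational logits:

```python
import torch
import torch.distributions as dist

# A Gumbel-Softmax / Concrete variational factor for a discrete latent with
# 4 categories: samples are relaxed one-hot vectors, so gradients flow
# through the reparameterized sampling.
logits = torch.zeros(4, requires_grad=True)        # Variational parameters
q = dist.RelaxedOneHotCategorical(temperature=torch.tensor(0.5), logits=logits)
z = q.rsample((8,))                                # [8, 4] relaxed one-hot samples
print(z.shape, z.sum(-1))                          # Each row sums to 1

loss = (z ** 2).sum()
loss.backward()                                    # Gradients reach the logits
print(logits.grad is not None)                     # True
```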
After optimization, you can diagnose family adequacy:
1. Monitor the ELBO gap
The gap between ELBO and true log-evidence (if computable) indicates approximation quality. A large gap suggests the family is too restrictive.
2. Check posterior samples
Compare samples from $q$ against any known posterior structure. Do samples cover expected modes? Do marginals match?
3. Predictive performance
If a richer family significantly improves downstream predictions, the simpler family was inadequate.
4. Compare against MCMC
For small problems, compare VI posterior moments against MCMC as ground truth. Systematic bias indicates family mismatch.
In practice, start with mean-field and add structure only when diagnostics indicate it's needed. Mean-field often works surprisingly well, especially when the main goal is prediction rather than posterior uncertainty. Premature complexity adds implementation burden and can make optimization harder.
You now understand the variational family—the search space for approximate inference. You've seen how parameterization enables optimization, explored the mean-field assumption and its limitations, and surveyed structured and implicit families. The next page develops the KL divergence objective that measures closeness between approximate and true posteriors.