Imagine you're trying to understand a room full of people engaged in complex conversations. Everyone influences everyone else—their choices of topic, their body language, their emotional states form an intricate web of dependencies. Understanding this system completely would require tracking every possible combination of all their states—an exponentially large space.
But what if you could approximate this complex reality by assuming each person acts independently?
This is the central insight of the mean-field approximation: we approximate a complex, intractable joint distribution with a simpler, factorized distribution where variables are treated as independent. While this assumption is clearly wrong (people do influence each other), the resulting approximation is often remarkably accurate and, crucially, computationally tractable.
By the end of this page, you will understand the factorization assumption at its mathematical core: why we make this assumption, what it implies about our approximation, and how it transforms an intractable optimization problem into a sequence of tractable ones. This is the theoretical foundation upon which all mean-field methods are built.
To understand why the factorization assumption is necessary, we must first understand the problem it solves. Recall from our variational inference framework that we seek to approximate an intractable posterior distribution $p(\mathbf{z} | \mathbf{x})$ with a tractable variational distribution $q(\mathbf{z})$.
The Core Challenge:
Consider a probabilistic model with latent variables $\mathbf{z} = (z_1, z_2, \ldots, z_m)$. By Bayes' rule, the true posterior is

$$p(\mathbf{z} \mid \mathbf{x}) = \frac{p(\mathbf{x}, \mathbf{z})}{p(\mathbf{x})}, \qquad p(\mathbf{x}) = \int p(\mathbf{x}, \mathbf{z}) \, d\mathbf{z}$$
This integral sums over all possible configurations of the latent variables—an exponentially large or continuous space that rarely admits a closed-form solution.
Even when we turn to variational inference to avoid computing p(x) directly, we still face a challenge: the ELBO optimization involves expectations over the full joint q(z). If z has many dimensions with complex dependencies, these expectations remain intractable. The factorization assumption addresses this specific coupling problem.
Quantifying the Complexity:
Consider discrete latent variables where each $z_i$ can take $K$ states. The full joint distribution $q(\mathbf{z})$ would require $K^m - 1$ free parameters (one probability per configuration, minus the normalization constraint), whereas a fully factorized distribution requires only $m(K - 1)$.
For a model with $m = 100$ latent variables each with $K = 10$ states, the joint would need on the order of $10^{100}$ parameters (more than the number of atoms in the observable universe), while the factorized form needs only $100 \times 9 = 900$.
This exponential reduction is the fundamental reason we adopt the factorization assumption.
| Latent Dims (m) | States (K) | Joint Parameters | Factorized Parameters | Reduction Factor |
|---|---|---|---|---|
| 10 | 2 | 1,023 | 10 | ~100× |
| 20 | 2 | ~1 million | 20 | ~50,000× |
| 50 | 10 | ~10⁵⁰ | 450 | ~10⁴⁷× |
| 100 | 10 | ~10¹⁰⁰ | 900 | ~10⁹⁷× |
| 1000 | 100 | ~10²⁰⁰⁰ | 99,000 | ~10¹⁹⁹⁵× |
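The table rows follow directly from the two counting rules: $K^m - 1$ free probabilities for the full joint versus $m(K-1)$ for the factorized form. A quick sketch (the helper names are ours) reproduces a few rows:

```python
def joint_params(m: int, K: int) -> int:
    # A full joint over m discrete variables, each with K states,
    # has K^m outcomes and hence K^m - 1 free probabilities.
    return K ** m - 1

def factorized_params(m: int, K: int) -> int:
    # A fully factorized q keeps m independent marginals,
    # each with K - 1 free probabilities.
    return m * (K - 1)

for m, K in [(10, 2), (20, 2), (50, 10), (100, 10)]:
    print(f"m={m}, K={K}: joint={joint_params(m, K):.2e}, "
          f"factorized={factorized_params(m, K)}")
```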
The mean-field approximation restricts the variational family $\mathcal{Q}$ to distributions that fully factorize over the latent variables:
$$q(\mathbf{z}) = \prod_{i=1}^{m} q_i(z_i)$$
Each factor $q_i(z_i)$ is a marginal distribution over a single latent variable (or group of variables), and crucially, these factors are assumed to be independent of each other.
The term 'mean-field' originates from statistical physics, where it describes approximations that replace complex many-body interactions with an effective 'mean field.' Each particle feels an average effect of all others, rather than tracking individual interactions. In variational inference, each latent variable 'sees' the expected values of others, not their full distributions—hence the name.
Mathematical Formulation:
Let $\mathbf{z} = (z_1, z_2, \ldots, z_m)$ be our latent variables. The mean-field variational family is:
$$\mathcal{Q}_{\text{MF}} = \left\{ q : q(\mathbf{z}) = \prod_{i=1}^{m} q_i(z_i) \right\}$$
Each factor $q_i(z_i)$ can be:

- a **parametric** distribution (e.g., a Gaussian with its own mean and variance, or a categorical with its own probability vector), or
- a **free-form (non-parametric)** distribution whose shape is left unspecified and emerges from the optimization.
The non-parametric formulation is particularly elegant because it allows the optimal form of each factor to emerge from the optimization, rather than being pre-specified.
```python
import numpy as np
from typing import List
from dataclasses import dataclass


@dataclass
class MeanFieldFactor:
    """
    Represents a single factor q_i(z_i) in the mean-field approximation.

    Each factor is a probability distribution over a single latent variable.
    The mean-field assumption dictates that all factors are independent.
    """
    name: str
    dimension: int  # Dimension of z_i
    # Parameters of the variational distribution;
    # their form depends on the chosen distributional family
    params: dict

    def log_prob(self, z_i: np.ndarray) -> np.ndarray:
        """Compute log q_i(z_i) for given values."""
        raise NotImplementedError("Subclasses must implement")

    def sample(self, n_samples: int) -> np.ndarray:
        """Draw samples from q_i(z_i)."""
        raise NotImplementedError("Subclasses must implement")

    def expected_value(self) -> np.ndarray:
        """Return E_{q_i}[z_i] - the 'mean field'."""
        raise NotImplementedError("Subclasses must implement")

    def entropy(self) -> float:
        """Return H[q_i] = -E_{q_i}[log q_i(z_i)]."""
        raise NotImplementedError("Subclasses must implement")


class GaussianFactor(MeanFieldFactor):
    """
    A Gaussian mean-field factor: q_i(z_i) = N(z_i | μ_i, σ²_i).

    The univariate Gaussian is the most common choice for continuous latents.
    """

    def __init__(self, name: str, init_mean: float = 0.0, init_var: float = 1.0):
        super().__init__(name=name, dimension=1,
                         params={'mean': init_mean, 'variance': init_var})

    @property
    def mean(self) -> float:
        return self.params['mean']

    @mean.setter
    def mean(self, value: float):
        self.params['mean'] = value

    @property
    def variance(self) -> float:
        return self.params['variance']

    @variance.setter
    def variance(self, value: float):
        self.params['variance'] = max(value, 1e-10)  # Numerical stability

    def log_prob(self, z_i: np.ndarray) -> np.ndarray:
        """Log probability under the Gaussian."""
        return -0.5 * (np.log(2 * np.pi * self.variance)
                       + (z_i - self.mean) ** 2 / self.variance)

    def sample(self, n_samples: int) -> np.ndarray:
        """Draw samples from the Gaussian."""
        return np.random.normal(self.mean, np.sqrt(self.variance), size=n_samples)

    def expected_value(self) -> float:
        """E[z_i] = μ_i for a Gaussian."""
        return self.mean

    def expected_square(self) -> float:
        """E[z_i²] = μ_i² + σ²_i for a Gaussian."""
        return self.mean ** 2 + self.variance

    def entropy(self) -> float:
        """H[N(μ, σ²)] = 0.5 * (1 + log(2πσ²))."""
        return 0.5 * (1 + np.log(2 * np.pi * self.variance))


class MeanFieldApproximation:
    """
    The complete mean-field variational approximation: q(z) = ∏_i q_i(z_i).

    Each factor is independent, enabling tractable computation.
    """

    def __init__(self, factors: List[MeanFieldFactor]):
        self.factors = factors
        self.n_factors = len(factors)

    def log_prob(self, z: List[np.ndarray]) -> np.ndarray:
        """
        Compute log q(z) = Σ_i log q_i(z_i).

        Due to factorization, the joint log-prob is the sum of marginal log-probs.
        """
        return sum(factor.log_prob(z_i) for factor, z_i in zip(self.factors, z))

    def sample(self, n_samples: int) -> List[np.ndarray]:
        """
        Sample from q(z) by independently sampling each factor.

        Independence makes sampling trivially parallelizable.
        """
        return [factor.sample(n_samples) for factor in self.factors]

    def entropy(self) -> float:
        """
        Total entropy H[q] = Σ_i H[q_i].

        The entropy of independent variables is the sum of individual
        entropies — a KEY computational advantage of the factorization.
        """
        return sum(factor.entropy() for factor in self.factors)

    def expected_values(self) -> List[float]:
        """Get expected values for all factors (the 'mean fields')."""
        return [factor.expected_value() for factor in self.factors]


# Example: setting up a mean-field approximation for a 3-variable model
def example_setup():
    """Demonstrate mean-field setup for a simple model."""
    # The model has 3 continuous latent variables z1, z2, z3 whose true
    # posterior p(z1, z2, z3 | x) is intractable; we approximate it
    # with q(z1) q(z2) q(z3).
    factors = [
        GaussianFactor("z1", init_mean=0.0, init_var=1.0),
        GaussianFactor("z2", init_mean=0.0, init_var=1.0),
        GaussianFactor("z3", init_mean=0.0, init_var=1.0),
    ]
    q = MeanFieldApproximation(factors)

    print("Mean-Field Approximation Setup:")
    print(f"  Number of factors: {q.n_factors}")
    print(f"  Total parameters: {q.n_factors * 2}")  # mean + variance per factor
    print(f"  Initial entropy: {q.entropy():.4f}")
    print(f"  Initial mean fields: {q.expected_values()}")

    # Compare to the full joint:
    # joint Gaussian over 3 variables: 3 means + 6 covariance entries = 9
    # mean-field: 3 means + 3 variances = 6 (no off-diagonal covariances)
    print("Parameter comparison:")
    print("  Full joint Gaussian: 9 parameters")
    print("  Mean-field Gaussian: 6 parameters")
    print("  Reduction: no covariance modeling between variables")

    return q


if __name__ == "__main__":
    example_setup()
```

The factorization assumption has deep implications for what our approximation can and cannot represent. Understanding these implications is crucial for knowing when mean-field methods are appropriate and when they might fail.
Key Properties of the Factorized Distribution:

- **Independence:** under $q$, any two latent variables are independent: $q(z_i, z_j) = q_i(z_i)\,q_j(z_j)$, so all cross-covariances are exactly zero.
- **Decomposition:** the log-density and entropy both split into sums over factors: $\log q(\mathbf{z}) = \sum_i \log q_i(z_i)$ and $H[q] = \sum_i H[q_i]$.
- **Marginal-only expressiveness:** at best, $q$ can match the marginals of the true posterior; it cannot represent any joint structure between variables.
If the true posterior has strong correlations between latent variables—for example, if knowing z₁ tells you a lot about z₂—the mean-field approximation will miss this structure entirely. Each variable's uncertainty is treated in isolation, which can lead to overconfident or poorly calibrated posterior approximations.
Visualizing the Independence Assumption:
Consider a 2D posterior where $z_1$ and $z_2$ are correlated. The true posterior might be an elongated ellipse tilted at 45°, indicating that when $z_1$ is high, $z_2$ tends to be high as well.
The mean-field approximation must represent this with a product of two independent 1D distributions. The best it can do is an axis-aligned ellipse (or rectangle for non-Gaussian marginals), completely missing the tilt that captures the correlation.
```python
import numpy as np


def visualize_factorization_bias():
    """
    Demonstrate how the mean-field approximation misses correlations.

    True posterior: correlated bivariate Gaussian.
    Mean-field: product of independent univariate Gaussians.
    """
    # True posterior: correlated bivariate Gaussian with
    # correlation coefficient ρ = 0.8 (strong positive correlation)
    rho = 0.8
    true_mean = np.array([0.0, 0.0])
    true_cov = np.array([[1.0, rho],
                         [rho, 1.0]])

    # Mean-field approximation: independent marginals.
    # The best we can do is match the marginal means and variances.
    mf_mean_1, mf_var_1 = 0.0, 1.0  # From the marginal of z1
    mf_mean_2, mf_var_2 = 0.0, 1.0  # From the marginal of z2

    # Generate samples for comparison
    n_samples = 2000
    true_samples = np.random.multivariate_normal(true_mean, true_cov, n_samples)
    mf_samples = np.column_stack([
        np.random.normal(mf_mean_1, np.sqrt(mf_var_1), n_samples),
        np.random.normal(mf_mean_2, np.sqrt(mf_var_2), n_samples),
    ])

    # Compute sample correlations
    true_corr = np.corrcoef(true_samples.T)[0, 1]
    mf_corr = np.corrcoef(mf_samples.T)[0, 1]

    print("Correlation Comparison:")
    print(f"  True posterior correlation: {true_corr:.4f}")
    print(f"  Mean-field correlation: {mf_corr:.4f}")
    print(f"  Correlation captured: {mf_corr / true_corr * 100:.1f}%")

    # The fundamental limitation: mean-field cannot capture covariance.
    # This is not a bug — it is a direct consequence of the factorization.

    # KL divergence (analytic for Gaussians): KL(q_MF || p_true) measures
    # the information lost due to factorization. For Gaussians:
    # KL = 0.5 * (tr(Σ_p⁻¹ Σ_q) + (μ_p - μ_q)ᵀ Σ_p⁻¹ (μ_p - μ_q)
    #             - d + ln(det Σ_p / det Σ_q))
    mf_cov = np.array([[mf_var_1, 0.0], [0.0, mf_var_2]])  # Diagonal
    true_cov_inv = np.linalg.inv(true_cov)
    kl_divergence = 0.5 * (
        np.trace(true_cov_inv @ mf_cov)
        + 0.0   # Means are equal
        - 2     # d = 2
        + np.log(np.linalg.det(true_cov) / np.linalg.det(mf_cov))
    )
    print(f"KL(q_MF || p_true) = {kl_divergence:.4f} nats")
    print("This measures information lost due to ignoring correlations")

    return {
        'true_samples': true_samples,
        'mf_samples': mf_samples,
        'kl_divergence': kl_divergence,
        'true_cov': true_cov,
        'mf_cov': mf_cov,
    }


def compute_variance_underestimation():
    """
    Show how the mean-field approximation misestimates the variance of
    combinations of latent variables when the true posterior is correlated.
    """
    # Consider a posterior with the constraint z1 + z2 ≈ 0, which creates
    # strong NEGATIVE correlation:
    #   z1 ~ N(0, 1),  z2 | z1 ~ N(-z1, std 0.1)
    n_samples = 10_000
    z1_true = np.random.normal(0, 1, n_samples)
    z2_true = np.random.normal(-z1_true, 0.1)

    # True marginal variances and variance of the sum
    var_z1_true = np.var(z1_true)
    var_z2_true = np.var(z2_true)
    var_sum = np.var(z1_true + z2_true)

    # Mean-field: independent distributions matching the marginal variances
    mf_z1 = np.random.normal(0, np.sqrt(var_z1_true), n_samples)
    mf_z2 = np.random.normal(0, np.sqrt(var_z2_true), n_samples)
    mf_var_sum = np.var(mf_z1 + mf_z2)

    print("Variance of z1 + z2:")
    print(f"  True posterior: {var_sum:.4f}")
    print(f"  Mean-field: {mf_var_sum:.4f}")
    print(f"  Overestimation factor: {mf_var_sum / var_sum:.1f}x")
    print()
    print("Mean-field overestimates the variance of sums when the true")
    print("variables are negatively correlated, and underestimates it")
    print("when they are positively correlated.")


if __name__ == "__main__":
    visualize_factorization_bias()
    print()
    compute_variance_underestimation()
```

The magic of the mean-field approximation becomes apparent when we examine how the factorization simplifies the ELBO (Evidence Lower Bound). This simplification is what makes optimization tractable.
The General ELBO:
$$\mathcal{L}(q) = \mathbb{E}_{q(\mathbf{z})}[\log p(\mathbf{x}, \mathbf{z})] - \mathbb{E}_{q(\mathbf{z})}[\log q(\mathbf{z})]$$
Under Mean-Field Factorization:
Substituting $q(\mathbf{z}) = \prod_i q_i(z_i)$:
$$\mathcal{L}(q) = \mathbb{E}_{\prod_i q_i}[\log p(\mathbf{x}, \mathbf{z})] - \sum_{i=1}^{m} \mathbb{E}_{q_i}[\log q_i(z_i)]$$
The entropy term H[q] = -E[log q] decomposes into a sum of individual entropies H[q_i]. This means we can compute the entropy contribution from each factor independently—a massive simplification compared to computing entropy of a complex joint distribution.
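This additivity is easy to check numerically: a factorized Gaussian is just a diagonal-covariance multivariate normal, whose entropy scipy can compute directly. A small sanity check (the variance values here are arbitrary):

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

variances = [0.5, 1.0, 2.5]

# Entropy of the factorized joint: a diagonal-covariance Gaussian
joint_entropy = multivariate_normal(mean=np.zeros(3),
                                    cov=np.diag(variances)).entropy()

# Sum of the per-factor entropies H[q_i]
sum_entropy = sum(norm(loc=0.0, scale=np.sqrt(v)).entropy() for v in variances)

print(joint_entropy, sum_entropy)
assert np.isclose(joint_entropy, sum_entropy)  # H[q] = Σ_i H[q_i]
```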
Deriving the Update for a Single Factor:
The key insight is that we can optimize each factor $q_j$ while holding all other factors fixed. Consider optimizing $q_j(z_j)$ while treating $\{q_i(z_i)\}_{i \neq j}$ as constants.
Rewrite the ELBO, isolating terms involving $q_j$:
$$\mathcal{L}(q_j) = \mathbb{E}_{q_j}\left[\mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})]\right] - \mathbb{E}_{q_j}[\log q_j(z_j)] + \text{const}$$
where $q_{-j} = \prod_{i \neq j} q_i(z_i)$ is the product of all factors except $j$.
Defining the Expected Log Joint:
Let: $$\widetilde{\log p}(z_j) = \mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})]$$
This is the log joint probability, averaged over all latent variables except $z_j$. Then:
$$\mathcal{L}(q_j) = \mathbb{E}_{q_j}[\widetilde{\log p}(z_j)] - \mathbb{E}_{q_j}[\log q_j(z_j)] + \text{const}$$
```python
import numpy as np
from typing import Callable, List


class FactorizedELBO:
    """
    Compute the ELBO under mean-field factorization.

    ELBO = E_q[log p(x, z)] - E_q[log q(z)]
         = E_q[log p(x, z)] + H[q]            (H = entropy)
         = E_q[log p(x, z)] + Σ_i H[q_i]      (under factorization)

    The factorization makes both terms tractable:
      1. log p(x, z) often has structure we can exploit.
      2. The entropy decomposes into a sum of individual entropies.
    """

    def __init__(self, log_joint: Callable, factors: List):
        # log_joint: log p(x, z) as a function of the list of latent values
        # factors: variational factors exposing sample() and entropy()
        self.log_joint = log_joint
        self.factors = factors

    def compute_elbo(self, n_samples: int = 1000) -> float:
        """
        Compute the ELBO via Monte Carlo estimation:
        ELBO = E_q[log p(x, z)] + H[q].
        """
        # Sample from the factorized q(z)
        samples = [f.sample(n_samples) for f in self.factors]

        # Estimate E_q[log p(x, z)] by Monte Carlo
        log_joint_values = np.array([
            self.log_joint([s[i] for s in samples])
            for i in range(n_samples)
        ])
        expected_log_joint = np.mean(log_joint_values)

        # Compute H[q] = Σ_i H[q_i] analytically
        total_entropy = sum(f.entropy() for f in self.factors)

        return expected_log_joint + total_entropy

    def compute_expected_log_joint_for_factor(
        self,
        factor_idx: int,
        z_j_values: np.ndarray,
        n_mc_samples: int = 100,
    ) -> np.ndarray:
        """
        Compute E_{q_{-j}}[log p(x, z)] for given values of z_j.

        This is the key quantity needed to update factor j: we average over
        all factors except j, evaluated at each specified z_j.
        Returns an array of the same length as z_j_values.
        """
        results = []
        for z_j in z_j_values:
            # Monte Carlo estimate over q_{-j}
            mc_estimates = []
            for _ in range(n_mc_samples):
                # Sample all factors except j; clamp factor j at z_j
                z_full = []
                for i, f in enumerate(self.factors):
                    if i == factor_idx:
                        z_full.append(z_j)
                    else:
                        z_full.append(f.sample(1)[0])
                mc_estimates.append(self.log_joint(z_full))
            results.append(np.mean(mc_estimates))
        return np.array(results)


def demonstrate_elbo_decomposition():
    """Show how the ELBO decomposes under the mean-field assumption."""
    print("ELBO Decomposition Under Mean-Field")
    print("=" * 50)
    print()
    print("General ELBO:")
    print("  L(q) = E_q[log p(x,z)] - E_q[log q(z)]")
    print()
    print("Under factorization q(z) = Π_i q_i(z_i):")
    print("  L(q) = E_q[log p(x,z)] + Σ_i H[q_i]")
    print()
    print("Key simplifications:")
    print("  1. The entropy term decomposes into a sum")
    print("  2. Each H[q_i] can be computed independently")
    print("  3. For exponential-family q_i, H[q_i] is analytic")
    print()

    # Example: Gaussian factors
    print("Example: Gaussian Mean-Field")
    print("-" * 30)
    n_factors = 5
    variances = [0.5, 1.0, 1.5, 2.0, 2.5]

    # Entropy of a Gaussian: H[N(μ, σ²)] = 0.5 * (1 + log(2πσ²))
    individual_entropies = [0.5 * (1 + np.log(2 * np.pi * v)) for v in variances]
    total_entropy = sum(individual_entropies)

    print(f"Number of factors: {n_factors}")
    print(f"Variances: {variances}")
    print(f"Individual entropies: {[f'{h:.3f}' for h in individual_entropies]}")
    print(f"Total entropy H[q] = Σ H[q_i] = {total_entropy:.3f}")
    print()
    print("Without factorization, computing H[q] for a 5D joint")
    print("distribution would require 5D integration!")


if __name__ == "__main__":
    demonstrate_elbo_decomposition()
```

Perhaps the most elegant result in mean-field variational inference is the derivation of the optimal form for each factor. Given the factorization assumption, we can derive the exact form that each $q_j(z_j)$ must take to maximize the ELBO.
The Optimal Factor Theorem:
The optimal form for factor $q_j(z_j)$, holding all other factors fixed, is:
$$\log q_j^*(z_j) = \mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})] + \text{const}$$
Or equivalently:
$$q_j^*(z_j) \propto \exp\left(\mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})]\right)$$
The ELBO for updating q_j can be written as the negative KL divergence: L(q_j) = -KL(q_j || q̃_j) + const, where log q̃_j(z_j) = E_{q_{-j}}[log p(x, z)]. Since KL divergence is minimized (equals zero) when q_j = q̃_j, the optimal factor has the form above.
The Intuition:
The optimal factor $q_j^*(z_j)$ is proportional to the exponentiated expected log joint. Think of it this way:
The expectation over other factors creates an "effective" log-density for $z_j$ that accounts for its interactions with other variables through their expected values.
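For a Gaussian posterior this recipe can be carried out exactly, which makes a useful sanity check. For a zero-mean bivariate Gaussian with precision matrix $\Lambda$, taking $\mathbb{E}_{q_2}[\log p(\mathbf{z})]$ and collecting terms in $z_1$ yields a Gaussian optimal factor with precision $\Lambda_{11}$ and mean $-\Lambda_{11}^{-1}\Lambda_{12}\,\mathbb{E}[z_2]$ (a standard textbook result). A sketch of iterating the two resulting updates:

```python
import numpy as np

# True posterior: zero-mean bivariate Gaussian with correlation rho
rho = 0.8
Sigma = np.array([[1.0, rho], [rho, 1.0]])
Lam = np.linalg.inv(Sigma)  # Precision matrix

# From log q_j*(z_j) = E_{q_{-j}}[log p(z)] + const, each optimal factor is
# Gaussian: q_1 = N(m1, 1/Lam[0,0]) with m1 = -Lam[0,1]/Lam[0,0] * m2,
# and symmetrically for q_2.
m1, m2 = 1.0, -1.0  # Arbitrary initialization of the factor means
for _ in range(100):
    m1 = -Lam[0, 1] / Lam[0, 0] * m2  # Update q_1 holding q_2 fixed
    m2 = -Lam[1, 0] / Lam[1, 1] * m1  # Update q_2 holding q_1 fixed

print(f"Converged means: ({m1:.6f}, {m2:.6f})")  # Both approach the true mean 0
print(f"Mean-field variance of z1: {1 / Lam[0, 0]:.4f}")  # 1 - rho**2 = 0.36
print(f"True marginal variance:    {Sigma[0, 0]:.4f}")    # 1.0
```

Note the variances: each factor ends up with variance $1/\Lambda_{ii} = 1 - \rho^2 = 0.36$, well below the true marginal variance of $1$, a concrete instance of mean-field understating uncertainty for correlated variables.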
Exponential Family Simplification:
When the model $p(\mathbf{x}, \mathbf{z})$ belongs to the exponential family (which includes Gaussians, Bernoullis, multinomials, Poissons, etc.), something beautiful happens: the optimal factor $q_j^*(z_j)$ is also in the exponential family, often in the same family as the prior or complete conditional.
For example:

- A Gaussian likelihood with a Gaussian prior on the mean yields a Gaussian optimal factor.
- A Gaussian likelihood with a Gamma prior on the precision yields a Gamma optimal factor.
- A multinomial likelihood with a Dirichlet prior yields a Dirichlet optimal factor.
This means we don't need to represent arbitrary distributions—we only need to track the natural parameters of familiar distributions.
```python
"""
Optimal factor derivation for mean-field variational inference.

For a model p(x, z) = p(x | z) p(z), the optimal factor is
    q_j*(z_j) ∝ exp(E_{q_{-j}}[log p(x, z)]).

This file demonstrates the derivation for a simple but illustrative case:
Bayesian linear regression with a Gaussian prior.
"""
import numpy as np


def bayesian_linear_regression_mean_field():
    """
    Example: Bayesian linear regression.

    Model:
        w ~ N(0, τ⁻¹ I)             (prior on weights)
        τ ~ Gamma(a₀, b₀)           (prior on precision)
        y | X, w, τ ~ N(Xw, τ⁻¹ I)  (likelihood)

    Latent variables: z = (w, τ).  Goal: q(w, τ) = q(w) q(τ).
    We derive the optimal forms for q(w) and q(τ).
    """
    # Generate synthetic data
    np.random.seed(42)
    n, d = 100, 3
    X = np.random.randn(n, d)
    true_w = np.array([1.0, -0.5, 0.3])
    true_tau = 2.0
    y = X @ true_w + np.random.randn(n) / np.sqrt(true_tau)

    # Prior hyperparameters for the Gamma prior on τ
    a_0, b_0 = 1.0, 1.0

    # ================================================
    # DERIVE OPTIMAL FORM FOR q(w)
    # ================================================
    print("Deriving optimal q(w)...")
    print("-" * 50)

    # log p(x, z) = log p(y | X, w, τ) + log p(w | τ) + log p(τ)
    #
    # log p(y | X, w, τ) = (n/2) log τ - (τ/2) ||y - Xw||² + const
    # log p(w | τ)       = (d/2) log τ - (τ/2) ||w||²      + const
    # log p(τ)           = (a₀-1) log τ - b₀ τ             + const
    #
    # Taking E_{q(τ)} and keeping only terms involving w:
    #   log q*(w) = -E[τ]/2 * (wᵀ(XᵀX + I)w - 2wᵀXᵀy) + const
    # This is quadratic in w!
    print("log p(x, z) contains terms:")
    print("  From likelihood: -(τ/2) ||y - Xw||²")
    print("  From prior:      -(τ/2) ||w||²")
    print()
    print("Taking E_{q(τ)} and collecting terms in w:")
    print("  log q*(w) = -E[τ]/2 * wᵀ(XᵀX + I)w + E[τ] * wᵀXᵀy + const")
    print()
    print("This is a quadratic form → q*(w) is GAUSSIAN!")
    print()
    print("Parameters of optimal q(w) = N(μ_w, Σ_w):")
    print("  Σ_w⁻¹ = E_q[τ] * (XᵀX + I)")
    print("  μ_w   = Σ_w * E_q[τ] * Xᵀy")
    print()

    # ================================================
    # DERIVE OPTIMAL FORM FOR q(τ)
    # ================================================
    print("Deriving optimal q(τ)...")
    print("-" * 50)

    # Taking E_{q(w)} and keeping only terms involving τ:
    #   log q*(τ) = (a_n - 1) log τ - b_n τ + const
    # where a_n = a₀ + (n + d)/2
    #       b_n = b₀ + E_q[||y - Xw||² + ||w||²] / 2
    # This is the log of a Gamma distribution: q*(τ) = Gamma(a_n, b_n).
    print("log p(x,z) contains τ terms:")
    print("  From likelihood: (n/2) log τ - (τ/2)||y - Xw||²")
    print("  From w prior:    (d/2) log τ - (τ/2)||w||²")
    print("  From τ prior:    (a₀-1) log τ - b₀ τ")
    print()
    print("Taking E_{q(w)} and collecting terms in τ:")
    print("  log q*(τ) = (a_n - 1) log τ - b_n τ + const")
    print()
    print("This is the log of a Gamma → q*(τ) is GAMMA!")
    print()
    print("Parameters of optimal q(τ) = Gamma(a_n, b_n):")
    print("  a_n = a₀ + (n + d)/2")
    print("  b_n = b₀ + E_q[||y - Xw||² + ||w||²]/2")
    print()

    # ================================================
    # Numerical verification
    # ================================================
    print("=" * 50)
    print("Running coordinate ascent to verify...")
    print()

    E_tau = 1.0  # Initialize E[τ] under q(τ)
    for iteration in range(10):
        # Update q(w)
        precision_w = E_tau * (X.T @ X + np.eye(d))
        Sigma_w = np.linalg.inv(precision_w)
        mu_w = Sigma_w @ (E_tau * X.T @ y)

        # E[||w||²] = ||μ_w||² + tr(Σ_w)
        E_w_sq = mu_w @ mu_w + np.trace(Sigma_w)

        # E[||y - Xw||²] = ||y||² - 2yᵀX μ_w + tr(XᵀX (μ_w μ_wᵀ + Σ_w))
        residual_sq = (y @ y - 2 * y @ X @ mu_w
                       + np.trace(X.T @ X @ (np.outer(mu_w, mu_w) + Sigma_w)))

        # Update q(τ)
        a_n = a_0 + (n + d) / 2
        b_n = b_0 + (residual_sq + E_w_sq) / 2
        E_tau = a_n / b_n  # E[τ] for Gamma(a, b)

        print(f"Iteration {iteration + 1}:")
        print(f"  q(w): μ = [{', '.join(f'{m:.4f}' for m in mu_w)}]")
        print(f"  q(τ): E[τ] = {E_tau:.4f} (vs true τ = {true_tau})")

    print()
    print(f"True weights:     [{', '.join(f'{w:.4f}' for w in true_w)}]")
    print(f"Inferred weights: [{', '.join(f'{m:.4f}' for m in mu_w)}]")
    print(f"True precision: {true_tau:.4f}")
    print(f"Inferred E[τ]:  {E_tau:.4f}")


if __name__ == "__main__":
    bayesian_linear_regression_mean_field()
```

The fully-factorized mean-field assumption—where every single latent variable has its own independent factor—is sometimes too restrictive. A more flexible approach is structured mean-field, where we group related variables together and only assume independence between groups.
Grouped Factorization:
$$q(\mathbf{z}) = \prod_{g=1}^{G} q_g(\mathbf{z}_g)$$
where $\mathbf{z}_g$ is a subset of latent variables forming group $g$, and $\{\mathbf{z}_1, \ldots, \mathbf{z}_G\}$ partitions all latent variables.
Larger groups capture more correlations within the group (better approximation) but are more expensive to optimize (the group's distribution is a joint over more variables). Smaller groups are cheaper but miss intra-group correlations. The art is in choosing groups that balance accuracy and tractability.
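With Gaussian factors, this trade-off shows up directly in parameter counts: a group of size $k$ needs $k$ means plus $k(k+1)/2$ entries of a symmetric covariance matrix. A small sketch (the group sizes are illustrative):

```python
def gaussian_group_params(group_sizes):
    # Each group of size k contributes k means plus k(k+1)/2
    # entries of a symmetric covariance matrix.
    return sum(k + k * (k + 1) // 2 for k in group_sizes)

m = 12  # Total number of latent dimensions (illustrative)
print(gaussian_group_params([1] * m))  # Fully factorized: 12 means + 12 variances = 24
print(gaussian_group_params([3] * 4))  # Four blocks of 3: 4 * (3 + 6) = 36
print(gaussian_group_params([m]))      # One full-covariance group: 12 + 78 = 90
```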
Common Grouping Strategies: the standard options (fully factorized, pairwise, block, per-data-point, and hierarchical grouping) are compared in the table at the end of this section.
Example: Topic Models
In Latent Dirichlet Allocation (LDA):
Fully-factorized: $q(\theta, \phi, z) = \prod_d q(\theta_d) \prod_k q(\phi_k) \prod_{d,n} q(z_{dn})$
Grouped by document: $q(\theta, \phi, z) = \left(\prod_d q(\theta_d, z_d)\right) \prod_k q(\phi_k)$
The grouped version captures correlation between a document's topic distribution and its word assignments, which can significantly improve inference quality.
| Strategy | Correlations Captured | Computational Cost | Best For |
|---|---|---|---|
| Fully Factorized | None | O(m) per iteration | Initial exploration, massive models |
| Pairs | Pairwise within groups | O(m) with larger constants | Strongly coupled variable pairs |
| Blocks (size k) | Within k-sized blocks | O(m/k × k²) = O(mk) | Spatially/temporally local structure |
| By Data Point | All latents for one observation | O(n × local cost) | Variational autoencoders |
| Hierarchical | Within hierarchy levels | Depends on structure | Hierarchical Bayesian models |
The factorization assumption is the foundational principle that makes mean-field variational inference tractable. Let's consolidate the key insights from this page:

- The mean-field family restricts $q$ to fully factorized distributions: $q(\mathbf{z}) = \prod_{i} q_i(z_i)$.
- Factorization reduces the parameter count from exponential ($K^m - 1$) to linear ($m(K - 1)$) in the number of latent variables.
- The entropy, and hence the ELBO, decomposes into per-factor terms, so each factor can be handled independently.
- The optimal factor has the closed form $q_j^*(z_j) \propto \exp\left(\mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})]\right)$, which for exponential-family models stays within familiar distribution families.
- The price is independence: mean-field cannot represent posterior correlations and can misestimate the uncertainty of correlated variables.
- Structured mean-field recovers some correlations by factorizing over groups of variables rather than individual ones.
Now that we understand the factorization assumption, we can explore how to actually optimize the factors. The next page covers Coordinate Ascent Variational Inference (CAVI)—the algorithm that iteratively updates each factor while holding others fixed, climbing toward the optimal factorized approximation.
The Philosophical Perspective:
The mean-field approximation represents a fundamental trade-off in machine learning: we sacrifice fidelity (by ignoring correlations) to gain tractability (polynomial-time computation). This trade-off is not unique to variational inference; it appears throughout machine learning, from the conditional-independence assumption of naive Bayes, to Markov assumptions that truncate long-range dependencies, to stochastic gradients that trade exactness for cheap updates.
The mean-field assumption is one of many tools for making intractable problems tractable. Understanding when this assumption is reasonable—and when it breaks down—is essential for applying variational inference effectively.
You now understand the factorization assumption—the mathematical foundation of mean-field variational inference. You know why this assumption is necessary, what it implies about the approximation, and how it leads to the optimal factor form. Next, we'll learn the coordinate ascent algorithm that makes optimization tractable.