Every powerful tool has limitations. Mean-field variational inference gives us tractable approximate inference for complex models—but at a cost. The factorization assumption that enables tractability also constrains what the approximation can represent.
Understanding these limitations is not about dismissing mean-field VI. It's about knowing when to use it confidently, when to use it cautiously, and when to reach for alternatives. A skilled practitioner knows both the power and the boundaries of their tools.
By the end of this page, you will deeply understand the fundamental limitations of mean-field VI: the correlation blindness problem, systematic variance underestimation, mode-seeking behavior, and failure cases. You'll also learn about alternatives that address these limitations.
The fundamental limitation of mean-field VI stems directly from its defining assumption: the approximate posterior factorizes as a product of independent marginals.
$$q(\mathbf{z}) = \prod_{i=1}^{m} q_i(z_i)$$
This means:
$$\text{Cov}_q(z_i, z_j) = 0 \quad \text{for all } i \neq j$$
If the true posterior has ANY correlation between variables, mean-field cannot represent it. The approximation is fundamentally blind to dependencies between latent variables.
This is not a bug or an implementation issue—it's a direct consequence of the factorization assumption. No matter how well you optimize, no matter how many iterations you run, mean-field will NEVER capture posterior correlations. If correlations matter for your application, you need a different method.
Why Correlations Matter:
Posterior correlations carry crucial information:
Parameter Uncertainty Relationships: If one parameter is overestimated, another may need to be underestimated to fit the data. This negative correlation matters for understanding what the data actually tells us.
Identifiability Issues: Some models have parameters that are only identifiable in combination (e.g., slope and intercept, or factor loadings and factors). Correlations encode these constraints.
Predictive Uncertainty: When making predictions, correlated uncertainties can either compound or cancel. Ignoring correlations misestimates predictive variance.
Model Comparison: Bayesian model comparison relies on properly accounting for uncertainty. Ignoring correlations distorts model evidence calculations.
```python
import numpy as np


def demonstrate_correlation_blindness():
    """
    Show how mean-field completely misses posterior correlations.

    Example: Bayesian linear regression with correlated predictors.
    """
    print("Correlation Blindness in Mean-Field VI")
    print("=" * 60)

    # Generate data with correlated predictors
    np.random.seed(42)
    n = 100

    # Correlated predictors: x1 and x2 have correlation 0.9
    rho = 0.9
    cov_x = np.array([[1, rho], [rho, 1]])
    X = np.random.multivariate_normal([0, 0], cov_x, n)

    # True coefficients
    true_beta = np.array([1.0, 1.0])
    sigma = 1.0

    # Generate response
    y = X @ true_beta + np.random.randn(n) * sigma

    print(f"Data: n={n}, p=2")
    print(f"Predictor correlation: {rho}")
    print(f"True coefficients: {true_beta}")

    # =====================================================
    # Compute true posterior (analytic for linear regression)
    # =====================================================
    # Prior: β ~ N(0, τ²I) with τ² = 10
    tau_sq = 10.0
    prior_precision = np.eye(2) / tau_sq

    # Posterior: β | y ~ N(μ_post, Σ_post)
    #   Σ_post^{-1} = X'X / σ² + prior_precision
    #   μ_post      = Σ_post @ X'y / σ²
    XtX = X.T @ X
    Xty = X.T @ y

    post_precision = XtX / sigma**2 + prior_precision
    post_cov = np.linalg.inv(post_precision)
    post_mean = post_cov @ Xty / sigma**2

    post_corr = post_cov[0, 1] / np.sqrt(post_cov[0, 0] * post_cov[1, 1])

    print("True Posterior:")
    print(f"  Means: [{post_mean[0]:.3f}, {post_mean[1]:.3f}]")
    print(f"  Variances: [{post_cov[0, 0]:.4f}, {post_cov[1, 1]:.4f}]")
    print(f"  Correlation: {post_corr:.3f}")
    print(f"  Covariance: {post_cov[0, 1]:.4f}")

    # =====================================================
    # Mean-field approximation
    # =====================================================
    # Mean-field: q(β) = q(β1) q(β2), each q(βj) Gaussian but independent.
    # For a Gaussian posterior, the optimal mean-field marginals match the
    # true marginal means...
    mf_means = post_mean.copy()

    # ...but their variances come from the diagonal of the PRECISION matrix
    # (not the diagonal of the covariance!): var_j = 1 / precision_jj
    mf_vars = 1.0 / np.diag(post_precision)

    print("Mean-Field Approximation:")
    print(f"  Means: [{mf_means[0]:.3f}, {mf_means[1]:.3f}]")
    print(f"  Variances: [{mf_vars[0]:.4f}, {mf_vars[1]:.4f}]")
    print("  Correlation: 0.000 (by construction)")
    print("  Covariance: 0.000 (by construction)")

    # =====================================================
    # Compare consequences
    # =====================================================
    print("\n" + "=" * 60)
    print("Consequences of Correlation Blindness:")
    print("=" * 60)

    # 1. Variance comparison
    print("1. Marginal Variance Comparison:")
    print(f"   True variances: [{post_cov[0, 0]:.4f}, {post_cov[1, 1]:.4f}]")
    print(f"   MF variances:   [{mf_vars[0]:.4f}, {mf_vars[1]:.4f}]")
    print(f"   MF underestimates variance by factor: "
          f"[{post_cov[0, 0] / mf_vars[0]:.2f}x, {post_cov[1, 1] / mf_vars[1]:.2f}x]")

    # 2. Sum of coefficients
    print("2. Variance of β₁ + β₂:")
    true_var_sum = post_cov[0, 0] + post_cov[1, 1] + 2 * post_cov[0, 1]
    mf_var_sum = mf_vars[0] + mf_vars[1]  # No covariance term!
    print(f"   True: Var(β₁ + β₂) = {true_var_sum:.4f}")
    print(f"   MF:   Var(β₁ + β₂) = {mf_var_sum:.4f}")
    print(f"   Error: {abs(true_var_sum - mf_var_sum) / true_var_sum * 100:.1f}%")

    # 3. Credible regions
    print("3. Joint 95% Credible Region:")
    print(f"   True: ellipse (captures {post_corr:.2f} correlation)")
    print("   MF:   axis-aligned rectangle (wrong shape entirely)")

    # 4. Predictive variance at a new point
    x_new = np.array([1.0, 1.0])
    true_pred_var = x_new @ post_cov @ x_new + sigma**2
    mf_pred_var = x_new @ np.diag(mf_vars) @ x_new + sigma**2

    print("4. Predictive Variance at x=[1, 1]:")
    print(f"   True: {true_pred_var:.4f}")
    print(f"   MF:   {mf_pred_var:.4f}")
    print(f"   MF overestimates by {(mf_pred_var / true_pred_var - 1) * 100:.1f}%")

    return {
        'true_cov': post_cov,
        'mf_vars': mf_vars,
        'post_corr': post_corr,
    }


if __name__ == "__main__":
    demonstrate_correlation_blindness()
```

One of the most practically important limitations of mean-field VI is its tendency to underestimate posterior variance. This leads to overconfident uncertainty estimates.
The Mechanism:
Consider a true posterior with positive correlation between $z_1$ and $z_2$. When $z_1$ is high, $z_2$ tends to be high. The joint uncertainty 'moves together.'
Mean-field approximates this with independent marginals. Each marginal captures its own uncertainty, but misses that the uncertainties are linked. For positively correlated variables, this means:
$$\text{Var}_q(z_1 + z_2) = \text{Var}_q(z_1) + \text{Var}_q(z_2) < \text{Var}_p(z_1) + \text{Var}_p(z_2) + 2\text{Cov}_p(z_1, z_2) = \text{Var}_p(z_1 + z_2)$$
Mean-field credible intervals are often too narrow. A '95% credible interval' from mean-field might actually contain the true value only 70-80% of the time. This is a serious problem for decision-making and uncertainty quantification applications.
Why This Happens: The KL Divergence Perspective
Mean-field minimizes $\text{KL}(q || p)$, not $\text{KL}(p || q)$. These are different objectives:
Minimizing the reverse (exclusive) KL, $\text{KL}(q \| p)$, encourages $q$ to concentrate within high-density regions of $p$ and avoid putting probability in low-density regions. This zero-avoiding behavior naturally leads to underestimation of variance. Minimizing the forward KL, $\text{KL}(p \| q)$, would instead stretch $q$ to cover all of $p$'s mass, typically overestimating variance.
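This shrinkage can be computed exactly when the target is Gaussian. For $p(\mathbf{z}) = \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1})$ with precision matrix $\boldsymbol{\Lambda}$, a standard result is that the mean-field-optimal factorized Gaussian (the one minimizing $\text{KL}(q \| p)$) is

$$q_i^*(z_i) = \mathcal{N}\!\left(\mu_i,\; \Lambda_{ii}^{-1}\right), \qquad \text{while the true marginal variance is } (\boldsymbol{\Lambda}^{-1})_{ii} \geq \Lambda_{ii}^{-1}$$

For a bivariate Gaussian with unit marginal variances and correlation $\rho$, this gives a mean-field variance of $1 - \rho^2$: with $\rho = 0.8$, mean-field reports a variance of $0.36$ where the true marginal variance is $1$.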
```python
import numpy as np


def analyze_variance_underestimation():
    """
    Quantify how mean-field underestimates variance.
    """
    print("Variance Underestimation in Mean-Field VI")
    print("=" * 60)

    # True bivariate Gaussian posterior with strong positive correlation
    rho = 0.8
    true_cov = np.array([[1.0, rho],
                         [rho, 1.0]])

    # Mean-field: diagonal covariance matching the true marginals
    mf_cov = np.diag([1.0, 1.0])

    print(f"True posterior: bivariate Gaussian, ρ = {rho}")
    print("Mean-field: product of independent Gaussians")

    # Compare variances of various quantities
    quantities = [
        ("z₁", np.array([1, 0])),
        ("z₂", np.array([0, 1])),
        ("z₁ + z₂", np.array([1, 1])),
        ("z₁ - z₂", np.array([1, -1])),
        ("2z₁ + z₂", np.array([2, 1])),
    ]

    print(f"{'Quantity':<12} {'True Var':<12} {'MF Var':<12} {'Ratio':<10} {'Error'}")
    print("-" * 60)

    for name, w in quantities:
        true_var = w @ true_cov @ w
        mf_var = w @ mf_cov @ w
        ratio = true_var / mf_var
        error = (mf_var - true_var) / true_var * 100
        print(f"{name:<12} {true_var:<12.4f} {mf_var:<12.4f} {ratio:<10.2f} {error:+.1f}%")

    print()
    print("Key observations:")
    print("- For z₁ + z₂: True var includes positive cov term, MF misses it")
    print("- For z₁ - z₂: True var is reduced by positive cov, MF overestimates")
    print("- Mean-field errors compound for linear combinations")

    # Coverage probability experiment
    print("\n" + "=" * 60)
    print("Credible Interval Coverage Analysis")
    print("=" * 60)

    n_samples = 10000

    # Generate samples from the true posterior
    true_samples = np.random.multivariate_normal([0, 0], true_cov, n_samples)

    # Mean-field "posterior"
    mf_sample_z1 = np.random.randn(n_samples)
    mf_sample_z2 = np.random.randn(n_samples)

    # For z1 + z2, compute 95% intervals
    true_sum = true_samples[:, 0] + true_samples[:, 1]
    mf_sum = mf_sample_z1 + mf_sample_z2

    # True 95% interval
    true_lower = np.percentile(true_sum, 2.5)
    true_upper = np.percentile(true_sum, 97.5)

    # MF 95% interval
    mf_lower = np.percentile(mf_sum, 2.5)
    mf_upper = np.percentile(mf_sum, 97.5)

    print("95% Credible Intervals for z₁ + z₂:")
    print(f"  True: [{true_lower:.3f}, {true_upper:.3f}] "
          f"(width: {true_upper - true_lower:.3f})")
    print(f"  MF:   [{mf_lower:.3f}, {mf_upper:.3f}] "
          f"(width: {mf_upper - mf_lower:.3f})")

    # What coverage does the MF interval achieve under the true posterior?
    mf_coverage = np.mean((true_sum >= mf_lower) & (true_sum <= mf_upper))
    print(f"MF '95%' interval actual coverage: {mf_coverage * 100:.1f}%")
    print("MF interval is too narrow!")


def scaling_with_dimensions():
    """
    Show how variance underestimation scales with dimensionality.
    """
    print("\n" + "=" * 60)
    print("Scaling of Variance Underestimation with Dimension")
    print("=" * 60)
    print()

    rho = 0.3  # Moderate correlation
    print(f"Equal correlation ρ = {rho} between all pairs")
    print("Quantity: sum of all variables")
    print()

    print(f"{'Dimension':<12} {'True Var(sum)':<15} {'MF Var(sum)':<15} {'Underest. Factor'}")
    print("-" * 60)

    for d in [2, 5, 10, 20, 50]:
        # True covariance: compound symmetry
        true_cov = rho * np.ones((d, d)) + (1 - rho) * np.eye(d)

        # MF: diagonal
        mf_cov = np.eye(d)

        # Variance of the sum = 1ᵀ Σ 1
        ones = np.ones(d)
        true_var_sum = ones @ true_cov @ ones
        mf_var_sum = ones @ mf_cov @ ones
        factor = true_var_sum / mf_var_sum

        print(f"{d:<12} {true_var_sum:<15.2f} {mf_var_sum:<15.2f} {factor:.2f}x")

    print()
    print("As dimensionality increases, the underestimation of")
    print("variance for sums/averages gets WORSE, not better!")


if __name__ == "__main__":
    analyze_variance_underestimation()
    scaling_with_dimensions()
```

When the true posterior is multi-modal (has multiple local maxima), mean-field VI typically focuses on a single mode, ignoring the others. This can dramatically misrepresent the posterior uncertainty.
Why Mode-Seeking?
The KL divergence $\text{KL}(q || p)$ is infinite if $q(z) > 0$ where $p(z) = 0$. To avoid infinite divergence, $q$ must be zero wherever $p$ is zero or very small—including the valleys between modes.
This forces $q$ to concentrate on a single mode. Placing probability mass between modes (in low-density regions of $p$) is heavily penalized.
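The penalty is easy to check numerically. The sketch below (an illustration, not part of any library) grid-integrates $\text{KL}(q \| p)$ against the bimodal posterior used in the demonstration code on this page: a Gaussian hugging one mode pays only for the mixture weight it misses, while a moment-matched Gaussian spanning both modes is heavily penalized for the mass it places in the valley.

```python
import numpy as np
from scipy import stats

# Bimodal target used in the demo below: 0.6·N(-2, 0.5²) + 0.4·N(2, 0.5²)
z = np.linspace(-12.0, 12.0, 48001)
dz = z[1] - z[0]
p = 0.6 * stats.norm.pdf(z, -2.0, 0.5) + 0.4 * stats.norm.pdf(z, 2.0, 0.5)

def reverse_kl(q):
    """KL(q || p), approximated by numerical integration on the grid."""
    mask = q > 1e-300  # skip points where q·log(q/p) is numerically zero
    return float(np.sum(q[mask] * np.log(q[mask] / p[mask])) * dz)

# Candidate 1: hug one mode (what KL(q||p) optimization finds)
q_mode = stats.norm.pdf(z, -2.0, 0.5)

# Candidate 2: moment-matched Gaussian (mean -0.4, variance 4.09)
q_wide = stats.norm.pdf(z, -0.4, np.sqrt(4.09))

kl_mode = reverse_kl(q_mode)  # ≈ log(1/0.6) ≈ 0.51: pays only for missed weight
kl_wide = reverse_kl(q_wide)  # much larger: mass sits in the valley where p ≈ 0

print(f"KL(q_mode || p) = {kl_mode:.3f}")
print(f"KL(q_wide || p) = {kl_wide:.3f}")
```

The single-mode candidate wins decisively, which is exactly why optimization collapses onto one peak even though the wide Gaussian matches the true mean and variance.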
Each mode of the posterior often represents a qualitatively different explanation of the data. By focusing on just one mode, mean-field commits to a single explanation and ignores all alternatives. If the ignored modes have substantial probability, this is a serious limitation.
Common Multi-Modal Scenarios:
Label switching in mixtures: Permuting component labels leaves a mixture likelihood unchanged, so the posterior has symmetric, equally valid modes.
Symmetries in latent factor models: Sign and rotation ambiguities in factor loadings create equivalent modes.
Genuinely competing explanations: Different parameter settings that fit the data comparably well produce distinct, non-symmetric modes.
Practical Implications:
Initialization sensitivity: Different initializations lead to different modes. This isn't just about finding the 'best' mode—each mode might be equally valid.
Posterior averaging: If multiple modes are a priori equally likely, the posterior mean should average across modes. Mean-field can't do this.
Uncertainty underestimation: The uncertainty from considering multiple modes is lost.
Decision making: If decisions depend on which mode is true, ignoring modes leads to incorrect decisions.
```python
import numpy as np


def demonstrate_mode_seeking():
    """
    Show how mean-field focuses on single modes.
    """
    print("Mode-Seeking Behavior in Mean-Field VI")
    print("=" * 60)

    # True posterior: mixture of two Gaussians.
    # Represents a bimodal posterior over a single parameter.
    mu1, sigma1 = -2.0, 0.5
    mu2, sigma2 = 2.0, 0.5
    weight1 = 0.6  # Slightly more mass in mode 1

    print("True posterior: mixture of Gaussians")
    print(f"  Mode 1: N({mu1}, {sigma1}²) with weight {weight1}")
    print(f"  Mode 2: N({mu2}, {sigma2}²) with weight {1 - weight1}")

    # True posterior statistics
    true_mean = weight1 * mu1 + (1 - weight1) * mu2
    true_var = (weight1 * (sigma1**2 + mu1**2)
                + (1 - weight1) * (sigma2**2 + mu2**2)
                - true_mean**2)

    print("True posterior statistics:")
    print(f"  Mean: {true_mean:.3f}")
    print(f"  Variance: {true_var:.3f}")
    print(f"  Std Dev: {np.sqrt(true_var):.3f}")

    # Mean-field: single Gaussian.
    # KL minimization will focus on ONE mode.
    print("Mean-field approximation (Gaussian):")

    # Mean-field focused on mode 1
    mf_mode1_mean = mu1
    mf_mode1_var = sigma1**2

    # Mean-field focused on mode 2
    mf_mode2_mean = mu2
    mf_mode2_var = sigma2**2

    print("  If initialized near mode 1:")
    print(f"    Mean: {mf_mode1_mean:.3f}")
    print(f"    Variance: {mf_mode1_var:.3f}")
    print("  If initialized near mode 2:")
    print(f"    Mean: {mf_mode2_mean:.3f}")
    print(f"    Variance: {mf_mode2_var:.3f}")

    # What if we naively averaged across modes?
    print("  Naive Gaussian matching true mean/var:")
    print(f"    Mean: {true_mean:.3f}")
    print(f"    Variance: {true_var:.3f}")
    print("    (This is NOT what mean-field produces!)")

    # Coverage analysis
    print("\n" + "=" * 60)
    print("Coverage Analysis")
    print("=" * 60)

    # Generate samples from the true bimodal posterior
    n = 10000
    assignments = np.random.binomial(1, weight1, n)
    true_samples = np.where(
        assignments,
        np.random.normal(mu1, sigma1, n),
        np.random.normal(mu2, sigma2, n)
    )

    # Mean-field 95% CI (using mode 1)
    mf1_ci_low = mu1 - 1.96 * sigma1
    mf1_ci_high = mu1 + 1.96 * sigma1
    coverage_mf1 = np.mean((true_samples >= mf1_ci_low)
                           & (true_samples <= mf1_ci_high))

    print(f"Mean-field (mode 1) '95%' CI: [{mf1_ci_low:.2f}, {mf1_ci_high:.2f}]")
    print(f"  Actual coverage: {coverage_mf1 * 100:.1f}%")
    print("  COMPLETELY misses mode 2!")

    # True 95% CI (empirical)
    true_ci_low = np.percentile(true_samples, 2.5)
    true_ci_high = np.percentile(true_samples, 97.5)
    print(f"True 95% CI: [{true_ci_low:.2f}, {true_ci_high:.2f}]")
    print("  (Much wider, covers both modes)")


def mixture_label_switching():
    """
    Illustrate the label switching issue in mixture models.
    """
    print("\n" + "=" * 60)
    print("Label Switching in Mixture Models")
    print("=" * 60)
    print("""
    Consider a 2-component Gaussian mixture:

      p(data | μ₁, μ₂, π) = Πᵢ [π N(xᵢ | μ₁, σ²) + (1-π) N(xᵢ | μ₂, σ²)]

    The likelihood is SYMMETRIC in (μ₁, μ₂):
      Swapping μ₁ ↔ μ₂ and π ↔ (1-π) gives an identical likelihood.

    The true posterior has (at least) 2 equivalent modes:
      Mode A: μ₁ ≈ 0, μ₂ ≈ 5
      Mode B: μ₁ ≈ 5, μ₂ ≈ 0  (swapped)

    Mean-field will pick ONE mode based on initialization.
    It cannot represent both simultaneously.

    Consequences:
    • The posterior mean of μ₁ is NOT meaningful
    • Results must be relabeled ('label switching' post-processing)
    • Uncertainty about cluster means is underestimated
    """)


if __name__ == "__main__":
    demonstrate_mode_seeking()
    mixture_label_switching()
```

Let's examine specific scenarios where mean-field VI performs particularly poorly, demonstrating its limitations in practice.
Failure Case 1: The Banana Distribution
A 'banana' or 'boomerang' shaped posterior is common when parameters have strong nonlinear dependencies. Mean-field must approximate this curved region with an axis-aligned ellipse, which is a poor fit.
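To make this concrete, here is a minimal sketch using an assumed toy banana, $z_2 = z_1^2 + \text{noise}$ (chosen for illustration, not a model from the text). The two variables are strongly dependent yet nearly uncorrelated, so a single per-dimension variance is badly wrong somewhere along the curve.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50_000

# Toy banana-shaped posterior: z1 ~ N(0, 1), z2 = z1² + N(0, 0.5²).
# The variables are strongly dependent but nearly UNcorrelated
# (Cov(z1, z2) = E[z1³] = 0), so the dependence is purely nonlinear.
z1 = rng.standard_normal(n)
z2 = z1**2 + 0.5 * rng.standard_normal(n)

# Mean-field assigns z2 one fixed variance everywhere...
mf_var_z2 = z2.var()

# ...but the true conditional spread of z2 near z1 = 0 is just the noise
cond_var_z2 = z2[np.abs(z1) < 0.3].var()

print(f"Corr(z1, z2)     = {np.corrcoef(z1, z2)[0, 1]:+.3f}")  # ≈ 0
print(f"MF Var(z2)       = {mf_var_z2:.2f}")   # ≈ Var(z1²) + 0.25 = 2.25
print(f"Var(z2 | z1 ≈ 0) = {cond_var_z2:.2f}") # ≈ 0.25
```

Note that even full-covariance Gaussian VI would miss this structure, since the dependence is invisible to the covariance; capturing it requires a nonlinear family such as a normalizing flow.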
Failure Case 2: Funnel Posteriors
Hierarchical models often produce 'funnel' geometries: at one end, variance is high; at the other, it's low. The posterior looks like a funnel. Mean-field's fixed-variance-per-dimension assumption cannot capture this.
Failure Case 3: Strong Constraints
When the posterior is concentrated on a low-dimensional manifold (e.g., z₁ + z₂ ≈ 1), the true posterior is a thin 'sheet' in high-dimensional space. Mean-field's axis-aligned approximation is volumetrically wasteful.
Banana distributions arise from nonlinear models. Funnel posteriors appear in virtually every hierarchical Bayesian model. Constraints appear whenever parameters must satisfy relationships. These 'failure cases' are actually common in real problems.
| Failure Case | Posterior Shape | Mean-Field Approximation | Consequence |
|---|---|---|---|
| Banana/curved | Curved manifold | Axis-aligned ellipse | Misses curvature, wrong area |
| Funnel | Wide top, narrow bottom | Uniform width | Wrong uncertainty at both ends |
| Manifold constraint | Thin sheet | Full-dimensional box | Massive probability outside support |
| Multi-modal | Multiple peaks | Single peak | Ignores alternative explanations |
| Heavy tails | Slow decay | Gaussian decay | Underestimates extreme events |
```python
import numpy as np


def funnel_posterior_failure():
    """
    Demonstrate mean-field failure on funnel posteriors.

    Neal's funnel is a canonical example from hierarchical models.
    """
    print("Funnel Posterior Failure Case")
    print("=" * 60)
    print("""
    Neal's Funnel:
      v ~ N(0, 3²)           # Wide prior on log-variance
      x | v ~ N(0, exp(v))   # Conditional Gaussian

    The joint p(x, v) has a funnel shape:
    - When v is large: x has high variance (wide)
    - When v is small: x has low variance (narrow)

    Mean-field q(x)q(v) uses a FIXED variance for x:
    - Cannot adapt to the v-dependent width
    - Either too wide at the narrow end OR too narrow at the wide end
    """)

    # Simulate the issue
    n_samples = 10000

    # True funnel samples
    v_samples = np.random.randn(n_samples) * 3
    x_samples = np.random.randn(n_samples) * np.exp(v_samples / 2)

    # Analyze x variance conditional on v
    v_bins = [(-np.inf, -2), (-2, 0), (0, 2), (2, np.inf)]

    print("True x variance by v region:")
    for low, high in v_bins:
        mask = (v_samples >= low) & (v_samples < high)
        if np.sum(mask) > 100:
            x_var = np.var(x_samples[mask])
            print(f"  v ∈ [{low:+.0f}, {high:+.0f}): Var(x) = {x_var:.2f}")

    # Mean-field uses a single variance for all v
    mf_x_var = np.var(x_samples)
    print(f"Mean-field uses single Var(x) = {mf_x_var:.2f}")
    print("This is WAY too large for small v, too small for large v!")


def constraint_failure():
    """
    Demonstrate failure when the posterior has constraints.
    """
    print("\n" + "=" * 60)
    print("Constraint Satisfaction Failure")
    print("=" * 60)
    print("""
    Scenario: Simplex constraint
      True posterior: (z₁, z₂, z₃) must sum to 1
      Lives on a 2D triangle in 3D space

    Mean-field: q(z₁)q(z₂)q(z₃)
      Each marginal is independent
      The product is a full 3D region

    Problem: Mean-field places probability OUTSIDE the valid simplex!
    Probability leaks into invalid regions.
    """)

    # Simulate
    n_samples = 10000

    # True samples: on the simplex
    true_samples = np.random.dirichlet([2, 2, 2], n_samples)

    # Mean-field: match marginals independently
    mf_means = np.mean(true_samples, axis=0)
    mf_vars = np.var(true_samples, axis=0)

    print(f"True marginal means: {mf_means.round(3)}")
    print(f"True marginal vars:  {mf_vars.round(4)}")

    # Sample from mean-field (independent Gaussian marginals,
    # a crude approximation)
    mf_samples = np.column_stack([
        np.random.normal(mf_means[i], np.sqrt(mf_vars[i]), n_samples)
        for i in range(3)
    ])

    # Check how many samples violate the simplex constraint
    sums = mf_samples.sum(axis=1)
    valid = (mf_samples >= 0).all(axis=1) & np.isclose(sums, 1, atol=0.3)

    print(f"Mean-field samples summing to ~1 (±0.3): {np.mean(valid) * 100:.1f}%")
    print(f"Mean-field samples with all positive: "
          f"{np.mean((mf_samples >= 0).all(axis=1)) * 100:.1f}%")
    print("Mean-field puts substantial probability in invalid regions!")


def summarize_failures():
    """Summary of when to avoid mean-field."""
    print("\n" + "=" * 60)
    print("When to Avoid Mean-Field VI")
    print("=" * 60)
    print("""
    Avoid mean-field when:

    1. Strong posterior correlations matter for your analysis
       → Use structured VI or full-covariance VI
    2. Accurate uncertainty quantification is critical
       → Use MCMC or importance-weighted VI
    3. Posterior is known to be multi-modal
       → Use mixture VI or tempering methods
    4. Parameters have constraints (simplex, positive definite, etc.)
       → Use constrained parameterizations
    5. Hierarchical models with funnel geometries
       → Use non-centered parameterizations

    Use mean-field when:

    1. Point estimates are sufficient
    2. Problem scale requires fast inference
    3. Posterior correlations are expected to be weak
    4. Initial exploration before refined analysis
    """)


if __name__ == "__main__":
    funnel_posterior_failure()
    constraint_failure()
    summarize_failures()
```

When mean-field's limitations are prohibitive, several alternatives are available. Each offers different trade-offs between accuracy and computational cost.
Beyond Fully-Factorized Mean-Field: structured VI retains selected dependencies between groups of variables; full-covariance Gaussian VI models all pairwise correlations at higher cost; normalizing flows push a simple base distribution through invertible transformations to capture complex, correlated posteriors.
Alternative Inference Methods: MCMC methods (including Hamiltonian Monte Carlo) draw samples from the exact posterior, trading speed for asymptotic correctness; expectation propagation (EP) optimizes a different divergence and tends to produce better-calibrated marginals.
Start with mean-field for quick exploration and baseline. If results seem suspicious (too-confident intervals, sensitivity to initialization), upgrade to structured VI or MCMC for validation. Use domain knowledge to identify which correlations matter most.
| Method | Accuracy | Speed | Scalability | Best For |
|---|---|---|---|---|
| Mean-field VI | Low-Medium | Fast | High | Large-scale exploration |
| Structured VI | Medium | Medium | Medium | Known correlation structure |
| Full-cov Gaussian VI | Medium-High | Medium | Low-Medium | Moderate dimensions |
| Normalizing Flows | High | Medium | Medium | Complex posteriors |
| MCMC (standard) | High | Slow | Low | Gold standard verification |
| Hamiltonian MC | High | Medium | Low-Medium | Continuous parameters |
| EP | Medium-High | Medium | Medium | When marginals matter |
Mean-field variational inference is a powerful tool, but like all tools, it has fundamental limitations. Understanding these limitations is essential for using it appropriately.
Mean-field VI exemplifies a recurring theme in machine learning: trading accuracy for tractability. The factorization assumption makes inference possible at scale—but at the cost of ignoring posterior structure. This trade-off appears throughout probabilistic machine learning, and understanding it deeply is the mark of an expert practitioner.
Module Completion:
You have now completed the Mean-Field Approximation module. You understand the factorization assumption, the tractable inference it enables, and its fundamental limitations: correlation blindness, variance underestimation, mode-seeking behavior, and the alternatives that address them.
This knowledge forms the foundation for more advanced topics in variational inference, including stochastic VI, variational autoencoders, and normalizing flows.
Congratulations! You have mastered Mean-Field Variational Inference. You understand both its power and its limitations, enabling you to apply it appropriately in practice. The next module explores ELBO Optimization and gradient-based approaches that extend VI to neural network models.