We've derived the mean-field update equations and implemented CAVI. But critical questions remain: Does the algorithm always converge? How fast? What does it converge to? When can we trust the result?
These questions are not merely theoretical; they directly affect how we use variational inference in practice. Understanding convergence helps us choose sensible stopping criteria, catch implementation bugs early, and judge how much to trust the resulting approximation.
By the end of this page, you will understand the theoretical convergence guarantees of mean-field VI, the factors that affect convergence speed, the nature of local optima in the ELBO landscape, and practical strategies for achieving reliable convergence.
Mean-field variational inference with coordinate ascent (CAVI) enjoys strong convergence guarantees that stem from the structure of the optimization problem.
Theorem (CAVI Convergence):
Let $\{q^{(t)}\}_{t=1}^{\infty}$ be the sequence of variational distributions produced by CAVI. Then:

1. The ELBO never decreases: $\text{ELBO}(q^{(t+1)}) \geq \text{ELBO}(q^{(t)})$ for all $t$.
2. The sequence of ELBO values converges to a finite limit.
3. Any limit point of the sequence is a stationary point of the ELBO under the mean-field factorization.
Proof Sketch:
Monotonicity: Each coordinate update maximizes the ELBO with respect to one factor while holding others fixed. Since we're choosing the optimal update, the ELBO cannot decrease.
Convergence: The ELBO is bounded above by $\log p(\mathbf{x})$ (since $\text{KL}(q || p) \geq 0$). A bounded, monotonically increasing sequence must converge.
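In symbols, the bound used here is just the evidence decomposition:

$$\log p(\mathbf{x}) = \text{ELBO}(q) + \text{KL}(q \,\|\, p) \geq \text{ELBO}(q),$$

since the KL term is non-negative.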
Stationarity: If the ELBO stops changing, no factor can improve—this is the definition of a stationary point for coordinate optimization.
These guarantees do NOT say that CAVI finds the global maximum of the ELBO. The ELBO landscape can have many local optima, saddle points, and plateaus. CAVI converges to a local maximum (or saddle point), which may be far from optimal. This is similar to gradient descent on non-convex functions.
Conditions for Stronger Convergence:
In special cases, we can prove stronger results:
Convex ELBO: If the ELBO is convex in all factors jointly (rare), CAVI converges to the global optimum.
Unique Stationary Point: If the ELBO has a unique stationary point (common in well-conditioned small models), CAVI converges to it.
Exponential Family with Convex Sufficient Statistics: For certain exponential family models, the ELBO can be shown to have favorable structure.
In practice, most models have multiple local optima, and we use multiple restarts to find good solutions.
| Property | Guaranteed? | Condition | Implication |
|---|---|---|---|
| ELBO non-decreasing | Yes | Optimal factor updates | Can monitor for bugs |
| ELBO converges | Yes | Bounded above by log p(x) | Algorithm terminates |
| Converges to stationary point | Yes | Coordinate-wise optimality | No single-factor improvement possible |
| Converges to global optimum | No | Only if ELBO convex | May find suboptimal solution |
| Unique limit point | No | Depends on initialization | Different runs may differ |
| Fast convergence | No | Depends on model structure | May need many iterations |
How quickly does CAVI converge? The answer depends on the model structure and can range from just a few iterations to thousands.
Factors Affecting Convergence Rate:
Coupling Strength: Strong dependencies between latent variables slow convergence. If $z_i$'s optimal value depends heavily on $z_j$, and vice versa, many iterations may be needed to equilibrate.
Condition Number: The 'shape' of the ELBO landscape matters. Highly elongated or ill-conditioned regions lead to slow progress.
Dimensionality: More latent variables generally means more iterations, though not always linearly.
Initialization Quality: Starting close to the optimum dramatically reduces iterations needed.
Under favorable conditions (smooth, strongly convex objective), coordinate ascent exhibits LINEAR convergence: the error decreases by a constant factor each iteration. If ε_t ≤ ρᵗ ε₀ for some ρ < 1, then ε_t < ε requires t > log(ε₀/ε) / log(1/ρ) iterations. For ρ = 0.9, halving error takes ~7 iterations; for ρ = 0.99, it takes ~69 iterations.
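As a quick numerical check of those figures, here is a small standalone sketch (a hypothetical helper, not part of the lesson code) that evaluates the iteration-count formula:

```python
import numpy as np

def iterations_needed(rho: float, eps0: float, eps: float) -> float:
    """Iterations t required so that rho**t * eps0 < eps under linear convergence."""
    return np.log(eps0 / eps) / np.log(1.0 / rho)

# Iterations needed just to halve the error for different rates rho
for rho in [0.5, 0.9, 0.99]:
    print(f"rho = {rho:.2f}: ~{iterations_needed(rho, 1.0, 0.5):.1f} iterations per halving")
```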
```python
import numpy as np
from typing import List


def analyze_convergence_rate(elbo_history: List[float]) -> dict:
    """
    Analyze convergence rate from ELBO history.

    Estimates the convergence rate by fitting an exponential decay
    to the ELBO improvements (distance from final value).

    Returns:
        Dictionary with convergence statistics
    """
    elbo = np.array(elbo_history)
    final_elbo = elbo[-1]

    # Distance from the final value (decays exponentially under linear convergence)
    distance = np.abs(elbo - final_elbo)
    distance = np.maximum(distance, 1e-15)  # Avoid log(0)

    # Estimate convergence rate from a log-linear fit:
    # log(distance) ≈ log(ε₀) + t × log(ρ)
    t = np.arange(len(distance))
    valid = distance > 1e-10  # Only use points not yet at convergence

    if np.sum(valid) > 2:
        log_dist = np.log(distance[valid])
        t_valid = t[valid]
        coeffs = np.polyfit(t_valid, log_dist, 1)  # Linear regression
        log_rho = coeffs[0]                        # Slope
        rho = np.exp(log_rho)                      # Convergence rate
    else:
        rho = None

    # Per-iteration relative improvements
    improvements = np.diff(elbo)
    relative_improvements = improvements[:-1] / np.maximum(np.abs(elbo[:-2]), 1e-10)

    return {
        'n_iterations': len(elbo),
        'initial_elbo': elbo[0],
        'final_elbo': final_elbo,
        'total_improvement': final_elbo - elbo[0],
        'estimated_rho': rho,
        'estimated_half_life': -np.log(2) / np.log(rho) if rho and 0 < rho < 1 else None,
        'mean_relative_improvement': np.mean(np.abs(relative_improvements)),
    }


def demonstrate_coupling_effects():
    """
    Show how coupling between variables affects convergence.

    Simulates CAVI for a bivariate Gaussian with varying correlation.
    """
    print("Effect of Variable Coupling on Convergence")
    print("=" * 60)
    print()
    print("Consider a bivariate Gaussian posterior with correlation ρ.")
    print("Mean-field approximates it with independent q(z₁)q(z₂).")
    print()
    print("The CAVI updates for the means are:")
    print("  μ₁ ← (data term) + ρ × E[z₂]")
    print("  μ₂ ← (data term) + ρ × E[z₁]")
    print()
    print("Strong correlation ρ slows how quickly the two updates equilibrate.")
    print()

    def simulate_cavi(rho: float, n_iter: int = 50) -> List[float]:
        """Simulate CAVI for a 2D Gaussian with correlation rho."""
        # True posterior: N([0, 0], [[1, ρ], [ρ, 1]]); target means are 0,
        # but we start away from them.
        mu1, mu2 = 5.0, 5.0
        errors = []
        for _ in range(n_iter):
            # Damped coordinate updates (simplified for illustration):
            # each mean is pulled toward the other's current value, scaled by ρ
            mu1 = 0.5 * (mu1 + rho * mu2)
            mu2 = 0.5 * (mu2 + rho * mu1)
            errors.append(np.sqrt(mu1**2 + mu2**2))
        return errors

    correlations = [0.0, 0.3, 0.6, 0.9]
    print("Correlation | Iterations to ε<0.1 | Final Error")
    print("-" * 50)
    for rho in correlations:
        errors = simulate_cavi(rho, n_iter=100)
        iters_to_conv = next((i for i, e in enumerate(errors) if e < 0.1), 100)
        print(f"    {rho:.1f}     |         {iters_to_conv:3d}         | {errors[-1]:.4f}")
    print()
    print("Higher correlation → slower convergence.")
    print("This is why strong dependencies hurt mean-field performance.")


def analyze_elbo_trajectory():
    """Summarize typical ELBO trajectory shapes."""
    print("\n" + "=" * 60)
    print("Typical ELBO Trajectories")
    print("=" * 60)

    n_iter = 100

    # Scenario 1: Fast convergence (well-conditioned)
    fast = [-1000]
    for i in range(n_iter - 1):
        fast.append(fast[-1] + 200 * np.exp(-0.2 * i))

    # Scenario 2: Slow convergence (ill-conditioned)
    slow = [-1000]
    for i in range(n_iter - 1):
        slow.append(slow[-1] + 50 * np.exp(-0.05 * i))

    # Scenario 3: Plateau, then progress
    plateau = [-1000]
    for i in range(n_iter - 1):
        improvement = 5 if i < 30 else 100 * np.exp(-0.1 * (i - 30))
        plateau.append(plateau[-1] + improvement)

    print()
    print("Scenario 1 (Fast): Well-conditioned, weak coupling")
    stats_fast = analyze_convergence_rate(fast)
    print(f"  Converged in ~{stats_fast['n_iterations']} iterations")
    print(f"  Estimated ρ = {stats_fast['estimated_rho']:.3f}")
    print()
    print("Scenario 2 (Slow): Ill-conditioned, strong coupling")
    stats_slow = analyze_convergence_rate(slow)
    print(f"  Still improving at {stats_slow['n_iterations']} iterations")
    print(f"  Estimated ρ = {stats_slow['estimated_rho']:.3f}")
    print()
    print("Scenario 3 (Plateau): Gets stuck, then escapes")
    stats_plateau = analyze_convergence_rate(plateau[:60])
    print("  Shows plateau behavior - may indicate a local optimum")


if __name__ == "__main__":
    demonstrate_coupling_effects()
    analyze_elbo_trajectory()
```

Understanding the ELBO's local optima structure is crucial for using mean-field VI effectively. Different initial conditions can lead to dramatically different solutions.
Why Multiple Local Optima Exist:
Symmetry: Many models have inherent symmetries. In mixture models, relabeling clusters gives equivalent solutions. In factor models, rotating factors gives equivalent fits. The ELBO has multiple equivalent optima.
Multi-modality: Some posteriors are genuinely multi-modal—multiple explanations fit the data. Mean-field can only approximate one mode at a time.
Factorization Artifacts: The mean-field constraint creates artificial local optima that wouldn't exist if we optimized over all distributions.
Non-convexity: The ELBO is generally non-convex in the variational parameters, admitting multiple stationary points.
Consider a Gaussian mixture with 3 clusters. The ELBO has at least 3! = 6 equivalent global optima (cluster relabelings). But it also has many suboptimal local optima: two true clusters might be merged and one split, or data might be assigned incorrectly. Random initialization often finds these suboptimal solutions.
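To make the relabeling symmetry concrete, the sketch below (a hypothetical 1D mixture, not the lesson's model) evaluates the same mixture log-likelihood under all 3! = 6 permutations of the component labels; every permutation gives an identical value, and the ELBO inherits the same symmetry.

```python
import numpy as np
from itertools import permutations

# Hypothetical 1D mixture with K=3 components (illustration only)
rng = np.random.default_rng(0)
x = rng.normal(size=100)

weights = np.array([0.5, 0.3, 0.2])
means = np.array([-2.0, 0.0, 3.0])
stds = np.array([1.0, 0.5, 1.5])

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (np.sqrt(2 * np.pi) * sigma)

def gmm_loglik(x, w, mu, sigma):
    """sum_n log sum_k w_k N(x_n | mu_k, sigma_k)."""
    comp = w * normal_pdf(x[:, None], mu[None, :], sigma[None, :])
    return np.sum(np.log(comp.sum(axis=1)))

# Every relabeling of the components gives exactly the same log-likelihood
for perm in permutations(range(3)):
    idx = list(perm)
    print(perm, round(gmm_loglik(x, weights[idx], means[idx], stds[idx]), 6))
```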
Strategies for Handling Local Optima: the most common are multiple random restarts (keep the run with the highest ELBO), deterministic annealing (optimize a smoothed objective at high "temperature" and gradually sharpen it), and careful initialization. The code below sketches the first two, plus a helper for analyzing the distribution of optima found across restarts.
```python
import numpy as np
from typing import List, Tuple, Callable, Optional
from dataclasses import dataclass


@dataclass
class CAVIResult:
    """Result from a single CAVI run."""
    final_elbo: float
    elbo_history: List[float]
    parameters: dict
    n_iterations: int


def multiple_restarts(
    run_cavi: Callable[[], CAVIResult],
    n_restarts: int = 10,
    verbose: bool = True
) -> CAVIResult:
    """
    Run CAVI multiple times and return the best result.

    Args:
        run_cavi: Function that runs CAVI with random initialization
        n_restarts: Number of random restarts
        verbose: Whether to print progress

    Returns:
        Best result (highest ELBO)
    """
    best_result = None
    best_elbo = float('-inf')
    all_elbos = []

    for i in range(n_restarts):
        result = run_cavi()
        all_elbos.append(result.final_elbo)

        if result.final_elbo > best_elbo:
            best_elbo = result.final_elbo
            best_result = result

        if verbose:
            print(f"Restart {i+1}/{n_restarts}: ELBO = {result.final_elbo:.2f}")

    if verbose:
        print(f"\nBest ELBO: {best_elbo:.2f}")
        print(f"ELBO range: [{min(all_elbos):.2f}, {max(all_elbos):.2f}]")
        print(f"ELBO std: {np.std(all_elbos):.2f}")

        # Analyze how many restarts found similar solutions
        threshold = 0.01 * abs(best_elbo)  # Within 1%
        n_good = sum(1 for e in all_elbos if abs(e - best_elbo) < threshold)
        print(f"Restarts finding best (within 1%): {n_good}/{n_restarts}")

    return best_result


def deterministic_annealing(
    run_cavi_with_temp: Callable[[float], CAVIResult],
    temperatures: List[float] = None,
    verbose: bool = True
) -> CAVIResult:
    """
    Use deterministic annealing to escape local optima.

    Start with high temperature (smoothed ELBO), gradually decrease.
    Uses the solution at each temperature to initialize the next.

    Args:
        run_cavi_with_temp: Function (temperature) -> CAVIResult
        temperatures: Temperature schedule (default: geometric from 10 to 1)

    Returns:
        Final result at temperature 1.0
    """
    if temperatures is None:
        # Geometric schedule from 10 down to 1
        temperatures = list(np.geomspace(10, 1, num=10))

    current_result = None
    for temp in temperatures:
        if verbose:
            print(f"Temperature {temp:.2f}:", end=" ")
        result = run_cavi_with_temp(temp)
        if verbose:
            print(f"ELBO = {result.final_elbo:.2f}")
        current_result = result

    return current_result


def analyze_optima_distribution(
    run_cavi: Callable[[], CAVIResult],
    n_runs: int = 50
) -> dict:
    """
    Analyze the distribution of local optima found by random restarts.
    """
    elbos = []
    for _ in range(n_runs):
        result = run_cavi()
        elbos.append(result.final_elbo)

    elbos = np.array(elbos)

    # Cluster similar ELBO values (likely the same local optimum)
    sorted_elbos = np.sort(elbos)[::-1]

    # Find distinct optima (gaps > 1% of range)
    range_elbo = sorted_elbos[0] - sorted_elbos[-1]
    threshold = 0.01 * range_elbo if range_elbo > 0 else 1.0

    n_distinct = 1
    for i in range(1, len(sorted_elbos)):
        if sorted_elbos[i-1] - sorted_elbos[i] > threshold:
            n_distinct += 1

    return {
        'n_runs': n_runs,
        'best_elbo': np.max(elbos),
        'worst_elbo': np.min(elbos),
        'mean_elbo': np.mean(elbos),
        'std_elbo': np.std(elbos),
        'n_distinct_optima': n_distinct,
        'elbo_range': np.max(elbos) - np.min(elbos)
    }


def demonstrate_local_optima():
    """Demonstrate the local optima phenomenon."""
    print("Local Optima in Mean-Field VI")
    print("=" * 60)
    print()

    # Simulated scenario: Gaussian mixture with K=3 clusters
    print("Scenario: Gaussian Mixture Model with K=3 clusters")
    print()
    print("True clusters: well-separated, roughly equal size")
    print()

    # Simulate different local optima
    optima = [
        ("Global optimum (correct clustering)", -1500.2),
        ("Local optimum: clusters 1&2 merged, cluster 3 split", -1523.7),
        ("Local optimum: cluster 1 empty", -1548.1),
        ("Local optimum: poor centroid placement", -1561.9),
    ]

    print("Possible optima found by random restarts:")
    print("-" * 60)
    for description, elbo in optima:
        print(f"  ELBO = {elbo:.1f}: {description}")
    print()
    print("Key observations:")
    print("  • ELBO difference between best and worst: "
          f"{optima[0][1] - optima[-1][1]:.1f}")
    print("  • Random initialization finds suboptimal solutions frequently")
    print("  • Multiple restarts are essential for good solutions")
    print()

    # Simulate restart statistics
    print("Simulated restart statistics (50 runs):")
    print("-" * 60)
    # Pretend we ran 50 restarts
    print("  Found global optimum:  15 times (30%)")
    print("  Found second-best:     18 times (36%)")
    print("  Found poor solutions:  17 times (34%)")
    print()
    print("Recommendation: Run at least 10-20 restarts for mixture models")


if __name__ == "__main__":
    demonstrate_local_optima()
```

How do we know when CAVI has converged? And how do we distinguish between convergence to a good solution versus getting stuck? Here are essential diagnostic tools.
Primary Diagnostic: ELBO Monitoring
The ELBO should be computed and tracked at every iteration (or every few iterations for large models). Key things to watch for:
Watch for: (1) ELBO decreasing — indicates implementation bug; (2) ELBO going to ±∞ — numerical instability; (3) Very slow progress — may need different parameterization or initialization; (4) Immediate plateau — degenerate initialization or model misspecification.
Secondary Diagnostics:
Parameter Stability: Track changes in variational parameters. Convergence means parameters stop changing.
Responsibility Entropy: For mixture models, monitor the entropy of cluster assignments. Very low entropy (hard assignments) early in optimization may indicate premature convergence; a small sketch for computing this follows the list.
Predictive Performance: If possible, monitor held-out likelihood or other external metrics.
Component Usage: For mixture/factor models, check that all components are being used. Empty components suggest identifiability issues.
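As a concrete helper for the responsibility-entropy check above, here is a minimal sketch; it assumes responsibilities are stored as a row-normalized array of shape (n_points, n_clusters), and the names are illustrative rather than part of the lesson code.

```python
import numpy as np

def mean_responsibility_entropy(resp: np.ndarray) -> float:
    """Average entropy (in nats) of the per-point cluster-assignment distributions.

    resp has shape (n_points, n_clusters); each row sums to 1.
    Near-zero entropy early in optimization suggests assignments have
    hardened prematurely.
    """
    resp = np.clip(resp, 1e-12, 1.0)
    return float(np.mean(-np.sum(resp * np.log(resp), axis=1)))

# Example: soft vs. hard responsibilities for 4 points and 3 clusters
soft = np.full((4, 3), 1.0 / 3.0)
hard = np.eye(3)[[0, 1, 2, 0]]
print(mean_responsibility_entropy(soft))  # ≈ log(3) ≈ 1.10
print(mean_responsibility_entropy(hard))  # ≈ 0
```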
```python
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass, field
import warnings


@dataclass
class ConvergenceDiagnostics:
    """
    Comprehensive diagnostics for CAVI convergence.

    Tracks ELBO, parameters, and derived quantities to detect
    convergence and diagnose issues.
    """
    # Settings
    elbo_rtol: float = 1e-6   # Relative tolerance for ELBO
    elbo_atol: float = 1e-8   # Absolute tolerance for ELBO
    param_tol: float = 1e-5   # Tolerance for parameter changes

    # History
    elbo_history: List[float] = field(default_factory=list)
    param_history: List[Dict] = field(default_factory=list)

    # Diagnostics
    n_decreases: int = 0
    decrease_magnitudes: List[float] = field(default_factory=list)

    def record_iteration(
        self,
        elbo: float,
        params: Optional[Dict] = None
    ) -> Dict:
        """
        Record one iteration and return diagnostics.

        Args:
            elbo: Current ELBO value
            params: Optional dictionary of current parameters

        Returns:
            Dictionary with diagnostic information
        """
        diag = {
            'iteration': len(self.elbo_history),
            'elbo': elbo,
            'converged': False,
            'issues': []
        }

        # Check for invalid values
        if np.isnan(elbo):
            diag['issues'].append('ELBO is NaN')
            warnings.warn("ELBO is NaN - numerical instability")
        elif np.isinf(elbo):
            diag['issues'].append('ELBO is infinite')
            warnings.warn("ELBO is infinite - check for overflow/underflow")

        # Check monotonicity
        if len(self.elbo_history) > 0:
            prev = self.elbo_history[-1]
            change = elbo - prev
            diag['elbo_change'] = change
            diag['elbo_rel_change'] = change / abs(prev) if prev != 0 else 0

            if change < -1e-10:  # Allow tiny numerical errors
                self.n_decreases += 1
                self.decrease_magnitudes.append(-change)
                diag['issues'].append(f'ELBO decreased by {-change:.2e}')
                warnings.warn(
                    f"ELBO decreased at iteration {diag['iteration']}: "
                    f"{prev:.4f} -> {elbo:.4f}"
                )

            # Check convergence (ELBO criterion)
            rel_change = abs(change) / max(abs(prev), 1e-10)
            if rel_change < self.elbo_rtol and abs(change) < self.elbo_atol:
                diag['converged'] = True

        self.elbo_history.append(elbo)

        if params is not None:
            self.param_history.append(params.copy())

            # Check parameter convergence
            if len(self.param_history) > 1:
                max_param_change = self._max_param_change(
                    self.param_history[-2], self.param_history[-1]
                )
                diag['max_param_change'] = max_param_change
                if max_param_change < self.param_tol:
                    diag['params_converged'] = True

        return diag

    def _max_param_change(self, old: Dict, new: Dict) -> float:
        """Compute maximum absolute change across all parameters."""
        max_change = 0
        for key in old:
            if key in new:
                old_val = np.asarray(old[key])
                new_val = np.asarray(new[key])
                change = np.max(np.abs(old_val - new_val))
                max_change = max(max_change, change)
        return max_change

    def summary_report(self) -> str:
        """Generate a summary report of the optimization."""
        lines = [
            "=" * 60,
            "CAVI Convergence Report",
            "=" * 60,
            f"Total iterations: {len(self.elbo_history)}",
        ]

        if len(self.elbo_history) > 0:
            lines.extend([
                f"Initial ELBO: {self.elbo_history[0]:.4f}",
                f"Final ELBO: {self.elbo_history[-1]:.4f}",
                f"Total improvement: {self.elbo_history[-1] - self.elbo_history[0]:.4f}",
            ])

        lines.append("")
        lines.append("Monotonicity Check:")
        if self.n_decreases == 0:
            lines.append("  ✓ ELBO never decreased (correct behavior)")
        else:
            lines.append(f"  ✗ ELBO decreased {self.n_decreases} times (BUG!)")
            lines.append(f"    Max decrease: {max(self.decrease_magnitudes):.2e}")

        if len(self.elbo_history) > 1:
            lines.append("")
            lines.append("Convergence Assessment:")
            final_change = abs(self.elbo_history[-1] - self.elbo_history[-2])
            final_rel = final_change / max(abs(self.elbo_history[-2]), 1e-10)
            lines.append(f"  Final ELBO change: {final_change:.2e}")
            lines.append(f"  Final relative change: {final_rel:.2e}")
            if final_rel < self.elbo_rtol:
                lines.append("  ✓ Converged by ELBO criterion")
            else:
                lines.append("  ✗ May not have fully converged")

        return "\n".join(lines)

    def detect_plateau(self, window: int = 10, threshold: float = 0.01) -> bool:
        """
        Detect if optimization is stuck on a plateau.

        A plateau is detected if the ELBO improvement over the last
        'window' iterations is less than 'threshold' fraction of the
        total improvement so far.
        """
        if len(self.elbo_history) < window + 1:
            return False

        recent_improvement = self.elbo_history[-1] - self.elbo_history[-window]
        total_improvement = self.elbo_history[-1] - self.elbo_history[0]

        if total_improvement <= 0:
            return True  # No improvement at all

        return recent_improvement / total_improvement < threshold


def run_with_diagnostics():
    """Example of running CAVI with full diagnostics."""
    print("Running CAVI with Convergence Diagnostics")
    print("=" * 60)

    # Simulate a CAVI run
    diag = ConvergenceDiagnostics()
    elbo = -1000.0
    np.random.seed(42)

    for i in range(100):
        # Simulate ELBO improvement
        improvement = 50 * np.exp(-0.1 * i) + np.random.randn() * 0.1
        elbo += max(improvement, 0)  # Ensure non-decreasing

        result = diag.record_iteration(elbo)
        if result['converged']:
            print(f"Converged at iteration {i}")
            break

    print()
    print(diag.summary_report())


if __name__ == "__main__":
    run_with_diagnostics()
```

Mean-field VI isn't always the right choice. Understanding when it works well—and when it fails—helps you decide whether to use it for your problem.
Conditions Favoring Mean-Field: mean-field works best when posterior dependencies between latent variables are weak, so the factorized approximation loses little, and when point estimates or marginal summaries matter more than accurate joint uncertainty. The table below summarizes suitability for common model families.
Mean-field systematically underestimates posterior variance whenever variables are correlated. Intuitively: if z₁ being high makes z₂ likely to be high, but we model them as independent, we miss that they 'move together' and underestimate how uncertain the combination is. This is a fundamental limitation of the factorization, not a bug.
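A minimal numerical illustration of this effect, assuming a bivariate Gaussian posterior with unit marginal variances and correlation $\rho$ (not a model from this page): the optimal mean-field Gaussian factor has precision equal to the corresponding diagonal entry of the joint precision matrix, so its variance is $1 - \rho^2$ while the true marginal variance is 1.

```python
import numpy as np

# Bivariate Gaussian posterior with unit variances and correlation rho.
# The optimal mean-field Gaussian factor q(z_i) has variance 1 / Lambda_ii,
# where Lambda is the joint precision matrix; here that equals 1 - rho**2.
for rho in [0.0, 0.5, 0.9, 0.99]:
    Sigma = np.array([[1.0, rho], [rho, 1.0]])
    Lambda = np.linalg.inv(Sigma)
    mf_var = 1.0 / Lambda[0, 0]
    print(f"rho = {rho:.2f}: true marginal variance = 1.000, mean-field variance = {mf_var:.3f}")
```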
| Model Type | Mean-Field Suitability | Notes |
|---|---|---|
| Mixture models (GMM, LDA) | Good | Local latents often weakly correlated |
| Factor models (PCA, FA) | Moderate | May miss factor correlations |
| Linear regression | Good | Especially for point estimates |
| Hierarchical models | Moderate | Global-local correlations can matter |
| Time series (HMM, LGSSM) | Moderate | Sequential structure helps some |
| Deep latent models (VAE) | Good with caveats | Amortization compensates somewhat |
| Spatial models | Variable | Depends on correlation length |
| Small Bayesian networks | Often poor | Strong conditional dependencies |
Understanding convergence is essential for effectively using mean-field variational inference. The key points:

- CAVI increases the ELBO monotonically and converges to a stationary point, but not necessarily to the global optimum.
- Convergence speed depends on coupling strength between latent variables, the conditioning of the ELBO landscape, dimensionality, and initialization quality.
- The ELBO landscape typically contains many local optima (from symmetry, multi-modality, and the factorization itself); multiple restarts, and sometimes annealing, are essential.
- Monitor the ELBO at every iteration: it should never decrease, and a decrease signals an implementation bug. Track parameter changes and model-specific diagnostics as well.
- Mean-field underestimates posterior variance for correlated variables, so its suitability depends on how strongly the posterior couples the latents.
The final page of this module examines the limitations of mean-field VI in depth. We'll explore exactly when and why the factorization assumption fails, and what alternatives exist for cases where mean-field isn't sufficient.
You now understand the convergence properties of mean-field VI: what's guaranteed, what affects convergence speed, how to diagnose problems, and when to expect good versus poor performance. This knowledge is essential for applying mean-field VI confidently and correctly.