Imagine you're trying to find the highest point in a mountain range, but you can only move in one direction at a time—only North-South, or only East-West, never diagonally. Despite this constraint, you can still climb to a peak by alternating: move North until you can't go higher, then move East until you can't go higher, then North again, and so on.
This is the essence of coordinate ascent.
In mean-field variational inference, we face a formidable optimization problem: find the best factorized distribution $q(\mathbf{z}) = \prod_i q_i(z_i)$ that maximizes the ELBO. Optimizing all factors simultaneously would be intractable. But the factorization assumption hands us a gift: we can optimize each factor independently while holding the others fixed, iterating until convergence.
By the end of this page, you will understand and be able to implement Coordinate Ascent Variational Inference (CAVI). You'll know how to derive update equations, why the algorithm is guaranteed to converge, and the practical considerations that make it work efficiently in real applications.
Coordinate ascent (also called coordinate-wise optimization or block coordinate descent/ascent) is a general optimization strategy that breaks a multivariate optimization problem into a sequence of univariate problems.
General Principle:
To maximize $f(x_1, x_2, \ldots, x_m)$:

1. Initialize $x_1, x_2, \ldots, x_m$.
2. For each $j = 1, \ldots, m$: set $x_j \leftarrow \arg\max_{x_j} f(x_1, \ldots, x_j, \ldots, x_m)$ while holding all other variables fixed.
3. Repeat step 2 until the objective stops improving.
Each step optimizes a single variable while treating all others as constants. The key insight is that the univariate optimization is often much easier than the joint optimization.
In the optimization landscape, each variable represents a 'coordinate axis.' Coordinate ascent moves along one axis at a time, rather than in arbitrary directions. When maximizing (going uphill), it's 'ascent'; when minimizing, it's 'descent.' Mean-field VI maximizes the ELBO, so we use coordinate ascent.
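To make this concrete, here is a minimal sketch (illustrative, not part of the lesson's code) that runs coordinate ascent on a simple concave quadratic; the function `f` and its closed-form per-coordinate maximizers are chosen just for this example.

```python
import numpy as np

def f(x, y):
    """A concave quadratic: f(x, y) = -(x^2 + y^2 + x*y) + 3x + 2y."""
    return -(x**2 + y**2 + x * y) + 3 * x + 2 * y

x, y = 0.0, 0.0                      # start at the origin
for step in range(20):
    x = (3.0 - y) / 2.0              # argmax over x with y fixed (set df/dx = 0)
    y = (2.0 - x) / 2.0              # argmax over y with x fixed (set df/dy = 0)
    print(f"step {step:2d}: x={x:.4f}, y={y:.4f}, f={f(x, y):.6f}")

# The iterates approach the joint maximizer (x, y) = (4/3, 1/3),
# and f never decreases from one coordinate update to the next.
```

Each line of the loop solves a one-dimensional problem exactly, which is the same structure CAVI exploits: the per-coordinate problem has a closed-form answer even though the joint problem does not decompose.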
Application to Mean-Field VI:
In mean-field variational inference, the 'variables' we're optimizing are not scalar numbers—they're the factors $q_1(z_1), q_2(z_2), \ldots, q_m(z_m)$. Each factor is a probability distribution.
Our objective is the ELBO:
$$\mathcal{L}(q_1, q_2, \ldots, q_m) = \mathbb{E}_{q}[\log p(\mathbf{x}, \mathbf{z})] - \sum_{i=1}^{m} \mathbb{E}_{q_i}[\log q_i(z_i)]$$
Coordinate ascent for mean-field VI (commonly called CAVI) iterates over the factors: for each $j$, hold the other factors fixed and replace $q_j$ with the factor that maximizes the ELBO, then sweep through $j = 1, \ldots, m$ again until the ELBO converges. This coordinate-wise pattern appears throughout machine learning:
| Problem | Variables | Update Step | Guarantee |
|---|---|---|---|
| Linear regression (coordinate descent on least squares) | Coefficients β₁, β₂, ... | Closed-form per coefficient | Converges to global optimum (one pass if features are orthogonal) |
| Lasso regression | Coefficients β₁, β₂, ... | Soft-thresholding per coef | Converges to global optimum |
| Mean-field VI | Factors q₁, q₂, ... | Optimal factor formula | Converges to local ELBO maximum |
| K-means clustering | Assignments, centroids | Nearest centroid / mean | Converges to local minimum |
| EM algorithm | Responsibilities, parameters θ | Posterior (E-step) / MLE (M-step) | Converges to local maximum |
Coordinate Ascent Variational Inference (CAVI) applies the coordinate ascent principle to the ELBO, using the optimal factor theorem we derived in the previous page.
Algorithm: CAVI
Input: Model $p(\mathbf{x}, \mathbf{z})$, observed data $\mathbf{x}$, initial factors $\{q_j^{(0)}(z_j)\}_{j=1}^{m}$
Output: Optimized factors $\{q_j^*(z_j)\}_{j=1}^{m}$
The notation q_{-j} means 'all factors except j.' When we write E_{q_{-j}}, we're taking an expectation over all latent variables except z_j, using their current approximate distributions. This expectation uses the most recent updates for factors 1 through j-1 and the previous iteration's values for factors j+1 through m.
```python
import numpy as np
from typing import List, Callable, Tuple, Optional
from dataclasses import dataclass
from abc import ABC, abstractmethod


@dataclass
class CAVIResult:
    """Results from CAVI optimization."""
    factors: List['VariationalFactor']
    elbo_history: List[float]
    converged: bool
    n_iterations: int


class VariationalFactor(ABC):
    """Abstract base class for variational factors."""

    @abstractmethod
    def update(self, expected_sufficient_stats: dict):
        """Update factor parameters given expected sufficient statistics."""
        pass

    @abstractmethod
    def entropy(self) -> float:
        """Compute entropy H[q_j]."""
        pass

    @abstractmethod
    def expected_log_prob(self) -> float:
        """Compute E_q[log q_j(z_j)]."""
        pass


class CAVIOptimizer:
    """
    Coordinate Ascent Variational Inference (CAVI).

    Optimizes a mean-field variational approximation by iteratively
    updating each factor while holding the others fixed. The algorithm
    guarantees monotonic improvement of the ELBO and convergence to a
    local optimum.
    """

    def __init__(
        self,
        compute_elbo: Callable[[], float],
        update_factor: Callable[[int], None],
        n_factors: int,
        verbose: bool = True
    ):
        """
        Args:
            compute_elbo: Function that computes the current ELBO
            update_factor: Function that updates factor j
            n_factors: Number of factors in the mean-field approximation
            verbose: Whether to print progress
        """
        self.compute_elbo = compute_elbo
        self.update_factor = update_factor
        self.n_factors = n_factors
        self.verbose = verbose

    def optimize(
        self,
        max_iterations: int = 100,
        tolerance: float = 1e-6,
        check_elbo_every: int = 1
    ) -> Tuple[List[float], bool]:
        """
        Run CAVI until convergence.

        Args:
            max_iterations: Maximum number of full passes through the factors
            tolerance: Convergence threshold for the ELBO change
            check_elbo_every: How often to compute and check the ELBO

        Returns:
            Tuple of (elbo_history, converged)
        """
        elbo_history = []
        prev_elbo = float('-inf')

        for iteration in range(max_iterations):
            # One full pass: update each factor in turn
            for j in range(self.n_factors):
                # Update factor j using the optimal factor formula:
                #   log q_j(z_j) = E_{q_{-j}}[log p(x, z)] + const
                self.update_factor(j)

            # Check ELBO and convergence
            if iteration % check_elbo_every == 0:
                current_elbo = self.compute_elbo()
                elbo_history.append(current_elbo)

                # Verify monotonic increase (with numerical tolerance)
                if current_elbo < prev_elbo - 1e-10:
                    print(f"WARNING: ELBO decreased! {prev_elbo:.6f} -> {current_elbo:.6f}")
                    print("This indicates a bug in the update equations.")

                if self.verbose and iteration % 10 == 0:
                    print(f"Iteration {iteration}: ELBO = {current_elbo:.6f}")

                # Check convergence
                if abs(current_elbo - prev_elbo) < tolerance:
                    if self.verbose:
                        print(f"Converged at iteration {iteration}")
                        print(f"Final ELBO: {current_elbo:.6f}")
                    return elbo_history, True

                prev_elbo = current_elbo

        if self.verbose:
            print(f"Did not converge in {max_iterations} iterations")
            print(f"Final ELBO: {elbo_history[-1]:.6f}")
        return elbo_history, False


def cavi_pseudocode():
    """Display the CAVI algorithm in clear pseudocode format."""
    pseudocode = """
    COORDINATE ASCENT VARIATIONAL INFERENCE
    ---------------------------------------
    INPUT:  Model p(x, z), data x, initial factors {q_j}
    OUTPUT: Optimized factors {q_j*}

    1. Initialize factors q_1, q_2, ..., q_m

    2. REPEAT:
           FOR j = 1 TO m:
               log q_j(z_j) <- E_{q_{-j}}[log p(x, z)] + const
               (this is the OPTIMAL FACTOR FORMULA)
           END FOR

           Compute ELBO = E_q[log p(x, z)] - E_q[log q(z)]

    3. UNTIL ELBO converges

    4. RETURN {q_1*, q_2*, ..., q_m*}

    GUARANTEES:
      * ELBO increases (or stays the same) at each step
      * Algorithm converges to a local maximum
      * Each update is optimal given the fixed other factors
    """
    print(pseudocode)


if __name__ == "__main__":
    cavi_pseudocode()
```

The heart of CAVI is deriving the specific update equation for each factor. This is where the mathematics meets the model. Let's work through the process systematically.
The General Recipe:

1. Write down the joint log-probability $\log p(\mathbf{x}, \mathbf{z})$.
2. For factor $j$, keep only the terms involving $z_j$ and take their expectation under $q_{-j}$ (all the other factors).
3. Recognize the resulting function of $z_j$ as the log of a known distribution.
4. Normalize to obtain the updated factor $q_j^*(z_j)$.
After computing E_{q_{-j}}[log p(x, z)], you need to recognize the functional form. If it looks like z_j log λ - log z_j!, that's a Poisson. If it looks like -(z_j - μ)²/(2σ²), that's a Gaussian. If it looks like α z_j + β(1 - z_j) for binary z_j, that's a Bernoulli. This pattern recognition becomes second nature with practice.
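For instance, whenever the expectation is quadratic in $z_j$, you can read off a Gaussian by completing the square:

$$\log q_j^*(z_j) = a\,z_j - \frac{b}{2}\,z_j^2 + \text{const}, \; b > 0 \quad\Longrightarrow\quad q_j^*(z_j) = \mathcal{N}\!\left(z_j \;\middle|\; \frac{a}{b},\, \frac{1}{b}\right),$$

which is exactly the recognition used in Step 3 of the GMM derivation below.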
Worked Example: Gaussian Mixture Model
Consider a Gaussian Mixture Model for data $\mathbf{x} = \{x_1, \ldots, x_n\}$ with $K$ components, fixed mixing proportions $\boldsymbol{\pi}$, known observation variance $\sigma^2$, and prior variance $\sigma_0^2$ on the component means:

$$\mu_k \sim \mathcal{N}(0, \sigma_0^2), \qquad z_i \sim \text{Categorical}(\boldsymbol{\pi}), \qquad x_i \mid z_i, \boldsymbol{\mu} \sim \mathcal{N}(\mu_{z_i}, \sigma^2).$$
Mean-field factorization: $$q(\mathbf{z}, \boldsymbol{\mu}) = \prod_{i=1}^{n} q(z_i) \prod_{k=1}^{K} q(\mu_k)$$
```python
"""CAVI for a Gaussian Mixture Model.

Complete derivation and implementation of the update equations.

Model:
    π    = mixing proportions (fixed, or given a Dirichlet prior)
    μ_k  ~ N(0, σ₀²)              for k = 1, ..., K
    z_i  ~ Categorical(π)         for i = 1, ..., n
    x_i | z_i, μ ~ N(μ_{z_i}, σ²)

Variational family:
    q(z, μ)  = ∏_i q(z_i) × ∏_k q(μ_k)
    q(z_i)   = Categorical(r_i)   where r_{ik} = q(z_i = k)
    q(μ_k)   = N(m_k, s_k²)
"""

import numpy as np
from scipy.special import logsumexp


def derive_gmm_updates():
    """Step-by-step derivation of the CAVI updates for the GMM."""
    print("=" * 70)
    print("CAVI Updates for Gaussian Mixture Model")
    print("=" * 70)

    # STEP 1: Write the joint log-probability
    print("\n" + "─" * 70)
    print("STEP 1: Joint Log-Probability")
    print("─" * 70)
    print("""
    log p(x, z, μ) = log p(μ) + log p(z) + log p(x | z, μ)

        = Σ_k log N(μ_k | 0, σ₀²)              [prior on means]
        + Σ_i log π_{z_i}                       [prior on assignments]
        + Σ_i log N(x_i | μ_{z_i}, σ²)          [likelihood]

    Expanding:
        = Σ_k [-μ_k²/(2σ₀²)]
        + Σ_i log π_{z_i}
        + Σ_i [-(x_i - μ_{z_i})²/(2σ²)]
        + constants
    """)

    # STEP 2: Update for q(z_i) - cluster assignments
    print("\n" + "─" * 70)
    print("STEP 2: Update for q(z_i)")
    print("─" * 70)
    print("""
    log q*(z_i) = E_{q(-z_i)}[log p(x, z, μ)] + const

    Terms involving z_i:
        = E[log π_{z_i}] + E[-(x_i - μ_{z_i})²/(2σ²)]

    For z_i = k:
        log q*(z_i = k) = log π_k - (1/(2σ²)) E_q[(x_i - μ_k)²]
                        = log π_k - (1/(2σ²)) [x_i² - 2x_i E[μ_k] + E[μ_k²]]
                        = log π_k - (1/(2σ²)) [x_i² - 2x_i m_k + (m_k² + s_k²)]

    After normalization:
        r_{ik} ∝ π_k × exp(-(x_i - m_k)²/(2σ²)) × exp(-s_k²/(2σ²))

        The first two terms are the standard GMM responsibility.
        The third term penalizes uncertain cluster means!
    """)

    # STEP 3: Update for q(μ_k) - cluster means
    print("\n" + "─" * 70)
    print("STEP 3: Update for q(μ_k)")
    print("─" * 70)
    print("""
    log q*(μ_k) = E_{q(-μ_k)}[log p(x, z, μ)] + const

    Terms involving μ_k:
        = -μ_k²/(2σ₀²) + Σ_i E_q[𝟙(z_i=k)] × [-(x_i - μ_k)²/(2σ²)]
        = -μ_k²/(2σ₀²) + Σ_i r_{ik} × [-(x_i - μ_k)²/(2σ²)]

    Let N_k  = Σ_i r_{ik}                (effective number of points in cluster k)
    Let x̄_k = (Σ_i r_{ik} x_i) / N_k    (weighted mean of the data)

    Collecting terms in μ_k:
        log q*(μ_k) = -μ_k² [1/(2σ₀²) + N_k/(2σ²)] + μ_k [N_k x̄_k / σ²] + const

    This is a Gaussian! Completing the square:
        q*(μ_k) = N(m_k, s_k²)

        where:  s_k² = 1 / (1/σ₀² + N_k/σ²)
                m_k  = s_k² × (N_k x̄_k / σ²)

    Interpretation:
        - More data points → smaller variance (more certain)
        - The mean is a weighted average of the prior (0) and the data mean (x̄_k)
    """)

    print("\n" + "=" * 70)
    print("Summary: CAVI for GMM cycles between these two updates")
    print("=" * 70)


class GaussianMixtureCavi:
    """Complete CAVI implementation for a Gaussian Mixture Model."""

    def __init__(
        self,
        X: np.ndarray,
        K: int,
        sigma_sq: float = 1.0,
        sigma0_sq: float = 10.0
    ):
        """
        Args:
            X: Data matrix, shape (n, d)
            K: Number of clusters
            sigma_sq: Observation noise variance
            sigma0_sq: Prior variance on cluster means
        """
        self.X = X
        self.n, self.d = X.shape
        self.K = K
        self.sigma_sq = sigma_sq
        self.sigma0_sq = sigma0_sq

        # Initialize variational parameters
        # q(z_i) = Categorical(r_i)
        self.r = np.random.dirichlet(np.ones(K), size=self.n)

        # q(μ_k) = N(m_k, s_k²)
        self.m = np.random.randn(K, self.d)    # means
        self.s_sq = np.ones((K, self.d))       # variances (diagonal)

        # Uniform mixing proportions
        self.pi = np.ones(K) / K

    def update_assignments(self):
        """
        Update q(z_i) for all i.

        r_{ik} ∝ π_k × exp(-(x_i - m_k)²/(2σ²) - s_k²/(2σ²))
        """
        log_r = np.zeros((self.n, self.K))

        for k in range(self.K):
            # Log mixing weight
            log_r[:, k] = np.log(self.pi[k])

            # Expected log-likelihood under q(μ_k):
            # E_q[-(x - μ)²/(2σ²)] = -(x - m)²/(2σ²) - s²/(2σ²)
            diff_sq = np.sum((self.X - self.m[k])**2, axis=1)
            var_term = np.sum(self.s_sq[k])
            log_r[:, k] -= (diff_sq + var_term) / (2 * self.sigma_sq)

        # Normalize (softmax in log-space)
        log_r -= logsumexp(log_r, axis=1, keepdims=True)
        self.r = np.exp(log_r)

    def update_means(self):
        """
        Update q(μ_k) for all k.

        s_k² = 1 / (1/σ₀² + N_k/σ²)
        m_k  = s_k² × (N_k x̄_k / σ²)
        """
        for k in range(self.K):
            # Effective count
            N_k = np.sum(self.r[:, k])

            # Weighted data mean
            if N_k > 1e-10:
                x_bar_k = np.sum(self.r[:, k, np.newaxis] * self.X, axis=0) / N_k
            else:
                x_bar_k = np.zeros(self.d)

            # Update variance
            precision = 1 / self.sigma0_sq + N_k / self.sigma_sq
            self.s_sq[k] = 1 / precision

            # Update mean
            self.m[k] = self.s_sq[k] * (N_k * x_bar_k / self.sigma_sq)

    def compute_elbo(self) -> float:
        """
        Compute the ELBO.

        ELBO = E_q[log p(x, z, μ)] - E_q[log q(z, μ)]
        """
        elbo = 0.0

        # E_q[log p(μ)] - prior on means
        for k in range(self.K):
            elbo -= np.sum(self.m[k]**2 + self.s_sq[k]) / (2 * self.sigma0_sq)
            elbo -= 0.5 * self.d * np.log(2 * np.pi * self.sigma0_sq)

        # E_q[log p(z)] - prior on assignments
        for i in range(self.n):
            for k in range(self.K):
                if self.r[i, k] > 1e-10:
                    elbo += self.r[i, k] * np.log(self.pi[k])

        # E_q[log p(x | z, μ)] - likelihood
        for i in range(self.n):
            for k in range(self.K):
                if self.r[i, k] > 1e-10:
                    diff_sq = np.sum((self.X[i] - self.m[k])**2)
                    var_term = np.sum(self.s_sq[k])
                    log_lik = -0.5 * self.d * np.log(2 * np.pi * self.sigma_sq)
                    log_lik -= (diff_sq + var_term) / (2 * self.sigma_sq)
                    elbo += self.r[i, k] * log_lik

        # -E_q[log q(z)] - entropy of assignments
        for i in range(self.n):
            for k in range(self.K):
                if self.r[i, k] > 1e-10:
                    elbo -= self.r[i, k] * np.log(self.r[i, k])

        # -E_q[log q(μ)] - entropy of means
        for k in range(self.K):
            # Entropy of a diagonal Gaussian: 0.5 d (1 + log 2π) + 0.5 Σ log s²
            elbo += 0.5 * self.d * (1 + np.log(2 * np.pi))
            elbo += 0.5 * np.sum(np.log(self.s_sq[k]))

        return elbo

    def fit(self, max_iter: int = 100, tol: float = 1e-6) -> list:
        """Run CAVI until convergence."""
        elbo_history = []

        for iteration in range(max_iter):
            # E-step-like: update q(z)
            self.update_assignments()

            # M-step-like: update q(μ)
            self.update_means()

            # Monitor the ELBO
            elbo = self.compute_elbo()
            elbo_history.append(elbo)

            if iteration > 0 and abs(elbo - elbo_history[-2]) < tol:
                print(f"Converged at iteration {iteration}")
                break

        return elbo_history


def demo_gmm_cavi():
    """Demonstrate CAVI for a GMM on synthetic data."""
    # Generate synthetic data from 3 clusters
    np.random.seed(42)
    n_per_cluster = 100
    true_means = np.array([[0, 0], [5, 0], [2.5, 4]])

    X = np.vstack([
        np.random.randn(n_per_cluster, 2) + true_means[0],
        np.random.randn(n_per_cluster, 2) + true_means[1],
        np.random.randn(n_per_cluster, 2) + true_means[2]
    ])

    print("Running CAVI for Gaussian Mixture Model")
    print(f"Data: {X.shape[0]} points, {X.shape[1]} dimensions")
    print(f"True cluster means:\n{true_means}")
    print()

    model = GaussianMixtureCavi(X, K=3)
    elbo_history = model.fit()

    print("Inferred cluster means:")
    print(model.m)
    print(f"Final ELBO: {elbo_history[-1]:.2f}")

    return model, elbo_history


if __name__ == "__main__":
    derive_gmm_updates()
    print("\n" + "=" * 70 + "\n")
    demo_gmm_cavi()
```

The order in which we update factors can significantly affect CAVI's performance. While the algorithm converges regardless of ordering (under mild conditions), different orderings can lead to faster convergence or different local optima.
Common Ordering Strategies:

- Sequential: update factors $1, 2, \ldots, m$ in a fixed order, each using the most recent values of the others.
- Parallel: compute every update from the previous iteration's values, then apply them all at once.
- Random: permute the update order each iteration.
- Adaptive: prioritize the factors whose updates are expected to change the most.
- Block: group related factors (e.g., all local variables, then all global parameters) and update group by group.
Sequential updates typically converge in fewer iterations because each update uses the most recent information. However, parallel updates can be faster in wall-clock time on multi-core systems by updating all factors simultaneously. The choice depends on your computational environment and model structure.
The EM Connection:
For many models, CAVI naturally decomposes into updates that resemble the E-step and M-step of Expectation Maximization:
E-step-like: Update the factors corresponding to local latent variables (e.g., $q(z_i)$ for each data point). These updates compute 'responsibilities' or 'posterior assignments.'
M-step-like: Update the factors corresponding to global parameters (e.g., $q(\theta)$ for model parameters). These updates aggregate information from all data points.
Grouping updates this way often improves convergence because it respects the natural structure of the model: local variables depend on global parameters, which aggregate local information.
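As a minimal sketch of this grouping, the loop below simply mirrors the `fit` method of the `GaussianMixtureCavi` class defined above (with `X` a data matrix as in the demo); nothing new is introduced.

```python
# Assumes GaussianMixtureCavi and a data matrix X as in the GMM example above.
model = GaussianMixtureCavi(X, K=3)

for iteration in range(100):
    model.update_assignments()   # E-step-like: local factors q(z_i) given current q(mu_k)
    model.update_means()         # M-step-like: global factors q(mu_k) given current q(z_i)
    elbo = model.compute_elbo()  # monitor after each full local/global sweep
```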
| Strategy | Parallelizable | Convergence Speed | Memory | Best For |
|---|---|---|---|---|
| Sequential | No | Fast (fewer iterations) | Low | Small/medium models |
| Parallel | Yes | Slower (more iterations) | Higher | Large-scale, GPU |
| Random | No | Variable | Low | Escaping local optima |
| Adaptive | Partial | Fastest | Medium | When updates differ in cost |
| Block | Partial | Fast | Medium | Hierarchical structure |
```python
import numpy as np
from typing import List, Callable
from enum import Enum


class UpdateOrder(Enum):
    SEQUENTIAL = "sequential"
    PARALLEL = "parallel"
    RANDOM = "random"
    ADAPTIVE = "adaptive"


def cavi_with_ordering(
    update_factor: Callable[[int, dict], dict],
    n_factors: int,
    compute_elbo: Callable[[], float],
    ordering: UpdateOrder = UpdateOrder.SEQUENTIAL,
    max_iter: int = 100
) -> List[float]:
    """
    CAVI with configurable update ordering.

    Args:
        update_factor: Function (factor_idx, current_params) -> new_params
        n_factors: Number of factors
        compute_elbo: Function to compute the current ELBO
        ordering: Update ordering strategy
        max_iter: Maximum iterations

    Returns:
        History of ELBO values
    """
    elbo_history = []
    params = {j: None for j in range(n_factors)}  # Factor parameters

    for iteration in range(max_iter):
        if ordering == UpdateOrder.SEQUENTIAL:
            # Sequential: update 1, 2, ..., m; each update uses the most recent values
            for j in range(n_factors):
                params[j] = update_factor(j, params)

        elif ordering == UpdateOrder.PARALLEL:
            # Parallel: compute all updates from the previous iteration, then apply at once
            new_params = {}
            for j in range(n_factors):
                new_params[j] = update_factor(j, params)
            params = new_params

        elif ordering == UpdateOrder.RANDOM:
            # Random: permute the order each iteration; helps avoid pathological orderings
            order = np.random.permutation(n_factors)
            for j in order:
                params[j] = update_factor(j, params)

        elif ordering == UpdateOrder.ADAPTIVE:
            # Adaptive: prioritize factors with large expected improvement
            # (simplified version - true adaptive scheduling is more complex;
            #  a random permutation is used here as a placeholder)
            order = np.random.permutation(n_factors)
            for j in order:
                params[j] = update_factor(j, params)

        elbo = compute_elbo()
        elbo_history.append(elbo)

    return elbo_history


def demonstrate_ordering_effects():
    """Show how different orderings affect convergence."""
    print("Update Ordering Effects")
    print("=" * 50)
    print()
    print("For a model with factors q(z₁), q(z₂), q(z₃):")
    print()
    print("SEQUENTIAL (iteration t):")
    print("  q₁⁽ᵗ⁾ ← update using q₂⁽ᵗ⁻¹⁾, q₃⁽ᵗ⁻¹⁾")
    print("  q₂⁽ᵗ⁾ ← update using q₁⁽ᵗ⁾ [new!], q₃⁽ᵗ⁻¹⁾")
    print("  q₃⁽ᵗ⁾ ← update using q₁⁽ᵗ⁾ [new!], q₂⁽ᵗ⁾ [new!]")
    print()
    print("PARALLEL (iteration t):")
    print("  q₁⁽ᵗ⁾ ← update using q₂⁽ᵗ⁻¹⁾, q₃⁽ᵗ⁻¹⁾")
    print("  q₂⁽ᵗ⁾ ← update using q₁⁽ᵗ⁻¹⁾, q₃⁽ᵗ⁻¹⁾")
    print("  q₃⁽ᵗ⁾ ← update using q₁⁽ᵗ⁻¹⁾, q₂⁽ᵗ⁻¹⁾")
    print()
    print("Key insight: Sequential uses 'fresher' information")
    print("but cannot be parallelized across factors.")


if __name__ == "__main__":
    demonstrate_ordering_effects()
```

One of the most important properties of CAVI is its convergence guarantee. Unlike many optimization algorithms that might oscillate or diverge, CAVI is guaranteed to converge under mild conditions.
Theorem (CAVI Convergence):
For a mean-field variational family, if each factor is updated with the optimal factor formula, then:

1. Each update leaves the ELBO unchanged or increases it (monotonicity).
2. The sequence of ELBO values converges, since it is bounded above by $\log p(\mathbf{x})$.
3. The resulting factors form a local maximum of the ELBO: each factor is optimal given the others.
Proof Sketch:
Each factor update sets $q_j$ to maximize the ELBO while holding other factors fixed. Since we're choosing the optimal update (via the optimal factor formula), the ELBO can only increase or stay the same. Since the ELBO is bounded above by $\log p(\mathbf{x})$, the monotonically increasing sequence must converge.
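In symbols, a single full pass at iteration $t$ chains coordinate-wise maximizations:

$$\mathcal{L}\!\left(q^{(t)}\right) \le \mathcal{L}\!\left(q_1^{(t+1)}, q_2^{(t)}, \ldots, q_m^{(t)}\right) \le \cdots \le \mathcal{L}\!\left(q_1^{(t+1)}, \ldots, q_m^{(t+1)}\right) = \mathcal{L}\!\left(q^{(t+1)}\right) \le \log p(\mathbf{x}),$$

so the ELBO values form a non-decreasing sequence that is bounded above and therefore converges.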
CAVI converges to a LOCAL maximum of the ELBO, not necessarily the GLOBAL maximum. The ELBO landscape can have many local optima, and CAVI will find one depending on initialization. This is similar to EM, K-means, and other coordinate-wise algorithms. Multiple restarts with different initializations are often used to find better solutions.
Why Monotonicity Matters:
The monotonic ELBO increase property is crucial for:

- Debugging: if the ELBO ever decreases, there is a bug in the update equations (or a numerical problem), so monitoring the ELBO is a built-in correctness check.
- Convergence detection: because the sequence only moves upward, a small change between iterations is a reliable stopping signal.
- The convergence guarantee itself: a bounded, non-decreasing sequence must converge.
Conditions for Convergence:

- The ELBO must be bounded above, which it is: $\mathcal{L}(q) \le \log p(\mathbf{x})$.
- Each factor update must be the exact coordinate-wise optimum, i.e., the optimal factor formula (and the expectations it requires) must be computed exactly.
- Every factor must be updated repeatedly, which any fixed or random sweep over $j = 1, \ldots, m$ satisfies.
```python
import numpy as np
from typing import List, Optional
import warnings


class CAVIMonitor:
    """
    Monitor CAVI convergence and detect issues.

    This class tracks ELBO history and provides diagnostics for
    debugging and ensuring a correct implementation.
    """

    def __init__(self, tolerance: float = 1e-6):
        self.tolerance = tolerance
        self.elbo_history: List[float] = []
        self.iteration = 0
        self.n_increases = 0
        self.n_decreases = 0
        self.decrease_magnitudes: List[float] = []

    def log_elbo(self, elbo: float) -> dict:
        """
        Log an ELBO value and return diagnostics.

        Args:
            elbo: Current ELBO value

        Returns:
            Dictionary with convergence diagnostics
        """
        self.iteration += 1
        diagnostics = {
            'iteration': self.iteration,
            'elbo': elbo,
            'converged': False,
            'issue': None
        }

        if len(self.elbo_history) > 0:
            prev_elbo = self.elbo_history[-1]
            change = elbo - prev_elbo
            diagnostics['change'] = change

            # Check for invalid values
            if np.isnan(elbo) or np.isinf(elbo):
                diagnostics['issue'] = 'ELBO is NaN or Inf - numerical instability'
                warnings.warn(diagnostics['issue'])

            # Check monotonicity (ELBO should not decrease)
            if change < -1e-10:  # Small tolerance for numerical errors
                self.n_decreases += 1
                self.decrease_magnitudes.append(-change)
                diagnostics['issue'] = f'ELBO DECREASED by {-change:.6f}'
                warnings.warn(f"CAVI iteration {self.iteration}: {diagnostics['issue']}")
                warnings.warn("This indicates a bug in the update equations!")
            else:
                self.n_increases += 1

            # Check convergence
            if abs(change) < self.tolerance:
                diagnostics['converged'] = True

        self.elbo_history.append(elbo)
        return diagnostics

    def summary(self) -> str:
        """Generate a summary of the CAVI run."""
        if len(self.elbo_history) == 0:
            return "No ELBO values logged."

        lines = [
            "=" * 50,
            "CAVI Convergence Summary",
            "=" * 50,
            f"Total iterations: {self.iteration}",
            f"Initial ELBO: {self.elbo_history[0]:.4f}",
            f"Final ELBO: {self.elbo_history[-1]:.4f}",
            f"Total improvement: {self.elbo_history[-1] - self.elbo_history[0]:.4f}",
            "",
            "Monotonicity check:",
            f"  Increases: {self.n_increases}",
            f"  Decreases: {self.n_decreases} {'(BUG!)' if self.n_decreases > 0 else '(OK)'}",
        ]

        if self.n_decreases > 0:
            lines.extend([
                f"  Max decrease: {max(self.decrease_magnitudes):.6f}",
                "  WARNING: ELBO should NEVER decrease.",
                "  Check your update equations!"
            ])

        return "\n".join(lines)

    def plot_convergence(self):
        """Generate convergence plot data."""
        return {
            'iterations': list(range(len(self.elbo_history))),
            'elbo': self.elbo_history,
            'converged_at': self._find_convergence_point()
        }

    def _find_convergence_point(self) -> Optional[int]:
        """Find the first iteration where the convergence criterion is met."""
        for i in range(1, len(self.elbo_history)):
            if abs(self.elbo_history[i] - self.elbo_history[i - 1]) < self.tolerance:
                return i
        return None


def analyze_local_optima():
    """Demonstrate that CAVI can find different local optima."""
    print("Local Optima in CAVI")
    print("=" * 50)
    print()
    print("CAVI finds LOCAL maxima, not global.")
    print("Different initializations can lead to different solutions.")
    print()
    print("Strategies to find better optima:")
    print("  1. Multiple random restarts")
    print("  2. Informed initialization (e.g., from a simpler method)")
    print("  3. Annealing (start with a flattened ELBO)")
    print("  4. Hierarchical initialization (coarse to fine)")
    print()
    print("Example: In a Gaussian mixture with K clusters,")
    print("random initialization might find a solution where")
    print("two true clusters are merged and one is split.")
    print("A different initialization might find the correct clustering.")


if __name__ == "__main__":
    # Demonstrate monitoring
    monitor = CAVIMonitor(tolerance=1e-4)

    # Simulate a normal CAVI run
    np.random.seed(42)
    elbo = -1000.0

    print("Simulating CAVI run...")
    for i in range(50):
        # The ELBO should increase (with decreasing increments)
        improvement = np.random.exponential(10) / (i + 1)
        elbo += improvement
        diag = monitor.log_elbo(elbo)
        if diag['converged']:
            print(f"Converged at iteration {i}")
            break

    print()
    print(monitor.summary())
    print()
    analyze_local_optima()
```

Implementing CAVI in practice requires attention to several details that can make the difference between an algorithm that works and one that fails. Here are the key practical considerations:
1. Initialization:
Good initialization can dramatically speed convergence and help find better local optima:

- Multiple random restarts: run CAVI from several random initializations and keep the run with the highest final ELBO.
- Data-driven initialization: seed the factors with a cheaper method (e.g., k-means for mixture models).
- Informed or hierarchical initialization: start from a simpler model's solution, or from a coarse fit that is then refined.
For mixture models, initializing cluster assignments randomly is often better than initializing all points to one cluster. For topic models, initializing word-topic assignments based on word frequencies can help. The key is to start in a 'reasonable' part of the space, not necessarily the optimal one.
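As one illustration, the sketch below seeds the GMM example above from a quick k-means run. The scikit-learn `KMeans` call and the `init_from_kmeans` helper are assumptions for this example (not part of the lesson's code), and the 0.99/0.01 responsibility split is an arbitrary "mostly confident" starting point.

```python
import numpy as np
from sklearn.cluster import KMeans  # illustrative choice; any rough clustering would do

def init_from_kmeans(model, X, K, seed=0):
    """Hypothetical helper: seed the CAVI responsibilities and means
    from a k-means run instead of a purely random start (sketch only;
    assumes every k-means cluster is non-empty and K >= 2)."""
    labels = KMeans(n_clusters=K, n_init=10, random_state=seed).fit_predict(X)

    # Soft-ish responsibilities: mostly on the k-means cluster, a little spread out
    r = np.full((X.shape[0], K), 0.01 / (K - 1))
    r[np.arange(X.shape[0]), labels] = 0.99
    model.r = r

    # Start each q(mu_k) at the k-means centroid with moderate uncertainty
    model.m = np.vstack([X[labels == k].mean(axis=0) for k in range(K)])
    model.s_sq = np.ones_like(model.m)
    return model
```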
2. Numerical Stability:
Variational inference involves many operations that can cause numerical issues:

- Work in log-space throughout; exponentiate only at the very end, if at all.
- Normalize with a logsumexp operation rather than exponentiating and dividing, which underflows.
- Floor variance (and other positivity-constrained) parameters at a small positive value so they never become zero or negative.
- Guard expressions like $r_{ik} \log r_{ik}$ by skipping terms where $r_{ik}$ is numerically zero, as the ELBO code above does.

A minimal sketch of the log-space normalization idiom follows this list.
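This sketch shows the pattern (the same idiom appears in `update_assignments` above); the specific numbers are illustrative.

```python
import numpy as np
from scipy.special import logsumexp

# Unnormalized log-responsibilities for one data point over K=3 clusters.
# Values this negative underflow to 0.0 if exponentiated directly.
log_r = np.array([-1050.0, -1052.0, -1060.0])

naive = np.exp(log_r) / np.exp(log_r).sum()   # underflow: 0/0 gives nan
stable = np.exp(log_r - logsumexp(log_r))     # correct probabilities

print(naive)    # [nan nan nan]
print(stable)   # approximately [0.8808 0.1192 0.0000]
```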
3. Convergence Criteria:
Choosing when to stop CAVI requires balancing computation time against solution quality:

- Absolute ELBO change: stop when $|\mathcal{L}^{(t)} - \mathcal{L}^{(t-1)}| < \epsilon$.
- Relative ELBO change: stop when the change is small relative to $|\mathcal{L}^{(t-1)}|$, which is less sensitive to the ELBO's overall scale.
- Maximum iterations: always cap the number of passes as a safety net, but do not rely on it alone.
Typical values: $\epsilon \approx 10^{-4}$ to $10^{-6}$ for ELBO-based criteria.
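A minimal sketch of a stopping rule that combines the absolute and relative checks (the helper name and thresholds are illustrative):

```python
def elbo_converged(prev_elbo: float, curr_elbo: float,
                   abs_tol: float = 1e-6, rel_tol: float = 1e-6) -> bool:
    """Stop when the ELBO change is small in absolute or relative terms."""
    change = abs(curr_elbo - prev_elbo)
    return change < abs_tol or change < rel_tol * abs(prev_elbo)

# Usage inside the CAVI loop:
#     if iteration > 0 and elbo_converged(elbo_history[-2], elbo_history[-1]):
#         break
```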
| Component | Best Practice | Common Mistake |
|---|---|---|
| Initialization | Use data-driven or multiple restarts | Single random initialization |
| Log probabilities | Work in log-space throughout | Compute probabilities directly |
| Normalization | Use logsumexp function | exp then divide (underflow) |
| Variance params | Floor at small positive value | Allow zero or negative |
| ELBO monitoring | Check every iteration | Never compute ELBO |
| Convergence check | Relative and absolute | Only max iterations |
| Debugging | Verify ELBO never decreases | Ignore ELBO behavior |
Coordinate Ascent Variational Inference (CAVI) is the workhorse algorithm for mean-field variational inference. Let's consolidate the key points:

- Each iteration sweeps through the factors, setting $\log q_j^*(z_j) = \mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})] + \text{const}$ while the other factors are held fixed.
- Every update can only increase (or leave unchanged) the ELBO, so the algorithm converges to a local maximum.
- Because the optimum is only local, initialization and update ordering matter; multiple restarts and data-driven starts help.
- Practical implementations work in log-space, monitor the ELBO every iteration, and combine absolute and relative convergence checks.
Now that we understand the generic CAVI algorithm, the next page derives the specific update equations for common models. You'll see how the optimal factor formula gives closed-form updates for exponential family models, making CAVI highly efficient in practice.
You now understand Coordinate Ascent Variational Inference—the fundamental algorithm for optimizing mean-field approximations. You know how to derive updates, choose ordering strategies, ensure convergence, and handle practical implementation details. Next, we'll derive specific update equations that make CAVI work for real models.