We've established the mean-field factorization and the CAVI algorithm. We know that each optimal factor satisfies:
$$\log q_j^*(z_j) = \mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})] + \text{const}$$
But how do we actually compute this for real models? The answer lies in the remarkable structure of exponential family distributions. When our model is composed of exponential family components—which includes nearly all distributions used in practice—the optimal factors have closed-form updates that can be computed efficiently.
This page is your reference for deriving and implementing these update equations.
By the end of this page, you will be able to derive mean-field update equations for any model built from exponential family components. You'll have a catalog of updates for the most common distributions and understand the pattern that makes these derivations systematic.
Before diving into specific update equations, let's review the exponential family—the mathematical structure that makes mean-field VI tractable.
Exponential Family Definition:
A distribution is in the exponential family if it can be written as:
$$p(x | \eta) = h(x) \exp\left(\eta^\top t(x) - A(\eta)\right)$$
where $h(x)$ is the base measure, $\eta$ the natural parameters, $t(x)$ the sufficient statistics, and $A(\eta)$ the log-partition function that normalizes the distribution. Many familiar distributions fit this template:
| Distribution | Natural Params η | Sufficient Stats t(x) | Log-Partition A(η) |
|---|---|---|---|
| Gaussian N(μ, σ²) | [μ/σ², -1/(2σ²)] | [x, x²] | μ²/(2σ²) + log σ |
| Bernoulli(p) | log(p/(1-p)) | x | log(1 + exp(η)) |
| Poisson(λ) | log λ | x | exp(η) |
| Gamma(α, β) | [α-1, -β] | [log x, x] | log Γ(α) - α log β |
| Dirichlet(α) | α - 1 | log x | Σ log Γ(αₖ) - log Γ(Σαₖ) |
| Multinomial(π) | log π | count(x) | log(Σ exp(ηₖ)) |
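To make the table concrete, here is a quick numerical check (a sketch; the function names are ours) that the Bernoulli and Poisson rows reproduce the usual pmfs when plugged into $h(x)\exp(\eta^\top t(x) - A(\eta))$:

```python
import numpy as np
from math import factorial

def bernoulli_expfam(x, p):
    # η = log(p/(1-p)), t(x) = x, A(η) = log(1 + e^η), h(x) = 1
    eta = np.log(p / (1 - p))
    return np.exp(eta * x - np.log(1 + np.exp(eta)))

def poisson_expfam(x, lam):
    # η = log λ, t(x) = x, A(η) = e^η, h(x) = 1/x!
    eta = np.log(lam)
    return (1 / factorial(x)) * np.exp(eta * x - np.exp(eta))

# Bernoulli(0.3): exponential-family form recovers P(x=1) = 0.3
assert abs(bernoulli_expfam(1, 0.3) - 0.3) < 1e-12

# Poisson(2.5): matches the standard pmf λ^x e^{-λ} / x!
for x in range(6):
    direct = 2.5**x * np.exp(-2.5) / factorial(x)
    assert abs(poisson_expfam(x, 2.5) - direct) < 1e-12
```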
When log p(x, z) is linear in the sufficient statistics of z, taking expectations E_{q_{-j}}[log p(x, z)] yields an expression linear in E[t(z_j)]. This linear form means log q_j*(z_j) is in the same exponential family as the prior/conditional—we just update the natural parameters!
The Exponential Family Conjugacy Property:
For mean-field VI, the key insight is:
If $z_j$ appears in $\log p(\mathbf{x}, \mathbf{z})$ only through its sufficient statistics $t(z_j)$, then $q_j^*(z_j)$ is in the same exponential family as the prior/complete conditional for $z_j$.
This means we don't need to search over arbitrary distributions—we just need to find the new natural parameters for a known distribution family.
Gaussian distributions are ubiquitous in machine learning. Let's derive the mean-field update for a Gaussian factor.
Setting:
Suppose $z_j \in \mathbb{R}$ and the terms in $\log p(\mathbf{x}, \mathbf{z})$ involving $z_j$ are quadratic:
$$\log p(\mathbf{x}, \mathbf{z}) = -\frac{a}{2} z_j^2 + b z_j + \text{terms not involving } z_j$$
where $a$ and $b$ may depend on other latent variables and data.
Derivation:
Taking the expectation over $q_{-j}$:
$$\mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})] = -\frac{\mathbb{E}[a]}{2} z_j^2 + \mathbb{E}[b] z_j + \text{const}$$
This is a quadratic in $z_j$—the log of a Gaussian! Completing the square:
$$q_j^*(z_j) = \mathcal{N}(z_j | \mu_j, \sigma_j^2)$$
where: $$\sigma_j^2 = \frac{1}{\mathbb{E}_{q_{-j}}[a]}, \quad \mu_j = \sigma_j^2 \cdot \mathbb{E}_{q_{-j}}[b]$$
```python
import numpy as np
from dataclasses import dataclass


@dataclass
class GaussianVariationalFactor:
    """
    Gaussian variational factor q(z) = N(z | μ, σ²)

    In natural parameter form:
    - η₁ = μ/σ²     (precision-weighted mean)
    - η₂ = -1/(2σ²) (negative half precision)

    Updates are computed from expected quadratic coefficients.
    """
    mean: float
    variance: float

    @property
    def precision(self) -> float:
        return 1.0 / self.variance

    @property
    def natural_param1(self) -> float:
        """η₁ = μ/σ² = μ × precision"""
        return self.mean * self.precision

    @property
    def natural_param2(self) -> float:
        """η₂ = -1/(2σ²) = -precision/2"""
        return -0.5 * self.precision

    def expected_x(self) -> float:
        """E[z] = μ"""
        return self.mean

    def expected_x_squared(self) -> float:
        """E[z²] = μ² + σ²"""
        return self.mean**2 + self.variance

    def entropy(self) -> float:
        """H[N(μ, σ²)] = 0.5 × (1 + log(2πσ²))"""
        return 0.5 * (1 + np.log(2 * np.pi * self.variance))

    @classmethod
    def from_quadratic_coefficients(
        cls,
        E_a: float,
        E_b: float,
        min_variance: float = 1e-10
    ) -> 'GaussianVariationalFactor':
        """
        Create Gaussian factor from expected quadratic coefficients.

        If log p(x,z) = -(a/2)z² + bz + const, then q*(z) = N(μ, σ²) where:
            σ² = 1/E[a]
            μ  = σ² × E[b]

        Args:
            E_a: E_{q_{-j}}[a] - expected coefficient on z²
            E_b: E_{q_{-j}}[b] - expected coefficient on z
            min_variance: Minimum variance for numerical stability
        """
        if E_a <= 0:
            raise ValueError(f"E[a] must be positive for valid Gaussian, got {E_a}")
        variance = max(1.0 / E_a, min_variance)
        mean = variance * E_b
        return cls(mean=mean, variance=variance)


def derive_gaussian_update_example():
    """
    Example: Bayesian linear regression weight update

    Model:
        w ~ N(0, τ₀⁻¹)                (prior)
        y_i | x_i, w ~ N(w x_i, τ⁻¹)  (likelihood)

    log p(y, w) = -τ₀/2 × w² + Σᵢ[-τ/2 × (yᵢ - w xᵢ)²] + const
                = -w²/2 × (τ₀ + τΣxᵢ²) + w × (τΣxᵢyᵢ) + const

    So:  a = τ₀ + τΣxᵢ²,   b = τΣxᵢyᵢ
    """
    print("Gaussian Update: Bayesian Linear Regression")
    print("=" * 60)

    # Simulate data
    np.random.seed(42)
    n = 50
    x = np.random.randn(n)
    true_w = 2.5
    tau = 4.0  # Known noise precision
    y = true_w * x + np.random.randn(n) / np.sqrt(tau)

    # Prior
    tau_0 = 1.0  # Prior precision

    # Compute quadratic coefficients
    sum_x_sq = np.sum(x**2)
    sum_xy = np.sum(x * y)
    E_a = tau_0 + tau * sum_x_sq
    E_b = tau * sum_xy

    print(f"Data: n={n} observations")
    print(f"True weight: {true_w}")
    print(f"Prior precision τ₀ = {tau_0}")
    print(f"Noise precision τ = {tau}\n")
    print("Quadratic coefficients:")
    print(f"  a = τ₀ + τΣxᵢ² = {E_a:.4f}")
    print(f"  b = τΣxᵢyᵢ = {E_b:.4f}")

    # Create optimal factor
    q_w = GaussianVariationalFactor.from_quadratic_coefficients(E_a, E_b)

    print("Optimal q*(w) = N(μ, σ²):")
    print(f"  μ = {q_w.mean:.4f} (true w = {true_w})")
    print(f"  σ² = {q_w.variance:.6f}")
    print(f"  σ = {np.sqrt(q_w.variance):.4f}\n")
    print(f"95% credible interval: [{q_w.mean - 1.96*np.sqrt(q_w.variance):.3f}, "
          f"{q_w.mean + 1.96*np.sqrt(q_w.variance):.3f}]")
    return q_w


if __name__ == "__main__":
    derive_gaussian_update_example()
```

For multivariate Gaussian factors q(z_j) = N(μ_j, Σ_j), the same principle applies. If log p contains -z_jᵀ A z_j / 2 + bᵀ z_j, then Σ_j⁻¹ = E[A] and μ_j = Σ_j E[b]. The key is recognizing the quadratic form.
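The multivariate rule can be sketched directly in a few lines (a toy check; `E_A` and `E_b` stand for the precomputed expectations E[A] and E[b]):

```python
import numpy as np

def multivariate_gaussian_update(E_A: np.ndarray, E_b: np.ndarray):
    """If log p contains -z^T A z / 2 + b^T z, then
    q*(z) = N(mu, Sigma) with Sigma = E[A]^{-1} and mu = Sigma @ E[b]."""
    Sigma = np.linalg.inv(E_A)  # E[A] must be positive definite
    mu = Sigma @ E_b
    return mu, Sigma

# Toy check: E[A] = 2I and E[b] = (2, 4) give Sigma = I/2 and mu = (1, 2)
mu, Sigma = multivariate_gaussian_update(2.0 * np.eye(2), np.array([2.0, 4.0]))
```

Note that this recovers the scalar case when the matrices are 1×1: Σ = 1/E[a] and μ = Σ·E[b].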
Categorical (discrete) latent variables are common in mixture models, topic models, and hidden Markov models. The mean-field update takes a particularly elegant form.
Setting:
Suppose $z_j \in \{1, 2, \ldots, K\}$ (a discrete variable with $K$ states). The terms in $\log p(\mathbf{x}, \mathbf{z})$ involving $z_j$ can be written:
$$\log p(\mathbf{x}, \mathbf{z}) = \sum_{k=1}^{K} \mathbb{1}(z_j = k) \cdot f_k(\text{other variables}) + \text{terms not involving } z_j$$
where $f_k$ is some function that may depend on other latent variables and data.
Derivation:
$$\log q_j^*(z_j = k) = \mathbb{E}_{q_{-j}}[f_k] + \text{const}$$
Normalizing to get valid probabilities:
$$q_j^*(z_j = k) = \frac{\exp\left(\mathbb{E}_{q_{-j}}[f_k]\right)}{\sum_{k'=1}^{K} \exp\left(\mathbb{E}_{q_{-j}}[f_{k'}]\right)}$$
This is a softmax over expected log-terms!
```python
import numpy as np
from scipy.special import logsumexp, digamma
from dataclasses import dataclass


@dataclass
class CategoricalVariationalFactor:
    """
    Categorical variational factor q(z) = Cat(z | r)

    where r = [r_1, ..., r_K] are the probabilities for each state.
    Often called "responsibilities" in the mixture-model context.
    """
    probabilities: np.ndarray  # Shape (K,), sums to 1

    def __post_init__(self):
        # Ensure normalization
        self.probabilities = self.probabilities / np.sum(self.probabilities)

    @property
    def K(self) -> int:
        return len(self.probabilities)

    def expected_indicator(self, k: int) -> float:
        """E[𝟙(z = k)] = r_k"""
        return self.probabilities[k]

    def expected_indicators(self) -> np.ndarray:
        """E[𝟙(z = k)] for all k"""
        return self.probabilities.copy()

    def entropy(self) -> float:
        """H[Cat(r)] = -Σ r_k log r_k"""
        # Avoid log(0)
        p = self.probabilities
        p_safe = np.where(p > 0, p, 1)
        return -np.sum(p * np.log(p_safe))

    @classmethod
    def from_log_unnormalized(
        cls,
        log_unnorm: np.ndarray,
        temperature: float = 1.0
    ) -> 'CategoricalVariationalFactor':
        """
        Create categorical factor from log-unnormalized probabilities.

        q*(z = k) ∝ exp(log_unnorm[k])

        This is the standard form for mean-field categorical updates:
            log_unnorm[k] = E_{q_{-j}}[terms involving z_j = k]

        Args:
            log_unnorm: Log unnormalized probabilities, shape (K,)
            temperature: Temperature for softmax (1.0 = standard)
        """
        # Softmax with numerical stability
        log_unnorm = log_unnorm / temperature
        log_probs = log_unnorm - logsumexp(log_unnorm)
        probs = np.exp(log_probs)
        return cls(probabilities=probs)


def derive_gmm_assignment_update():
    """
    Example: Gaussian mixture model cluster assignment update

    Model:
        π = mixing proportions
        μ_k ~ N(0, σ₀²) for each cluster k
        z_i ~ Cat(π)
        x_i | z_i ~ N(μ_{z_i}, σ²)

    For z_i with mean-field q(z_i) = Cat(r_i):
        log q*(z_i = k) = E[log π_k] + E[log N(x_i | μ_k, σ²)]
                        = log π_k - (1/2σ²) E[(x_i - μ_k)²]
    """
    print("Categorical Update: GMM Cluster Assignments")
    print("=" * 60)

    # Simulate scenario
    K = 3           # clusters
    x_i = 2.3       # one data point
    sigma_sq = 1.0  # observation variance

    # Current variational parameters for cluster means: q(μ_k) = N(m_k, s_k²)
    m = np.array([0.0, 2.5, 5.0])     # means
    s_sq = np.array([0.5, 0.3, 0.4])  # variances

    # Mixing proportions
    pi = np.array([0.3, 0.4, 0.3])

    print(f"Observation x_i = {x_i}")
    print(f"Cluster means (m_k): {m}")
    print(f"Cluster variances (s_k²): {s_sq}")
    print(f"Mixing proportions (π): {pi}\n")

    # Compute log_unnorm[k] = E[log p(z_i=k, x_i | μ_k)]
    log_unnorm = np.zeros(K)
    for k in range(K):
        # E[log π_k] = log π_k (fixed)
        log_unnorm[k] += np.log(pi[k])
        # E[log N(x_i | μ_k, σ²)] = -1/(2σ²) × E[(x_i - μ_k)²]
        #   = -1/(2σ²) × (x_i² - 2 x_i m_k + m_k² + s_k²)
        E_sq_diff = x_i**2 - 2*x_i*m[k] + m[k]**2 + s_sq[k]
        log_unnorm[k] += -0.5/sigma_sq * E_sq_diff

    print("Log-unnormalized probabilities:")
    for k in range(K):
        print(f"  log q*(z_i = {k}) = {log_unnorm[k]:.4f}")

    # Create optimal factor
    q_z = CategoricalVariationalFactor.from_log_unnormalized(log_unnorm)

    print("Optimal q*(z_i) = Cat(r_i):")
    for k in range(K):
        print(f"  r_{k} = P(z_i = {k}) = {q_z.probabilities[k]:.4f}")
    print(f"Most likely cluster: {np.argmax(q_z.probabilities)}")
    print(f"Entropy H[q*(z_i)] = {q_z.entropy():.4f} nats")
    return q_z


def derive_topic_model_assignment():
    """
    Example: LDA word-topic assignment update

    Similar pattern but with Dirichlet priors on mixing proportions.
    """
    print("\n" + "=" * 60)
    print("Categorical Update: LDA Word-Topic Assignment")
    print("=" * 60)

    K = 4  # number of topics

    # Word in vocabulary
    word_idx = 42

    # Current variational Dirichlet parameters for:
    #   - θ_d: document's topic distribution q(θ_d) = Dir(γ_d)
    #   - φ_k: topic's word distribution    q(φ_k) = Dir(λ_k)
    gamma_d = np.array([10.0, 5.0, 8.0, 3.0])       # doc-topic params
    lambda_k_word = np.array([0.5, 2.0, 0.1, 1.5])  # topic-word params for word_idx
    lambda_k_sum = np.array([100.0, 100.0, 100.0, 100.0])  # sum of λ params

    # log q*(z_{dn} = k) = E[log θ_{dk}] + E[log φ_{k, word}]
    #                    = ψ(γ_{dk}) - ψ(Σγ_d) + ψ(λ_{k,word}) - ψ(Σλ_k)
    log_unnorm = np.zeros(K)
    for k in range(K):
        # E[log θ_k] under Dirichlet: ψ(γ_k) - ψ(sum(γ))
        E_log_theta_k = digamma(gamma_d[k]) - digamma(np.sum(gamma_d))
        # E[log φ_{k,w}] under Dirichlet: ψ(λ_{kw}) - ψ(sum(λ_k))
        E_log_phi_kw = digamma(lambda_k_word[k]) - digamma(lambda_k_sum[k])
        log_unnorm[k] = E_log_theta_k + E_log_phi_kw

    print(f"Word index: {word_idx}")
    print(f"Dirichlet params γ_d: {gamma_d}\n")
    print("Expected log probabilities:")
    for k in range(K):
        E_log_theta = digamma(gamma_d[k]) - digamma(np.sum(gamma_d))
        E_log_phi = digamma(lambda_k_word[k]) - digamma(lambda_k_sum[k])
        print(f"  Topic {k}: E[log θ_k]={E_log_theta:.3f}, E[log φ_kw]={E_log_phi:.3f}")

    q_z = CategoricalVariationalFactor.from_log_unnormalized(log_unnorm)
    print("Optimal q*(z_dn) = Cat(φ_dn):")
    for k in range(K):
        print(f"  Topic {k}: {q_z.probabilities[k]:.4f}")


if __name__ == "__main__":
    derive_gmm_assignment_update()
    derive_topic_model_assignment()
```

The Dirichlet distribution is the conjugate prior for categorical and multinomial distributions, making it extremely common in mixture models, topic models, and hierarchical Bayesian models.
Setting:
Suppose $\boldsymbol{\theta}_j$ is a probability vector (sums to 1) with Dirichlet prior $\text{Dir}(\boldsymbol{\alpha}_0)$. The terms in $\log p(\mathbf{x}, \mathbf{z})$ involving $\boldsymbol{\theta}_j$ typically look like:
$$\log p(\mathbf{x}, \mathbf{z}) = \sum_{k=1}^{K} (\alpha_{0k} - 1) \log \theta_{jk} + \sum_{k=1}^{K} n_k \log \theta_{jk} + \ldots$$
where $n_k$ counts observations assigned to category $k$.
Derivation:
Taking expectations:
$$\mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})] = \sum_{k=1}^{K} (\alpha_{0k} - 1 + \mathbb{E}[n_k]) \log \theta_{jk} + \text{const}$$
This is the log of a Dirichlet! The optimal factor is:
$$q_j^*(\boldsymbol{\theta}_j) = \text{Dir}(\boldsymbol{\alpha}_j)$$
where: $$\alpha_{jk} = \alpha_{0k} + \mathbb{E}_{q_{-j}}[n_k]$$
The pattern is beautifully simple: start with the prior concentration parameters α₀, then add expected counts. Each expected count comes from the responsibilities of categorical latent variables: E[n_k] = Σᵢ q(zᵢ = k). This is the Bayesian 'pseudocount' interpretation of Dirichlet parameters.
```python
import numpy as np
from scipy.special import digamma, gammaln
from dataclasses import dataclass


@dataclass
class DirichletVariationalFactor:
    """
    Dirichlet variational factor q(θ) = Dir(θ | α)

    The Dirichlet is conjugate to categorical/multinomial observations.
    Updates have the simple form: α_new = α_prior + expected_counts
    """
    alpha: np.ndarray  # Concentration parameters, shape (K,)

    @property
    def K(self) -> int:
        return len(self.alpha)

    @property
    def alpha_sum(self) -> float:
        return np.sum(self.alpha)

    def expected_theta(self) -> np.ndarray:
        """E[θ_k] = α_k / Σα"""
        return self.alpha / self.alpha_sum

    def expected_log_theta(self) -> np.ndarray:
        """E[log θ_k] = ψ(α_k) - ψ(Σα)"""
        return digamma(self.alpha) - digamma(self.alpha_sum)

    def variance_theta(self, k: int) -> float:
        """Var[θ_k] = α_k(α₀ - α_k) / (α₀²(α₀ + 1)) where α₀ = Σα"""
        a0 = self.alpha_sum
        ak = self.alpha[k]
        return ak * (a0 - ak) / (a0**2 * (a0 + 1))

    def entropy(self) -> float:
        """
        H[Dir(α)] = log B(α) - (K - α₀)ψ(α₀) - Σ(α_k - 1)ψ(α_k)
        where B(α) = ∏Γ(α_k) / Γ(Σα_k)
        """
        log_B = np.sum(gammaln(self.alpha)) - gammaln(self.alpha_sum)
        return (log_B
                - (self.K - self.alpha_sum) * digamma(self.alpha_sum)
                - np.sum((self.alpha - 1) * digamma(self.alpha)))

    @classmethod
    def from_prior_and_counts(
        cls,
        alpha_prior: np.ndarray,
        expected_counts: np.ndarray
    ) -> 'DirichletVariationalFactor':
        """
        Create Dirichlet factor from prior and expected counts.

        This is the standard mean-field update: α_k = α₀_k + E[n_k]

        Args:
            alpha_prior: Prior concentration parameters
            expected_counts: Expected counts for each category
        """
        alpha = alpha_prior + expected_counts
        return cls(alpha=alpha)


def derive_lda_document_topic_update():
    """
    Example: LDA document-topic distribution update

    Model:
        θ_d ~ Dir(α)             (prior on doc's topic dist)
        z_{dn} | θ_d ~ Cat(θ_d)  (topic for word n in doc d)

    For mean-field q(θ_d) = Dir(γ_d):
        γ_{dk} = α_k + Σ_n q(z_{dn} = k)
               = α_k + (expected count of topic k in document d)
    """
    print("Dirichlet Update: LDA Document-Topic Distribution")
    print("=" * 60)

    K = 5      # number of topics
    N_d = 100  # words in document

    # Prior
    alpha = np.ones(K) * 0.1  # Sparse Dirichlet prior

    # Current responsibilities q(z_{dn} = k) for words in this document
    # Simulate: this document is mostly about topics 1 and 3
    np.random.seed(42)
    topic_probs = np.array([0.1, 0.35, 0.1, 0.35, 0.1])
    responsibilities = np.random.dirichlet(topic_probs * 10, size=N_d)

    # Expected counts: E[n_k] = Σ_n q(z_{dn} = k)
    expected_counts = np.sum(responsibilities, axis=0)

    print(f"Document with {N_d} words, {K} topics")
    print(f"Prior α = {alpha}")
    print(f"Expected topic counts: {expected_counts.round(1)}")

    # Update
    q_theta = DirichletVariationalFactor.from_prior_and_counts(alpha, expected_counts)

    print("Posterior Dirichlet parameters γ:")
    print(f"  γ = α + E[counts] = {q_theta.alpha.round(2)}")
    print("Expected topic distribution E[θ_d]:")
    for k in range(K):
        print(f"  Topic {k}: {q_theta.expected_theta()[k]:.3f}")
    print("Expected log-probabilities E[log θ_d]:")
    for k in range(K):
        print(f"  Topic {k}: {q_theta.expected_log_theta()[k]:.3f}")
    print(f"Entropy H[q(θ_d)] = {q_theta.entropy():.4f} nats")
    return q_theta


def derive_multinomial_topic_word_update():
    """
    Example: Topic-word distribution update in LDA

    Model:
        φ_k ~ Dir(β)                          (prior on topic's word dist)
        w_{dn} | z_{dn}, φ ~ Cat(φ_{z_{dn}})

    For mean-field q(φ_k) = Dir(λ_k):
        λ_{kv} = β_v + Σ_{d,n} q(z_{dn} = k) × 𝟙(w_{dn} = v)
               = β_v + (expected count of word v in topic k)
    """
    print("\n" + "=" * 60)
    print("Dirichlet Update: LDA Topic-Word Distribution")
    print("=" * 60)

    K = 3   # topics
    V = 10  # vocabulary size (simplified)

    # Prior (symmetric)
    beta = np.ones(V) * 0.01  # Very sparse prior

    # Simulated expected counts: which words go with which topics
    # Topic 0: words 0, 1, 2 are common
    # Topic 1: words 3, 4, 5 are common
    # Topic 2: words 6, 7, 8, 9 are common
    expected_word_counts = np.array([
        [50, 40, 45,  2,  3,  1,  1,  2,  1,  0],  # Topic 0
        [ 2,  3,  1, 60, 55, 48,  2,  1,  3,  1],  # Topic 1
        [ 1,  2,  1,  3,  2,  1, 45, 50, 40, 38],  # Topic 2
    ])

    print(f"{K} topics, {V} vocabulary words")
    print(f"Prior β = {beta[0]} (same for all words)")

    # Update each topic
    for k in range(K):
        q_phi_k = DirichletVariationalFactor.from_prior_and_counts(
            beta, expected_word_counts[k]
        )
        print(f"Topic {k}:")
        print("  Top words by E[φ_kv]:", end=" ")
        top_words = np.argsort(q_phi_k.expected_theta())[::-1][:3]
        for w in top_words:
            print(f"word_{w} ({q_phi_k.expected_theta()[w]:.3f})", end=" ")
        print()


if __name__ == "__main__":
    derive_lda_document_topic_update()
    derive_multinomial_topic_word_update()
```

Gamma distributions are used for positive-valued parameters like precisions, rates, and scale parameters. They're conjugate to Poisson and exponential likelihoods, and appear in many Bayesian models.
Setting:
Suppose $\tau_j > 0$ (a positive-valued latent variable) with Gamma prior $\text{Gamma}(a_0, b_0)$. Typical terms in $\log p(\mathbf{x}, \mathbf{z})$ involving $\tau_j$:
$$\log p(\mathbf{x}, \mathbf{z}) = (a_0 - 1) \log \tau_j - b_0 \tau_j + c \log \tau_j - d \tau_j + \ldots$$
where $c$ and $d$ depend on other variables.
Derivation:
Taking expectations:
$$\mathbb{E}_{q_{-j}}[\log p(\mathbf{x}, \mathbf{z})] = (a_0 - 1 + \mathbb{E}[c]) \log \tau_j - (b_0 + \mathbb{E}[d]) \tau_j + \text{const}$$
This is the log of a Gamma distribution:
$$q_j^*(\tau_j) = \text{Gamma}(a_j, b_j)$$
where: $$a_j = a_0 + \mathbb{E}_{q_{-j}}[c], \quad b_j = b_0 + \mathbb{E}_{q_{-j}}[d]$$
```python
import numpy as np
from scipy.special import digamma, gammaln
from dataclasses import dataclass


@dataclass
class GammaVariationalFactor:
    """
    Gamma variational factor q(τ) = Gamma(τ | a, b)

    Parameterized by shape a and rate b.
    Mean = a/b, Variance = a/b²
    """
    shape: float  # a > 0
    rate: float   # b > 0

    @property
    def scale(self) -> float:
        return 1.0 / self.rate

    def expected_tau(self) -> float:
        """E[τ] = a/b"""
        return self.shape / self.rate

    def expected_log_tau(self) -> float:
        """E[log τ] = ψ(a) - log(b)"""
        return digamma(self.shape) - np.log(self.rate)

    def variance(self) -> float:
        """Var[τ] = a/b²"""
        return self.shape / (self.rate ** 2)

    def entropy(self) -> float:
        """H[Gamma(a,b)] = a - log(b) + log(Γ(a)) + (1-a)ψ(a)"""
        return (self.shape - np.log(self.rate) + gammaln(self.shape)
                + (1 - self.shape) * digamma(self.shape))

    @classmethod
    def from_expected_sufficient_stats(
        cls,
        prior_shape: float,
        prior_rate: float,
        E_c: float,  # Expected contribution to shape
        E_d: float   # Expected contribution to rate
    ) -> 'GammaVariationalFactor':
        """
        Create Gamma factor from prior and expected sufficient statistics.

        If log p = (a₀-1)log τ - b₀τ + c log τ - dτ + const, then:
            a = a₀ + E[c]
            b = b₀ + E[d]
        """
        shape = prior_shape + E_c
        rate = prior_rate + E_d
        return cls(shape=shape, rate=rate)


def derive_precision_update():
    """
    Example: Bayesian linear regression precision update

    Model:
        τ ~ Gamma(a₀, b₀)               (prior on noise precision)
        y_i | x_i, w, τ ~ N(xᵢᵀw, τ⁻¹)

    log p(y, w, τ) includes:
        (a₀-1) log τ - b₀ τ                        [prior]
        + (n/2) log τ - (τ/2) Σᵢ(yᵢ - xᵢᵀw)²       [likelihood]

    So:  c = n/2,  d = (1/2) Σᵢ(yᵢ - xᵢᵀw)²
    With mean-field over w:  E[d] = (1/2) Σᵢ E[(yᵢ - xᵢᵀw)²]
    """
    print("Gamma Update: Bayesian Linear Regression Precision")
    print("=" * 60)

    # Data
    np.random.seed(42)
    n, d = 100, 3
    X = np.random.randn(n, d)
    true_w = np.array([1.0, -0.5, 0.3])
    true_tau = 4.0
    y = X @ true_w + np.random.randn(n) / np.sqrt(true_tau)

    # Prior
    a_0, b_0 = 1.0, 1.0

    # Current variational distribution for w: q(w) = N(m_w, S_w)
    m_w = np.array([0.95, -0.48, 0.32])  # Close to true
    S_w = np.eye(d) * 0.01               # Low variance

    # Expected sufficient statistics
    E_c = n / 2  # c = n/2 (deterministic)

    # E[d] = (1/2) Σᵢ E[(yᵢ - xᵢᵀw)²]
    #      = (1/2) Σᵢ [yᵢ² - 2yᵢ xᵢᵀ m_w + xᵢᵀ (m_w m_wᵀ + S_w) xᵢ]
    E_d = 0.0
    for i in range(n):
        x_i = X[i]
        y_i = y[i]
        E_residual_sq = (y_i**2 - 2 * y_i * x_i @ m_w
                         + x_i @ (np.outer(m_w, m_w) + S_w) @ x_i)
        E_d += 0.5 * E_residual_sq

    print(f"Data: n={n}, d={d}")
    print(f"Prior: Gamma({a_0}, {b_0})")
    print(f"True precision τ = {true_tau}\n")
    print("Expected sufficient statistics:")
    print(f"  E[c] = n/2 = {E_c:.1f}")
    print(f"  E[d] = E[sum of squared residuals]/2 = {E_d:.2f}")

    # Update
    q_tau = GammaVariationalFactor.from_expected_sufficient_stats(a_0, b_0, E_c, E_d)

    print("Optimal q*(τ) = Gamma(a, b):")
    print(f"  a = a₀ + E[c] = {q_tau.shape:.2f}")
    print(f"  b = b₀ + E[d] = {q_tau.rate:.2f}\n")
    print(f"Expected precision E[τ] = {q_tau.expected_tau():.3f}")
    print(f"True precision = {true_tau}")
    print(f"Standard deviation: {np.sqrt(q_tau.variance()):.3f}")
    return q_tau


def derive_poisson_rate_update():
    """
    Example: Poisson-Gamma conjugacy

    Model:
        λ ~ Gamma(a₀, b₀)  (prior on rate)
        x_i ~ Poisson(λ)   (observations)

    log p(x, λ) = (a₀-1) log λ - b₀ λ + Σᵢ[xᵢ log λ - λ] + const
                = (a₀ - 1 + Σxᵢ) log λ - (b₀ + n) λ + const

    Posterior: λ | x ~ Gamma(a₀ + Σxᵢ, b₀ + n)
    """
    print("\n" + "=" * 60)
    print("Gamma Update: Poisson Rate")
    print("=" * 60)

    # Data: Poisson observations
    np.random.seed(42)
    true_lambda = 5.0
    n = 50
    x = np.random.poisson(true_lambda, n)

    # Prior
    a_0, b_0 = 1.0, 1.0

    # For Poisson: c = Σxᵢ, d = n
    E_c = np.sum(x)
    E_d = n

    print(f"Data: {n} Poisson observations")
    print(f"True rate λ = {true_lambda}")
    print(f"Sample mean = {np.mean(x):.2f}")
    print(f"Sum of observations = {E_c}")

    q_lambda = GammaVariationalFactor.from_expected_sufficient_stats(a_0, b_0, E_c, E_d)

    print(f"Posterior Gamma({q_lambda.shape:.1f}, {q_lambda.rate:.1f}):")
    print(f"  E[λ] = {q_lambda.expected_tau():.3f}")
    print(f"  Std[λ] = {np.sqrt(q_lambda.variance()):.3f}")


if __name__ == "__main__":
    derive_precision_update()
    derive_poisson_rate_update()
```

This page has shown how to derive mean-field update equations for the most common exponential family distributions. Here's a quick reference for future use:
| Factor Type | Prior Form | Update Rule | Key Expectations |
|---|---|---|---|
| Gaussian | N(μ₀, σ₀²) | σ² = 1/E[a], μ = σ²E[b] | E[z], E[z²] = μ² + σ² |
| Categorical | Cat(π) | r_k ∝ exp(E[f_k]) | E[𝟙(z=k)] = r_k |
| Dirichlet | Dir(α₀) | α_k = α₀_k + E[n_k] | E[log θ_k] = ψ(α_k) - ψ(Σα) |
| Gamma | Gam(a₀, b₀) | a = a₀+E[c], b = b₀+E[d] | E[τ] = a/b, E[log τ] = ψ(a)-log(b) |
| Beta | Beta(α₀, β₀) | α = α₀+E[n₁], β = β₀+E[n₀] | E[log θ] = ψ(α)-ψ(α+β) |
| Wishart | W(W₀, ν₀) | W = (W₀⁻¹+E[S])⁻¹, ν = ν₀+E[n] | E[Λ] = νW |
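The Beta row is just the two-category special case of the Dirichlet update. A minimal sketch (the soft counts `E_n1` and `E_n0` would come from summing categorical responsibilities, as in the Dirichlet examples above):

```python
import numpy as np
from scipy.special import digamma

def beta_update(alpha0: float, beta0: float, E_n1: float, E_n0: float):
    """Mean-field Beta update: add expected success/failure counts to the prior."""
    alpha, beta = alpha0 + E_n1, beta0 + E_n0
    E_theta = alpha / (alpha + beta)                   # posterior mean of θ
    E_log_theta = digamma(alpha) - digamma(alpha + beta)  # E[log θ]
    return alpha, beta, E_theta, E_log_theta

# 30 expected successes and 70 expected failures under a Beta(1, 1) prior
a, b, E_t, E_log_t = beta_update(1.0, 1.0, 30.0, 70.0)
```

Note that E[log θ] < log E[θ] (Jensen's inequality), which is why these updates use the digamma-based expectation rather than plugging in the mean.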
Every exponential family update follows the same pattern: (1) Write log p in terms of z_j, (2) Identify coefficients of sufficient statistics, (3) Take expectations over q_{-j}, (4) Recognize the distribution family, (5) Read off parameters. With practice, this becomes mechanical.
With update equations in hand, the next page examines convergence—how CAVI finds its solution and what guarantees we have. We'll explore convergence rates, initialization strategies, and when mean-field performs well versus poorly.
You now have a comprehensive reference for deriving mean-field update equations. The exponential family structure makes these derivations systematic: identify sufficient statistics, take expectations, recognize the family, read off parameters. This knowledge enables you to implement CAVI for virtually any exponential family model.