We've seen that standard variational inference minimizes the reverse KL divergence, KL(q || p), which is mode-seeking—it finds a q that concentrates on a single mode of the posterior and tends to underestimate variance. But what if we need better calibrated uncertainties? What if the posterior has multiple modes that all matter?
Expectation Propagation (EP) takes a fundamentally different approach. Instead of optimizing a global approximation, EP breaks the posterior into factors and matches moments (means and variances) locally. The result often provides better uncertainty estimates, especially for models where the posterior is non-Gaussian.
EP was developed by Tom Minka in 2001 and has become a workhorse for probabilistic models in machine learning, particularly for Gaussian process classification and Bayesian neural ranking systems.
By the end of this page, you will understand how EP approximates intractable factors with tractable exponential family distributions, derive the moment-matching updates that drive the algorithm, recognize when EP outperforms variational inference, and appreciate the connections between EP, Assumed Density Filtering, and belief propagation.
Consider a posterior that factorizes as a product of terms:
$$p(\theta|D) \propto p(\theta) \prod_{i=1}^n f_i(\theta)$$
where p(θ) is the prior and each f_i(θ) represents the i-th data point's contribution to the likelihood. The full product is intractable, but each factor might be individually manageable.
EP's strategy:

The key insight is that we don't need to approximate the full posterior directly—we approximate each factor individually and let the product of the factor approximations form the global approximation.
EP works with exponential family distributions because they're characterized by sufficient statistics that can be easily matched through moment conditions. For Gaussians, matching first and second moments is sufficient; for other families, we match the expected sufficient statistics.
The exponential family form:
An exponential family distribution has the form:
$$q(\theta|\eta) = h(\theta) \exp(\eta^T \phi(\theta) - A(\eta))$$
where:

- h(θ) is the base measure,
- η is the vector of natural parameters,
- φ(θ) is the vector of sufficient statistics,
- A(η) is the log-partition function that ensures normalization.
For a Gaussian: φ(θ) = (θ, θ²), η = (μ/σ², -1/(2σ²)), and expectations of sufficient statistics give the moments.
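To make the parameterization concrete, here is a minimal sketch of the conversion between moment and natural parameters for a 1D Gaussian; the function names (`to_natural`, `to_moment`) are illustrative, not from any library:

```python
import numpy as np

def to_natural(mu, sigma2):
    """Moment params (mu, sigma^2) -> natural params (eta1, eta2) = (mu/sigma^2, -1/(2 sigma^2))."""
    return mu / sigma2, -1.0 / (2.0 * sigma2)

def to_moment(eta1, eta2):
    """Natural params -> moment params; requires eta2 < 0 for a proper Gaussian."""
    sigma2 = -1.0 / (2.0 * eta2)
    return eta1 * sigma2, sigma2

# Round trip recovers the original moments
eta1, eta2 = to_natural(1.5, 4.0)
mu, sigma2 = to_moment(eta1, eta2)
assert np.isclose(mu, 1.5) and np.isclose(sigma2, 4.0)
```

EP does all its bookkeeping in the natural parameterization, because there products and quotients of Gaussians reduce to addition and subtraction.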
Let's develop the EP update equations step by step. We want to approximate:
$$p(\theta|D) \propto \underbrace{p(\theta)}_{\text{prior}} \prod_{i=1}^n \underbrace{f_i(\theta)}_{\text{likelihood factors}}$$
with:
$$q(\theta) \propto p(\theta) \prod_{i=1}^n \tilde{f}_i(\theta)$$
where each f̃_i is a tractable approximation to f_i.
Step 1: Define the cavity distribution
To update f̃_i, we first compute the cavity distribution—the approximation with factor i removed:
$$q_{\backslash i}(\theta) \propto \frac{q(\theta)}{\tilde{f}_i(\theta)} = p(\theta) \prod_{j \neq i} \tilde{f}_j(\theta)$$
For Gaussian q and f̃_i (both in natural parameterization), this is simple subtraction of natural parameters.
Step 2: Form the tilted distribution
Combine the cavity with the true factor f_i:
$$\hat{p}_i(\theta) \propto q_{\backslash i}(\theta)\, f_i(\theta)$$
This "tilted" distribution represents our best guess at the posterior if we had the true factor f_i combined with the current approximation of everything else.
Step 3: Moment matching (projection)
Project the tilted distribution back into the approximating family by matching moments:
$$q^{\text{new}}(\theta) = \text{proj}[\hat{p}_i(\theta)]$$
For Gaussians, this means:
$$\mathbb{E}_{q^{\text{new}}}[\theta] = \mathbb{E}_{\hat{p}_i}[\theta]$$ $$\mathbb{E}_{q^{\text{new}}}[\theta\theta^T] = \mathbb{E}_{\hat{p}_i}[\theta\theta^T]$$
Step 4: Update the site approximation
Recover the new f̃_i from the updated global approximation:
$$\tilde{f}_i^{\text{new}}(\theta) \propto \frac{q^{\text{new}}(\theta)}{q_{\backslash i}(\theta)}$$
This completes one EP update. Iterate over all factors until convergence.
EP updates can be unstable, especially when the tilted distribution has much larger variance than the cavity. Damping—interpolating between old and new natural parameters—is often necessary: η_new = (1-ε)η_old + ε·η_update for some ε ∈ (0, 1].
Let's work through EP for the most common case: Gaussian approximations. Suppose the prior and all site approximations are Gaussian:
$$p(\theta) = \mathcal{N}(\theta | \mu_0, \Sigma_0)$$ $$\tilde{f}_i(\theta) = \mathcal{N}(\theta | \mu_i, \Sigma_i) \cdot Z_i$$
where Z_i is a normalization constant (site normalizers contribute to marginal likelihood).
Natural parameterization for Gaussians:
Gaussians are more naturally manipulated in terms of:

- the precision matrix Λ = Σ⁻¹, and
- the precision-adjusted mean ν = Λμ.
Products of Gaussians correspond to addition of natural parameters:
$$\mathcal{N}(\theta|\mu_1, \Sigma_1) \cdot \mathcal{N}(\theta|\mu_2, \Sigma_2) \propto \mathcal{N}(\theta|\mu_{1+2}, \Sigma_{1+2})$$
where Λ₁₊₂ = Λ₁ + Λ₂ and ν₁₊₂ = ν₁ + ν₂.
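A quick numerical sketch of this identity in 1D (the function name `gaussian_product` is illustrative):

```python
import numpy as np

def gaussian_product(mu1, var1, mu2, var2):
    """Multiply two 1D Gaussian densities by adding natural parameters."""
    lam = 1.0 / var1 + 1.0 / var2      # precisions add
    nu = mu1 / var1 + mu2 / var2       # precision-means add
    return nu / lam, 1.0 / lam         # back to moment parameters

mu, var = gaussian_product(0.0, 1.0, 2.0, 1.0)
# Equal variances: the product is centered halfway, with half the variance
assert np.isclose(mu, 1.0) and np.isclose(var, 0.5)
```

The cavity step (Step 1 above) is just the inverse operation: subtract one site's precision and precision-mean from the global ones.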
```python
import numpy as np


class GaussianSite:
    """A site approximation in natural parameterization."""

    def __init__(self, dim):
        self.precision = np.zeros((dim, dim))  # Λ
        self.precision_mean = np.zeros(dim)    # ν = Λμ
        self.log_z = 0.0                       # Log normalization constant

    def to_moment_params(self):
        """Convert to mean and covariance."""
        if np.allclose(self.precision, 0):
            return None, None
        cov = np.linalg.inv(self.precision)
        mean = cov @ self.precision_mean
        return mean, cov


class GaussianEP:
    """Expectation Propagation with Gaussian approximations."""

    def __init__(self, prior_mean, prior_cov, log_factor_fn, n_factors):
        """
        Parameters
        ----------
        prior_mean : ndarray
            Prior mean μ₀
        prior_cov : ndarray
            Prior covariance Σ₀
        log_factor_fn : callable
            Function (i, theta) -> log f_i(θ)
        n_factors : int
            Number of likelihood factors
        """
        self.dim = len(prior_mean)
        self.n_factors = n_factors
        self.log_factor_fn = log_factor_fn

        # Prior in natural params
        self.prior_precision = np.linalg.inv(prior_cov)
        self.prior_prec_mean = self.prior_precision @ prior_mean

        # Initialize site approximations
        self.sites = [GaussianSite(self.dim) for _ in range(n_factors)]

        # Global approximation (sum of prior + all sites)
        self._update_global()

    def _update_global(self):
        """Compute global q from prior + all sites."""
        self.global_precision = self.prior_precision.copy()
        self.global_prec_mean = self.prior_prec_mean.copy()
        self.log_z_global = 0.0
        for site in self.sites:
            self.global_precision += site.precision
            self.global_prec_mean += site.precision_mean
            self.log_z_global += site.log_z

    def get_cavity(self, i):
        """Compute the cavity distribution by subtracting site i's natural params."""
        cavity_precision = self.global_precision - self.sites[i].precision
        cavity_prec_mean = self.global_prec_mean - self.sites[i].precision_mean
        return cavity_precision, cavity_prec_mean

    def update_site(self, i, n_samples=1000, damping=0.5):
        """Update site i via moment matching."""
        # Get cavity distribution
        cavity_prec, cavity_prec_mean = self.get_cavity(i)

        # Check cavity is valid (positive definite)
        try:
            cavity_cov = np.linalg.inv(cavity_prec)
            cavity_mean = cavity_cov @ cavity_prec_mean
        except np.linalg.LinAlgError:
            return False  # Skip update if cavity is invalid

        # Sample from the tilted distribution via importance sampling,
        # using the cavity distribution as the proposal
        samples = np.random.multivariate_normal(cavity_mean, cavity_cov, n_samples)

        # Importance weights: cavity(θ)·f_i(θ) / proposal(θ) ∝ f_i(θ)
        log_weights = np.array([self.log_factor_fn(i, s) for s in samples])
        log_weights -= np.max(log_weights)  # Numerical stability
        weights = np.exp(log_weights)
        weights /= np.sum(weights)

        # Estimate moments of tilted distribution
        tilted_mean = np.sum(weights[:, None] * samples, axis=0)
        centered = samples - tilted_mean
        tilted_cov = sum(w * np.outer(c, c) for w, c in zip(weights, centered))

        # New global approximation matches these moments
        try:
            new_global_prec = np.linalg.inv(tilted_cov)
            new_global_prec_mean = new_global_prec @ tilted_mean
        except np.linalg.LinAlgError:
            return False  # Skip if tilted cov is singular

        # Extract new site parameters
        new_site_prec = new_global_prec - cavity_prec
        new_site_prec_mean = new_global_prec_mean - cavity_prec_mean

        # Apply damping
        self.sites[i].precision = (1 - damping) * self.sites[i].precision + damping * new_site_prec
        self.sites[i].precision_mean = (1 - damping) * self.sites[i].precision_mean + damping * new_site_prec_mean

        # Update global
        self._update_global()
        return True

    def fit(self, n_iters=10, damping=0.5):
        """Run EP for a fixed number of sweeps over all sites."""
        for iteration in range(n_iters):
            for i in range(self.n_factors):
                self.update_site(i, damping=damping)
            mean, cov = self.posterior_moments()
            print(f"Iter {iteration}: mean = {mean}, diag(cov) = {np.diag(cov)}")

    def posterior_moments(self):
        """Return mean and covariance of global approximation."""
        cov = np.linalg.inv(self.global_precision)
        mean = cov @ self.global_prec_mean
        return mean, cov
```

The moment-matching projection in EP is intimately connected to forward KL divergence minimization. This reveals why EP has different properties than VI.
Theorem: For exponential families, moment matching is equivalent to minimizing KL(p || q):
$$\text{proj}[p] = \arg\min_{q \in \mathcal{Q}} \text{KL}(p \| q)$$
Proof sketch: The forward KL divergence is:
$$\text{KL}(p \| q) = \mathbb{E}_p[\log p] - \mathbb{E}_p[\log q]$$
For exponential family q(θ) = h(θ)exp(η^T φ(θ) - A(η)):
$$\text{KL}(p \| q) = \mathbb{E}_p[\log p] - \eta^T \mathbb{E}_p[\phi(\theta)] + A(\eta) + \text{const}$$
Setting derivative w.r.t. η to zero:
$$\frac{\partial A}{\partial \eta} = \mathbb{E}_p[\phi(\theta)]$$
But ∂A/∂η = 𝔼_q[φ(θ)] for exponential families. So the optimal q has:
$$\mathbb{E}_q[\phi(\theta)] = \mathbb{E}_p[\phi(\theta)]$$
This is exactly moment matching! □
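A small numerical illustration of why forward KL behaves differently from reverse KL: projecting a bimodal mixture onto a Gaussian by moment matching yields a broad distribution covering both modes. The mixture parameters here are an arbitrary example:

```python
import numpy as np

# Target p: equal mixture of N(-2, 0.5^2) and N(+2, 0.5^2)
w = np.array([0.5, 0.5])
mus = np.array([-2.0, 2.0])
sds = np.array([0.5, 0.5])

# proj[p]: match E[theta] and Var[theta] (law of total variance)
proj_mean = np.sum(w * mus)
proj_var = np.sum(w * (sds**2 + mus**2)) - proj_mean**2

assert np.isclose(proj_mean, 0.0)
assert np.isclose(proj_var, 4.25)  # 0.25 (within-mode) + 4.0 (between-mode)
# Reverse KL (VI) would instead lock onto one mode, with variance near 0.25
```

The forward-KL projection places mass between the modes with a large variance; whether that is better than a mode-seeking fit depends on whether you care more about coverage or about local accuracy.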
Unlike VI, EP doesn't minimize a single global objective. It performs local moment matching at each step, which can lead to different final approximations depending on update order. This lack of a global objective also means EP can fail to converge or oscillate.
EP's most important application is Gaussian Process Classification (GPC), where it provides accurate uncertainty estimates that Laplace approximation often gets wrong.
The GPC model:

$$f \sim \mathcal{N}(0, K), \qquad p(y_i | f_i) = \sigma(y_i f_i), \quad y_i \in \{-1, +1\}$$

where σ is the sigmoid function, K is the kernel matrix, and f = (f(x₁), ..., f(x_n)).
The posterior over f:
$$p(f|y) \propto \mathcal{N}(f|0, K) \prod_{i=1}^n \underbrace{\sigma(y_i f_i)}_{t_i(f_i)}$$
The prior is Gaussian, but the likelihood factors t_i are sigmoid—not Gaussian. The product doesn't yield a closed-form posterior.
EP for GPC:
Approximate each sigmoid factor with a Gaussian site:
$$t_i(f_i) \approx \tilde{t}_i(f_i) = Z_i \mathcal{N}(f_i | \tilde{\mu}_i, \tilde{\sigma}_i^2)$$
The global approximation becomes:
$$q(f) = \mathcal{N}(f | \mu, \Sigma)$$
with analytically computable mean and covariance.
EP update for site i:
Cavity: Remove site i to get q_{\i}(f_i) = N(μ_{\i}, σ²_{\i})
Tilted distribution: p̂_i(f_i) ∝ N(f_i | μ_{\i}, σ²_{\i}) · σ(y_i f_i)
Moment matching: Compute 𝔼[f], 𝔼[f²] under tilted distribution (requires 1D numerical integration or approximation)
Update site: New site parameters from matched moments
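The moment-matching step above can be sketched with simple grid-based 1D integration; `tilted_moments` is an illustrative helper, not a library function (production code would use Gauss–Hermite quadrature or the closed-form probit expressions):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def tilted_moments(y_i, cav_mu, cav_var, n_grid=2001):
    """Mean and variance of p_hat(f) ∝ N(f | cav_mu, cav_var) · sigmoid(y_i f), via a grid."""
    sd = np.sqrt(cav_var)
    f = np.linspace(cav_mu - 8 * sd, cav_mu + 8 * sd, n_grid)
    unnorm = np.exp(-0.5 * (f - cav_mu) ** 2 / cav_var) * sigmoid(y_i * f)
    w = unnorm / unnorm.sum()            # normalized weights (grid spacing cancels)
    mean = np.sum(w * f)
    var = np.sum(w * (f - mean) ** 2)
    return mean, var

m, v = tilted_moments(y_i=+1, cav_mu=0.0, cav_var=1.0)
# A positive label tilts the mean upward and shrinks the variance below the cavity's
assert m > 0.0 and v < 1.0
```

The new site parameters then follow from Step 4: divide the matched Gaussian by the cavity, i.e. subtract natural parameters.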
| Aspect | EP Approximation | Laplace Approximation |
|---|---|---|
| Update strategy | Sequential moment matching | Find mode, compute Hessian |
| Local approximation | Moment match to sigmoid | Second-order Taylor at mode |
| Computational cost | O(n³) per full iteration | O(n³) for Hessian inverse |
| Variance calibration | Better calibrated | Tends to underestimate |
| Predictive uncertainty | More accurate for extremes | Overconfident on boundaries |
| Convergence | Not guaranteed | Guaranteed (convex optimization) |
For GP classification, EP typically provides better predictive probabilities, especially near the decision boundary where uncertainty matters most. Laplace often produces overconfident predictions. The computational cost is similar, so EP is often the better default choice.
Assumed Density Filtering (ADF) is the online, single-pass version of EP. Instead of iterating until convergence, ADF processes observations once in sequence.
The ADF algorithm:
Start with prior q₀(θ) = p(θ). For each observation i = 1, ..., n:
Compute tilted distribution: $$\hat{p}_i(\theta) \propto q_{i-1}(\theta)\, f_i(\theta)$$
Project back to approximating family: $$q_i(\theta) = \text{proj}[\hat{p}_i(\theta)]$$
After processing all observations, q_n(θ) is the approximate posterior.
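A minimal 1D sketch of the ADF loop, assuming a generic grid-based projection step; the helpers `project` and `adf` and the Gaussian-likelihood example are illustrative:

```python
import numpy as np

def project(log_f, mu, var, n_grid=2001):
    """Match Gaussian moments to p_hat(theta) ∝ N(theta | mu, var) · exp(log_f(theta))."""
    sd = np.sqrt(var)
    t = np.linspace(mu - 8 * sd, mu + 8 * sd, n_grid)
    w = np.exp(-0.5 * (t - mu) ** 2 / var + log_f(t))
    w /= w.sum()
    new_mu = np.sum(w * t)
    new_var = np.sum(w * (t - new_mu) ** 2)
    return new_mu, new_var

def adf(prior_mu, prior_var, log_factors):
    """Single forward pass: fold each factor in once, never revisiting it."""
    mu, var = prior_mu, prior_var
    for log_f in log_factors:
        mu, var = project(log_f, mu, var)
    return mu, var

# Example: unit-variance Gaussian likelihood factors for observations of theta
obs = [1.0, 1.4, 0.8]
factors = [lambda t, y=y: -0.5 * (y - t) ** 2 for y in obs]
mu, var = adf(prior_mu=0.0, prior_var=10.0, log_factors=factors)
# With conjugate (Gaussian) factors ADF is exact: posterior precision = 1/10 + 3
assert np.isclose(1.0 / var, 3.1, rtol=1e-2)
```

For conjugate factors the projection is exact and order does not matter; for non-Gaussian factors each projection discards information, which is exactly the order dependence that EP's repeated sweeps correct.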
ADF vs EP:

- ADF makes a single pass, so the result depends on the order in which observations arrive, and early projection errors are never corrected.
- EP revisits each factor repeatedly, removing its site (the cavity step) before re-incorporating it, which reduces order dependence at the cost of storing all n site approximations.
Power EP (α-divergence generalization):
Both EP and ADF can be generalized using α-divergences:
$$D_\alpha(p \| q) = \frac{1}{\alpha(1-\alpha)}\left(1 - \int p(\theta)^\alpha q(\theta)^{1-\alpha} d\theta\right)$$
Power EP uses fractional updates that interpolate between VI and EP:
$$\tilde{f}_i^{\alpha}(\theta) \propto \frac{q^{\text{new}}(\theta)^{1/\alpha}}{q_{\backslash i}(\theta)^{1/\alpha - 1}}$$
This provides a continuous trade-off between mode-seeking (small α) and mean-seeking (large α) behavior.
Unlike VI (which optimizes a global objective), EP's convergence properties are more subtle. EP can oscillate, diverge, or produce negative variances if not carefully managed.
Practical remedies:
Damping: Interpolate between old and new site parameters: $$\eta_i^{\text{new}} = (1-\epsilon)\eta_i^{\text{old}} + \epsilon \cdot \eta_i^{\text{update}}$$ Smaller ε (0.1-0.5) improves stability at cost of slower convergence.
Parallel EP: Update all sites simultaneously, then project. More stable than sequential updates for some problems.
Double-loop EP: Inner loop: converge with fixed normalizers Z_i. Outer loop: update normalizers. Guarantees a fixed point exists.
Expectation Consistent EP: Add constraints ensuring marginal consistency, providing a well-defined objective function.
Natural gradient EP: Use natural gradients for updates, interpreting EP as gradient descent in natural parameter space.
If EP fails to converge after damping and other remedies, consider: (1) using a more flexible approximating family, (2) switching to VI which has guaranteed convergence, or (3) using MCMC for that portion of the model while using EP elsewhere.
EP has found success in several important machine learning applications beyond GP classification.
| Application | Why EP Works Well | Notable Systems |
|---|---|---|
| GP Classification | Better calibrated than Laplace | GPy, GPflow, scikit-learn |
| TrueSkill™ | Handles complex factor graphs for ranking | Xbox Live matchmaking |
| Sparse GP Regression | Moment matching for inducing points | FITC/VFE approximations |
| Probit Regression | 1D integrals are tractable | Bayesian GLMs |
| Gaussian Mixture Models | Approximate E-step in EM | Hybrid EM-EP algorithms |
| Neural Network Pruning | Uncertainty for parameter importance | Sparse Bayesian NNs |
Case Study: TrueSkill™
Microsoft's TrueSkill system for Xbox Live uses EP to infer player skills from game outcomes. The model includes:

- a Gaussian belief over each player's skill,
- a per-game performance variable equal to skill plus Gaussian noise,
- win/loss (and draw) outcomes determined by comparing team performances.
The resulting factor graph has complex structure with shared variables (players in multiple games). EP efficiently handles this by:

- passing Gaussian messages along the edges of the factor graph,
- applying moment matching only at the non-Gaussian outcome factors,
- scheduling updates so information propagates between games that share players.
EP's ability to handle non-Gaussian factors (win/loss) while maintaining Gaussian beliefs about skills makes it ideal for this application.
Choose EP when: (1) you need well-calibrated uncertainties, (2) factors are non-Gaussian but moment matching is tractable, (3) you're working with GP classification or ranking. Choose VI when: (1) you need guaranteed convergence, (2) you want a well-defined objective for model comparison, (3) you're working with neural networks (reparameterization is easier than moment matching).
Expectation Propagation provides a complementary approach to variational inference, using moment matching (forward KL) instead of mode seeking (reverse KL) to obtain approximate posteriors. While EP lacks VI's convergence guarantees, it often produces better calibrated uncertainties.
What's next:
We've now covered three deterministic approximation methods: Laplace (Gaussian at mode), VI (optimal in a family via ELBO), and EP (moment matching via forward KL). The next page compares deterministic vs. stochastic methods systematically, helping you choose the right approach for your inference problem.
You now understand Expectation Propagation—its moment-matching algorithm, its connection to forward KL, and its applications in GP classification and ranking systems. EP completes the trio of major deterministic approximation methods.