The Laplace approximation, while elegant, has a fundamental limitation: it captures only local information at the posterior mode. If the true posterior is skewed or has significant probability mass away from the mode, Laplace will mischaracterize it.
Expectation Propagation (EP) takes a fundamentally different approach. Rather than approximating at a single point, EP finds the Gaussian that best matches the moments (mean and variance) of the true posterior. This global matching often yields superior predictive performance, particularly when the posterior marginals are asymmetric.
EP achieves this through a clever message-passing scheme that iteratively refines the approximation, one likelihood term at a time. The result is an algorithm that, while more complex than Laplace, often provides the best trade-off between accuracy and computational cost for GP classification.
By the end of this page, you will understand: (1) The EP philosophy of moment matching, (2) The site approximation and how it factorizes the posterior, (3) The EP update equations for GP classification, (4) The connection to message passing on factor graphs, (5) Convergence properties and practical considerations, and (6) When EP outperforms Laplace and when it struggles.
Expectation Propagation is a deterministic approximate inference algorithm developed by Tom Minka in 2001. Its core idea is deceptively simple:
Instead of approximating the posterior at a single point, approximate it by a distribution that matches the moments of the true posterior.
The Setup:
Recall the posterior:
$$p(\mathbf{f} | \mathbf{y}, X) \propto p(\mathbf{f} | X) \prod_{i=1}^n p(y_i | f_i)$$
The prior $p(\mathbf{f} | X)$ is Gaussian. Each likelihood term $p(y_i | f_i)$ is a sigmoid (non-Gaussian). EP approximates each likelihood term with an (unnormalized) Gaussian 'site':
$$p(y_i | f_i) \approx \tilde{Z}_i \cdot \mathcal{N}(f_i | \tilde{\mu}_i, \tilde{\sigma}_i^2) \equiv \tilde{t}_i(f_i)$$
where $\tilde{Z}_i$, $\tilde{\mu}_i$, and $\tilde{\sigma}_i^2$ are the site parameters to be determined.
The approximate posterior is then:
$$q(\mathbf{f}) \propto p(\mathbf{f} | X) \prod_i \tilde{t}_i(f_i)$$
Since the product of Gaussians is Gaussian, $q(\mathbf{f})$ is Gaussian.
EP doesn't directly approximate the posterior—it approximates each likelihood term separately, then combines these approximations with the prior. This divide-and-conquer strategy is both computationally tractable and often more accurate than single-point approximations.
The Moment Matching Principle:
EP chooses the site approximations to minimize the KL divergence from the true posterior to the approximate posterior:
$$\text{minimize } \text{KL}\left(p(\mathbf{f} | \mathbf{y}, X) \,\|\, q(\mathbf{f})\right)$$
However, this global optimization is intractable. EP instead performs local moment matching: for each site $i$, it ensures that the marginal $q(f_i)$ matches the first two moments of what you'd get if you used the true likelihood for site $i$.
Why Moment Matching?
When approximating a distribution $p$ with a Gaussian $q$, setting the mean and variance of $q$ equal to those of $p$ minimizes $\text{KL}(p \,\|\, q)$ among all Gaussians. This is the inclusive KL, or moment-projection, form of the approximation.
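The following small numerical sketch (illustrative, not part of the standard derivation; all names and values are made up) discretizes a skewed, probit-tilted density on a grid, brute-forces $\text{KL}(p \,\|\, q)$ over a family of candidate Gaussians, and confirms that the minimizer essentially coincides with the moment-matched Gaussian:

```python
import numpy as np
from scipy.stats import norm

# A skewed "true" density p on a grid: a probit-tilted Gaussian,
# similar in shape to the tilted distributions EP works with.
x = np.linspace(-6.0, 6.0, 2001)
dx = x[1] - x[0]
p = norm.pdf(x, loc=0.0, scale=1.5) * norm.cdf(2.0 * x)
p /= p.sum() * dx  # normalize numerically

# Moment-matched Gaussian: the mean and variance of p
mean_p = np.sum(x * p) * dx
var_p = np.sum((x - mean_p) ** 2 * p) * dx

def kl_p_q(mu, sigma):
    """Grid approximation of KL(p || N(mu, sigma^2))."""
    q = norm.pdf(x, loc=mu, scale=sigma)
    mask = p > 1e-12
    return np.sum(p[mask] * (np.log(p[mask]) - np.log(q[mask]))) * dx

# Brute-force search over candidate Gaussians
mus = np.linspace(mean_p - 1.0, mean_p + 1.0, 41)
sigmas = np.linspace(0.5, 1.5, 41) * np.sqrt(var_p)
best = min(((kl_p_q(m, s), m, s) for m in mus for s in sigmas), key=lambda t: t[0])

print(f"moment matching: mu={mean_p:.3f}, sigma={np.sqrt(var_p):.3f}")
print(f"KL minimizer   : mu={best[1]:.3f}, sigma={best[2]:.3f}")
```

Up to the grid resolution, the brute-force KL minimizer and the moment-matched Gaussian agree, which is exactly the property EP exploits locally at each site.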
The core data structure in EP is the site approximation—an unnormalized Gaussian that approximates each likelihood term.
Site Parameterization:
For computational reasons, we parameterize sites in the natural (canonical) form:
$$\tilde{t}_i(f_i) \propto \exp\left(-\frac{1}{2}\tilde{\tau}_i f_i^2 + \tilde{\nu}_i f_i\right)$$
where $\tilde{\tau}_i = 1/\tilde{\sigma}_i^2$ is the site precision and $\tilde{\nu}_i = \tilde{\mu}_i/\tilde{\sigma}_i^2 = \tilde{\tau}_i \tilde{\mu}_i$ is the site's natural (precision-weighted) mean.
The approximate posterior then has precision and natural mean:
$$\mathbf{\Sigma}^{-1} = K^{-1} + \text{diag}(\tilde{\boldsymbol{\tau}})$$ $$\mathbf{\Sigma}^{-1}\boldsymbol{\mu} = K^{-1}\mathbf{m} + \tilde{\boldsymbol{\nu}}$$
where $\mathbf{m}$ is the prior mean (often zero).
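As a concrete illustration of these two formulas, here is a short sketch (with a made-up 3x3 kernel matrix and arbitrary site values) that assembles $q(\mathbf{f})$ from the natural parameters; a production implementation would avoid the explicit inverses, as in the full EP listing later on this page:

```python
import numpy as np

def ep_posterior_from_sites(K, site_tau, site_nu, m=None):
    """Assemble q(f) = N(mu, Sigma) from the prior N(m, K) and the sites.

    Implements Sigma^{-1} = K^{-1} + diag(tau~) and
    Sigma^{-1} mu = K^{-1} m + nu~ (naive inverses, for clarity only).
    """
    n = len(site_nu)
    if m is None:
        m = np.zeros(n)
    K_inv = np.linalg.inv(K)  # acceptable for a tiny illustration
    Sigma = np.linalg.inv(K_inv + np.diag(site_tau))
    mu = Sigma @ (K_inv @ m + site_nu)
    return mu, Sigma

# Toy example with a hypothetical kernel matrix and arbitrary site parameters
K = np.array([[1.0, 0.5, 0.2],
              [0.5, 1.0, 0.5],
              [0.2, 0.5, 1.0]])
mu, Sigma = ep_posterior_from_sites(K,
                                    site_tau=np.array([0.8, 1.2, 0.5]),
                                    site_nu=np.array([0.3, -0.6, 0.1]))
print("posterior mean:", mu)
print("marginal variances:", np.diag(Sigma))
```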
Cavity Distribution:
A key concept in EP is the cavity distribution for site $i$. This is what the approximate posterior would look like if we removed site $i$:
$$q_{-i}(f_i) = \frac{q(f_i)}{\tilde{t}_i(f_i)}$$
In natural parameters:
$$\tau_{-i} = \tau_i - \tilde{\tau}_i$$ $$\nu_{-i} = \nu_i - \tilde{\nu}_i$$
where $\tau_i = 1/\text{Var}_q[f_i]$ and $\nu_i = \mathbb{E}_q[f_i]/\text{Var}_q[f_i]$ are the marginal parameters of the full posterior.
The cavity represents our knowledge about $f_i$ from all factors except the likelihood at point $i$.
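In code, forming the cavity is just a subtraction in natural parameters. The helper below is an illustrative sketch (the function name and example numbers are made up):

```python
def cavity_parameters(mu_i, sigma_i_sq, site_tau_i, site_nu_i):
    """Cavity q_{-i}(f_i) obtained by dividing the site out of the marginal.

    Natural parameters subtract:
        tau_{-i} = 1/sigma_i^2 - tau~_i,  nu_{-i} = mu_i/sigma_i^2 - nu~_i.
    Returns the cavity mean and variance.
    """
    tau_cav = 1.0 / sigma_i_sq - site_tau_i
    nu_cav = mu_i / sigma_i_sq - site_nu_i
    if tau_cav <= 0:
        raise ValueError("negative cavity precision; skip or damp this update")
    return nu_cav / tau_cav, 1.0 / tau_cav

# Example with arbitrary marginal and site values
mu_cav, var_cav = cavity_parameters(mu_i=0.4, sigma_i_sq=0.7,
                                    site_tau_i=0.9, site_nu_i=0.2)
print(mu_cav, var_cav)
```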
Tilted Distribution:
The tilted distribution combines the cavity with the true likelihood:
$$\hat{p}(f_i) \propto p(y_i | f_i) \cdot q_{-i}(f_i)$$
This is the distribution whose moments we want to match.
EP works by cycling through sites. For each site: (1) Remove its contribution to get the cavity, (2) Add back the TRUE likelihood to get the tilted distribution, (3) Compute the tilted moments, (4) Update the site so that the new approximate posterior matches these moments.
Let's derive the EP update equations step by step.
Given: Current approximate posterior marginal $q(f_i) = \mathcal{N}(f_i | \mu_i, \sigma_i^2)$ and site $\tilde{t}_i(f_i)$.
Step 1: Compute the Cavity
$$\tau_{-i} = \sigma_i^{-2} - \tilde{\tau}_i$$ $$\nu_{-i} = \mu_i \sigma_i^{-2} - \tilde{\nu}_i$$
The cavity mean and variance: $$\mu_{-i} = \nu_{-i} / \tau_{-i}$$ $$\sigma_{-i}^2 = 1 / \tau_{-i}$$
Step 2: Compute Tilted Moments
For the tilted distribution $\hat{p}(f_i) \propto p(y_i | f_i) \mathcal{N}(f_i | \mu_{-i}, \sigma_{-i}^2)$, compute:
$$\hat{Z}_i = \int p(y_i | f_i)\, \mathcal{N}(f_i | \mu_{-i}, \sigma_{-i}^2)\, df_i$$ $$\hat{\mu}_i = \frac{1}{\hat{Z}_i} \int f_i \, p(y_i | f_i)\, \mathcal{N}(f_i | \mu_{-i}, \sigma_{-i}^2)\, df_i$$ $$\hat{\sigma}_i^2 = \frac{1}{\hat{Z}_i} \int (f_i - \hat{\mu}_i)^2 \, p(y_i | f_i)\, \mathcal{N}(f_i | \mu_{-i}, \sigma_{-i}^2)\, df_i$$
Step 3: Update Site Parameters
Set the new site parameters so that the resulting approximate posterior marginal matches $\hat{\mu}_i, \hat{\sigma}_i^2$:
$$\tilde{\tau}_i^{\text{new}} = \hat{\sigma}_i^{-2} - \tau_{-i}$$ $$\tilde{\nu}_i^{\text{new}} = \hat{\mu}_i \hat{\sigma}_i^{-2} - \nu_{-i}$$
Damping (Important for Stability):
In practice, we often use damping to prevent oscillations:
$$\tilde{\tau}_i \leftarrow (1-\epsilon)\tilde{\tau}_i^{\text{old}} + \epsilon \tilde{\tau}_i^{\text{new}}, \qquad \tilde{\nu}_i \leftarrow (1-\epsilon)\tilde{\nu}_i^{\text{old}} + \epsilon \tilde{\nu}_i^{\text{new}}$$
where $\epsilon \in (0, 1]$ is the damping factor (typically 0.5-0.9).
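Step 3 plus damping can be packaged as a small helper. The sketch below is illustrative (the function name is made up, and the tilted moments are assumed to come from a routine such as the probit one shown later):

```python
def damped_site_update(site_tau_i, site_nu_i, tau_cav, nu_cav,
                       mu_hat, sigma_hat_sq, damping=0.7):
    """Turn tilted moments into new (damped) site natural parameters.

    tau~_new = 1/sigma_hat^2 - tau_cav,  nu~_new = mu_hat/sigma_hat^2 - nu_cav,
    then interpolate between old and new values with factor `damping`.
    """
    tau_new = 1.0 / sigma_hat_sq - tau_cav
    nu_new = mu_hat / sigma_hat_sq - nu_cav
    if tau_new < 0:  # keep the old site rather than accept a negative precision
        return site_tau_i, site_nu_i
    site_tau_i = (1 - damping) * site_tau_i + damping * tau_new
    site_nu_i = (1 - damping) * site_nu_i + damping * nu_new
    return site_tau_i, site_nu_i

# Example: tilted moments (0.9, 0.5) against a cavity with tau=1.2, nu=0.6
print(damped_site_update(0.4, 0.1, 1.2, 0.6, 0.9, 0.5, damping=0.7))
```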
Step 4: Update Global Posterior
After updating all sites (or a subset), recompute the global approximate posterior:
$$\mathbf{\Sigma} = (K^{-1} + \tilde{T})^{-1}$$ $$\boldsymbol{\mu} = \mathbf{\Sigma}(K^{-1}\mathbf{m} + \tilde{\boldsymbol{\nu}})$$
where $\tilde{T} = \text{diag}(\tilde{\boldsymbol{\tau}})$.
The cavity precision τ₋ᵢ can become negative (or the tilted variance can exceed the cavity variance), leading to negative site precisions. This indicates instability. Solutions: (1) Skip the update for this site, (2) Increase damping, (3) Reduce the site precision. Robust implementations must handle these edge cases.
The key computational challenge in EP is computing the moments of the tilted distribution. For the probit likelihood, closed-form solutions exist. For the logistic likelihood, we need numerical approximations.
Probit Likelihood (Closed Form):
For $p(y | f) = \Phi(yf)$ where $y \in \{-1, +1\}$ and $\Phi$ is the standard normal CDF:
$$\hat{Z}_i = \Phi(z_i)$$
where $z_i = \frac{y_i \mu_{-i}}{\sqrt{1 + \sigma_{-i}^2}}$
$$\hat{\mu}_i = \mu_{-i} + \frac{y_i \sigma_{-i}^2 \, \phi(z_i)}{\sqrt{1 + \sigma_{-i}^2} \, \Phi(z_i)}$$
$$\hat{\sigma}_i^2 = \sigma_{-i}^2 - \frac{\sigma_{-i}^4 \, \phi(z_i)}{(1 + \sigma_{-i}^2)\,\Phi(z_i)}\left(z_i + \frac{\phi(z_i)}{\Phi(z_i)}\right)$$
where $\phi$ is the standard normal PDF.
```python
import numpy as np
from scipy.stats import norm
from scipy.special import log_ndtr

def compute_tilted_moments_probit(y, mu_cavity, sigma_cavity):
    """
    Compute tilted distribution moments for probit likelihood.

    Parameters:
    -----------
    y : label in {-1, +1}
    mu_cavity : cavity mean
    sigma_cavity : cavity standard deviation

    Returns:
    --------
    Z_hat : normalizing constant
    mu_hat : tilted mean
    sigma_hat : tilted standard deviation
    """
    # Standardized argument
    s = np.sqrt(1 + sigma_cavity**2)
    z = y * mu_cavity / s

    # Use log computations for numerical stability
    log_Z = log_ndtr(z)  # log Φ(z)
    Z_hat = np.exp(log_Z)

    # Ratio φ(z)/Φ(z) - the inverse Mills ratio
    # Use numerically stable computation
    ratio = np.exp(norm.logpdf(z) - log_Z)

    # Tilted mean
    mu_hat = mu_cavity + y * sigma_cavity**2 * ratio / s

    # Tilted variance
    sigma_hat_sq = sigma_cavity**2 - sigma_cavity**4 * ratio / (1 + sigma_cavity**2) * (z + ratio)
    sigma_hat = np.sqrt(np.maximum(sigma_hat_sq, 1e-10))

    return Z_hat, mu_hat, sigma_hat


def compute_tilted_moments_logistic(y, mu_cavity, sigma_cavity, n_gauss_hermite=20):
    """
    Compute tilted distribution moments for logistic likelihood.
    Uses Gauss-Hermite quadrature.

    Parameters:
    -----------
    y : label in {0, 1}
    mu_cavity : cavity mean
    sigma_cavity : cavity standard deviation
    n_gauss_hermite : number of quadrature points

    Returns:
    --------
    Z_hat : normalizing constant
    mu_hat : tilted mean
    sigma_hat : tilted standard deviation
    """
    # Gauss-Hermite quadrature points and weights
    points, weights = np.polynomial.hermite.hermgauss(n_gauss_hermite)

    # Transform points: x = μ + σ√2 * t (for Hermite quadrature)
    f_points = mu_cavity + sigma_cavity * np.sqrt(2) * points

    # Logistic likelihood: σ(f)^y * (1-σ(f))^(1-y)
    pi = 1 / (1 + np.exp(-f_points))
    likelihood = pi**y * (1 - pi)**(1 - y)

    # Weights for the transformed integral (1/√π factor from Hermite)
    w = weights / np.sqrt(np.pi)

    # Compute moments
    Z_hat = np.sum(w * likelihood)
    mu_hat = np.sum(w * likelihood * f_points) / Z_hat
    var_hat = np.sum(w * likelihood * (f_points - mu_hat)**2) / Z_hat
    sigma_hat = np.sqrt(np.maximum(var_hat, 1e-10))

    return Z_hat, mu_hat, sigma_hat


# Test with example values
y_pm = 1    # {-1, +1} label
y_01 = 1    # {0, 1} label
mu_cav = 0.5
sigma_cav = 1.0

Z_p, mu_p, sigma_p = compute_tilted_moments_probit(y_pm, mu_cav, sigma_cav)
Z_l, mu_l, sigma_l = compute_tilted_moments_logistic(y_01, mu_cav, sigma_cav)

print(f"Probit:   Z={Z_p:.4f}, μ={mu_p:.4f}, σ={sigma_p:.4f}")
print(f"Logistic: Z={Z_l:.4f}, μ={mu_l:.4f}, σ={sigma_l:.4f}")
```

Logistic Likelihood (Numerical):
For the logistic likelihood $p(y | f) = \sigma(f)^y(1-\sigma(f))^{1-y}$, the tilted moments must be computed numerically:
Gauss-Hermite quadrature: Approximate the integral using weighted sum over quadrature points (typically 20-30 points suffice).
Series expansion: Approximate the logistic with one or more probit functions; the single-term version uses the approximation $\sigma(f) \approx \Phi(\lambda f)$ with $\lambda^2 = \pi/8$ (i.e., $\lambda \approx 0.63$), which makes the Gaussian integrals analytic. A quick numerical check of this variant appears after this list.
Numerical integration: Use adaptive quadrature (scipy.integrate.quad).
Gauss-Hermite quadrature is most common due to its efficiency and accuracy for this specific integral structure.
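As a sanity check on the probit-scaling variant mentioned above (an illustrative sketch; the function names are made up), the Gaussian-smoothed logistic can be compared against Gauss-Hermite quadrature:

```python
import numpy as np
from scipy.stats import norm

def logistic_gauss_expectation_quad(mu, sigma, n_points=30):
    """E[sigmoid(f)] under N(mu, sigma^2) via Gauss-Hermite quadrature."""
    t, w = np.polynomial.hermite.hermgauss(n_points)
    f = mu + sigma * np.sqrt(2.0) * t
    return np.sum(w / np.sqrt(np.pi) / (1.0 + np.exp(-f)))

def logistic_gauss_expectation_probit(mu, sigma):
    """Same expectation via the scaled-probit approximation sigmoid(f) ~ Phi(lam*f)."""
    lam = np.sqrt(np.pi / 8.0)
    return norm.cdf(lam * mu / np.sqrt(1.0 + lam**2 * sigma**2))

for mu, sig in [(0.5, 1.0), (-1.0, 2.0), (2.0, 0.5)]:
    q = logistic_gauss_expectation_quad(mu, sig)
    a = logistic_gauss_expectation_probit(mu, sig)
    print(f"mu={mu:+.1f}, sigma={sig:.1f}: quadrature={q:.4f}, probit approx={a:.4f}")
```

The two agree to a few decimal places over typical ranges, which is why the scaled-probit shortcut is often good enough for prediction, while quadrature remains the standard choice inside the EP moment computations.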
Let's consolidate the EP algorithm for GP classification.
```python
import numpy as np
from scipy.linalg import cholesky, solve_triangular

def ep_gpc(K, y, max_iter=100, tol=1e-4, damping=0.5, m=None):
    """
    Expectation Propagation for GP Classification.

    Parameters:
    -----------
    K : kernel matrix (n, n)
    y : labels in {0, 1}
    max_iter : maximum EP iterations
    tol : convergence tolerance
    damping : damping factor (0, 1]
    m : prior mean (default zero)

    Returns:
    --------
    mu : posterior mean
    Sigma_diag : posterior marginal variances
    site_tau : site precisions
    site_nu : site natural means
    log_Z : approximate log marginal likelihood
    """
    n = len(y)
    if m is None:
        m = np.zeros(n)

    # Convert labels to {-1, +1} for probit
    y_pm = 2 * y - 1

    # Initialize sites (uniform = zero precision/mean)
    site_tau = np.zeros(n)
    site_nu = np.zeros(n)

    # Initialize posterior = prior
    Sigma = K.copy()
    mu = m.copy()

    log_Z_terms = np.zeros(n)

    for iteration in range(max_iter):
        site_tau_old = site_tau.copy()

        for i in range(n):
            # Get current marginal
            sigma_i_sq = Sigma[i, i]
            sigma_i = np.sqrt(sigma_i_sq)
            mu_i = mu[i]

            # Compute cavity
            tau_i = 1 / sigma_i_sq
            nu_i = mu_i / sigma_i_sq
            tau_cavity = tau_i - site_tau[i]
            nu_cavity = nu_i - site_nu[i]

            # Check for negative cavity precision (numerical issue)
            if tau_cavity <= 0:
                continue

            sigma_cavity = 1 / np.sqrt(tau_cavity)
            mu_cavity = nu_cavity / tau_cavity

            # Compute tilted moments (probit likelihood)
            Z_hat, mu_hat, sigma_hat = compute_tilted_moments_probit(
                y_pm[i], mu_cavity, sigma_cavity
            )
            log_Z_terms[i] = np.log(np.maximum(Z_hat, 1e-300))

            # Compute new site parameters
            tau_hat = 1 / sigma_hat**2
            nu_hat = mu_hat / sigma_hat**2

            site_tau_new = tau_hat - tau_cavity
            site_nu_new = nu_hat - nu_cavity

            # Check for negative site precision
            if site_tau_new < 0:
                continue

            # Apply damping
            site_tau[i] = (1 - damping) * site_tau[i] + damping * site_tau_new
            site_nu[i] = (1 - damping) * site_nu[i] + damping * site_nu_new

        # Update global posterior
        # Σ = (K⁻¹ + T̃)⁻¹ where T̃ = diag(site_tau)
        # Use Woodbury identity for efficiency
        sqrt_tau = np.sqrt(np.maximum(site_tau, 1e-10))
        B = np.eye(n) + np.outer(sqrt_tau, sqrt_tau) * K
        L = cholesky(B + 1e-8 * np.eye(n), lower=True)

        # Sigma = K - K sqrt(τ) B⁻¹ sqrt(τ) K
        V = solve_triangular(L, np.diag(sqrt_tau) @ K, lower=True)
        Sigma = K - V.T @ V

        # mu = Σ(K⁻¹m + ν̃)
        # Simplified when m = 0: mu = Σ ν̃
        mu = Sigma @ site_nu

        # Check convergence
        if np.max(np.abs(site_tau - site_tau_old)) < tol:
            print(f"EP converged after {iteration + 1} iterations")
            break
    else:
        print(f"EP did not converge after {max_iter} iterations")

    # Approximate log marginal likelihood
    log_Z = np.sum(log_Z_terms)  # Simplified; full version has more terms

    Sigma_diag = np.diag(Sigma)

    return mu, Sigma_diag, site_tau, site_nu, log_Z


def compute_tilted_moments_probit(y, mu_cavity, sigma_cavity):
    """Same as before - compute tilted moments for probit."""
    from scipy.special import log_ndtr
    from scipy.stats import norm

    s = np.sqrt(1 + sigma_cavity**2)
    z = y * mu_cavity / s

    log_Z = log_ndtr(z)
    Z_hat = np.exp(log_Z)

    # Inverse Mills ratio
    ratio = np.exp(norm.logpdf(z) - log_Z)

    mu_hat = mu_cavity + y * sigma_cavity**2 * ratio / s
    sigma_hat_sq = sigma_cavity**2 - sigma_cavity**4 * ratio / (1 + sigma_cavity**2) * (z + ratio)
    sigma_hat = np.sqrt(np.maximum(sigma_hat_sq, 1e-10))

    return Z_hat, mu_hat, sigma_hat
```

Making predictions with EP follows the same structure as Laplace, but uses the EP posterior.
Latent Predictive Distribution:
For a test point $\mathbf{x}_*$:
$$p(f_* | \mathbf{y}, X, \mathbf{x}_*) \approx \mathcal{N}(f_* | \mu_*, \sigma_*^2)$$
with:
$$\mu_* = \mathbf{k}_*^\top (K + \tilde{T}^{-1})^{-1}(\tilde{\boldsymbol{\mu}} - \mathbf{m}) + m(\mathbf{x}_*)$$
where $\tilde{\boldsymbol{\mu}} = \tilde{T}^{-1}\tilde{\boldsymbol{\nu}}$ collects the site means, so the sites act as pseudo-observations with heteroscedastic noise $\tilde{T}^{-1}$.
Simplifying (using $\boldsymbol{\mu} = \Sigma\tilde{\boldsymbol{\nu}}$ when $\mathbf{m} = 0$):
$$\mu_* = \mathbf{k}_*^\top \left(\tilde{\boldsymbol{\nu}} - \tilde{T}\boldsymbol{\mu}\right)$$
(This uses the identity $K^{-1}\Sigma = I - \tilde{T}\Sigma$, so that $K^{-1}\boldsymbol{\mu} = K^{-1}\Sigma\tilde{\boldsymbol{\nu}} = \tilde{\boldsymbol{\nu}} - \tilde{T}\boldsymbol{\mu}$.)
The variance:
$$\sigma_*^2 = k(\mathbf{x}_*, \mathbf{x}_*) - \mathbf{k}_*^\top (K + \tilde{T}^{-1})^{-1} \mathbf{k}_*$$
Class Probability:
For probit likelihood:
$$\pi_* = \Phi\left(\frac{\mu_*}{\sqrt{1 + \sigma_*^2}}\right)$$
For logistic likelihood, use the probit approximation or numerical integration:
$$\pi_* \approx \sigma(\kappa \mu_*) \quad \text{where } \kappa = (1 + \pi\sigma_*^2/8)^{-1/2}$$
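Putting the predictive equations together, a minimal prediction routine for a single test point might look like the sketch below. It assumes a zero prior mean and the probit likelihood, consistent with the `ep_gpc` listing above; the function name is illustrative:

```python
import numpy as np
from scipy.stats import norm

def ep_predict(K, k_star, k_star_star, mu, site_tau, site_nu):
    """Latent predictive mean/variance and class probability for one test point.

    Uses mu_* = k_*^T (nu~ - T~ mu) and
    sigma_*^2 = k_** - k_*^T (K + T~^{-1})^{-1} k_*, with the inverse rewritten
    as (I + T~ K)^{-1} T~ to avoid dividing by small site precisions.
    """
    n = K.shape[0]
    mu_star = k_star @ (site_nu - site_tau * mu)
    A = np.eye(n) + site_tau[:, None] * K          # I + T~ K, T~ = diag(site_tau)
    var_star = k_star_star - k_star @ np.linalg.solve(A, site_tau * k_star)
    pi_star = norm.cdf(mu_star / np.sqrt(1.0 + var_star))
    return mu_star, var_star, pi_star

# Usage sketch: after  mu, _, site_tau, site_nu, _ = ep_gpc(K, y),
# call ep_predict(K, k_star, k_star_star, mu, site_tau, site_nu)
# with k_star = k(X, x_*) and k_star_star = k(x_*, x_*).
```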
EP Marginal Likelihood:
The approximate marginal likelihood from EP is:
$$\log q(\mathbf{y} | X) = \sum_i \log \hat{Z}_i - \frac{1}{2}\log|B| - \frac{1}{2}\tilde{\boldsymbol{\nu}}^\top\Sigma\tilde{\boldsymbol{\nu}} + \frac{1}{2}\sum_i \left(\frac{\tilde{\nu}_i^2}{\tilde{\tau}_i} - \log\tilde{\tau}_i\right) + \text{const}$$
This is more complex than Laplace but often more accurate for hyperparameter optimization.
When using EP for hyperparameter optimization, be aware that the marginal likelihood can have discontinuities if the algorithm converges to different solutions for different hyperparameter values. Use multiple random restarts and smooth the optimization landscape if needed.
Convergence Properties:
EP does not have guaranteed convergence for general factor graphs. However, for log-concave likelihoods (including logistic and probit), EP typically converges in practice.
Convergence Tips: (1) Use damping, and use stronger damping (smaller $\epsilon$) if the site parameters oscillate, (2) Initialize all sites at zero precision so the initial $q$ equals the prior, (3) Skip, or further damp, updates that would produce negative cavity or site precisions, (4) Monitor the maximum change in site parameters between sweeps and stop once it falls below a tolerance, and (5) Fall back to the Laplace approximation if EP repeatedly fails to converge.
Comparison with Laplace:
| Aspect | Laplace | EP |
|---|---|---|
| Convergence | Guaranteed | Not guaranteed |
| Implementation | Simple | Complex |
| Accuracy (small n) | Good | Better |
| Accuracy (large n) | Good | Slightly better |
| Speed | Fast | Slower (more iterations) |
| Marginal likelihood | Biased | Less biased |
EP can fail to converge when: (1) Data is very noisy with overlapping classes, (2) Kernel hyperparameters are poorly chosen, (3) The posterior is highly multimodal (shouldn't happen for standard likelihoods). If EP fails, fall back to Laplace or use stronger damping.
Expectation Propagation offers a more sophisticated approach to GP classification inference, often yielding superior results to Laplace at the cost of increased complexity.
What's Next:
The next page introduces Variational Inference for GP classification—an approach that optimizes a lower bound on the marginal likelihood. Variational methods are particularly important for scalable GP classification, where inducing points make large-scale problems tractable.
You now understand Expectation Propagation for GP classification—from the site approximation framework, through the moment matching updates, to practical considerations for convergence. EP often provides the best accuracy for moderate-sized problems.