Standard gradient descent treats parameter space as a flat Euclidean manifold—moving in the direction of steepest descent measured by the ordinary \(L_2\) norm. But for probability distributions, this assumption is fundamentally flawed.
Consider a simple example: optimizing the mean \(\mu\) of a Gaussian distribution \(\mathcal{N}(\mu, \sigma^2)\). A step \(\Delta\mu = 0.1\) has vastly different effects depending on \(\sigma\): if \(\sigma = 0.01\), the step shifts the mean by ten standard deviations, producing an almost entirely different distribution; if \(\sigma = 10\), the same step is a negligible nudge.
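The effect is easy to quantify: for a fixed variance, the KL divergence between the original and mean-shifted Gaussians has the closed form \(\delta^2 / (2\sigma^2)\). A minimal sketch (the helper name is illustrative):

```python
import numpy as np

def kl_gaussian_mean_shift(delta: float, sigma: float) -> float:
    """KL[N(mu, sigma^2) || N(mu + delta, sigma^2)] = delta^2 / (2 sigma^2)."""
    return delta**2 / (2 * sigma**2)

# The same parameter step produces wildly different distributional changes
step = 0.1
for sigma in [0.01, 1.0, 10.0]:
    print(f"sigma={sigma:6.2f}  KL={kl_gaussian_mean_shift(step, sigma):.6f}")
```

With \(\sigma = 0.01\) the step costs 50 nats of distributional change; with \(\sigma = 10\), only \(5 \times 10^{-5}\).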
Euclidean gradient descent is oblivious to this difference, taking the same step regardless of the distribution's spread. This leads to inefficient optimization: steps that are too aggressive in low-variance regions and too conservative in high-variance regions.
Natural gradients resolve this by measuring distances in the space of distributions using the Fisher Information Matrix, which captures the intrinsic geometry of probability distributions.
By the end of this page, you will understand the information-geometric foundations of natural gradients, the Fisher Information Matrix and its role in measuring distributional change, why natural gradients accelerate variational inference, and practical algorithms for computing and applying natural gradient updates.
To understand why natural gradients matter, we must first appreciate the limitations of standard (Euclidean) gradient descent when applied to probability distributions.
The optimization landscape is parameterization-dependent:
Consider optimizing a Bernoulli distribution with success probability \(p\). We could parameterize it in two equivalent ways: directly by the probability \(\theta = p \in (0, 1)\), or by the log-odds \(\eta = \log(p/(1-p)) \in \mathbb{R}\).
Both parameterizations represent the same family of distributions, yet Euclidean gradient descent behaves dramatically differently. Mapping a gradient step \(\Delta\theta = \alpha \nabla_\theta \mathcal{L}\) into \(\eta\)-coordinates via the chain rule gives

$$\Delta\eta = \frac{d\eta}{d\theta}\,\Delta\theta = \alpha \left(\frac{d\eta}{d\theta}\right)^{2} \nabla_\eta \mathcal{L} \neq \alpha \nabla_\eta \mathcal{L}$$

so the update performed in \(\theta\)-coordinates is not the update Euclidean descent would perform in \(\eta\)-coordinates.
The optimization trajectory depends entirely on which parameterization we choose—even though the underlying optimization problem is identical. This is deeply unsatisfying: the "optimal" direction shouldn't depend on arbitrary coordinate choices.
A well-designed optimization algorithm should produce parameterization-invariant behavior. If two parameterizations describe the same model, the optimizer should trace equivalent paths through distribution space, regardless of which coordinates we use. Euclidean gradient descent fails this test.
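A concrete numeric illustration of the failure, assuming a single observation \(x = 1\) and comparing gradient magnitudes under the probability and log-odds parameterizations (helper names are illustrative):

```python
import numpy as np

def grad_p(p: float, x: float = 1.0) -> float:
    """Euclidean gradient of the Bernoulli log-likelihood w.r.t. p."""
    return x / p - (1 - x) / (1 - p)

def grad_logit(p: float, x: float = 1.0) -> float:
    """Same gradient w.r.t. theta = log(p / (1 - p)), via dp/dtheta = p(1-p)."""
    return grad_p(p, x) * p * (1 - p)

p = 0.01
print(grad_p(p))      # ~100: a fixed-size step in p overshoots badly near the boundary
print(grad_logit(p))  # ~0.99: same objective, tame gradient in logit coordinates
```

Identical objective, identical distribution family, yet the two coordinate systems report gradients two orders of magnitude apart.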
Visualizing the problem:
Consider optimizing a 2D Gaussian with mean \(\boldsymbol{\mu}\) and covariance \(\boldsymbol{\Sigma}\). The ELBO landscape in parameter space is typically highly anisotropic: sharply curved along directions where small parameter changes alter the distribution dramatically, and nearly flat along directions where the distribution barely responds.

Euclidean gradient descent sees these as equivalent directions and takes uniform steps: too aggressive along the sharply curved directions, too timid along the flat ones.
The algorithm oscillates wildly in some dimensions while crawling in others—a symptom of ignoring the problem's geometry.
The core issue:
Euclidean distance in parameter space \(|\theta_1 - \theta_2|_2\) doesn't correspond to meaningful distance between distributions \(p(\cdot; \theta_1)\) and \(p(\cdot; \theta_2)\). We need a metric that measures distributional similarity.
Information geometry treats the space of probability distributions as a Riemannian manifold—a curved space where distances are measured using a metric tensor rather than simple Euclidean norms.
The KL divergence as a local distance:
The natural measure of dissimilarity between distributions is the Kullback-Leibler (KL) divergence:
$$D_{\text{KL}}[p(\cdot; \theta) \,\|\, p(\cdot; \theta + \delta)] = \int p(x; \theta) \log \frac{p(x; \theta)}{p(x; \theta + \delta)} \, dx$$
For small perturbations \(\delta\), we can Taylor-expand the KL divergence:
$$D_{\text{KL}}[p(\cdot; \theta) \,\|\, p(\cdot; \theta + \delta)] \approx \frac{1}{2} \delta^T \mathbf{F}(\theta) \delta + O(\|\delta\|^3)$$
where \(\mathbf{F}(\theta)\) is the Fisher Information Matrix.
The Fisher Information Matrix F(θ) measures the curvature of the KL divergence at θ. It tells us how sensitive the distribution is to parameter changes in each direction. Large eigenvalues indicate directions where small parameter changes cause large distributional changes.
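This quadratic approximation can be checked numerically for a univariate Gaussian with \(\theta = (\mu, \sigma^2)\), using the exact Gaussian KL formula and the Fisher matrix \(\mathbf{F} = \mathrm{diag}(1/\sigma^2,\; 1/(2\sigma^4))\) that appears later in this page (a sketch; helper names are illustrative):

```python
import numpy as np

def kl_gauss(m1, v1, m2, v2):
    """Exact KL[N(m1, v1) || N(m2, v2)] for univariate Gaussians."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def fisher_quadratic(mu, v, delta):
    """0.5 * delta^T F delta with F = diag(1/v, 1/(2 v^2)) for theta = (mu, v)."""
    F = np.diag([1.0 / v, 1.0 / (2 * v**2)])
    return 0.5 * delta @ F @ delta

mu, v = 0.0, 1.0
delta = np.array([0.01, 0.02])
exact = kl_gauss(mu, v, mu + delta[0], v + delta[1])
approx = fisher_quadratic(mu, v, delta)
```

For this perturbation, the exact KL and the quadratic form agree to better than \(10^{-5}\), consistent with the \(O(\|\delta\|^3)\) remainder.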
Definition of the Fisher Information Matrix:
The Fisher Information Matrix is defined as:
$$\mathbf{F}(\theta) = \mathbb{E}_{p(x; \theta)}\left[\nabla_\theta \log p(x; \theta) \, \nabla_\theta \log p(x; \theta)^T\right]$$
Equivalently, using the score function \(s(x; \theta) = \nabla_\theta \log p(x; \theta)\):
$$\mathbf{F}(\theta) = \mathbb{E}_{p}[s \cdot s^T] = \text{Cov}_p[s]$$
(The last equality holds because \(\mathbb{E}_p[s] = 0\) under regularity conditions.)
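Both identities are easy to verify by Monte Carlo. For \(\mathcal{N}(\mu, \sigma^2)\), the score with respect to \(\mu\) is \((x - \mu)/\sigma^2\); under the model its mean is zero and its variance is \(1/\sigma^2\) (a sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma2 = 1.0, 4.0
x = rng.normal(mu, np.sqrt(sigma2), size=200_000)

# Score of N(mu, sigma2) w.r.t. mu, evaluated at samples drawn from the model
score_mu = (x - mu) / sigma2

print(score_mu.mean())  # ~ 0              (E_p[s] = 0)
print(score_mu.var())   # ~ 1/sigma2 = 0.25 (F = Cov_p[s])
```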
Alternative formulation:
Under appropriate regularity conditions, the Fisher Information also equals the negative expected Hessian of the log-likelihood:
$$\mathbf{F}(\theta) = -\mathbb{E}_{p(x; \theta)}\left[\nabla_\theta^2 \log p(x; \theta)\right]$$
This connects the Fisher Information to curvature of the log-likelihood surface.
Properties of the Fisher Information Matrix:

- It is symmetric and positive semi-definite (positive definite under standard regularity conditions), so it defines a valid Riemannian metric.
- It transforms as a metric tensor: under a reparameterization with Jacobian \(J = \partial\theta/\partial\eta\), \(\mathbf{F}_\eta = J^T \mathbf{F}_\theta J\).
- It is additive over independent observations: \(N\) i.i.d. samples carry information \(N\mathbf{F}(\theta)\).
With the Fisher Information Matrix defining the local geometry of distribution space, we can now formulate optimization that respects this geometry.
The natural gradient definition:
The natural gradient is defined as:
$$\tilde{\nabla}_\theta \mathcal{L} = \mathbf{F}(\theta)^{-1} \nabla_\theta \mathcal{L}$$
where \(\nabla_\theta \mathcal{L}\) is the ordinary (Euclidean) gradient and \(\mathbf{F}(\theta)^{-1}\) is the inverse Fisher Information Matrix.
Intuition:
The natural gradient answers: "What parameter change produces the steepest increase in \(\mathcal{L}\) per unit of distributional change?"
Mathematically, natural gradient descent solves:
$$\tilde{\nabla}_\theta \mathcal{L} \propto \arg\max_{\delta:\, \|\delta\|_F \leq \epsilon} \mathcal{L}(\theta + \delta) \quad \text{as } \epsilon \to 0$$
where \(\|\delta\|_F^2 = \delta^T \mathbf{F}(\theta) \delta\) is the Fisher norm. We maximize improvement subject to a constraint on how much the distribution changes, not how much the parameters change.
Parameterization invariance:
A crucial property of natural gradient descent is that it produces the same trajectory through distribution space regardless of parameterization.
If \(\eta = f(\theta)\), then with Jacobian \(J = \partial\theta/\partial\eta\):
$$\tilde{\nabla}_\eta \mathcal{L} = \mathbf{F}_\eta^{-1} \nabla_\eta \mathcal{L} = (J^T \mathbf{F}_\theta J)^{-1} J^T \nabla_\theta \mathcal{L}$$
The update \(\Delta\theta = \alpha \tilde{\nabla}_\theta \mathcal{L}\) in one coordinate system corresponds to \(\Delta\eta = J^{-1} \Delta\theta\) in the other—they trace the same path through distributions.
This invariance is why natural gradients are fundamentally "right" for probabilistic optimization.
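The invariance can be verified numerically for a Bernoulli with one observation \(x = 1\): the natural gradient step computed in \(p\)-coordinates equals the natural gradient step computed in logit coordinates and mapped back through the Jacobian (a sketch; variable names are illustrative):

```python
p = 0.2
g_p = 1.0 / p                  # Euclidean gradient of log p (x = 1) w.r.t. p
F_p = 1.0 / (p * (1 - p))      # Bernoulli Fisher information in p-coordinates

# Natural gradient step in p-coordinates
nat_p = g_p / F_p              # = p(1-p) * g_p = 1 - p

# Same computation in logit coordinates theta = log(p / (1 - p))
dp_dtheta = p * (1 - p)        # Jacobian of the reparameterization
g_theta = g_p * dp_dtheta      # chain rule
F_theta = p * (1 - p)          # Fisher information in logit coordinates
nat_theta = g_theta / F_theta  # natural gradient in theta

# Map the theta-step back to p-space: both coordinate systems agree
nat_theta_in_p = dp_dtheta * nat_theta
print(nat_p, nat_theta_in_p)   # both ~0.8
```

Both routes produce the identical move in \(p\)-space, even though the raw Euclidean gradients differ.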
Natural gradients are particularly elegant for exponential family distributions—a class that includes Gaussians, Bernoullis, Poissons, Dirichlets, and many other common distributions.
Exponential family form:
An exponential family distribution has the form:
$$p(x; \boldsymbol{\eta}) = h(x) \exp\left(\boldsymbol{\eta}^T T(x) - A(\boldsymbol{\eta})\right)$$
where:

- \(h(x)\) is the base measure,
- \(\boldsymbol{\eta}\) is the vector of natural parameters,
- \(T(x)\) is the vector of sufficient statistics,
- \(A(\boldsymbol{\eta})\) is the log-partition function, which normalizes the distribution.
The Fisher Information in natural parameters:
For exponential families in natural parameterization, the Fisher Information Matrix takes a beautifully simple form:
$$\mathbf{F}(\boldsymbol{\eta}) = \nabla^2 A(\boldsymbol{\eta})$$
The Fisher Information is simply the Hessian of the log-partition function!
For exponential families, there's a duality between natural parameters η and mean parameters μ = E[T(x)] = ∇A(η). The Fisher Information in natural coordinates equals the inverse covariance of sufficient statistics: F(η) = Cov[T(x)]. Natural gradient descent in η-space is equivalent to ordinary gradient descent in μ-space!
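The identity \(\mathbf{F}(\eta) = \nabla^2 A(\eta) = \mathrm{Cov}[T(x)]\) can be checked for the Bernoulli, where \(A(\eta) = \log(1 + e^\eta)\) and \(T(x) = x\) (a sketch using a finite-difference Hessian):

```python
import numpy as np

def A(eta):
    """Log-partition function of the Bernoulli in natural form."""
    return np.log1p(np.exp(eta))

eta = 0.5
p = 1.0 / (1.0 + np.exp(-eta))  # mean parameter: p = A'(eta)

# Numerical second derivative of A (central finite difference)
h = 1e-4
hess_A = (A(eta + h) - 2 * A(eta) + A(eta - h)) / h**2

var_T = p * (1 - p)             # variance of the sufficient statistic T(x) = x
print(hess_A, var_T)            # agree: F(eta) = A''(eta) = Var[T]
```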
Example: Gaussian distribution
Consider a univariate Gaussian \(\mathcal{N}(\mu, \sigma^2)\):
Standard parameterization: \(\theta = (\mu, \sigma^2)\)
$$\mathbf{F}(\mu, \sigma^2) = \begin{pmatrix} 1/\sigma^2 & 0 \\ 0 & 1/(2\sigma^4) \end{pmatrix}$$
Natural parameterization: \(\eta = (\mu/\sigma^2, -1/(2\sigma^2))\)
The Fisher Information becomes the Hessian of: $$A(\eta_1, \eta_2) = -\frac{\eta_1^2}{4\eta_2} - \frac{1}{2}\log(-2\eta_2)$$
The practical implication:
When optimizing a Gaussian variational distribution, the natural gradient rescales the mean update by \(\sigma^2\) and the variance update by \(2\sigma^4\): bold steps where the distribution is broad and insensitive to parameter changes, careful steps where it is narrow and sensitive.
This automatic scaling is why natural gradients converge faster—they adapt step sizes to the local geometry.
```python
import numpy as np


def natural_gradient_gaussian(mu: float, sigma2: float,
                              grad_mu: float, grad_sigma2: float):
    """
    Compute natural gradient for univariate Gaussian.

    Standard parameterization: theta = (mu, sigma^2)
    Fisher Information Matrix:
        F = [[1/sigma^2, 0],
             [0, 1/(2*sigma^4)]]
    Natural gradient = F^{-1} @ euclidean_gradient
    """
    # Fisher Information inverse
    F_inv = np.array([
        [sigma2, 0],
        [0, 2 * sigma2**2]
    ])

    # Euclidean gradient
    grad = np.array([grad_mu, grad_sigma2])

    # Natural gradient
    natural_grad = F_inv @ grad
    return natural_grad[0], natural_grad[1]  # (natural_grad_mu, natural_grad_sigma2)


def natural_gradient_multivariate_gaussian(
    mu: np.ndarray,          # Shape: (D,)
    Sigma: np.ndarray,       # Shape: (D, D)
    grad_mu: np.ndarray,     # Shape: (D,)
    grad_Sigma: np.ndarray   # Shape: (D, D)
) -> tuple:
    """
    Natural gradient for multivariate Gaussian q(z) = N(mu, Sigma).

    For the mean parameter, natural gradient is: Sigma @ grad_mu
    For the covariance, it's more complex: 2 * Sigma @ grad_Sigma @ Sigma
    """
    # Natural gradient w.r.t. mean
    natural_grad_mu = Sigma @ grad_mu

    # Natural gradient w.r.t. covariance
    # F_Sigma^{-1} acting on grad_Sigma gives:
    natural_grad_Sigma = 2 * Sigma @ grad_Sigma @ Sigma

    return natural_grad_mu, natural_grad_Sigma
```

Applying natural gradients to variational inference yields powerful algorithms with convergence guarantees.
The VI setting:
We're optimizing the ELBO with respect to variational parameters \(\phi\) of the approximating distribution \(q(z; \phi)\):
$$\phi^* = \arg\max_\phi \mathcal{L}(\phi) = \arg\max_\phi \left\{ \mathbb{E}_q[\log p(x, z)] + \mathcal{H}[q] \right\}$$
The Fisher Information for this problem is:
$$\mathbf{F}_q(\phi) = \mathbb{E}_{q(z; \phi)}\left[\nabla_\phi \log q(z; \phi) \, \nabla_\phi \log q(z; \phi)^T\right]$$
Note: This is the Fisher Information of the variational distribution \(q\), not the model \(p\).
Natural gradient update for VI:
$$\phi_{t+1} = \phi_t + \rho_t \mathbf{F}_q(\phi_t)^{-1} \nabla_\phi \mathcal{L}(\phi_t)$$
For conjugate exponential family models with mean-field variational families, natural gradient VI with step size ρ = 1 exactly recovers the coordinate ascent variational inference (CAVI) updates! This reveals CAVI as a special case of natural gradient optimization, explaining its often-rapid convergence.
The remarkable simplification:
For exponential family variational distributions, the natural gradient of the ELBO takes an elegant form. Let \(q(z; \boldsymbol{\eta})\) be in exponential family form with natural parameters \(\boldsymbol{\eta}\). Then:
$$\tilde{\nabla}_\eta \mathcal{L} = \hat{\boldsymbol{\eta}} - \boldsymbol{\eta}$$
where \(\hat{\boldsymbol{\eta}}\) are the natural parameters that would be optimal if we ignored all other factors (essentially the "expected natural parameters" under the complete conditional).
The natural gradient update becomes:
$$\boldsymbol{\eta}_{t+1} = (1 - \rho_t) \boldsymbol{\eta}_t + \rho_t \hat{\boldsymbol{\eta}}_t$$
This is simply averaging between current and optimal parameters! With \(\rho_t = 1\), we jump directly to the coordinate-wise optimum; with \(\rho_t < 1\), we take a partial step.
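A toy sketch of this update rule with a fixed target \(\hat{\boldsymbol{\eta}}\) (in real SVI the target is recomputed from each mini-batch; the values here are illustrative):

```python
import numpy as np

eta = np.array([0.0, 0.0])       # current natural parameters
eta_hat = np.array([3.0, -1.5])  # coordinate-wise optimum (fixed for illustration)

# rho = 1: jump straight to the optimum in one step
assert np.allclose((1 - 1.0) * eta + 1.0 * eta_hat, eta_hat)

# rho < 1: geometric approach toward eta_hat
rho = 0.3
for _ in range(20):
    eta = (1 - rho) * eta + rho * eta_hat
# after 20 steps the remaining gap has shrunk by (1 - rho)^20 ~ 8e-4
```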
Why this matters for SVI:
In stochastic variational inference, we estimate \(\hat{\boldsymbol{\eta}}\) from mini-batches. The natural gradient formulation:
$$\boldsymbol{\eta}_{t+1} = (1 - \rho_t) \boldsymbol{\eta}_t + \rho_t \hat{\boldsymbol{\eta}}_t^{(B)}$$
is a stochastic averaging that automatically provides the right updates without needing to invert Fisher Information matrices explicitly.
We now present the complete Stochastic Variational Inference with Natural Gradients algorithm for conjugate exponential family models.
```python
import numpy as np
from typing import List, Tuple, Callable
from scipy.special import digamma  # note: digamma lives in scipy.special, not numpy


def svi_natural_gradient(
    data: np.ndarray,
    compute_local_optimal: Callable,   # Returns optimal local params given global
    compute_global_update: Callable,   # Returns target global natural params
    initial_global_params: np.ndarray,
    batch_size: int = 100,
    num_iterations: int = 10000,
    learning_rate_schedule: Callable = lambda t: (t + 1) ** (-0.7)
) -> np.ndarray:
    """
    Stochastic Variational Inference with Natural Gradients.

    For conjugate exponential family models with local/global latent variables.

    Args:
        data: Dataset of N observations
        compute_local_optimal: Function(global_params, observation) -> local_params
            Computes closed-form optimal local variational params
        compute_global_update: Function(global_params, local_params, batch) -> target_params
            Computes the target global natural parameters from mini-batch
        initial_global_params: Initial global variational natural parameters
        batch_size: Mini-batch size M
        num_iterations: Number of optimization iterations
        learning_rate_schedule: Function(iteration) -> learning_rate satisfying
            sum(rho_t) = infinity, sum(rho_t^2) < infinity
            (decay exponents in (0.5, 1] satisfy both conditions)

    Returns:
        Optimized global variational natural parameters
    """
    N = len(data)
    global_params = initial_global_params.copy()

    for t in range(num_iterations):
        # Step 1: Sample mini-batch
        batch_indices = np.random.choice(N, size=batch_size, replace=False)
        batch = data[batch_indices]

        # Step 2: Compute optimal local variational parameters for each batch element
        local_params_list = []
        for observation in batch:
            local_opt = compute_local_optimal(global_params, observation)
            local_params_list.append(local_opt)

        # Step 3: Compute the target global natural parameters.
        # This is the "expected" optimal global params given these local params,
        # scaled by N/M to account for the data points outside the mini-batch.
        # For many models, this scaling is built into compute_global_update.
        target_global = compute_global_update(global_params, local_params_list, batch)

        # Step 4: Natural gradient update (weighted average)
        rho_t = learning_rate_schedule(t)
        global_params = (1 - rho_t) * global_params + rho_t * target_global

        # Optional: Log progress
        if t % 1000 == 0:
            print(f"Iteration {t}, learning rate: {rho_t:.4f}")

    return global_params


# ============================================
# Example: SVI for Latent Dirichlet Allocation
# ============================================

class LDA_SVI:
    """
    Stochastic Variational Inference for Latent Dirichlet Allocation.

    Model:
    - K topics, V vocabulary size, D documents
    - Global: topic-word distributions beta_k ~ Dir(eta) for k=1..K
    - Local:  per-doc topic proportions theta_d ~ Dir(alpha)
              per-word topic assignments z_dn ~ Cat(theta_d)

    Variational family:
    - q(beta) = prod_k Dir(beta_k; lambda_k)
    - q(theta_d, z_d) = Dir(theta_d; gamma_d) prod_n Cat(z_dn; phi_dn)
    """

    def __init__(self, K: int, V: int, alpha: float = 0.1, eta: float = 0.01):
        self.K = K          # Number of topics
        self.V = V          # Vocabulary size
        self.alpha = alpha  # Dirichlet prior for topic proportions
        self.eta = eta      # Dirichlet prior for topic-word distributions

        # Initialize global variational parameters (natural params for Dirichlet)
        self.lmbda = np.random.gamma(100, 1 / 100, (K, V))  # Shape: (K, V)

    def compute_local_optimal(self, document: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Coordinate ascent for local variational parameters given current global.

        Args:
            document: Bag-of-words representation, shape (V,) with word counts

        Returns:
            gamma: Optimal Dirichlet params for this doc, shape (K,)
            phi: Optimal topic assignments, shape (N_words, K)
        """
        word_indices = np.where(document > 0)[0]
        word_counts = document[word_indices]
        N_words = len(word_indices)

        # Expected log topic-word probs: E[log beta_kv] = psi(lambda_kv) - psi(sum_v lambda_kv)
        E_log_beta = (
            digamma(self.lmbda)
            - digamma(self.lmbda.sum(axis=1, keepdims=True))
        )

        # Initialize local params
        gamma = np.ones(self.K) * (self.alpha + N_words / self.K)
        phi = np.ones((N_words, self.K)) / self.K

        # Coordinate ascent until convergence
        for _ in range(20):  # Usually converges in <10 iterations
            # Update phi (topic assignments)
            E_log_theta = digamma(gamma) - digamma(gamma.sum())
            for n, w in enumerate(word_indices):
                log_phi_n = E_log_theta + E_log_beta[:, w]
                phi[n] = np.exp(log_phi_n - np.max(log_phi_n))
                phi[n] /= phi[n].sum()

            # Update gamma (topic proportions)
            gamma = self.alpha + (phi * word_counts[:, np.newaxis]).sum(axis=0)

        return gamma, phi

    def compute_global_update(
        self,
        documents: List[np.ndarray],
        local_params: List[Tuple[np.ndarray, np.ndarray]],
        N_total: int
    ) -> np.ndarray:
        """
        Compute target global natural parameters from mini-batch.

        Returns the natural parameters that global vars would have if the
        entire dataset looked like this mini-batch.
        """
        M = len(documents)

        # Accumulate sufficient statistics
        lambda_update = np.full((self.K, self.V), self.eta)  # Start with prior

        for doc, (gamma, phi) in zip(documents, local_params):
            word_indices = np.where(doc > 0)[0]
            word_counts = doc[word_indices]
            for n, (w, count) in enumerate(zip(word_indices, word_counts)):
                lambda_update[:, w] += count * phi[n]

        # Scale to full dataset size
        lambda_update = self.eta + (N_total / M) * (lambda_update - self.eta)

        return lambda_update

    def fit(
        self,
        documents: np.ndarray,
        batch_size: int = 100,
        num_iterations: int = 10000
    ):
        """Run SVI with natural gradients."""
        N = len(documents)

        for t in range(num_iterations):
            # Learning rate satisfying Robbins-Monro
            rho_t = (t + 10) ** (-0.7)

            # Sample mini-batch
            batch_idx = np.random.choice(N, batch_size, replace=False)
            batch = [documents[i] for i in batch_idx]

            # Compute optimal local params for batch
            local_params = [self.compute_local_optimal(doc) for doc in batch]

            # Compute target global params
            lambda_target = self.compute_global_update(batch, local_params, N)

            # Natural gradient update
            self.lmbda = (1 - rho_t) * self.lmbda + rho_t * lambda_target
```

While natural gradients offer theoretical and empirical advantages, practical implementation requires attention to several issues.
Approximation strategies:
Several methods approximate the natural gradient without computing full Fisher inverses:
1. Diagonal approximation: Approximate \(\mathbf{F}\) with diagonal \(\text{diag}(F_{11}, \ldots, F_{dd})\). Reduces inversion to \(O(d)\) element-wise division.
2. Block-diagonal approximation: For structured parameters (e.g., layer-wise in neural networks), approximate \(\mathbf{F}\) as block-diagonal. Each block can be inverted separately.
3. Kronecker-factored approximation (K-FAC): For neural networks, approximate layer Fisher matrices as Kronecker products \(\mathbf{F}_l \approx \mathbf{A}_l \otimes \mathbf{B}_l\). Inversion becomes \((\mathbf{A}^{-1}) \otimes (\mathbf{B}^{-1})\).
4. Empirical Fisher: Use \(\hat{\mathbf{F}} = \frac{1}{N} \sum_{i} \nabla_\theta \log p(x_i; \theta) \nabla_\theta \log p(x_i; \theta)^T\), computed from data rather than expected under the model.
5. Online natural gradient (TONGA): Maintain a low-rank approximation to \(\mathbf{F}^{-1}\) updated incrementally each iteration.
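As a sketch of strategy 1, a diagonal natural gradient step can be assembled from per-sample score vectors; the function name and damping term are illustrative, not part of any particular library:

```python
import numpy as np

def diagonal_natural_step(theta: np.ndarray, grad: np.ndarray,
                          score_samples: np.ndarray,
                          lr: float = 0.1, damping: float = 1e-8) -> np.ndarray:
    """
    One natural-gradient step using a diagonal Fisher approximation.

    score_samples: array of shape (S, d); each row is a score vector
                   grad_theta log p(x_s; theta) for one sample.
    """
    F_diag = (score_samples ** 2).mean(axis=0)       # diagonal of E[s s^T]
    return theta + lr * grad / (F_diag + damping)    # elementwise "inversion"
```

The damping term guards against division by near-zero curvature estimates, a standard safeguard in approximate second-order methods.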
The Adam optimizer is closely related to natural gradients: its second-moment estimate \(v_t\) approximates the diagonal of the empirical Fisher, \(v_t \approx \text{diag}(\hat{\mathbf{F}})\), so dividing by \(\sqrt{v_t}\) roughly rescales each coordinate by its curvature. This connection partly explains Adam's effectiveness for probabilistic models.
| Method | Computation | Memory | Approximation Quality |
|---|---|---|---|
| Exact Fisher | O(d³) | O(d²) | Exact |
| Diagonal | O(d) | O(d) | Poor for correlated params |
| Block-diagonal | O(∑b³) | O(∑b²) | Good within blocks |
| K-FAC | O(d · √d) | O(d) | Good for neural networks |
| Low-rank | O(d·r²) | O(d·r) | Captures r top directions |
Natural gradients represent a fundamental advance in variational inference, aligning optimization with the intrinsic geometry of probability distributions.
What's next:
With mini-batch optimization and natural gradients established, we now turn to the scaling challenge: actually applying SVI to datasets with millions or billions of observations. The next page addresses the systems-level considerations—distributed computation, memory management, and algorithmic modifications—that enable VI at truly massive scale.
You now understand natural gradients—their theoretical foundation in information geometry, their elegant formulation for exponential families, and their practical implementation in stochastic variational inference. This knowledge is essential for understanding why modern VI algorithms converge efficiently and how to design new ones.