Standard gradient descent treats all directions in parameter space as equal—a step of size $\epsilon$ in any direction is considered "the same amount of change." But for probabilistic models like neural networks, this assumption is fundamentally flawed.
Consider a neural network's softmax output layer, parameterized by logits $z_1, z_2, z_3$. The probabilities $p_1, p_2, p_3$ are computed as $p_i = e^{z_i}/\sum_j e^{z_j}$. A small change $\Delta z_1 = 0.01$ has very different effects depending on the current values:
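For instance, here is a quick numerical check (a minimal PyTorch sketch; the logit values are chosen purely for illustration):

```python
import torch

def shift_first_logit(z: torch.Tensor, delta: float = 0.01) -> None:
    """Report how much p_1 moves when z_1 is increased by delta."""
    p_before = torch.softmax(z, dim=0)
    z_shifted = z.clone()
    z_shifted[0] += delta
    p_after = torch.softmax(z_shifted, dim=0)
    print(f"logits {z.tolist()}: dp_1 = {(p_after[0] - p_before[0]).item():.2e}")

# Near-uniform probabilities: the step moves p_1 by roughly 2e-3.
shift_first_logit(torch.tensor([0.0, 0.0, 0.0]))
# One dominant logit (p_1 is already ~1): the same step moves p_1 by roughly 1e-6.
shift_first_logit(torch.tensor([10.0, 0.0, 0.0]))
```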
The same parameter change causes vastly different changes to the model's predictions. Standard gradient descent ignores this, treating both cases identically.
Natural gradient descent addresses this fundamental issue by measuring distances in the space of probability distributions, not the space of parameters. It asks: "What parameter change causes a fixed-size change in the model's output distribution?"
By the end of this page, you will:
- Understand why Euclidean parameter space is the wrong geometry for optimization.
- Derive the Fisher information matrix as the natural metric for probability distributions.
- Develop the natural gradient update rule and understand its invariance properties.
- See the deep connections between natural gradient, second-order methods, and information geometry.
- Appreciate both the theoretical elegance and practical challenges of natural gradient descent.
To motivate natural gradient descent, we must first understand why standard gradient descent's implicit assumptions are problematic.
Standard gradient descent updates parameters as: $$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \nabla_\theta \mathcal{L}(\boldsymbol{\theta}_t)$$
This update minimizes a first-order Taylor approximation of the loss subject to a constraint on the Euclidean step size:
$$\boldsymbol{\theta}_{t+1} = \arg\min_{\boldsymbol{\theta}} \left\{ \mathcal{L}(\boldsymbol{\theta}_t) + \nabla_\theta \mathcal{L}(\boldsymbol{\theta}_t)^T (\boldsymbol{\theta} - \boldsymbol{\theta}_t) \right\} \quad \text{s.t.} \quad \|\boldsymbol{\theta} - \boldsymbol{\theta}_t\|_2^2 \leq \epsilon$$
The constraint uses Euclidean distance in parameter space: $\|\boldsymbol{\theta} - \boldsymbol{\theta}'\|_2^2 = \sum_i (\theta_i - \theta'_i)^2$.
For many optimization problems, this is perfectly sensible. But neural networks aren't just any function—they define probability distributions over outputs.
A critical flaw of Euclidean distance: it depends on the parameterization. If you replace parameter $\theta$ with, say, $\phi = \theta^3$, the "same" model has different gradient descent dynamics. This is arbitrary: the model's behavior depends only on the probability distribution it represents, not on how we parameterize it.
Consider a simple binary classifier with a single parameter $\theta$ and output probability $p = \sigma(\theta) = 1/(1 + e^{-\theta})$.
The gradient of cross-entropy loss with respect to $\theta$ is $(p - y)$ where $y \in \{0, 1\}$ is the label.
Now consider two scenarios, both with true label $y = 0$:
- Case A: $\theta = 0$, so $p = 0.5$ and the gradient is $0.5$. The model is uncertain.
- Case B: $\theta = 10$, so $p \approx 0.99995$ and the gradient is $\approx 1.0$. The model is confidently wrong.
But wait—Case B is making a very confident wrong prediction. A gradient update of 0.01 barely changes $p$ (since we're on the flat part of the sigmoid). Case A is uncertain, and the same 0.01 update causes a much larger change in $p$.
The gradient magnitudes don't reflect the impact on the output distribution. Case B needs larger steps in $\theta$ to achieve meaningful distribution changes, but standard gradient descent doesn't account for this.
| Starting point | Euclidean gradient | Effect on $p$ | Actual behavior change |
|---|---|---|---|
| $\theta = 0$ ($p = 0.5$) | 0.5 | Large | Model becomes more confident |
| $\theta = 10$ ($p \approx 1$) | 1.0 | Tiny | Almost no behavior change |
| $\theta = -10$ ($p \approx 0$) | 0.0 | None | No behavior change |
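The following minimal sketch (assuming label $y = 0$, so the cross-entropy gradient with respect to $\theta$ is simply $p$) reproduces these numbers, comparing the gradient magnitude with the actual change in $p$ after a fixed step of $0.01$:

```python
import torch

def effect_of_step(theta: float, step: float = 0.01) -> None:
    """Compare the gradient magnitude with the actual change in p = sigmoid(theta)."""
    t = torch.tensor(theta)
    p = torch.sigmoid(t)
    grad = p  # d/dtheta of -log(1 - p) for label y = 0 is sigmoid(theta) = p
    p_new = torch.sigmoid(t - step)  # fixed-size parameter step
    print(f"theta={theta:+6.1f}  p={p.item():.5f}  grad={grad.item():.5f}  "
          f"|dp|={abs(p_new - p).item():.2e}")

for theta in (0.0, 10.0, -10.0):
    effect_of_step(theta)
```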
The natural remedy is to measure distances not in parameter space, but in the space of probability distributions. This leads us to information geometry—the study of statistical manifolds.
The most natural measure of "distance" between probability distributions is the Kullback-Leibler (KL) divergence:
$$D_{KL}(p \| q) = \mathbb{E}_{x \sim p} \left[ \log \frac{p(x)}{q(x)} \right] = \sum_x p(x) \log \frac{p(x)}{q(x)}$$
KL divergence measures how much information is lost when we approximate distribution $p$ with distribution $q$. It has several important properties:
- $D_{KL}(p \| q) \geq 0$, with equality if and only if $p = q$.
- It is not symmetric: $D_{KL}(p \| q) \neq D_{KL}(q \| p)$ in general.
- It does not satisfy the triangle inequality, so it is not a true metric.
Despite not being a true metric, KL divergence is the right notion of "distance" for our purposes because it measures information-theoretic difference between distributions.
Consider a family of probability distributions ${p_\theta : \theta \in \Theta}$. This family forms a manifold—a curved surface in the infinite-dimensional space of all distributions. Just as the Earth's surface can be parameterized by latitude and longitude, this manifold can be parameterized by $\theta$. Information geometry studies the intrinsic geometry of this manifold.
For optimization, we care about small changes in parameters. Consider two nearby distributions $p_\theta$ and $p_{\theta + d\theta}$. The KL divergence between them is:
$$D_{KL}(p_\theta \| p_{\theta + d\theta}) \approx \frac{1}{2} d\theta^T \mathbf{F}(\theta)\, d\theta + \mathcal{O}(\|d\theta\|^3)$$
where $\mathbf{F}(\theta)$ is the Fisher information matrix:
$$F_{ij}(\theta) = \mathbb{E}_{x \sim p_\theta} \left[ \frac{\partial \log p_\theta(x)}{\partial \theta_i} \frac{\partial \log p_\theta(x)}{\partial \theta_j} \right]$$
This is a remarkable result: the Fisher information is the Hessian of KL divergence at zero. It defines a local inner product on the tangent space of the statistical manifold.
Alternative equivalent expressions for the Fisher information: $$\mathbf{F}(\theta) = -\mathbb{E}_{x \sim p_\theta} \left[ \nabla^2_\theta \log p_\theta(x) \right] = \text{Cov}_{x \sim p_\theta} \left[ \nabla_\theta \log p_\theta(x) \right]$$
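As a sanity check on the expansion above, here is a small numerical comparison for a Bernoulli distribution parameterized by its logit $\theta$, where the Fisher information is the scalar $F(\theta) = \sigma(\theta)(1 - \sigma(\theta))$ (a sketch; the particular values of $\theta$ and $d\theta$ are illustrative):

```python
import torch

def bernoulli_kl(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """KL divergence between Bernoulli(p) and Bernoulli(q)."""
    return p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))

theta, dtheta = torch.tensor(1.0), torch.tensor(0.01)
p, q = torch.sigmoid(theta), torch.sigmoid(theta + dtheta)

fisher = p * (1 - p)                        # Fisher information of the logit
kl_exact = bernoulli_kl(p, q)
kl_second_order = 0.5 * fisher * dtheta**2  # (1/2) dtheta^T F dtheta

print(f"exact KL:              {kl_exact.item():.3e}")
print(f"second-order estimate: {kl_second_order.item():.3e}")
```

The two numbers agree to several significant digits, as the quadratic expansion predicts for small $d\theta$.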
The Fisher information $\mathbf{F}(\theta)$ defines a Riemannian metric on the parameter manifold. This means:
- The squared length of an infinitesimal step is $ds^2 = d\theta^T \mathbf{F}(\theta)\, d\theta$, so "distance" is measured by how much the distribution changes.
- Inner products and angles between directions in parameter space are taken with respect to $\mathbf{F}(\theta)$.
- The metric varies from point to point, so the geometry is curved rather than Euclidean.
This geometry captures the fundamental insight: directions in which the distribution changes rapidly should have "shorter" distances than directions where it changes slowly.
The Fisher metric is, by Chentsov's theorem, the unique (up to scaling) Riemannian metric on a statistical manifold that is invariant under sufficient statistics.
Now we can derive natural gradient descent by replacing Euclidean distance with KL divergence.
The steepest descent direction minimizes the linear approximation of the loss while constraining the (squared) distance to be at most $\epsilon$:
$$\tilde{\nabla}_\theta \mathcal{L} = \arg\min_{d\theta} \left\{ \nabla_\theta \mathcal{L}^T d\theta \right\} \quad \text{s.t.} \quad d\theta^T \mathbf{F}\, d\theta \leq \epsilon$$
Using Lagrange multipliers: $$\mathcal{L}_{\text{Lagrange}} = \nabla_\theta \mathcal{L}^T d\theta + \lambda (d\theta^T \mathbf{F}\, d\theta - \epsilon)$$
Taking derivatives and setting to zero: $$\nabla_\theta \mathcal{L} + 2\lambda \mathbf{F}\, d\theta = 0$$ $$d\theta = -\frac{1}{2\lambda} \mathbf{F}^{-1} \nabla_\theta \mathcal{L}$$
The step length depends on $\lambda$, but the direction of steepest descent is always along $-\mathbf{F}^{-1} \nabla_\theta \mathcal{L}$. We therefore define the natural gradient as:
$$\tilde{\nabla}_\theta \mathcal{L} = \mathbf{F}^{-1} \nabla_\theta \mathcal{L}$$
The natural gradient is the ordinary gradient premultiplied by the inverse Fisher information. The natural gradient update is: $$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \mathbf{F}^{-1} \nabla_\theta \mathcal{L}$$
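Returning to the 1-D sigmoid classifier from earlier: for a Bernoulli output, the Fisher information of the logit $\theta$ is $F(\theta) = p(1-p)$, so the natural gradient simply divides the ordinary gradient by $p(1-p)$. A minimal sketch (the small damping constant is an assumption added to avoid division by zero):

```python
import torch

def ordinary_and_natural_grad(theta: torch.Tensor, y: float):
    """Ordinary vs. natural gradient of cross-entropy for p = sigmoid(theta)."""
    p = torch.sigmoid(theta)
    grad = p - y                          # ordinary gradient
    fisher = p * (1 - p)                  # Fisher information of the logit
    nat_grad = grad / (fisher + 1e-12)    # damping assumed for numerical stability
    return grad, nat_grad

for theta0 in (0.0, 10.0):
    g, ng = ordinary_and_natural_grad(torch.tensor(theta0), y=0.0)
    print(f"theta={theta0:5.1f}  ordinary={g.item():.4f}  natural={ng.item():.1f}")
```

At the saturated point ($\theta = 10$) the ordinary gradient is essentially 1 yet barely moves $p$; the natural gradient scales the step up by roughly $1/(p(1-p))$, so that to first order the induced change in the output distribution is comparable in both cases.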
The most profound property of natural gradient descent is its invariance to reparameterization. If we change variables from $\theta$ to $\phi = g(\theta)$ for any differentiable $g$, the natural gradient update produces the same change in the probability distribution.
Proof sketch: Under reparameterization, the gradient transforms as: $$\nabla_\phi \mathcal{L} = \mathbf{J}^{-T} \nabla_\theta \mathcal{L}$$ where $\mathbf{J} = \partial \phi / \partial \theta$ is the Jacobian of the change of variables.
The Fisher information transforms as: $$\mathbf{F}_\phi = \mathbf{J}^{-T} \mathbf{F}_\theta \mathbf{J}^{-1}$$
The natural gradient in the new parameterization: $$\mathbf{F}_\phi^{-1} \nabla_\phi \mathcal{L} = (\mathbf{J}^{-T} \mathbf{F}_\theta \mathbf{J}^{-1})^{-1} \mathbf{J}^{-T} \nabla_\theta \mathcal{L}$$ $$= \mathbf{J} \mathbf{F}_\theta^{-1} \mathbf{J}^T \mathbf{J}^{-T} \nabla_\theta \mathcal{L} = \mathbf{J} \mathbf{F}_\theta^{-1} \nabla_\theta \mathcal{L}$$
Since $d\theta = \mathbf{J}^{-1} d\phi$, the natural gradient step in $\phi$-coordinates induces the change $\mathbf{J}^{-1}(\mathbf{J} \mathbf{F}_\theta^{-1} \nabla_\theta \mathcal{L}) = \mathbf{F}_\theta^{-1} \nabla_\theta \mathcal{L}$ in $\theta$-coordinates, exactly the step we would have taken in the original parameterization. Both parameterizations therefore produce the same change in the distribution, confirming invariance.
Standard gradient descent lacks this property—its behavior depends on arbitrary aspects of the parameterization.
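A concrete 1-D check of this invariance, comparing the Bernoulli model in two coordinate systems, the logit $\theta$ and the mean $p = \sigma(\theta)$; the loss is cross-entropy against the label $y = 0$, and the particular value of $\theta$ is illustrative:

```python
import torch

theta = torch.tensor(2.0)
p = torch.sigmoid(theta)                # same model, mean parameterization

grad_theta = p                          # dL/dtheta for L = -log(1 - p), p = sigmoid(theta)
grad_p = 1 / (1 - p)                    # dL/dp for the same loss in p-coordinates

fisher_theta = p * (1 - p)              # Bernoulli Fisher information in logit coordinates
fisher_p = 1 / (p * (1 - p))            # Bernoulli Fisher information in mean coordinates

nat_theta = grad_theta / fisher_theta   # natural gradient w.r.t. theta
nat_p = grad_p / fisher_p               # natural gradient w.r.t. p

# Push the theta-step through the chain rule dp/dtheta = p(1 - p):
induced_dp = p * (1 - p) * nat_theta

print(f"natural gradient computed directly in p-coordinates: {nat_p.item():.6f}")
print(f"change in p induced by the theta-space update:       {induced_dp.item():.6f}")
```

The two printed values coincide: whichever coordinates we differentiate in, the natural gradient prescribes the same change to the underlying Bernoulli distribution.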
| Property | Gradient Descent | Natural Gradient |
|---|---|---|
| Update direction | $-\nabla_\theta \mathcal{L}$ | $-\mathbf{F}^{-1} \nabla_\theta \mathcal{L}$ |
| Distance constraint | Euclidean: $\|d\theta\|_2$ | KL: $\sqrt{d\theta^T \mathbf{F}\, d\theta}$ |
| Parameterization invariant | No | Yes |
| Per-iteration cost | $\mathcal{O}(n)$ | $\mathcal{O}(n^3)$ naively |
| Convergence on quadratics | Linear rate $\mathcal{O}(\kappa)$ | One step (if exact Fisher) |
For neural networks, we need to express the Fisher information in terms of the network's outputs and loss function.
Consider a neural network for K-class classification. Given input $x$, the network outputs logits $z = f_\theta(x)$, and the softmax produces probabilities: $$p_c = \text{softmax}(z)_c = \frac{e^{z_c}}{\sum_j e^{z_j}}$$
The log-likelihood for true class $c$ is: $$\log p_c = z_c - \log \sum_j e^{z_j}$$
The gradient with respect to logits is the familiar expression: $$\frac{\partial \log p_c}{\partial z} = \mathbf{e}_c - \mathbf{p}$$
where $\mathbf{e}_c$ is the one-hot vector for class $c$ and $\mathbf{p}$ is the probability vector.
The Fisher information with respect to logits is: $$\mathbf{F}_z = \mathbb{E}_{c \sim \mathbf{p}} \left[ (\mathbf{e}_c - \mathbf{p})(\mathbf{e}_c - \mathbf{p})^T \right] = \text{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^T$$
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_fisher_logit(probs: torch.Tensor) -> torch.Tensor:
    """
    Compute Fisher information matrix for softmax logits.

    For softmax probabilities p, the Fisher w.r.t. logits is:
        F = diag(p) - p @ p^T

    Args:
        probs: Softmax probabilities [batch_size, num_classes]

    Returns:
        Fisher matrix [batch_size, num_classes, num_classes]
    """
    # diag(p): [batch_size, num_classes, num_classes]
    diag_p = torch.diag_embed(probs)
    # p @ p^T: [batch_size, num_classes, num_classes]
    outer_p = probs.unsqueeze(-1) @ probs.unsqueeze(-2)
    return diag_p - outer_p


def compute_empirical_fisher(
    model: nn.Module,
    data_loader,
    num_samples: int = 1000,
) -> torch.Tensor:
    """
    Compute empirical Fisher information matrix.

    The empirical Fisher uses the true labels rather than samples
    from the model's distribution. This is a common approximation in
    deep learning; see the discussion below for how it relates to the
    true Fisher and the Gauss-Newton matrix.

    Args:
        model: Neural network
        data_loader: Data loader
        num_samples: Number of samples to use

    Returns:
        Empirical Fisher matrix (flattened parameters)
    """
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    n_params = sum(p.numel() for p in params)
    fisher = torch.zeros(n_params, n_params)
    n_processed = 0

    for x, y in data_loader:
        if n_processed >= num_samples:
            break

        # Forward pass
        output = model(x)
        log_probs = F.log_softmax(output, dim=-1)

        # For each sample in batch
        for i in range(min(len(x), num_samples - n_processed)):
            # Gradient of log probability for true class
            log_prob_i = log_probs[i, y[i]]
            grad = torch.autograd.grad(
                log_prob_i, params, retain_graph=True
            )
            grad_flat = torch.cat([g.flatten() for g in grad])

            # Accumulate outer product
            fisher += torch.outer(grad_flat, grad_flat)
            n_processed += 1

    fisher /= n_processed
    return fisher


def natural_gradient_step(
    model: nn.Module,
    loss_fn,
    x: torch.Tensor,
    y: torch.Tensor,
    lr: float = 0.01,
    damping: float = 1e-3,
):
    """
    Perform a single natural gradient descent step.

    Warning: This computes the full Fisher matrix - only feasible
    for small models! Real implementations use approximations.
    """
    # Forward pass and compute loss
    output = model(x)
    loss = loss_fn(output, y)

    # Compute gradient (retain the graph: the per-sample Fisher
    # gradients below reuse the same forward computation)
    params = [p for p in model.parameters() if p.requires_grad]
    grad = torch.autograd.grad(loss, params, retain_graph=True)
    grad_flat = torch.cat([g.flatten() for g in grad])

    # Compute empirical Fisher (batch approximation)
    n_params = grad_flat.numel()
    fisher = torch.zeros(n_params, n_params)
    log_probs = F.log_softmax(output, dim=-1)

    for i in range(len(x)):
        log_prob_i = log_probs[i, y[i]]
        g = torch.autograd.grad(log_prob_i, params, retain_graph=True)
        g_flat = torch.cat([gi.flatten() for gi in g])
        fisher += torch.outer(g_flat, g_flat)
    fisher /= len(x)

    # Add damping for numerical stability
    fisher += damping * torch.eye(n_params)

    # Compute natural gradient: F^{-1} @ grad
    nat_grad = torch.linalg.solve(fisher, grad_flat)

    # Apply update
    offset = 0
    with torch.no_grad():
        for p in params:
            numel = p.numel()
            p -= lr * nat_grad[offset:offset + numel].view(p.shape)
            offset += numel

    return loss.item()
```

In practice, we compute the empirical Fisher using the true data labels: $$\hat{\mathbf{F}} = \frac{1}{N} \sum_{i=1}^N \nabla_\theta \log p_\theta(y_i | x_i) \nabla_\theta \log p_\theta(y_i | x_i)^T$$
This differs from the true Fisher, which uses samples from the model's distribution: $$\mathbf{F} = \mathbb{E}_{x \sim p_{\text{data}}} \mathbb{E}_{y \sim p_\theta(\cdot | x)} \left[ \nabla_\theta \log p_\theta(y | x) \nabla_\theta \log p_\theta(y | x)^T \right]$$
The empirical Fisher is cheaper to compute, since it reuses the per-example gradients already needed for training, but it is only an approximation to the true Fisher; the two agree when the model's predictive distribution matches the conditional label distribution of the data.

The connection to Gauss-Newton is important: for cross-entropy and other exponential-family losses, the true Fisher coincides with the generalized Gauss-Newton matrix, so natural gradient bridges information geometry and classical second-order optimization.
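The distinction is easy to see for a single softmax output. The sketch below (with hypothetical helper names) builds both matrices with respect to the logits: the true Fisher averages score outer products over labels drawn from the model, while the empirical Fisher uses only the observed label.

```python
import torch

def true_fisher_logits(probs: torch.Tensor) -> torch.Tensor:
    """E_{c ~ p}[(e_c - p)(e_c - p)^T] = diag(p) - p p^T."""
    return torch.diag(probs) - torch.outer(probs, probs)

def empirical_fisher_logits(probs: torch.Tensor, y: int) -> torch.Tensor:
    """(e_y - p)(e_y - p)^T for the single observed label y."""
    score = -probs.clone()
    score[y] += 1.0
    return torch.outer(score, score)

probs = torch.tensor([0.7, 0.2, 0.1])
print(true_fisher_logits(probs))
print(empirical_fisher_logits(probs, y=2))   # generally differs from the true Fisher
```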
Natural gradient descent inherits the favorable convergence properties of Newton's method while being tailored to probabilistic models.
The Fisher information matrix arises in statistics as the Cramér-Rao bound: no unbiased estimator can have variance smaller than $\mathbf{F}^{-1}$. Estimators achieving this bound are called efficient.
Natural gradient descent is "efficient" in an analogous sense: it extracts the maximum information from each gradient evaluation about the optimal parameter direction. Standard gradient descent wastes effort moving in directions that barely change the distribution.
Near a local minimum $\theta^*$, natural gradient behaves like Newton's method:
For a well-specified model (where the data comes from some $p_{\theta^*}$), the loss Hessian at $\theta^*$ equals the Fisher information: $$\nabla^2_\theta \mathcal{L}(\theta^*) = \mathbf{F}(\theta^*)$$
This means natural gradient descent is asymptotically equivalent to Newton's method, enjoying quadratic local convergence.
Amari (1998) showed that natural gradient achieves Fisher-efficient online learning. When processing data sequentially, natural gradient updates achieve the optimal statistical efficiency predicted by the Cramér-Rao bound. Standard gradient descent is statistically suboptimal—it doesn't use data as efficiently as possible.
Deep networks often exhibit "plateaus": regions where the loss decreases very slowly. These occur when the gradient is small even though the model is far from a good solution, for example when units saturate or when the output distribution is nearly insensitive to many parameter directions.
Natural gradient helps escape plateaus by rescaling directions according to curvature. Directions with small curvature (where progress is easy) receive larger steps; directions with large curvature (where the landscape is steep) receive smaller steps.
Consider a toy loss landscape: $$\mathcal{L}(\theta_1, \theta_2) = \frac{1}{2}\theta_1^2 + \frac{1}{2} \times 100 \times \theta_2^2$$
The eigenvalues of the Hessian are 1 and 100 (condition number = 100).
To remain stable in the steep $\theta_2$ direction, gradient descent must keep its learning rate below $2/100$, so progress along the shallow $\theta_1$ direction is painfully slow; a curvature-aware method rescales each direction and converges in a handful of steps. This is precisely why second-order methods excel on ill-conditioned problems.
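A sketch of this effect on the toy quadratic, where the (constant) Hessian plays the role of the Fisher matrix; the learning rate and starting point are illustrative:

```python
import torch

H = torch.diag(torch.tensor([1.0, 100.0]))    # Hessian of the toy quadratic
theta_gd = torch.tensor([1.0, 1.0])           # plain gradient descent iterate
theta_ng = torch.tensor([1.0, 1.0])           # curvature-rescaled iterate

for _ in range(10):
    theta_gd = theta_gd - 0.01 * (H @ theta_gd)                 # lr capped by the stiff direction
    theta_ng = theta_ng - torch.linalg.solve(H, H @ theta_ng)   # rescale each direction by curvature

print("gradient descent after 10 steps:  ", theta_gd)   # theta_1 has barely moved
print("curvature-rescaled after 10 steps:", theta_ng)   # converged after the first step
```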
| Method | Updates to reach $\epsilon$ error | Condition sensitivity |
|---|---|---|
| Gradient Descent | $\mathcal{O}(\kappa \log(1/\epsilon))$ | High—scales with $\kappa$ |
| Gradient Descent + Momentum | $\mathcal{O}(\sqrt{\kappa} \log(1/\epsilon))$ | Medium—scales with $\sqrt{\kappa}$ |
| Natural Gradient | $\mathcal{O}(\log(1/\epsilon))$ | None—independent of $\kappa$ |
| Newton's Method | $\mathcal{O}(\log \log(1/\epsilon))$ | None—quadratic rate |
Despite its theoretical elegance, natural gradient faces the same computational barrier as Newton's method: the Fisher matrix is $n \times n$ where $n$ is the number of parameters.
For a network with $n$ parameters, storing the Fisher matrix requires $\mathcal{O}(n^2)$ memory and inverting it costs $\mathcal{O}(n^3)$ time.
For a ResNet-50 with 25M parameters, the Fisher matrix would have $6.25 \times 10^{14}$ entries, requiring petabytes of storage.
The simplest approximation uses only the diagonal of $\mathbf{F}$: $$\tilde{\nabla}_\theta \mathcal{L} \approx \text{diag}(\mathbf{F})^{-1} \odot \nabla_\theta \mathcal{L}$$
where $\odot$ denotes element-wise multiplication.
This is essentially AdaGrad or Adam (with appropriate averaging)! These popular optimizers can be viewed as natural gradient with diagonal Fisher approximation. The connection explains why adaptive optimizers work well—they're capturing per-parameter curvature.
Adam maintains running estimates of the first moment (momentum) and second moment (squared gradients). The second moment $v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$ approximates the diagonal of the empirical Fisher. The update $\theta_{t+1} = \theta_t - \eta\, g_t / \sqrt{v_t}$ is closely related to natural gradient with a diagonal Fisher approximation, though the square root means Adam preconditions by roughly $\text{diag}(\mathbf{F})^{-1/2}$ rather than $\text{diag}(\mathbf{F})^{-1}$.
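A minimal sketch of a diagonal-Fisher preconditioned update in this spirit (the function name, the AdaGrad-style accumulator, and the damping constant are illustrative; Adam additionally uses momentum, bias correction, and a square root):

```python
import torch
import torch.nn.functional as F

def diagonal_natural_step(params, grads, diag_fisher, lr=1e-2, damping=1e-8):
    """Update each parameter using an accumulated estimate of diag(F)."""
    with torch.no_grad():
        for p, g, d in zip(params, grads, diag_fisher):
            d += g * g                     # accumulate squared gradients (diagonal Fisher estimate)
            p -= lr * g / (d + damping)    # precondition by the inverse diagonal

# Illustrative usage on a toy linear classifier with random data.
layer = torch.nn.Linear(4, 3)
x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))
loss = F.cross_entropy(layer(x), y)
grads = torch.autograd.grad(loss, list(layer.parameters()))
diag_fisher = [torch.zeros_like(p) for p in layer.parameters()]
diagonal_natural_step(list(layer.parameters()), grads, diag_fisher)
```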
A better approximation uses a block-diagonal Fisher, where each block corresponds to one layer:
$$\mathbf{F} \approx \begin{pmatrix} \mathbf{F}_1 & 0 & \cdots & 0 \\ 0 & \mathbf{F}_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \mathbf{F}_L \end{pmatrix}$$
This ignores correlations between layers but captures within-layer curvature. For a layer with $n_l$ parameters, each block has size $n_l \times n_l$, and inversion costs $\mathcal{O}(n_l^3)$.
Still expensive for fully-connected layers with many parameters, but tractable for careful architectures.
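A small sketch of the idea (a slight simplification of the per-layer scheme: here each parameter tensor gets its own block, which additionally ignores weight-bias correlations within a layer; the model and data are hypothetical):

```python
import torch
import torch.nn.functional as F

def per_tensor_empirical_fisher(model, x, y):
    """One empirical Fisher block per parameter tensor, ignoring cross-block terms."""
    params = [p for p in model.parameters() if p.requires_grad]
    blocks = [torch.zeros(p.numel(), p.numel()) for p in params]
    log_probs = F.log_softmax(model(x), dim=-1)
    for i in range(len(x)):
        grads = torch.autograd.grad(log_probs[i, y[i]], params, retain_graph=True)
        for block, g in zip(blocks, grads):
            g_flat = g.flatten()
            block += torch.outer(g_flat, g_flat)
    return [b / len(x) for b in blocks]

# Illustrative usage on a tiny two-layer network with random data.
model = torch.nn.Sequential(torch.nn.Linear(4, 5), torch.nn.ReLU(), torch.nn.Linear(5, 3))
x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))
blocks = per_tensor_empirical_fisher(model, x, y)
print([tuple(b.shape) for b in blocks])   # one square block per weight/bias tensor
```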
The most successful approximation is K-FAC (which we'll cover in detail on the next page). It approximates each layer's Fisher block as a Kronecker product: $$\mathbf{F}_l \approx \mathbf{A}_{l-1} \otimes \mathbf{G}_l$$
where $\mathbf{A}_{l-1}$ is the second-moment matrix of the activations entering layer $l$, and $\mathbf{G}_l$ is the second-moment matrix of the gradients of the loss with respect to the layer's pre-activation outputs.
The Kronecker structure enables efficient inversion: $(\mathbf{A} \otimes \mathbf{G})^{-1} = \mathbf{A}^{-1} \otimes \mathbf{G}^{-1}$, reducing cost from $\mathcal{O}(n_l^3)$ to $\mathcal{O}(d_{in}^3 + d_{out}^3)$.
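A quick numerical check of the Kronecker-inverse identity that K-FAC relies on (small random matrices, illustrative only):

```python
import torch

d_in, d_out = 4, 3
# Random symmetric positive-definite factors standing in for A_{l-1} and G_l.
A = torch.randn(d_in, d_in)
A = A @ A.T + 0.1 * torch.eye(d_in)
G = torch.randn(d_out, d_out)
G = G @ G.T + 0.1 * torch.eye(d_out)

inv_full = torch.linalg.inv(torch.kron(A, G))                    # invert the full Kronecker block
inv_fact = torch.kron(torch.linalg.inv(A), torch.linalg.inv(G))  # invert only the small factors

print("max difference:", (inv_full - inv_fact).abs().max().item())  # ~1e-5 (float32 round-off)
```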
Natural gradient descent represents a fundamental shift in how we think about optimization for probabilistic models. By measuring progress in distribution space rather than parameter space, it achieves invariance properties and convergence rates that standard methods cannot match.
The theoretical power of natural gradient is clear, but its practical application requires clever approximations. The next page explores K-FAC (Kronecker-Factored Approximate Curvature)—arguably the most successful method for practical natural gradient optimization in deep learning. K-FAC exploits the structure of neural network layers to make Fisher approximation and inversion tractable.
You now have a rigorous understanding of natural gradient descent—from its information-geometric foundations to its convergence properties and computational challenges. This theoretical foundation is essential for understanding practical approximations like K-FAC and for appreciating why adaptive optimizers work.