Natural gradient descent promises parameterization-invariant optimization with fast convergence, but its $\mathcal{O}(n^2)$ storage and $\mathcal{O}(n^3)$ inversion costs render it impractical for deep neural networks. K-FAC (Kronecker-Factored Approximate Curvature) elegantly resolves this dilemma by exploiting the structure of neural network layers.
The key insight of K-FAC is that the Fisher information matrix for a neural network layer can be approximated as a Kronecker product of two much smaller matrices. For a layer with $d_{in}$ inputs and $d_{out}$ outputs, this factorization reduces per-layer storage from $\mathcal{O}((d_{in} d_{out})^2)$ to $\mathcal{O}(d_{in}^2 + d_{out}^2)$ and replaces one enormous matrix inversion with two small, easily manageable ones.
Since its introduction by Martens and Grosse in 2015, K-FAC has become one of the most successful second-order optimization methods for deep learning, enabling training of deep networks with fewer iterations than SGD and competitive wall-clock times on modern hardware.
By the end of this page, you will:

1. Understand the Kronecker product and its properties for matrix inversion
2. Derive the K-FAC approximation for fully-connected and convolutional layers
3. Learn the statistics that K-FAC maintains and updates
4. Understand the damping and trust region strategies essential for stability
5. Implement K-FAC for practical neural network training
6. Appreciate the trade-offs and when K-FAC is most effective
Before diving into K-FAC, we need to understand the Kronecker product—the mathematical operation that makes K-FAC computationally tractable.
The Kronecker product of matrices $\mathbf{A} \in \mathbb{R}^{m \times n}$ and $\mathbf{B} \in \mathbb{R}^{p \times q}$ is the $mp \times nq$ matrix:
$$\mathbf{A} \otimes \mathbf{B} = \begin{pmatrix} a_{11}\mathbf{B} & a_{12}\mathbf{B} & \cdots & a_{1n}\mathbf{B} \\ a_{21}\mathbf{B} & a_{22}\mathbf{B} & \cdots & a_{2n}\mathbf{B} \\ \vdots & \vdots & \ddots & \vdots \\ a_{m1}\mathbf{B} & a_{m2}\mathbf{B} & \cdots & a_{mn}\mathbf{B} \end{pmatrix}$$
Each element $a_{ij}$ of $\mathbf{A}$ is replaced by the $p \times q$ block $a_{ij}\mathbf{B}$.
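For example, with two arbitrary $2 \times 2$ matrices:

$$\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix} \otimes \begin{pmatrix} 0 & 5 \\ 6 & 7 \end{pmatrix} = \begin{pmatrix} 0 & 5 & 0 & 10 \\ 6 & 7 & 12 & 14 \\ 0 & 15 & 0 & 20 \\ 18 & 21 & 24 & 28 \end{pmatrix}$$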
The Kronecker product has remarkable properties that K-FAC exploits:
| Property | Formula | Implication for K-FAC |
|---|---|---|
| Inversion | $(\mathbf{A} \otimes \mathbf{B})^{-1} = \mathbf{A}^{-1} \otimes \mathbf{B}^{-1}$ | Invert small matrices instead of large one |
| Eigendecomposition | $\mathbf{A} \otimes \mathbf{B} = (\mathbf{U}_A \otimes \mathbf{U}_B)(\boldsymbol{\Lambda}_A \otimes \boldsymbol{\Lambda}_B)(\mathbf{U}_A \otimes \mathbf{U}_B)^{-1}$ | Eigenvalues are products, eigenvectors are Kronecker products |
| Vec-permutation | $(\mathbf{A} \otimes \mathbf{B})\text{vec}(\mathbf{X}) = \text{vec}(\mathbf{B}\mathbf{X}\mathbf{A}^T)$ | Multiply efficiently without forming Kronecker product |
| Trace | $\text{tr}(\mathbf{A} \otimes \mathbf{B}) = \text{tr}(\mathbf{A})\text{tr}(\mathbf{B})$ | Trace computations factor |
| Determinant | $\det(\mathbf{A} \otimes \mathbf{B}) = \det(\mathbf{A})^p \det(\mathbf{B})^m$ (square $\mathbf{A} \in \mathbb{R}^{m \times m}$, $\mathbf{B} \in \mathbb{R}^{p \times p}$) | Log-det computations factor |
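These identities are easy to check numerically. Below is a minimal sketch of our own (using `torch.kron` and small, well-conditioned matrices) that verifies the trace and determinant properties:

```python
import torch

torch.manual_seed(0)
m, p = 3, 4
# Well-conditioned symmetric positive-definite test matrices
A = torch.randn(m, m); A = A @ A.T + m * torch.eye(m)
B = torch.randn(p, p); B = B @ B.T + p * torch.eye(p)
K = torch.kron(A, B)   # shape (m*p, m*p) = (12, 12)

# tr(A ⊗ B) = tr(A) tr(B)
print(torch.allclose(torch.trace(K), torch.trace(A) * torch.trace(B)))

# det(A ⊗ B) = det(A)^p det(B)^m
print(torch.allclose(torch.linalg.det(K),
                     torch.linalg.det(A) ** p * torch.linalg.det(B) ** m,
                     rtol=1e-3))
```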
The inversion property is the heart of K-FAC's efficiency. If the Fisher is $\mathbf{F} = \mathbf{A} \otimes \mathbf{B}$ where $\mathbf{A}$ is $m \times m$ and $\mathbf{B}$ is $n \times n$, then inverting $\mathbf{F}$ directly costs $\mathcal{O}((mn)^3)$ operations, while inverting $\mathbf{A}$ and $\mathbf{B}$ separately costs only $\mathcal{O}(m^3 + n^3)$.
For a layer with 1000 inputs and 1000 outputs, this is $10^{18}$ vs $2 \times 10^9$—a factor of 500 million speedup!
```python
import torch


def kronecker_product(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Compute the Kronecker product A ⊗ B.

    For matrices of shape (m, n) and (p, q), returns shape (m*p, n*q).
    """
    m, n = A.shape
    p, q = B.shape
    return torch.einsum('ij,kl->ikjl', A, B).reshape(m * p, n * q)


def kronecker_mvp(A: torch.Tensor, B: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Compute (A ⊗ B) @ x efficiently without forming the Kronecker product.

    With PyTorch's row-major layout, (A ⊗ B) @ x = (A @ X @ B^T).flatten()
    where X = x.reshape(m, n). This is the row-major counterpart of the
    column-major identity (A ⊗ B) vec(X) = vec(B X A^T).

    Args:
        A: Matrix of shape (m, m)
        B: Matrix of shape (n, n)
        x: Vector of shape (m * n,)

    Returns:
        (A ⊗ B) @ x of shape (m * n,)
    """
    m = A.shape[0]
    n = B.shape[0]

    # Reshape x to an (m, n) matrix in row-major order
    X = x.reshape(m, n)

    # Compute A @ X @ B^T and flatten back to a vector
    result = A @ X @ B.T
    return result.flatten()


def kronecker_inverse_mvp(A: torch.Tensor, B: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Compute (A ⊗ B)^{-1} @ x efficiently.

    Uses: (A ⊗ B)^{-1} = A^{-1} ⊗ B^{-1}, then applies the efficient MVP.
    """
    A_inv = torch.linalg.inv(A)
    B_inv = torch.linalg.inv(B)
    return kronecker_mvp(A_inv, B_inv, x)


# Demonstration: verify correctness
def verify_kronecker_operations():
    torch.manual_seed(42)

    m, n = 50, 30
    A = torch.randn(m, m)
    B = torch.randn(n, n)
    x = torch.randn(m * n)

    # Full Kronecker product (expensive!)
    K = kronecker_product(A, B)

    # Method 1: Direct multiplication
    y1 = K @ x

    # Method 2: Efficient Kronecker MVP
    y2 = kronecker_mvp(A, B, x)

    print(f"MVP relative error: {torch.norm(y1 - y2) / torch.norm(y1):.2e}")

    # Test inverse
    K_inv = torch.linalg.inv(K)
    z1 = K_inv @ x
    z2 = kronecker_inverse_mvp(A, B, x)

    print(f"Inverse MVP relative error: {torch.norm(z1 - z2) / torch.norm(z1):.2e}")

    print("\nComplexity comparison:")
    print(f"  Full Kronecker: {m*n} x {m*n} = {(m*n)**2:,} elements")
    print(f"  K-FAC factors: {m}x{m} + {n}x{n} = {m**2 + n**2:,} elements")
    print(f"  Storage ratio: {(m*n)**2 / (m**2 + n**2):.0f}x savings")


verify_kronecker_operations()
```

K-FAC's approximation is grounded in the structure of neural network layers. Let's derive the exact Fisher information for a fully-connected layer and see why the Kronecker approximation is natural.
Consider layer $l$ with input activation $\mathbf{a}_{l-1} \in \mathbb{R}^{d_{in}}$, weight matrix $\mathbf{W}_l \in \mathbb{R}^{d_{out} \times d_{in}}$, pre-activation $\mathbf{s}_l = \mathbf{W}_l \mathbf{a}_{l-1}$, and backpropagated gradient $\mathbf{g}_l = \nabla_{\mathbf{s}_l} \mathcal{L} \in \mathbb{R}^{d_{out}}$.
The gradient with respect to the weights is an outer product: $$\nabla_{\mathbf{W}_l} \mathcal{L} = \mathbf{g}_l \mathbf{a}_{l-1}^T$$

Vectorized (column-major stacking): $$\nabla_{\text{vec}(\mathbf{W}_l)} \mathcal{L} = \mathbf{a}_{l-1} \otimes \mathbf{g}_l$$
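Both forms are easy to confirm with autograd on a toy layer. The sketch below is our own example (a single sample and an arbitrary quadratic loss) checking the outer-product form and its column-major vectorization:

```python
import torch

torch.manual_seed(0)
d_in, d_out = 3, 2

W = torch.randn(d_out, d_in, requires_grad=True)
a = torch.randn(d_in)                 # input activation a_{l-1}
s = W @ a                             # pre-activation s_l = W a
loss = (s ** 2).sum()                 # any scalar loss works for this check
loss.backward()

g = 2 * s.detach()                    # g_l = dL/ds_l for this particular loss

# Matrix form: dL/dW = g a^T
print(torch.allclose(W.grad, torch.outer(g, a)))          # True

# Vectorized (column-major) form: vec(dL/dW) = a ⊗ g
vec_grad = W.grad.T.reshape(-1)                            # column-major vec
print(torch.allclose(vec_grad, torch.kron(a, g)))          # True
```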
The Fisher information for layer $l$'s weights is: $$\mathbf{F}_l = \mathbb{E}\left[ (\mathbf{a}_{l-1} \otimes \mathbf{g}_l)(\mathbf{a}_{l-1} \otimes \mathbf{g}_l)^T \right]$$

Using the Kronecker product rule $(\mathbf{u} \otimes \mathbf{v})(\mathbf{u} \otimes \mathbf{v})^T = (\mathbf{u}\mathbf{u}^T) \otimes (\mathbf{v}\mathbf{v}^T)$: $$\mathbf{F}_l = \mathbb{E}\left[ (\mathbf{a}_{l-1}\mathbf{a}_{l-1}^T) \otimes (\mathbf{g}_l\mathbf{g}_l^T) \right]$$
K-FAC then approximates this expectation of a Kronecker product by the Kronecker product of expectations:
$$\mathbf{F}_l = \mathbb{E}\left[ (\mathbf{a}_{l-1}\mathbf{a}_{l-1}^T) \otimes (\mathbf{g}_l\mathbf{g}_l^T) \right] \approx \mathbb{E}\left[ \mathbf{a}_{l-1}\mathbf{a}_{l-1}^T \right] \otimes \mathbb{E}\left[ \mathbf{g}_l\mathbf{g}_l^T \right] = \mathbf{A}_{l-1} \otimes \mathbf{G}_l$$
K-FAC assumes that the input activations and the backpropagated output gradients are statistically independent, so that the expectation of the Kronecker product factors into the Kronecker product of expectations. This is exact when the two are truly independent and is a good approximation in many practical settings.
K-FAC's independence assumption $\mathbb{E}[\mathbf{a} \mathbf{a}^T \otimes \mathbf{g} \mathbf{g}^T] \approx \mathbb{E}[\mathbf{a} \mathbf{a}^T] \otimes \mathbb{E}[\mathbf{g} \mathbf{g}^T]$ is not exact in general: in a real network the activations and backpropagated gradients are correlated, and how strongly they are correlated (which varies with architecture, data, and stage of training) determines how accurate the factorization is.
Despite this approximation, K-FAC works remarkably well in practice, often outperforming exact methods that make no such assumption but have worse scalability.
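The size of the factorization error can be measured directly on small problems. The following sketch is our own synthetic example: it builds the exact per-layer Fisher block $\mathbb{E}[(\mathbf{a} \otimes \mathbf{g})(\mathbf{a} \otimes \mathbf{g})^T]$ by brute force and compares it to the Kronecker-factored form. Because $\mathbf{a}$ and $\mathbf{g}$ are drawn independently here, the reported error reflects only sampling noise; correlated draws would increase it.

```python
import torch

torch.manual_seed(0)
N, d_in, d_out = 10_000, 8, 6   # kept small so the exact Fisher block fits in memory

# Synthetic activations and backpropagated gradients, drawn independently here
a = torch.randn(N, d_in)
g = torch.randn(N, d_out)

# Exact per-layer Fisher block: E[(a ⊗ g)(a ⊗ g)^T]
V = torch.einsum('ni,nk->nik', a, g).reshape(N, d_in * d_out)   # each row is a ⊗ g
F_exact = V.T @ V / N

# K-FAC factorization: E[a a^T] ⊗ E[g g^T]
A = a.T @ a / N
G = g.T @ g / N
F_kfac = torch.kron(A, G)

rel_err = torch.norm(F_exact - F_kfac) / torch.norm(F_exact)
print(f"Relative factorization error: {rel_err:.3f}")
```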
K-FAC maintains two statistics per layer:
$$\mathbf{A}_{l-1} = \mathbb{E}[\mathbf{a}_{l-1}\mathbf{a}_{l-1}^T] \in \mathbb{R}^{d_{in} \times d_{in}}$$
$$\mathbf{G}_l = \mathbb{E}[\mathbf{g}_l\mathbf{g}_l^T] \in \mathbb{R}^{d_{out} \times d_{out}}$$
These are estimated via exponential moving averages during training:
$$\mathbf{A}_{l-1} \leftarrow \beta\,\mathbf{A}_{l-1} + (1-\beta) \frac{1}{B}\sum_{i=1}^B \mathbf{a}_{l-1}^{(i)} (\mathbf{a}_{l-1}^{(i)})^T$$
$$\mathbf{G}_l \leftarrow \beta\,\mathbf{G}_l + (1-\beta) \frac{1}{B}\sum_{i=1}^B \mathbf{g}_l^{(i)} (\mathbf{g}_l^{(i)})^T$$
Typical values: $\beta = 0.95$ (update every step with heavy momentum).
With the Kronecker approximation, computing the natural gradient becomes tractable.
The natural gradient for layer $l$ is: $$\text{vec}(\tilde{\nabla}_{\mathbf{W}_l}) = \mathbf{F}_l^{-1} \, \text{vec}(\nabla_{\mathbf{W}_l} \mathcal{L})$$

Using the K-FAC approximation: $$\text{vec}(\tilde{\nabla}_{\mathbf{W}_l}) = (\mathbf{A}_{l-1} \otimes \mathbf{G}_l)^{-1} \text{vec}(\nabla_{\mathbf{W}_l} \mathcal{L}) = (\mathbf{A}_{l-1}^{-1} \otimes \mathbf{G}_l^{-1}) \text{vec}(\nabla_{\mathbf{W}_l} \mathcal{L})$$

Using the vec-permutation identity $(\mathbf{A}^{-1} \otimes \mathbf{G}^{-1})\text{vec}(\mathbf{X}) = \text{vec}(\mathbf{G}^{-1}\mathbf{X}\mathbf{A}^{-T})$ and the symmetry of $\mathbf{A}_{l-1}$ (so $\mathbf{A}_{l-1}^{-T} = \mathbf{A}_{l-1}^{-1}$):

$$\text{vec}(\tilde{\nabla}_{\mathbf{W}_l}) = \text{vec}(\mathbf{G}_l^{-1} \, \nabla_{\mathbf{W}_l} \mathcal{L} \, \mathbf{A}_{l-1}^{-1})$$

Or in matrix form: $$\tilde{\nabla}_{\mathbf{W}_l} = \mathbf{G}_l^{-1} \, \nabla_{\mathbf{W}_l} \mathcal{L} \, \mathbf{A}_{l-1}^{-1}$$
The K-FAC natural gradient is a matrix sandwich: multiply the gradient matrix on the left by $\mathbf{G}^{-1}$ and on the right by $\mathbf{A}^{-1}$. This is extremely efficient—just two matrix multiplications instead of inverting a massive Kronecker product.
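Before building a full optimizer, it is worth verifying this identity numerically. The sketch below is our own toy example: it forms $\mathbf{A} \otimes \mathbf{G}$ explicitly for a tiny layer and confirms that the sandwich matches the result of inverting the full Kronecker product.

```python
import torch

torch.manual_seed(0)
d_in, d_out = 5, 4

# Symmetric positive-definite stand-ins for A_{l-1} and G_l
A = torch.randn(d_in, d_in);   A = A @ A.T + d_in * torch.eye(d_in)
G = torch.randn(d_out, d_out); G = G @ G.T + d_out * torch.eye(d_out)
grad_W = torch.randn(d_out, d_in)                # gradient matrix for the layer

# "Sandwich" form: G^{-1} grad A^{-1}
nat_grad_sandwich = torch.linalg.solve(G, grad_W) @ torch.linalg.inv(A)

# Reference: invert the full Kronecker product (only feasible for tiny layers)
F = torch.kron(A, G)                             # A ⊗ G, acting on column-major vec(W)
vec_grad = grad_W.T.reshape(-1)                  # column-major vec of grad_W
vec_nat = torch.linalg.inv(F) @ vec_grad
nat_grad_full = vec_nat.reshape(d_in, d_out).T   # undo the column-major vec

print(torch.allclose(nat_grad_sandwich, nat_grad_full, atol=1e-5))   # True
```

The K-FAC optimizer that follows is built around exactly this per-layer operation.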
```python
import torch
import torch.nn as nn
from typing import Dict


class KFACOptimizer:
    """
    K-FAC optimizer for neural networks.

    Maintains Kronecker-factored Fisher approximations for each layer
    and computes natural gradient updates efficiently.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 0.01,
        damping: float = 1e-3,
        cov_ema_decay: float = 0.95,
        inv_update_freq: int = 10,
        weight_decay: float = 0.0
    ):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.cov_ema_decay = cov_ema_decay
        self.inv_update_freq = inv_update_freq
        self.weight_decay = weight_decay

        # Find all linear layers
        self.layers: Dict[str, nn.Linear] = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                self.layers[name] = module

        # Initialize K-FAC statistics
        self.A: Dict[str, torch.Tensor] = {}      # Input activation covariance
        self.G: Dict[str, torch.Tensor] = {}      # Output gradient covariance
        self.A_inv: Dict[str, torch.Tensor] = {}  # Cached inverses
        self.G_inv: Dict[str, torch.Tensor] = {}

        # For capturing activations and gradients during forward/backward passes
        self.activations: Dict[str, torch.Tensor] = {}
        self.gradients: Dict[str, torch.Tensor] = {}

        # Register hooks to capture statistics
        self._register_hooks()

        self.step_count = 0

    def _register_hooks(self):
        """Register forward and backward hooks to capture statistics."""

        def make_forward_hook(name):
            def hook(module, input, output):
                # Capture input activation (augmented with a ones column for the bias)
                a = input[0].detach()
                if module.bias is not None:
                    ones = torch.ones(a.shape[0], 1, device=a.device)
                    a = torch.cat([a, ones], dim=1)
                self.activations[name] = a
            return hook

        def make_backward_hook(name):
            def hook(module, grad_input, grad_output):
                # Capture gradient w.r.t. the layer's output (pre-activation)
                g = grad_output[0].detach()
                self.gradients[name] = g
            return hook

        for name, layer in self.layers.items():
            layer.register_forward_hook(make_forward_hook(name))
            layer.register_full_backward_hook(make_backward_hook(name))

    def _update_cov_estimates(self):
        """Update running estimates of A and G."""
        for name in self.layers:
            a = self.activations[name]  # [batch_size, d_in (+1)]
            g = self.gradients[name]    # [batch_size, d_out]

            batch_size = a.shape[0]

            # Compute batch covariances
            batch_A = (a.T @ a) / batch_size
            batch_G = (g.T @ g) / batch_size

            # EMA update
            if name not in self.A:
                self.A[name] = batch_A
                self.G[name] = batch_G
            else:
                decay = self.cov_ema_decay
                self.A[name] = decay * self.A[name] + (1 - decay) * batch_A
                self.G[name] = decay * self.G[name] + (1 - decay) * batch_G

    def _update_inverses(self):
        """Compute inverses of A and G with damping."""
        for name in self.layers:
            A = self.A[name]
            G = self.G[name]

            # Pi factor for balanced damping (see the damping section):
            # pi^2 = (tr(A)/dim(A)) / (tr(G)/dim(G))
            pi = torch.sqrt(
                (torch.trace(A) / A.shape[0]) / (torch.trace(G) / G.shape[0] + 1e-8)
            )

            # Add damping: lambda * pi to A and lambda / pi to G
            damped_A = A + (self.damping * pi) * torch.eye(A.shape[0], device=A.device)
            damped_G = G + (self.damping / pi) * torch.eye(G.shape[0], device=G.device)

            # Compute inverses
            self.A_inv[name] = torch.linalg.inv(damped_A)
            self.G_inv[name] = torch.linalg.inv(damped_G)

    def _compute_natural_gradient(self) -> Dict[str, torch.Tensor]:
        """Compute natural gradient for each layer."""
        natural_grads = {}

        for name, layer in self.layers.items():
            # Get gradient (including bias in augmented form)
            grad_w = layer.weight.grad
            if layer.bias is not None:
                grad = torch.cat([grad_w, layer.bias.grad.unsqueeze(1)], dim=1)
            else:
                grad = grad_w

            # Natural gradient: G^{-1} @ grad @ A^{-1}
            nat_grad = self.G_inv[name] @ grad @ self.A_inv[name]
            natural_grads[name] = nat_grad

        return natural_grads

    def step(self):
        """Perform one K-FAC optimization step."""
        self.step_count += 1

        # Update covariance estimates every step
        self._update_cov_estimates()

        # Update inverses periodically (expensive)
        if self.step_count % self.inv_update_freq == 0:
            self._update_inverses()

        # Skip natural gradient if inverses not computed yet
        if not self.A_inv:
            # Fall back to regular gradient descent
            for layer in self.layers.values():
                if layer.weight.grad is not None:
                    layer.weight.data -= self.lr * layer.weight.grad
                if layer.bias is not None and layer.bias.grad is not None:
                    layer.bias.data -= self.lr * layer.bias.grad
            return

        # Compute natural gradients
        natural_grads = self._compute_natural_gradient()

        # Apply updates
        with torch.no_grad():
            for name, layer in self.layers.items():
                nat_grad = natural_grads[name]

                # Separate weight and bias components of the augmented gradient
                if layer.bias is not None:
                    grad_w = nat_grad[:, :-1]
                    grad_b = nat_grad[:, -1]
                else:
                    grad_w = nat_grad
                    grad_b = None

                # Weight decay
                if self.weight_decay > 0:
                    grad_w = grad_w + self.weight_decay * layer.weight

                # Apply updates
                layer.weight.data -= self.lr * grad_w
                if grad_b is not None:
                    layer.bias.data -= self.lr * grad_b

        # Clear captured activations/gradients
        self.activations.clear()
        self.gradients.clear()


# Example usage
def train_with_kfac():
    """Example training loop with K-FAC."""
    # Simple MLP
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )

    optimizer = KFACOptimizer(
        model,
        lr=0.001,
        damping=0.01,
        inv_update_freq=20
    )

    loss_fn = nn.CrossEntropyLoss()

    # Dummy data for illustration; run enough steps that the periodically
    # computed inverses are actually used (after inv_update_freq steps)
    for step in range(60):
        x = torch.randn(64, 784)
        y = torch.randint(0, 10, (64,))

        # Forward pass
        output = model(x)
        loss = loss_fn(output, y)

        # Backward pass (hooks capture activations/gradients)
        model.zero_grad()
        loss.backward()

        # K-FAC update
        optimizer.step()

        if (step + 1) % 10 == 0:
            print(f"Step {step+1}, Loss: {loss.item():.4f}")


train_with_kfac()
```

Damping is crucial for K-FAC's stability. Without it, the Fisher approximation can be singular or nearly so, leading to enormous and destabilizing updates.
The simplest damping adds a multiple of identity to the Fisher: $$\tilde{\mathbf{F}} = \mathbf{F} + \lambda \mathbf{I}$$
For K-FAC, we need to add damping to the Kronecker factors. Since $(\mathbf{A} + \epsilon_A \mathbf{I}) \otimes (\mathbf{G} + \epsilon_G \mathbf{I}) \neq (\mathbf{A} \otimes \mathbf{G}) + \lambda \mathbf{I}$, we face a choice.
Factored damping: Add $\epsilon_A \mathbf{I}$ to $\mathbf{A}$ and $\epsilon_G \mathbf{I}$ to $\mathbf{G}$. Expanding the product:
$$(\mathbf{A} + \epsilon_A \mathbf{I}) \otimes (\mathbf{G} + \epsilon_G \mathbf{I}) = \mathbf{A} \otimes \mathbf{G} + \epsilon_A (\mathbf{I} \otimes \mathbf{G}) + \epsilon_G (\mathbf{A} \otimes \mathbf{I}) + \epsilon_A \epsilon_G \mathbf{I}$$

The isotropic part of the added regularization is $\epsilon_A \epsilon_G \mathbf{I}$, so the effective damping strength is $\lambda \approx \epsilon_A \epsilon_G$; the cross terms are positive semi-definite and only add further, curvature-shaped damping.
The scale of $\mathbf{A}$ and $\mathbf{G}$ can differ significantly. To balance their contributions, K-FAC uses pi-adjusted damping:
$$\pi = \sqrt{\frac{\text{tr}(\mathbf{A}) / d_{in}}{\text{tr}(\mathbf{G}) / d_{out}}}$$
Then apply: $$\tilde{\mathbf{A}} = \mathbf{A} + \lambda \pi \mathbf{I}, \quad \tilde{\mathbf{G}} = \mathbf{G} + \frac{\lambda}{\pi} \mathbf{I}$$
This ensures the damping contribution is balanced across factors, preventing one factor's damping from dominating.
Too little damping: Updates can be enormous, causing divergence. The curvature estimate is trusted too much.
Too much damping: Reverts to gradient descent, losing second-order benefits. The curvature is ignored.
Adaptive damping strategies (like Levenberg-Marquardt) adjust $\lambda$ based on whether the update actually reduces the loss.
An alternative to damping is the trust region approach. Instead of regularizing the curvature, we constrain the step size:
$$\min_\delta \quad \mathbf{g}^T \delta + \frac{1}{2} \delta^T \mathbf{F} \delta \quad \text{s.t.} \quad \delta^T \mathbf{F} \delta \leq \Delta^2$$
Solving via Lagrange multipliers yields the same form as damping, but with $\lambda$ chosen to satisfy the constraint.
Adaptive trust region: Adjust $\Delta$ based on the actual vs. predicted improvement:
- If $\rho > 0.75$: increase $\Delta$ (the model is accurate)
- If $\rho < 0.25$: decrease $\Delta$ (the model is inaccurate)
- If $\rho < 0$: reject the step (it made things worse)
This creates a self-correcting system that automatically finds the right level of trust in the curvature estimate.
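As a concrete illustration, here is a minimal sketch of such an adjustment rule. The 0.75 / 0.25 / 0 thresholds come from the rules above; the growth and shrink factors and the bounds on $\Delta$ are illustrative choices of ours, not canonical values.

```python
def adjust_trust_region(delta: float, rho: float,
                        grow: float = 2.0, shrink: float = 0.5,
                        delta_min: float = 1e-6, delta_max: float = 10.0):
    """
    Adjust the trust-region radius from the reduction ratio
        rho = (actual loss decrease) / (decrease predicted by the quadratic model).

    Returns (new_delta, accept_step).
    """
    if rho < 0:           # step increased the loss: reject it and shrink the region
        return max(delta * shrink, delta_min), False
    if rho < 0.25:        # model over-predicted the improvement: shrink
        return max(delta * shrink, delta_min), True
    if rho > 0.75:        # model is accurate: allow larger steps
        return min(delta * grow, delta_max), True
    return delta, True    # otherwise keep the radius unchanged
```

The same rule can drive an adaptive damping parameter instead of $\Delta$, increasing $\lambda$ when $\rho$ is small and decreasing it when $\rho$ is large.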
Convolutional layers have additional structure that K-FAC can exploit. The key insight is treating each spatial location as a separate "sample" when computing statistics.
A convolutional layer has a kernel $\mathbf{K} \in \mathbb{R}^{C_{out} \times C_{in} \times k_h \times k_w}$, input activations with $C_{in}$ channels over a spatial grid, and output activations with $C_{out}$ channels.
At each spatial location $(h, w)$, the convolution extracts a patch: $$\mathbf{a}_{h,w} \in \mathbb{R}^{C_{in} \times k_h \times k_w}$$
And applies the kernel, reshaped into a matrix $\mathbf{W} \in \mathbb{R}^{C_{out} \times C_{in} k_h k_w}$: $$\mathbf{s}_{h,w} = \mathbf{W} \, \text{vec}(\mathbf{a}_{h,w})$$
We can view each patch as an independent sample and average over all patches: $$\mathbf{A} = \frac{1}{|\Omega|} \sum_{(h,w) \in \Omega} \text{vec}(\mathbf{a}_{h,w}) \text{vec}(\mathbf{a}_{h,w})^T$$
where $\Omega$ is the set of all spatial locations.
This treats spatial locations as i.i.d., which is an approximation (nearby patches are correlated), but works well in practice.
K-FAC for convolutions assumes patches at different spatial locations are independent. This is similar to the batch normalization assumption and is reasonable for diverse natural images. For highly structured inputs (like repeated patterns), this approximation degrades.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


def compute_kfac_cov_conv(
    activation: torch.Tensor,        # [batch, C_in, H, W]
    kernel_size: Tuple[int, int],
    stride: int = 1,
    padding: int = 0
) -> torch.Tensor:
    """
    Compute the K-FAC activation covariance for a convolutional layer.

    Extracts patches and computes the covariance treating each patch as a sample.
    """
    batch_size, C_in, H, W = activation.shape
    k_h, k_w = kernel_size

    # Extract patches using unfold
    # Result: [batch, C_in * k_h * k_w, num_patches]
    patches = F.unfold(activation, kernel_size, padding=padding, stride=stride)

    # Reshape to [num_samples, patch_size] where num_samples = batch * num_patches
    patches = patches.permute(0, 2, 1)                 # [batch, num_patches, patch_size]
    patches = patches.reshape(-1, patches.shape[-1])   # [total_patches, patch_size]

    # Add bias term (column of ones)
    ones = torch.ones(patches.shape[0], 1, device=patches.device)
    patches = torch.cat([patches, ones], dim=1)

    # Compute covariance
    num_samples = patches.shape[0]
    A = (patches.T @ patches) / num_samples

    return A


def compute_kfac_grad_cov_conv(
    grad_output: torch.Tensor        # [batch, C_out, H', W']
) -> torch.Tensor:
    """
    Compute the K-FAC gradient covariance for a convolutional layer.

    Treats each spatial location as a sample.
    """
    batch_size, C_out, H, W = grad_output.shape

    # Reshape: treat each spatial location as a sample
    # [batch, C_out, H, W] -> [batch * H * W, C_out]
    grad_flat = grad_output.permute(0, 2, 3, 1).reshape(-1, C_out)

    # Compute covariance
    num_samples = grad_flat.shape[0]
    G = (grad_flat.T @ grad_flat) / num_samples

    return G


class KFACConv2d:
    """K-FAC handler for a single Conv2d layer (assumes the layer has a bias)."""

    def __init__(self, layer: nn.Conv2d, damping: float = 1e-3):
        self.layer = layer
        self.damping = damping

        # Covariance matrices
        self.A = None  # Input (patch) covariance
        self.G = None  # Output gradient covariance

        # Cached inverses
        self.A_inv = None
        self.G_inv = None

    def update_covs(self, activation: torch.Tensor, grad_output: torch.Tensor):
        """Update running covariance estimates."""
        A_new = compute_kfac_cov_conv(
            activation,
            self.layer.kernel_size,
            self.layer.stride[0],
            self.layer.padding[0]
        )
        G_new = compute_kfac_grad_cov_conv(grad_output)

        # EMA update
        decay = 0.95
        if self.A is None:
            self.A = A_new
            self.G = G_new
        else:
            self.A = decay * self.A + (1 - decay) * A_new
            self.G = decay * self.G + (1 - decay) * G_new

    def update_inverses(self):
        """Compute inverses with damping."""
        # Pi-adjusted damping (see the damping section)
        pi = torch.sqrt(
            (torch.trace(self.A) / self.A.shape[0])
            / (torch.trace(self.G) / self.G.shape[0] + 1e-8)
        )

        A_damped = self.A + self.damping * pi * torch.eye(
            self.A.shape[0], device=self.A.device
        )
        G_damped = self.G + self.damping / pi * torch.eye(
            self.G.shape[0], device=self.G.device
        )

        self.A_inv = torch.linalg.inv(A_damped)
        self.G_inv = torch.linalg.inv(G_damped)

    def compute_natural_gradient(self) -> torch.Tensor:
        """Compute the natural gradient for the conv layer weights."""
        # Reshape weight gradient to 2D
        # [C_out, C_in, k_h, k_w] -> [C_out, C_in * k_h * k_w]
        grad = self.layer.weight.grad
        grad_2d = grad.reshape(grad.shape[0], -1)

        # Include the bias (matching the ones column appended to the patches)
        if self.layer.bias is not None:
            grad_2d = torch.cat([grad_2d, self.layer.bias.grad.unsqueeze(1)], dim=1)

        # Natural gradient: G^{-1} @ grad @ A^{-1}
        nat_grad = self.G_inv @ grad_2d @ self.A_inv

        return nat_grad
```

K-FAC's per-layer structure makes it naturally amenable to distributed training. Each layer's statistics can be computed and inverted independently, enabling efficient parallelization.
In data-parallel training, each worker processes different data. K-FAC statistics must be synchronized:
Gradient Aggregation: Same as SGD—all-reduce the gradients across workers.
Covariance Aggregation: Covariances from different workers must be averaged: $$\mathbf{A}_{\text{global}} = \frac{1}{K} \sum_{k=1}^K \mathbf{A}^{(k)}$$
This requires all-reducing the covariance matrices, which are $d \times d$ per factor.
Inverse Computation: Can be distributed across layers. Each worker computes inverses for a subset of layers and broadcasts results.
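A minimal sketch of these two communication patterns with `torch.distributed` is shown below; the helper names (`sync_kfac_factors`, `assign_inverse_owners`) are our own, and a process group is assumed to be initialized already:

```python
import torch
import torch.distributed as dist


def sync_kfac_factors(A_factors, G_factors):
    """
    Average the K-FAC covariance factors across data-parallel workers.

    A_factors, G_factors: dicts mapping layer name -> local covariance tensor.
    Assumes the default process group has already been initialized.
    """
    world_size = dist.get_world_size()
    for factors in (A_factors, G_factors):
        for cov in factors.values():
            dist.all_reduce(cov, op=dist.ReduceOp.SUM)  # sum over workers, in place
            cov.div_(world_size)                        # turn the sum into a mean


def assign_inverse_owners(layer_names, world_size):
    """
    Round-robin assignment of layers to workers: each worker inverts only the
    factors of its own layers and broadcasts the results (e.g. via dist.broadcast).
    """
    return {name: i % world_size for i, name in enumerate(layer_names)}
```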
K-FAC's block-diagonal structure also enables layer-wise parallelism: the Kronecker factors of different layers are independent, so each worker can compute and invert the factors for its own subset of layers and broadcast the resulting inverses to the others.
This achieves near-linear speedup with number of workers (up to the number of layers).
| Operation | Volume | Frequency |
|---|---|---|
| Gradient all-reduce | $\mathcal{O}(n)$ | Every step |
| Covariance all-reduce | $\mathcal{O}(\sum_l (d_{in}^{(l)})^2 + (d_{out}^{(l)})^2)$ | Every step |
| Inverse broadcast | $\mathcal{O}(\sum_l (d_{in}^{(l)})^2 + (d_{out}^{(l)})^2)$ | Every $k$ steps |
K-FAC particularly shines in large-batch training, where larger batches give lower-variance estimates of the Kronecker factors, the cost of the periodic inversions is amortized over more data per step, and the reduction in iteration count translates directly into wall-clock savings.
Studies have shown K-FAC can train ResNets on ImageNet in fewer iterations (and competitive wall-clock time) compared to SGD with momentum.
Practical distributed K-FAC implementations optimize for overlapping the factor communication with computation, reusing (slightly stale) inverses across steps, and spreading the inverse computations themselves across workers.
For very large models (billions of parameters), even K-FAC's reduced complexity can be challenging. Recent work explores further approximations: (1) K-FAC with low-rank updates for the inverses, (2) Shampoo which stores matrix roots instead of inverses, (3) Online K-FAC which uses streaming updates. These trade approximation quality for scalability.
K-FAC represents the most successful practical realization of natural gradient descent for deep learning. By exploiting the Kronecker structure inherent in neural network layers, it achieves the benefits of second-order optimization at a computational cost that scales linearly with parameter count.
We have now covered the major second-order optimization methods: Newton's method, Hessian-free optimization, natural gradient descent, and K-FAC. The final page of this module examines the practical limitations of these methods—when they work, when they don't, and how to choose between first-order and second-order approaches in practice.
You now have a comprehensive understanding of K-FAC—from the mathematical foundations of Kronecker products to practical implementation details for both fully-connected and convolutional layers. K-FAC exemplifies how exploiting problem structure can make theoretically elegant algorithms practically feasible.