Newton's method promises quadratic convergence but demands an impossible price: $\mathcal{O}(n^2)$ storage and $\mathcal{O}(n^3)$ computation for the Hessian. For neural networks with millions of parameters, this is utterly infeasible. Yet, the promise of faster convergence through curvature information remains tantalizing.
Hessian-free optimization resolves this impasse with an elegant insight: we never actually need the Hessian matrix itself—we only need to compute its product with vectors. By combining this observation with iterative linear solvers, we can approximate Newton's method at a cost comparable to first-order methods.
This approach, pioneered for deep learning by James Martens in his seminal 2010 paper, demonstrated that second-order methods could train deep networks that were previously considered intractable. It represented a breakthrough in our understanding of how curvature information could be exploited for deep learning.
By the end of this page, you will: (1) Understand how Hessian-vector products can be computed efficiently using automatic differentiation, (2) Master the conjugate gradient algorithm for solving linear systems iteratively, (3) Learn how to combine these techniques into Hessian-free optimization, (4) Appreciate the damping and regularization strategies essential for stability, and (5) Understand the practical considerations for applying HF optimization to neural networks.
The key enabler of Hessian-free optimization is the ability to compute $\mathbf{H}\mathbf{v}$—the Hessian-vector product—without forming the Hessian. This technique, known as Pearlmutter's method (1994), reduces the cost from $\mathcal{O}(n^2)$ to $\mathcal{O}(n)$.
Recall that the Hessian is the Jacobian of the gradient: $$\mathbf{H} = \nabla^2 f = \frac{\partial}{\partial \boldsymbol{\theta}} (\nabla f)$$
The product $\mathbf{H}\mathbf{v}$ is the directional derivative of the gradient in direction $\mathbf{v}$: $$\mathbf{H}\mathbf{v} = \lim_{\epsilon \to 0} \frac{\nabla f(\boldsymbol{\theta} + \epsilon \mathbf{v}) - \nabla f(\boldsymbol{\theta})}{\epsilon}$$
This directional derivative can be computed exactly using automatic differentiation, without taking limits or finite differences.
In the literature, this operation is often written with the R-operator $\mathcal{R}_{\mathbf{v}}\{\cdot\}$. For any expression $E(\boldsymbol{\theta})$, we define $\mathcal{R}_{\mathbf{v}}\{E\} = \lim_{\epsilon \to 0} \frac{E(\boldsymbol{\theta} + \epsilon \mathbf{v}) - E(\boldsymbol{\theta})}{\epsilon}$. This operator obeys the same differentiation rules as regular derivatives: linearity, product rule, chain rule. The Hessian-vector product is then $\mathbf{H}\mathbf{v} = \mathcal{R}_{\mathbf{v}}\{\nabla f\}$.
Modern deep learning frameworks support this directly through their automatic differentiation capabilities. The procedure is: (1) run a forward pass to compute the loss; (2) run a backward pass to compute the gradient $\mathbf{g}$, keeping the computation graph alive (e.g., `create_graph=True` in PyTorch); (3) form the scalar $\mathbf{g}^T \mathbf{v}$; and (4) run a second backward pass on this scalar to obtain $\mathbf{H}\mathbf{v}$.
The key insight is that step 4 involves differentiating through the gradient computation itself. The gradient $\mathbf{g}$ depends on $\boldsymbol{\theta}$, so $\nabla_\theta (\mathbf{g}^T \mathbf{v}) = \nabla_\theta \sum_i g_i v_i = \sum_i v_i \nabla_\theta g_i = \mathbf{H}\mathbf{v}$.
```python
import torch
import torch.nn as nn


def hessian_vector_product(loss_fn, model, x, y, v, damping=0.0):
    """
    Compute the Hessian-vector product H @ v efficiently using double backprop.

    Args:
        loss_fn: Loss function
        model: Neural network
        x, y: Input data and targets
        v: Vector to multiply with the Hessian (list of tensors matching params)
        damping: Tikhonov damping factor (adds damping * v to the result)

    Returns:
        Hv: List of tensors, the Hessian-vector product
    """
    # Forward pass
    output = model(x)
    loss = loss_fn(output, y)

    # First backward pass: compute gradient, retain graph for second backward
    params = list(model.parameters())
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Compute g^T @ v (dot product of gradient and vector)
    grad_v = sum((g * v_i).sum() for g, v_i in zip(grads, v))

    # Second backward pass: differentiate the dot product.
    # This gives us H @ v.
    Hv = torch.autograd.grad(grad_v, params, retain_graph=True)

    # Apply damping: (H + damping * I) @ v
    if damping > 0:
        Hv = tuple(hv + damping * v_i for hv, v_i in zip(Hv, v))

    return Hv


def hessian_vector_product_flat(loss_fn, model, x, y, v_flat, damping=0.0):
    """
    Same as above, but with a flattened parameter/vector representation.
    More convenient for the conjugate gradient implementation.
    """
    # Convert flat vector to a list of parameter-shaped tensors
    v = []
    offset = 0
    for p in model.parameters():
        numel = p.numel()
        v.append(v_flat[offset:offset + numel].view(p.shape))
        offset += numel

    # Compute Hv
    Hv = hessian_vector_product(loss_fn, model, x, y, v, damping)

    # Flatten the result
    return torch.cat([hv.flatten() for hv in Hv])


# Demonstration: verify against finite differences
def verify_hessian_vector_product():
    """Verify the Hv computation against a finite-difference approximation."""
    # Small model for verification
    model = nn.Sequential(
        nn.Linear(5, 10),
        nn.ReLU(),
        nn.Linear(10, 2)
    )
    loss_fn = nn.CrossEntropyLoss()

    # Random data
    x = torch.randn(8, 5)
    y = torch.randint(0, 2, (8,))

    # Random direction
    v = [torch.randn_like(p) for p in model.parameters()]

    # Compute Hv via double backprop
    Hv = hessian_vector_product(loss_fn, model, x, y, v)
    Hv_flat = torch.cat([hv.flatten() for hv in Hv])

    # Compute Hv via central finite differences
    def get_gradient(model):
        output = model(x)
        loss = loss_fn(output, y)
        grads = torch.autograd.grad(loss, model.parameters())
        return torch.cat([g.flatten() for g in grads])

    epsilon = 1e-5

    # Perturb parameters in direction v
    with torch.no_grad():
        for p, v_i in zip(model.parameters(), v):
            p.add_(epsilon * v_i)
    grad_plus = get_gradient(model)

    with torch.no_grad():
        for p, v_i in zip(model.parameters(), v):
            p.sub_(2 * epsilon * v_i)
    grad_minus = get_gradient(model)

    # Restore parameters
    with torch.no_grad():
        for p, v_i in zip(model.parameters(), v):
            p.add_(epsilon * v_i)

    Hv_fd = (grad_plus - grad_minus) / (2 * epsilon)

    # Compare
    error = torch.norm(Hv_flat - Hv_fd) / torch.norm(Hv_fd)
    print(f"Relative error: {error:.2e}")
    assert error < 1e-4, "Hessian-vector product verification failed!"
    print("Verification passed!")


verify_hessian_vector_product()
```

The Hessian-vector product requires one forward pass, one backward pass to compute the gradient (with the computation graph retained), and a second backward pass through the scalar $\mathbf{g}^T\mathbf{v}$.
Total: $\mathcal{O}(n)$ operations—the same order as computing the gradient itself. Memory overhead is modest: the computation graph from the first backward pass must be kept alive for the second.
Compare this with forming the full Hessian: $\mathcal{O}(n^2)$ memory just to store its entries, $\mathcal{O}(n^2)$ work for every explicit matrix-vector multiply, and $\mathcal{O}(n^3)$ work to factorize or invert it.
The savings compound when we use iterative methods that require multiple Hv products—each costs $\mathcal{O}(n)$, and we typically need $\ll n$ products for convergence.
Newton's method requires solving the linear system $\mathbf{H}\mathbf{d} = -\mathbf{g}$ for the Newton direction $\mathbf{d}$. Direct solution via matrix inversion costs $\mathcal{O}(n^3)$, but we can solve it iteratively using only matrix-vector products.
The Conjugate Gradient (CG) algorithm is ideal for this purpose. It's an iterative method for solving $\mathbf{A}\mathbf{x} = \mathbf{b}$ where $\mathbf{A}$ is symmetric positive definite. Crucially, CG only requires computing $\mathbf{A}\mathbf{v}$ for various vectors $\mathbf{v}$—exactly what we can do efficiently with the Hessian.
CG builds a sequence of conjugate directions $\{\mathbf{p}_0, \mathbf{p}_1, \ldots\}$ that satisfy: $$\mathbf{p}_i^T \mathbf{A} \mathbf{p}_j = 0 \quad \text{for } i \neq j$$
These directions are "orthogonal" with respect to the inner product defined by $\mathbf{A}$. The magic of CG is that if we minimize the objective $\frac{1}{2}\mathbf{x}^T\mathbf{A}\mathbf{x} - \mathbf{b}^T\mathbf{x}$ (whose solution is $\mathbf{A}^{-1}\mathbf{b}$) along each conjugate direction, we reach the exact solution in at most $n$ steps.
More importantly, CG typically converges to a good approximate solution in far fewer than $n$ steps, especially when the eigenvalues of $\mathbf{A}$ are clustered.
```python
import torch


def conjugate_gradient(
    Av_fn,                 # Function computing A @ v
    b,                     # Right-hand side vector
    x0=None,               # Initial guess (default: zero)
    max_iter=100,          # Maximum number of iterations
    tol=1e-6,              # Convergence tolerance (relative residual)
    preconditioner=None    # Optional preconditioner function computing M^{-1} @ v
):
    """
    Conjugate Gradient method for solving Ax = b.

    This implementation is suitable for Hessian-free optimization,
    where Av_fn computes the Hessian-vector product.

    Args:
        Av_fn: Function that takes a vector v and returns A @ v
        b: Right-hand side vector (the negative gradient)
        x0: Initial guess for the solution
        max_iter: Maximum CG iterations
        tol: Convergence tolerance on ||r|| / ||b||
        preconditioner: Optional function computing M^{-1} @ v

    Returns:
        x: Approximate solution to Ax = b
        residuals: List of residual norms at each iteration
    """
    # Initialize
    x = torch.zeros_like(b) if x0 is None else x0.clone()

    # Initial residual: r = b - Ax
    r = b - Av_fn(x)

    # Apply preconditioner if provided
    if preconditioner is not None:
        z = preconditioner(r)
    else:
        z = r.clone()

    p = z.clone()           # Initial search direction
    rz = torch.dot(r, z)    # r^T z (or r^T r if no preconditioner)

    b_norm = torch.norm(b)
    if b_norm < 1e-12:
        return x, [0.0]

    residuals = [torch.norm(r).item()]

    for i in range(max_iter):
        # Compute A @ p
        Ap = Av_fn(p)

        # Step size
        pAp = torch.dot(p, Ap)
        if pAp < 1e-12:
            # Detect non-positive-definiteness
            print(f"Warning: CG detected non-positive curvature at iteration {i}")
            break
        alpha = rz / pAp

        # Update solution and residual
        x = x + alpha * p
        r = r - alpha * Ap

        # Check convergence
        r_norm = torch.norm(r)
        residuals.append(r_norm.item())
        if r_norm / b_norm < tol:
            break

        # Apply preconditioner to the new residual
        if preconditioner is not None:
            z = preconditioner(r)
        else:
            z = r

        # Update search direction (standard preconditioned-CG beta)
        rz_new = torch.dot(r, z)
        beta = rz_new / rz
        p = z + beta * p
        rz = rz_new

    return x, residuals


# Example: CG convergence on a quadratic
def demo_cg_convergence():
    """Demonstrate CG convergence on a simple quadratic problem."""
    # Create symmetric positive definite matrices with different condition numbers
    n = 100

    # Well-conditioned case
    eigvals_good = torch.linspace(1.0, 10.0, n)
    Q = torch.randn(n, n)
    Q, _ = torch.linalg.qr(Q)  # Orthogonal matrix
    A_good = Q @ torch.diag(eigvals_good) @ Q.T

    # Ill-conditioned case
    eigvals_bad = torch.logspace(-3, 1, n)  # Condition number ~10000
    A_bad = Q @ torch.diag(eigvals_bad) @ Q.T

    # Right-hand sides
    x_true = torch.randn(n)
    b_good = A_good @ x_true
    b_bad = A_bad @ x_true

    # Run CG
    _, residuals_good = conjugate_gradient(lambda v: A_good @ v, b_good, max_iter=n)
    _, residuals_bad = conjugate_gradient(lambda v: A_bad @ v, b_bad, max_iter=n)

    print(f"Well-conditioned (κ=10): converged in {len(residuals_good)} iterations")
    print(f"Ill-conditioned (κ=10000): converged in {len(residuals_bad)} iterations")


demo_cg_convergence()
```

CG's convergence depends on the condition number $\kappa = \lambda_{\max}/\lambda_{\min}$ of the matrix $\mathbf{A}$:
$$\|\mathbf{x}_k - \mathbf{x}^*\|_\mathbf{A} \leq 2 \left( \frac{\sqrt{\kappa} - 1}{\sqrt{\kappa} + 1} \right)^k \|\mathbf{x}_0 - \mathbf{x}^*\|_\mathbf{A}$$
where $\|\mathbf{z}\|_\mathbf{A} = \sqrt{\mathbf{z}^T\mathbf{A}\mathbf{z}}$ is the $\mathbf{A}$-norm.
Key observations: the contraction factor $(\sqrt{\kappa}-1)/(\sqrt{\kappa}+1)$ is near zero when $\kappa \approx 1$, so well-conditioned systems converge in a handful of iterations; as $\kappa$ grows, the factor approaches 1 and the number of iterations needed scales roughly with $\sqrt{\kappa}$; and because this is a worst-case bound, clustered eigenvalues often give much faster convergence in practice.
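To make the bound concrete, the short calculation below (an illustration of the worst-case bound above, not part of the HF implementation; the function name is ours) reports how many iterations the bound predicts before the $\mathbf{A}$-norm error drops below $10^{-6}$ of its initial value:

```python
import math


def cg_iterations_from_bound(kappa, target=1e-6):
    """Smallest k with 2 * ((sqrt(kappa) - 1) / (sqrt(kappa) + 1))**k <= target,
    i.e. the worst-case iteration count predicted by the classical CG bound."""
    rho = (math.sqrt(kappa) - 1.0) / (math.sqrt(kappa) + 1.0)
    if rho <= 0.0:  # kappa == 1: one step suffices
        return 1
    return math.ceil(math.log(target / 2.0) / math.log(rho))


for kappa in (10.0, 1e3, 1e6):
    k = cg_iterations_from_bound(kappa)
    print(f"condition number {kappa:,.0f}: bound predicts <= {k} iterations")
```

The iteration counts grow roughly like $\sqrt{\kappa}$, which is exactly why ill-conditioned Hessians make raw CG expensive.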
For the neural network Hessian, condition numbers can be enormous (often $> 10^6$), making raw CG slow. This is where preconditioning and damping become essential.
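One common remedy is a diagonal (Jacobi-style) preconditioner. The sketch below is a minimal illustration, assuming the diagonal is built from squared gradient entries in the spirit of Martens (2010); the helper name and the fractional exponent are our choices, not a fixed recipe. It plugs directly into the `preconditioner` argument of the `conjugate_gradient` routine above.

```python
import torch


def make_diagonal_preconditioner(grad_flat, damping, exponent=0.75):
    """Jacobi-style preconditioner: approximate diag(H) with squared gradient
    entries (a crude but cheap curvature proxy), add the damping, and raise to
    a fractional power so the preconditioner is not overly aggressive."""
    diag = (grad_flat * grad_flat + damping) ** exponent

    def apply_inverse(v):
        # M^{-1} v for the diagonal matrix M = diag(diag)
        return v / diag

    return apply_inverse


# Illustrative usage with the CG routine above:
# precond = make_diagonal_preconditioner(grad_flat, damping=1.0)
# d, _ = conjugate_gradient(Hv_fn, -grad_flat, preconditioner=precond)
```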
Other iterative solvers exist (Jacobi, Gauss-Seidel, GMRES), but CG is particularly suited to Hessian-free optimization because: (1) it minimizes the quadratic objective over a growing Krylov subspace, so each additional iteration can only improve the model objective, (2) it is the optimal Krylov method for symmetric positive definite matrices, (3) it only requires matrix-vector products, and (4) when started from $\mathbf{x}_0 = \mathbf{0}$ with $\mathbf{b} = -\mathbf{g}$, every intermediate iterate is a descent direction, so early termination still produces a useful update.
Now we can assemble the pieces into the complete Hessian-Free (HF) optimization algorithm. The core idea is simple: at each iteration, use CG with Hessian-vector products to approximately solve for the Newton direction.
At each iteration $k$: (1) compute the gradient $\mathbf{g}_k$ on the current batch; (2) form the damped curvature matrix $\mathbf{B}_k = \mathbf{H}_k + \lambda\mathbf{I}$ (or a Gauss-Newton/Fisher substitute), implicitly, via matrix-vector products; (3) run CG to approximately solve $\mathbf{B}_k \mathbf{d} = -\mathbf{g}_k$; (4) choose a step size $\alpha$ along $\mathbf{d}$ with a line search; and (5) update $\boldsymbol{\theta}_{k+1} = \boldsymbol{\theta}_k + \alpha\mathbf{d}$ and adapt the damping $\lambda$.
The damping term $\lambda \mathbf{I}$ serves multiple purposes: it keeps the system CG solves positive definite even where the true Hessian is not; it limits the step length, keeping the update inside the region where the quadratic model can be trusted; and it interpolates between Newton-like steps (small $\lambda$) and scaled gradient-descent steps (large $\lambda$).
```python
import torch
import torch.nn as nn
from typing import Callable, Optional, List, Tuple


class HessianFreeOptimizer:
    """
    Hessian-Free (HF) optimizer for neural networks.

    Uses Conjugate Gradient with Hessian-vector products to compute
    approximate Newton directions efficiently.
    """

    def __init__(
        self,
        model: nn.Module,
        damping: float = 1.0,
        cg_max_iter: int = 100,
        cg_tol: float = 1e-4,
        initial_cg_solution: bool = True,  # Warm-start CG from the previous direction
    ):
        self.model = model
        self.damping = damping
        self.cg_max_iter = cg_max_iter
        self.cg_tol = cg_tol
        self.initial_cg_solution = initial_cg_solution

        # Store previous CG solution for warm starting
        self._prev_direction = None

    def _flatten_params(self) -> torch.Tensor:
        """Flatten all model parameters into a single vector."""
        return torch.cat([p.flatten() for p in self.model.parameters()])

    def _unflatten_params(self, flat: torch.Tensor) -> List[torch.Tensor]:
        """Convert a flat vector back to a list of parameter-shaped tensors."""
        result = []
        offset = 0
        for p in self.model.parameters():
            numel = p.numel()
            result.append(flat[offset:offset + numel].view(p.shape))
            offset += numel
        return result

    def _compute_gradient(
        self, loss_fn: Callable, x: torch.Tensor, y: torch.Tensor
    ) -> Tuple[float, torch.Tensor]:
        """Compute loss and gradient."""
        output = self.model(x)
        loss = loss_fn(output, y)
        grads = torch.autograd.grad(loss, self.model.parameters())
        grad_flat = torch.cat([g.flatten() for g in grads])
        return loss.item(), grad_flat

    def _hessian_vector_product(
        self, loss_fn: Callable, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Compute (H + damping * I) @ v."""
        output = self.model(x)
        loss = loss_fn(output, y)

        params = list(self.model.parameters())
        grads = torch.autograd.grad(loss, params, create_graph=True)
        grad_flat = torch.cat([g.flatten() for g in grads])

        # Compute g^T @ v
        grad_v = torch.dot(grad_flat, v)

        # Compute H @ v via a second backward pass
        Hv = torch.autograd.grad(grad_v, params, retain_graph=True)
        Hv_flat = torch.cat([hv.flatten() for hv in Hv])

        # Add damping: (H + λI) @ v
        return Hv_flat + self.damping * v

    def _conjugate_gradient(
        self, Av_fn: Callable, b: torch.Tensor, x0: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Solve Ax = b using Conjugate Gradient."""
        x = torch.zeros_like(b) if x0 is None else x0.clone()
        r = b - Av_fn(x)
        p = r.clone()
        rr = torch.dot(r, r)
        b_norm = torch.norm(b)

        for i in range(self.cg_max_iter):
            Ap = Av_fn(p)
            pAp = torch.dot(p, Ap)

            if pAp <= 0:  # Negative curvature
                if i == 0:
                    return b / self.damping  # Fall back to a scaled gradient step
                break

            alpha = rr / pAp
            x = x + alpha * p
            r = r - alpha * Ap

            rr_new = torch.dot(r, r)
            if torch.sqrt(rr_new) / b_norm < self.cg_tol:
                break

            beta = rr_new / rr
            p = r + beta * p
            rr = rr_new

        return x

    def _line_search(
        self,
        loss_fn: Callable,
        x: torch.Tensor,
        y: torch.Tensor,
        direction: List[torch.Tensor],
        current_loss: float,
        c1: float = 1e-4
    ) -> float:
        """Simple backtracking line search with Armijo condition."""
        alpha = 1.0

        # Save original parameters
        original_params = [p.clone() for p in self.model.parameters()]

        for _ in range(20):  # Max line search iterations
            # Try step
            with torch.no_grad():
                for p, orig, d in zip(
                    self.model.parameters(), original_params, direction
                ):
                    p.copy_(orig + alpha * d)

            output = self.model(x)
            new_loss = loss_fn(output, y).item()

            # Armijo condition: sufficient decrease
            # (simplified: just check that the loss decreases)
            if new_loss < current_loss:
                return alpha

            alpha *= 0.5

        # Restore original parameters if the line search fails
        with torch.no_grad():
            for p, orig in zip(self.model.parameters(), original_params):
                p.copy_(orig)
        return 0.0

    def step(
        self, loss_fn: Callable, x: torch.Tensor, y: torch.Tensor
    ) -> float:
        """
        Perform one HF optimization step.

        Returns:
            loss: Loss value after the step
        """
        # Compute gradient
        current_loss, grad = self._compute_gradient(loss_fn, x, y)

        # Define the Hessian-vector product function
        def Hv_fn(v):
            return self._hessian_vector_product(loss_fn, x, y, v)

        # Solve for the Newton direction using CG.
        # Warm start from the previous direction if available.
        x0 = self._prev_direction if self.initial_cg_solution else None
        direction_flat = self._conjugate_gradient(Hv_fn, -grad, x0)

        # Store for the next iteration
        self._prev_direction = direction_flat.detach()

        # Convert to a list of parameter-shaped tensors
        direction = self._unflatten_params(direction_flat)

        # Line search
        alpha = self._line_search(loss_fn, x, y, direction, current_loss)

        # Compute the new loss
        if alpha > 0:
            output = self.model(x)
            return loss_fn(output, y).item()
        return current_loss


# Training loop example
def train_with_hessian_free(model, train_loader, epochs=10):
    """Example training loop using the Hessian-Free optimizer."""
    loss_fn = nn.CrossEntropyLoss()
    optimizer = HessianFreeOptimizer(model, damping=1.0, cg_max_iter=50)

    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0

        for x, y in train_loader:
            loss = optimizer.step(loss_fn, x, y)
            total_loss += loss
            n_batches += 1

        avg_loss = total_loss / n_batches
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
```

The damping parameter $\lambda$ is critical for Hessian-free optimization's success. Choosing it poorly leads to divergence (too small) or slow convergence (too large). We need adaptive strategies.
One approach adjusts damping based on the quality of the quadratic approximation:
Compute the reduction ratio: $$\rho = \frac{f(\boldsymbol{\theta}_k) - f(\boldsymbol{\theta}_k + \mathbf{d})}{m_k(\mathbf{0}) - m_k(\mathbf{d})}$$ where $m_k(\mathbf{d}) = f(\boldsymbol{\theta}_k) + \mathbf{g}_k^T\mathbf{d} + \frac{1}{2}\mathbf{d}^T\mathbf{B}_k\mathbf{d}$ is the quadratic model that CG minimizes, with $\mathbf{B}_k$ the damped curvature matrix. A ratio near 1 means the model predicted the actual decrease well; a small or negative ratio means it did not.
Adjust damping: if $\rho < 1/4$ (poor model), increase $\lambda$ (e.g., multiply by $3/2$); if $\rho > 3/4$ (good model), decrease $\lambda$ (e.g., multiply by $2/3$); otherwise leave it unchanged.
This creates a self-correcting system: when the quadratic approximation is good, we trust it more (smaller damping, more Newton-like); when it's poor, we trust it less (larger damping, more gradient-descent-like).
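A minimal sketch of this rule follows. The thresholds $1/4$ and $3/4$ and the scaling factors are the commonly used Levenberg-Marquardt-style constants; treat them as tunable, and note that the helper names are ours:

```python
import torch


def reduction_ratio(f_old, f_new, grad_flat, d_flat, Bd_flat):
    """rho = actual decrease / decrease predicted by the quadratic model
    m(d) = f + g^T d + 0.5 * d^T B d, where B is the damped curvature matrix
    and Bd_flat = B @ d (e.g. reused from the last CG matrix-vector product)."""
    predicted = -(torch.dot(grad_flat, d_flat) + 0.5 * torch.dot(d_flat, Bd_flat))
    return (f_old - f_new) / predicted.item()


def update_damping(lmbda, rho, increase=1.5, decrease=2.0 / 3.0):
    """Levenberg-Marquardt style adjustment of the damping parameter."""
    if rho < 0.25:   # model over-predicted the decrease: trust it less
        return lmbda * increase
    if rho > 0.75:   # model was accurate: trust it more
        return lmbda * decrease
    return lmbda
```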
For recurrent networks, Martens and Sutskever (2011) introduced structural damping: penalizing changes in the hidden states rather than only changes in the parameters. This is particularly effective for sequence models, where the effect of a parameter change propagates through many timesteps. The damping term becomes $\lambda(\|\Delta \boldsymbol{\theta}\|^2 + \nu \sum_t \|\Delta \mathbf{h}_t\|^2)$, penalizing both parameter and hidden-state perturbations.
Instead of the true Hessian, we often use the Gauss-Newton (GN) approximation:
For a loss of the form $f(\boldsymbol{\theta}) = \frac{1}{2}|\mathbf{r}(\boldsymbol{\theta})|^2$ (sum of squared residuals), the Hessian is:
$$\mathbf{H} = \mathbf{J}^T\mathbf{J} + \sum_i r_i \nabla^2 r_i$$
where $\mathbf{J}$ is the Jacobian of the residuals. The GN approximation drops the second term:
$$\mathbf{H}_{GN} = \mathbf{J}^T\mathbf{J}$$
This approximation: (1) is always positive semi-definite, since $\mathbf{J}^T\mathbf{J} \succeq 0$, so CG never encounters negative curvature, (2) becomes exact as the residuals $r_i$ shrink, i.e., near a good optimum, and (3) still admits $\mathcal{O}(n)$ matrix-vector products, computed as $\mathbf{J}^T(\mathbf{J}\mathbf{v})$.
For classification with cross-entropy, the analogous approximation uses the generalized Gauss-Newton matrix, which is equivalent to the Fisher information matrix (as we'll see in the natural gradient discussion).
| Matrix | Always PSD? | Cost to compute Hv | When appropriate |
|---|---|---|---|
| True Hessian $\mathbf{H}$ | No | $\mathcal{O}(n)$ | Convex problems, fine-tuning |
| Gauss-Newton $\mathbf{J}^T\mathbf{J}$ | Yes | $\mathcal{O}(n)$ | Least squares, near optimum |
| Fisher $\mathbf{F}$ | Yes | $\mathcal{O}(n)$ | Probabilistic models, classification |
| Damped Hessian $\mathbf{H} + \lambda\mathbf{I}$ | If $\lambda$ large enough | $\mathcal{O}(n)$ | General, far from optimum |
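For reference, the generalized Gauss-Newton product $\mathbf{J}^T \mathbf{H}_L \mathbf{J}\mathbf{v}$ can be computed with the same autodiff machinery used for $\mathbf{H}\mathbf{v}$. The sketch below is one possible implementation, assuming the loss is convex in the network output (true for cross-entropy on logits); the function name and structure are ours, not a standard API:

```python
import torch
import torch.nn as nn


def ggn_vector_product(model, loss_fn, x, y, v):
    """Compute (J^T H_L J) @ v, where J is the Jacobian of the network output
    w.r.t. the parameters and H_L is the Hessian of the loss w.r.t. the output.
    `v` is a list of tensors with the same shapes as model.parameters()."""
    params = list(model.parameters())
    output = model(x)

    # J @ v via the double-vjp trick: differentiate (J^T u) . v w.r.t. the
    # dummy cotangent u, which yields the forward-mode product J v.
    u = torch.zeros_like(output, requires_grad=True)
    vjp = torch.autograd.grad(output, params, grad_outputs=u, create_graph=True)
    Jv = torch.autograd.grad(vjp, u, grad_outputs=v, retain_graph=True)[0]

    # H_L @ (J v): double backprop through the loss w.r.t. the output only.
    loss = loss_fn(output, y)
    grad_out = torch.autograd.grad(loss, output, create_graph=True)[0]
    HJv = torch.autograd.grad(grad_out, output, grad_outputs=Jv, retain_graph=True)[0]

    # J^T @ (H_L J v): an ordinary vector-Jacobian product.
    return torch.autograd.grad(output, params, grad_outputs=HJv)
```

The first block uses two reverse-mode passes to emulate the forward-mode product $\mathbf{J}\mathbf{v}$; frameworks with native forward-mode support can compute it directly.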
Implementing Hessian-free optimization for modern deep networks requires careful attention to several practical issues.
In stochastic settings, both gradient and Hessian-vector products are estimated from mini-batches:
Option 1: Same batch for gradient and Hv. Cheapest and simplest: the curvature estimate is consistent with the gradient it is applied to, though the two share the same sampling noise.
Option 2: Different batches. Decorrelates the noise in the gradient and curvature estimates, at the cost of extra forward and backward passes.
Option 3: Larger batch for Hv. Reduces noise in the products that CG iterates on, at additional cost per product.
Empirically, using the same batch works well, especially when batches are reasonably large (hundreds to thousands of examples).
A single HF iteration is more expensive than an SGD iteration: it needs one gradient evaluation plus one Hessian-vector product per CG iteration, and each Hv product costs roughly three times as much as a plain forward-backward pass.
If CG takes 50 iterations, HF costs ~150x more per iteration than SGD. But if HF converges in 10x fewer iterations, it's only ~15x more expensive overall.
The key question is whether fewer, higher-quality updates (HF) beat many cheap updates (SGD). The answer depends on: how ill-conditioned the loss surface is, how large a batch is available to amortize the cost of the inner CG solve, and how well the damped quadratic model captures the true loss far from the optimum.
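As a rough back-of-the-envelope model (the constants are illustrative assumptions, not measurements), the trade-off can be tabulated directly:

```python
def hf_cost_model(cg_iters, hv_cost=3.0, grad_cost=1.0):
    """Cost of one HF step in units of one SGD step (one forward + backward).
    hv_cost ~= 3 assumes an Hv product costs roughly a forward pass plus two
    backward passes; adjust both constants for your model and hardware."""
    return grad_cost + cg_iters * hv_cost


for cg_iters in (10, 50, 100):
    cost = hf_cost_model(cg_iters)
    print(f"{cg_iters:3d} CG iterations: ~{cost:.0f}x an SGD step; "
          f"HF must need ~{cost:.0f}x fewer updates to break even")
```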
HF requires retaining the computation graph for the second backward pass in Hessian-vector products. This roughly doubles memory usage compared to standard backpropagation. For very large models, memory may become the binding constraint before computation. Gradient checkpointing can help here.
Hessian-free optimization had a profound impact on deep learning history. Its success in training deep networks challenged the prevailing wisdom of the late 2000s.
Before 2010, training very deep networks was considered nearly impossible. Gradient descent struggled with vanishing gradients and ill-conditioning. Networks with more than 3-4 layers often failed to train at all.
Martens's 2010 paper "Deep learning via Hessian-free optimization" demonstrated that deep networks, including deep autoencoders, could be trained from random initialization, without the layer-wise unsupervised pretraining then considered essential, and that the results matched or surpassed pretraining-based approaches on the benchmarks of the time.
This work, alongside concurrent advances in ReLU activations and dropout, helped spark the modern deep learning revolution.
Despite its historical importance, HF is rarely used in production today. Several factors explain this:
1. Adam and variants: Adaptive optimizers like Adam capture some curvature information at lower cost. They're "good enough" for most problems.
2. Architecture innovations: ResNets, normalization layers, and transformers mitigate ill-conditioning, reducing the need for sophisticated optimization.
3. Massive scale: Modern LLMs have hundreds of billions of parameters. Even Hv products become expensive at this scale, and the CG iterations don't parallelize as well as data-parallel SGD.
4. Hyperparameter simplicity: SGD with momentum and a learning rate schedule is simple and well-understood. HF introduces more hyperparameters (damping, CG iterations, preconditioning).
Hessian-free methods remain valuable in specific scenarios: small-to-medium models trained with large or full batches, where the inner CG solve can be amortized; severely ill-conditioned objectives on which first-order methods stall; and analyses that already rely on Hessian- or Gauss-Newton-vector products, such as influence functions or curvature spectrum estimation.
Hessian-free optimization represents a practical realization of second-order methods for deep learning, overcoming the computational barriers of Newton's method through clever algorithmic design.
Hessian-free optimization uses the Hessian as its curvature metric, but it's not the only choice. The next page explores natural gradient descent, which uses the Fisher information matrix instead. This leads to updates that are invariant to reparameterization—a property with deep connections to information geometry and significant practical implications for neural network training.
You now understand how Hessian-free optimization bridges the gap between Newton's theoretical power and practical feasibility. The key insight—using efficient Hessian-vector products with iterative solvers—opens the door to exploiting curvature information at scale.