Consider a neural network with 100 million parameters and a scalar loss function. The Jacobian relating the parameters to the loss has 100 million entries, and computing each entry separately—by finite differences or one forward-mode pass per parameter—would require 100 million passes through the network. Yet backpropagation computes the entire gradient (this full row of the Jacobian) in roughly the same time as a single forward pass. How?
The answer lies in a beautiful mathematical insight: we never need the full Jacobian. What we need is the product of the Jacobian with a specific vector. This operation—the vector-Jacobian product (VJP) for reverse mode and Jacobian-vector product (JVP) for forward mode—can be computed far more efficiently than the full Jacobian.
In this page, we develop a deep understanding of JVPs and VJPs, the computational primitives at the heart of modern automatic differentiation. This knowledge reveals how frameworks like PyTorch, TensorFlow, and JAX achieve their remarkable efficiency.
By the end of this page, you will understand: (1) The difference between JVPs and VJPs, (2) Why VJPs are used for backpropagation, (3) How to compute JVPs/VJPs without forming full Jacobians, (4) The connection to forward and reverse mode autodiff, and (5) How modern frameworks implement these operations.
Before understanding Jacobian-vector products, we must fully appreciate the Jacobian matrix itself—what it represents and why directly computing it is often impractical.
Definition:
For a function $\mathbf{f}: \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian matrix $\mathbf{J} \in \mathbb{R}^{m \times n}$ contains all first-order partial derivatives:
$$\mathbf{J} = \begin{bmatrix} \frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n} \end{bmatrix}$$
Row Interpretation: Row $i$ contains the gradient of output $f_i$ with respect to all inputs—it tells us how the $i$-th output changes with each input.
Column Interpretation: Column $j$ contains the partial derivatives of all outputs with respect to input $x_j$—it tells us how all outputs change when we perturb $x_j$.
The Scale Problem:
In neural networks, the Jacobian can be astronomically large:
| Component | Typical Size | Jacobian Entries |
|---|---|---|
| Single dense layer (1024 → 512) | 524,288 weights | 524,288 × 512 = 268M |
| ResNet-50 | 25M parameters | 25M × batch_size × features |
| GPT-3 | 175B parameters | 175B × sequence_length |
Explicitly computing and storing these Jacobians would require petabytes of memory and astronomical compute. The key insight is that we don't need the full Jacobian—we only need its product with specific vectors.
Rather than viewing the Jacobian as a matrix to be stored, view it as a linear operator: a 'black box' that takes a vector input and produces a vector output. JVPs and VJPs are methods for applying this operator without ever materializing the full matrix.
A Jacobian-vector product (JVP) computes $\mathbf{J} \cdot \mathbf{v}$ where $\mathbf{J}$ is the $m \times n$ Jacobian and $\mathbf{v} \in \mathbb{R}^n$ is an input tangent vector.
What JVP Computes:
$$\text{JVP}(\mathbf{f}, \mathbf{x}, \mathbf{v}) = \mathbf{J}_{\mathbf{f}}(\mathbf{x}) \cdot \mathbf{v} = \sum_{j=1}^{n} v_j \frac{\partial \mathbf{f}}{\partial x_j}$$
This is a directional derivative: it tells us how $\mathbf{f}$ changes when we move in direction $\mathbf{v}$ from point $\mathbf{x}$.
The Output Size is $m$ (same as function output), regardless of input dimension $n$.
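As a quick numerical sanity check of the directional-derivative interpretation, here is a minimal sketch (using the same example function $\mathbf{f}$ that appears in the worked code below): a small finite-difference step in direction $\mathbf{v}$ should approximate the JVP.

```python
import numpy as np

# Finite-difference check: (f(x + eps*v) - f(x)) / eps ≈ J @ v for small eps.
def f(x):
    return np.array([x[0]**2 + x[1], x[0] * x[1], np.sin(x[0])])

x = np.array([1.5, 2.0])
v = np.array([0.3, -0.2])
eps = 1e-6

fd_jvp = (f(x + eps * v) - f(x)) / eps
print(fd_jvp)  # Close to the exact JVP computed analytically in the example below
```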
Forward Mode Autodiff = JVP Propagation:
In forward mode autodiff, we propagate tangent vectors alongside values:
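A minimal sketch of this idea, using a toy `Dual` class (an illustration, not a framework API) that carries a `(value, tangent)` pair through each operation:

```python
import numpy as np

class Dual:
    """Toy dual number: a value paired with its tangent (directional derivative)."""
    def __init__(self, value, tangent):
        self.value, self.tangent = value, tangent

    def __add__(self, other):
        return Dual(self.value + other.value, self.tangent + other.tangent)

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

def sin(d):
    return Dual(np.sin(d.value), np.cos(d.value) * d.tangent)

# Push the tangent v = (1, 0) through g(x1, x2) = x1*x2 + sin(x1):
x1, x2 = Dual(1.5, 1.0), Dual(2.0, 0.0)
y = x1 * x2 + sin(x1)
print(y.value, y.tangent)  # tangent = dg/dx1 = x2 + cos(x1) at (1.5, 2.0)
```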
Key Property: Computing one JVP costs roughly the same as one forward pass. To get the full Jacobian, we'd need $n$ JVPs (one per input dimension)—each with a different basis vector.
```python
import numpy as np

def jvp_example():
    """
    Demonstrate JVP computation for a simple function:
    f(x) = [x1^2 + x2, x1 * x2, sin(x1)]

    The JVP tells us: if we perturb x by v, how does f(x) change?
    """
    # Define the function
    def f(x):
        x1, x2 = x
        return np.array([
            x1**2 + x2,
            x1 * x2,
            np.sin(x1)
        ])

    # Jacobian of f (analytically computed for verification)
    def jacobian_f(x):
        x1, x2 = x
        return np.array([
            [2*x1, 1],          # df1/dx1, df1/dx2
            [x2, x1],           # df2/dx1, df2/dx2
            [np.cos(x1), 0]     # df3/dx1, df3/dx2
        ])

    # JVP computed via Jacobian multiplication
    def jvp_explicit(x, v):
        """Compute J @ v explicitly (expensive for large Jacobians)"""
        J = jacobian_f(x)
        return J @ v

    # JVP computed via dual numbers (forward mode)
    def jvp_forward_mode(x, v):
        """
        Compute JVP without forming the full Jacobian.

        Use dual numbers: propagate (value, tangent) pairs.
        For operation y = op(x), the tangent evolves as:
            tangent_y = d(op)/dx * tangent_x
        """
        x1, x2 = x
        v1, v2 = v

        # Operation 1: y1 = x1^2 + x2
        # tangent_y1 = 2*x1 * v1 + 1 * v2
        tangent_y1 = 2*x1 * v1 + v2

        # Operation 2: y2 = x1 * x2
        # tangent_y2 = x2 * v1 + x1 * v2
        tangent_y2 = x2 * v1 + x1 * v2

        # Operation 3: y3 = sin(x1)
        # tangent_y3 = cos(x1) * v1
        tangent_y3 = np.cos(x1) * v1

        return np.array([tangent_y1, tangent_y2, tangent_y3])

    # Test point and direction
    x = np.array([1.5, 2.0])
    v = np.array([0.3, -0.2])  # Arbitrary direction

    print("JVP Example: f(x) = [x1^2 + x2, x1*x2, sin(x1)]")
    print(f"Point x = {x}")
    print(f"Direction v = {v}")
    print()

    # Compute JVP both ways
    jvp_explicit_result = jvp_explicit(x, v)
    jvp_forward_result = jvp_forward_mode(x, v)

    print(f"JVP (explicit J @ v): {jvp_explicit_result}")
    print(f"JVP (forward mode):   {jvp_forward_result}")
    print(f"Match: {np.allclose(jvp_explicit_result, jvp_forward_result)}")
    print()

    # Show how to get full Jacobian columns via JVPs
    print("Getting Jacobian columns via JVPs:")
    e1 = np.array([1, 0])  # First standard basis vector
    e2 = np.array([0, 1])  # Second standard basis vector

    col1 = jvp_forward_mode(x, e1)  # J @ e1 = first column of J
    col2 = jvp_forward_mode(x, e2)  # J @ e2 = second column of J

    print(f"J[:, 0] via JVP(e1): {col1}")
    print(f"J[:, 1] via JVP(e2): {col2}")
    print()
    print("Full Jacobian (explicit):")
    print(jacobian_f(x))

jvp_example()
```

A vector-Jacobian product (VJP) computes $\mathbf{v}^T \cdot \mathbf{J}$ (equivalently, $\mathbf{J}^T \cdot \mathbf{v}$) where $\mathbf{v} \in \mathbb{R}^m$ is a "cotangent" or "adjoint" vector.
What VJP Computes:
$$\text{VJP}(\mathbf{f}, \mathbf{x}, \mathbf{v}) = \mathbf{J}_{\mathbf{f}}(\mathbf{x})^T \cdot \mathbf{v} = \sum_{i=1}^{m} v_i \nabla f_i(\mathbf{x})$$
This is a weighted sum of gradients: it tells us the gradient of a linear combination of outputs (with weights $\mathbf{v}$) with respect to inputs.
The Output Size is $n$ (same as function input), regardless of output dimension $m$.
Critical for Deep Learning: For a loss function $L: \mathbb{R}^n \to \mathbb{R}$, the gradient $\nabla L$ is exactly $\text{VJP}(L, x, 1)$—the VJP with scalar cotangent 1.
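A quick way to see this in PyTorch (used again later on this page): for a scalar output, calling `backward()` with no argument is shorthand for passing the cotangent 1.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x ** 2).sum()              # Scalar loss L = sum(x^2)
loss.backward(torch.tensor(1.0))   # Explicit cotangent of 1 (same as loss.backward())
print(x.grad)                      # tensor([2., 4., 6.]) = J^T · 1 = ∇L
```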
Reverse Mode Autodiff = VJP Propagation (Backpropagation!):
In reverse mode, we first complete the forward pass, then propagate cotangent vectors backward:
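A minimal sketch of the two phases for a hypothetical two-stage function $L(x) = \sum \tanh(Wx)$: the forward pass caches what the backward pass will need, then a single cotangent is pulled back through each stage.

```python
import numpy as np

def forward(W, x):
    h = W @ x
    y = np.tanh(h)
    cache = (W, h)          # Saved for the backward pass
    return y.sum(), cache

def backward(cache, cotangent_out=1.0):
    W, h = cache
    c_y = np.ones_like(h) * cotangent_out   # VJP through sum()
    c_h = (1 - np.tanh(h) ** 2) * c_y       # VJP through tanh
    c_x = W.T @ c_h                         # VJP through the linear map
    return c_x

W, x = np.random.randn(3, 5), np.random.randn(5)
loss, cache = forward(W, x)
grad_x = backward(cache)   # Gradient w.r.t. all 5 inputs from one backward pass
```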
Key Property: Computing one VJP costs roughly the same as one forward pass. One VJP gives the gradient w.r.t. all inputs at once—exactly what we need for gradient descent!
Neural networks have millions of inputs (parameters) but typically one output (loss). VJP gives us dL/d(all parameters) in one pass. JVP would require millions of passes (one per parameter). This asymmetry is why reverse mode (VJP-based) autodiff dominates deep learning.
```python
import numpy as np

def vjp_example():
    """
    Demonstrate VJP computation for a simple function:
    f(x) = [x1^2 + x2, x1 * x2, sin(x1)]

    The VJP tells us: given a gradient from downstream (v),
    what is the gradient w.r.t. inputs?
    """
    # Same function as JVP example
    def f(x):
        x1, x2 = x
        return np.array([x1**2 + x2, x1 * x2, np.sin(x1)])

    def jacobian_f(x):
        x1, x2 = x
        return np.array([
            [2*x1, 1],
            [x2, x1],
            [np.cos(x1), 0]
        ])

    # VJP computed via Jacobian multiplication (explicit)
    def vjp_explicit(x, v):
        """Compute J.T @ v explicitly"""
        J = jacobian_f(x)
        return J.T @ v  # Note the transpose!

    # VJP computed via reverse mode (without full Jacobian)
    def vjp_reverse_mode(x, v):
        """
        Compute VJP by propagating gradient backward through operations.
        This is exactly what backpropagation does!

        v = [v1, v2, v3] are the upstream gradients for each output.
        """
        x1, x2 = x
        v1, v2, v3 = v

        # Gradient accumulates for each input
        grad_x1 = 0
        grad_x2 = 0

        # Operation 1: y1 = x1^2 + x2
        # dy1/dx1 = 2*x1, dy1/dx2 = 1
        grad_x1 += v1 * 2 * x1
        grad_x2 += v1 * 1

        # Operation 2: y2 = x1 * x2
        # dy2/dx1 = x2, dy2/dx2 = x1
        grad_x1 += v2 * x2
        grad_x2 += v2 * x1

        # Operation 3: y3 = sin(x1)
        # dy3/dx1 = cos(x1), dy3/dx2 = 0
        grad_x1 += v3 * np.cos(x1)
        # grad_x2 += 0

        return np.array([grad_x1, grad_x2])

    # Test
    x = np.array([1.5, 2.0])
    v = np.array([1.0, 0.5, 0.3])  # Upstream gradients

    print("VJP Example: f(x) = [x1^2 + x2, x1*x2, sin(x1)]")
    print(f"Point x = {x}")
    print(f"Upstream gradient v = {v}")
    print()

    vjp_explicit_result = vjp_explicit(x, v)
    vjp_reverse_result = vjp_reverse_mode(x, v)

    print(f"VJP (explicit J.T @ v): {vjp_explicit_result}")
    print(f"VJP (reverse mode):     {vjp_reverse_result}")
    print(f"Match: {np.allclose(vjp_explicit_result, vjp_reverse_result)}")
    print()

    # Demonstrate getting Jacobian rows via VJPs
    print("Getting Jacobian rows via VJPs:")
    e1 = np.array([1, 0, 0])
    e2 = np.array([0, 1, 0])
    e3 = np.array([0, 0, 1])

    row1 = vjp_reverse_mode(x, e1)  # J.T @ e1 = first row of J (transposed)
    row2 = vjp_reverse_mode(x, e2)
    row3 = vjp_reverse_mode(x, e3)

    print(f"J[0, :] via VJP(e1): {row1}")
    print(f"J[1, :] via VJP(e2): {row2}")
    print(f"J[2, :] via VJP(e3): {row3}")
    print()
    print("Full Jacobian (explicit):")
    print(jacobian_f(x))

vjp_example()
```

The choice between JVP (forward mode) and VJP (reverse mode) depends on the shape of the function—specifically, the ratio of inputs to outputs.
Decision Rule: count the passes needed for a full Jacobian—$n$ JVPs versus $m$ VJPs. Use forward mode (JVP) when there are few inputs relative to outputs ($n \ll m$), and reverse mode (VJP) when there are few outputs relative to inputs ($m \ll n$).
For neural networks with millions of parameters and a scalar loss, we have $n \gg m = 1$, so VJP wins decisively.
| Aspect | JVP (Forward Mode) | VJP (Reverse Mode) |
|---|---|---|
| Computes | J · v (Jacobian times vector) | Jᵀ · v (Jacobian transpose times vector) |
| Output dimension | m (# of function outputs) | n (# of function inputs) |
| One pass gives | Derivative w.r.t. all outputs for one input direction | Derivative w.r.t. all inputs for one output direction |
| To get full Jacobian | n passes (one per input) | m passes (one per output) |
| Memory during computation | O(1) extra (no caching needed) | O(depth) (must cache forward pass) |
| Best for | f: R^n → R^m with m ≫ n | f: R^n → R^m with n ≫ m |
| Deep learning use case | Jacobian-vector products in second-order methods | Gradient computation (backprop) |
```python
import numpy as np
import time

def compare_jvp_vjp_efficiency():
    """
    Compare computational cost of JVP vs VJP for different
    input/output dimension ratios.
    """
    # Simulate a neural network layer
    def forward(x, W):
        """Layer: y = W @ x"""
        return W @ x

    def jvp_layer(W, x, v_in):
        """JVP: tangent_out = W @ tangent_in"""
        return W @ v_in

    def vjp_layer(W, x, v_out):
        """VJP: gradient_in = W.T @ gradient_out"""
        return W.T @ v_out

    # Test case 1: Many inputs, few outputs (typical NN)
    n_in, n_out = 10000, 10
    W = np.random.randn(n_out, n_in)
    x = np.random.randn(n_in)

    print("Case 1: n_inputs=10000, n_outputs=10 (like neural network)")
    print("-" * 60)

    # To recover the full Jacobian:
    # JVP approach: need n_in passes
    start = time.time()
    full_jacobian_jvp = np.zeros((n_out, n_in))
    for i in range(n_in):
        e_i = np.zeros(n_in)
        e_i[i] = 1
        full_jacobian_jvp[:, i] = jvp_layer(W, x, e_i)
    jvp_time = time.time() - start
    print(f"JVP approach (10000 passes): {jvp_time*1000:.1f} ms")

    # VJP approach: need n_out passes (or just 1 for scalar loss)
    start = time.time()
    full_jacobian_vjp = np.zeros((n_out, n_in))
    for i in range(n_out):
        e_i = np.zeros(n_out)
        e_i[i] = 1
        full_jacobian_vjp[i, :] = vjp_layer(W, x, e_i)
    vjp_time = time.time() - start
    print(f"VJP approach (10 passes): {vjp_time*1000:.1f} ms")
    print(f"VJP speedup: {jvp_time/vjp_time:.1f}x")
    print()

    # For a scalar loss (n_out effectively 1), it's even more dramatic.
    # Suppose a scalar loss L = w_loss @ y sits on top of the layer: the
    # cotangent arriving at y is w_loss * 1.0, and a single VJP through
    # the layer yields the full gradient w.r.t. all inputs.
    print("For scalar loss (n_outputs=1):")
    w_loss = np.random.randn(n_out)
    start = time.time()
    grad_vjp = vjp_layer(W, x, w_loss * 1.0)  # One backward pass
    vjp_single_time = time.time() - start
    print(f"VJP (1 pass): {vjp_single_time*1e6:.1f} μs")

    # JVP would still need n_in passes
    print(f"JVP would need {n_in} passes")
    print()
    print("=> VJP (reverse mode) is optimal for neural network gradients!")

compare_jvp_vjp_efficiency()
```

While VJP dominates gradient computation, JVP is valuable for: (1) Computing Hessian-vector products (used in second-order optimization), (2) Efficiently propagating uncertainty, (3) Computing directional derivatives for sensitivity analysis. JAX provides both jvp and vjp as first-class operations.
The power of JVPs and VJPs lies in their ability to be computed locally for each operation, then composed via the chain rule. Each primitive operation needs only its own local JVP/VJP rule.
Local JVP Rule for Operation $y = f(x)$: $$\text{tangent}_y = \mathbf{J}_f \cdot \text{tangent}_x$$
The local Jacobian $\mathbf{J}_f$ is often simple and sparse. We never form it explicitly—we just apply its effect.
Local VJP Rule for Operation $y = f(x)$: $$\text{cotangent}_x = \mathbf{J}_f^T \cdot \text{cotangent}_y$$
Again, we compute the effect of $\mathbf{J}_f^T$ without forming the matrix.
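As a sketch of what these local rules look like for one common primitive, here are hand-written JVP/VJP rules for $y = Wx$ (the function names `matmul_jvp`/`matmul_vjp` are illustrative, not a framework API):

```python
import numpy as np

def matmul_jvp(W, x, tW, tx):
    """Local JVP for y = W @ x: tangent of y given tangents (tW, tx) of the inputs."""
    return W @ tx + tW @ x

def matmul_vjp(W, x, cy):
    """Local VJP for y = W @ x: cotangents of (W, x) given the cotangent cy of y."""
    cW = np.outer(cy, x)   # Effect of J_W^T applied to cy; no Jacobian formed
    cx = W.T @ cy          # Effect of J_x^T applied to cy
    return cW, cx

W, x = np.random.randn(3, 5), np.random.randn(5)
cy = np.ones(3)
cW, cx = matmul_vjp(W, x, cy)   # Shapes (3, 5) and (5,)
```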
For element-wise operations $y_i = f(x_i)$, the Jacobian is diagonal: $$\mathbf{J} = \text{diag}(f'(x_1), f'(x_2), \ldots, f'(x_n))$$
JVP and VJP reduce to element-wise multiplication:
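Concretely, with $\odot$ denoting element-wise multiplication, both products are just a multiply by the vector of local derivatives (a diagonal Jacobian is its own transpose):

$$\mathbf{J}\,\mathbf{v} = f'(\mathbf{x}) \odot \mathbf{v}, \qquad \mathbf{J}^T \mathbf{v} = f'(\mathbf{x}) \odot \mathbf{v}$$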
Examples:
```python
import numpy as np

# VJP for element-wise operations
def relu_vjp(x, cotangent_y):
    """VJP for ReLU: cotangent_x = (x > 0) * cotangent_y"""
    return (x > 0).astype(float) * cotangent_y

def sigmoid_vjp(x, cotangent_y):
    """VJP for sigmoid"""
    s = 1 / (1 + np.exp(-x))
    grad = s * (1 - s)  # sigmoid'(x)
    return grad * cotangent_y

def square_vjp(x, cotangent_y):
    """VJP for y = x^2"""
    return 2 * x * cotangent_y

# Key: These are all just element-wise multiplications!
# No matrix formation or storage needed.
```

The beauty of JVP/VJP is how they compose through the chain rule. For a sequence of operations $y = f_n \circ f_{n-1} \circ \cdots \circ f_1(x)$:
JVP Composition (Forward Mode): $$\text{tangent}_y = \mathbf{J}_{f_n} \cdot \mathbf{J}_{f_{n-1}} \cdots \mathbf{J}_{f_1} \cdot \text{tangent}_x$$
Computed left-to-right, one JVP at a time during the forward pass.
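A minimal sketch of this left-to-right composition for a hypothetical two-stage chain $x \mapsto Wx \mapsto \tanh(Wx)$, where each stage pushes the tangent forward alongside its value:

```python
import numpy as np

def jvp_chain(W, x, v):
    # Stage 1: linear map
    h = W @ x
    t_h = W @ v                          # JVP of the linear stage
    # Stage 2: element-wise tanh
    y = np.tanh(h)
    t_y = (1 - np.tanh(h) ** 2) * t_h    # JVP of the element-wise stage
    return y, t_y

W = np.random.randn(3, 5)
x, v = np.random.randn(5), np.random.randn(5)
y, tangent_y = jvp_chain(W, x, v)        # One forward pass, one tangent pushed through
```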
VJP Composition (Reverse Mode): $$\text{cotangent}_x = \mathbf{J}_{f_1}^T \cdot \mathbf{J}_{f_2}^T \cdots \mathbf{J}_{f_n}^T \cdot \text{cotangent}_y$$
Computed right-to-left, one VJP at a time during the backward pass.
This is exactly backpropagation! Each layer's backward pass computes one VJP.
```python
import numpy as np

class AutoDiffLayer:
    """Base class for a layer that supports forward and VJP"""
    def forward(self, x):
        raise NotImplementedError

    def vjp(self, cotangent_y):
        """Return (cotangent_x, param_gradients)"""
        raise NotImplementedError

class Linear(AutoDiffLayer):
    """y = x @ W + b"""
    def __init__(self, in_features, out_features):
        self.W = np.random.randn(in_features, out_features) * 0.01
        self.b = np.zeros(out_features)
        self.cache = None

    def forward(self, x):
        self.cache = x  # Store for backward
        return x @ self.W + self.b

    def vjp(self, cotangent_y):
        x = self.cache
        cotangent_x = cotangent_y @ self.W.T
        grad_W = x.T @ cotangent_y
        grad_b = cotangent_y.sum(axis=0)
        return cotangent_x, {'W': grad_W, 'b': grad_b}

class ReLU(AutoDiffLayer):
    """y = max(0, x)"""
    def __init__(self):
        self.cache = None

    def forward(self, x):
        self.cache = x
        return np.maximum(0, x)

    def vjp(self, cotangent_y):
        x = self.cache
        cotangent_x = cotangent_y * (x > 0)
        return cotangent_x, {}

class MSELoss:
    """L = mean((y - target)^2)"""
    def __init__(self):
        self.cache = None

    def forward(self, y, target):
        self.cache = (y, target)
        return np.mean((y - target) ** 2)

    def vjp(self):
        """VJP for loss: returns cotangent for y"""
        y, target = self.cache
        # d/dy of mean((y - target)^2) = 2 * (y - target) / n
        cotangent_y = 2 * (y - target) / y.size
        return cotangent_y

class SimpleNetwork:
    """Demonstrates VJP composition through a network"""
    def __init__(self):
        self.layers = [
            Linear(10, 32),
            ReLU(),
            Linear(32, 16),
            ReLU(),
            Linear(16, 1),
        ]
        self.loss_fn = MSELoss()

    def forward(self, x, target):
        """Forward pass through all layers"""
        for layer in self.layers:
            x = layer.forward(x)
        loss = self.loss_fn.forward(x, target)
        return loss

    def backward(self):
        """
        Backward pass: compose VJPs through all layers (reverse order).
        This is exactly backpropagation!
        """
        # Start with cotangent from loss
        cotangent = self.loss_fn.vjp()

        all_gradients = []
        # Propagate through layers in reverse
        for layer in reversed(self.layers):
            cotangent, grads = layer.vjp(cotangent)
            all_gradients.append(grads)

        return all_gradients[::-1]  # Reverse to match layer order

# Demo
np.random.seed(42)
net = SimpleNetwork()

# Sample input and target
x = np.random.randn(4, 10)  # Batch of 4
target = np.random.randn(4, 1)

# Forward pass
loss = net.forward(x, target)
print(f"Loss: {loss:.6f}")

# Backward pass (composed VJPs)
gradients = net.backward()

print("\nGradients computed via VJP composition:")
for i, (layer, grads) in enumerate(zip(net.layers, gradients)):
    if grads:
        print(f"  Layer {i} ({layer.__class__.__name__}):")
        for name, grad in grads.items():
            print(f"    {name}: shape={grad.shape}, norm={np.linalg.norm(grad):.4f}")
```

Modern deep learning frameworks implement JVP/VJP as their core differentiation primitives. Understanding these APIs deepens your ability to work with and extend these systems.
PyTorch: Uses VJP-based reverse mode. The backward() method computes VJPs through the computation graph. Custom operations define their VJP via the backward method of torch.autograd.Function.
TensorFlow: Similar to PyTorch, uses reverse mode. Gradient tape records operations, then tape.gradient() computes VJPs.
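A minimal sketch of that pattern, assuming TensorFlow is installed (variable names here are illustrative):

```python
import tensorflow as tf

x = tf.Variable([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:   # Tape records the forward pass
    y = tf.reduce_sum(x ** 2)
grad = tape.gradient(y, x)        # VJP with implicit cotangent 1.0 (scalar y)
print(grad)                       # [2. 4. 6.]
# For a non-scalar y, a cotangent can be supplied via the output_gradients argument.
```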
JAX: Provides both JVP and VJP as explicit, composable functions. jax.jvp(f, primals, tangents) computes forward mode; jax.vjp(f, primals) returns the function's outputs together with a VJP function.
```python
# PyTorch: VJP via backward()
import torch

def pytorch_vjp_example():
    x = torch.randn(3, requires_grad=True)
    y = x ** 2 + 2 * x

    # VJP with cotangent v
    v = torch.tensor([1.0, 2.0, 3.0])
    y.backward(v)  # Computes J.T @ v

    print(f"PyTorch VJP: x.grad = {x.grad}")
    # Expected: (2*x + 2) * v

# JAX: Explicit JVP and VJP APIs
# import jax
# import jax.numpy as jnp

def jax_example():
    """
    JAX provides both JVP and VJP as first-class operations.
    (Commented out to avoid dependency - showing API pattern)
    """
    # def f(x):
    #     return jnp.array([x[0]**2 + x[1], x[0] * x[1], jnp.sin(x[0])])

    # x = jnp.array([1.5, 2.0])

    # # JVP: Forward mode
    # tangent = jnp.array([0.3, -0.2])
    # primals_out, tangents_out = jax.jvp(f, (x,), (tangent,))

    # # VJP: Reverse mode
    # primals_out, vjp_fn = jax.vjp(f, x)
    # cotangent = jnp.array([1.0, 0.5, 0.3])
    # (grad_x,) = vjp_fn(cotangent)

    print("JAX API pattern:")
    print("  jax.jvp(f, primals, tangents) -> (outputs, output_tangents)")
    print("  jax.vjp(f, primals) -> (outputs, vjp_function)")
    print("")
    print("  JAX makes JVP/VJP composable and differentiable!")

# Custom PyTorch autograd function with VJP
class MySquare(torch.autograd.Function):
    """
    Custom operation with explicit forward and VJP (backward).
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # Cache for VJP
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # This IS the VJP!
        x, = ctx.saved_tensors
        # VJP: J.T @ grad_output = diag(2x) @ grad_output = 2x * grad_output
        return 2 * x * grad_output

# Usage
pytorch_vjp_example()
jax_example()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = MySquare.apply(x)
y.sum().backward()
print(f"Custom VJP gradient: {x.grad}")  # Should be [2, 4, 6]
```

JAX allows computing the VJP of a VJP, which gives Hessian-vector products. This enables efficient second-order optimization without forming the full Hessian. The composability of JVP/VJP is one of JAX's key innovations.
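The same reverse-over-reverse composition can be sketched in PyTorch (already imported above); here $f(x) = \sum x^3$ is a hypothetical example function, so its Hessian is $\text{diag}(6x)$:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x ** 3).sum()

# First VJP: gradient of the scalar loss (cotangent 1), keeping the graph alive.
(grad,) = torch.autograd.grad(loss, x, create_graph=True)

# Second VJP: differentiate <grad, v> w.r.t. x, which yields H @ v.
v = torch.tensor([1.0, 0.0, -1.0])
(hvp,) = torch.autograd.grad(grad, x, grad_outputs=v)
print(hvp)  # tensor([ 6., 0., -18.]) = diag(6x) @ v — no full Hessian formed
```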
We have developed a comprehensive understanding of Jacobian-vector products and their role in automatic differentiation. These concepts are the computational foundation of all modern deep learning.
Looking Ahead:
With JVP/VJP understood, we're ready to address a practical concern: memory. Computing VJPs requires caching intermediate activations from the forward pass—potentially gigabytes of data for large models. In the next section, we'll explore memory considerations in backpropagation, including strategies for trading computation for memory.
You now understand the computational primitives that power all neural network training. JVPs and VJPs aren't just implementation details—they're the conceptual lens through which to understand automatic differentiation. Next, we'll see how to manage the memory costs of storing intermediate values for the backward pass.