The reparameterization trick is one of the most important technical innovations in modern machine learning. Introduced in the VAE paper (Kingma & Welling, 2014), it transformed variational inference from a computationally challenging technique into a practical, scalable method.
This page provides a complete treatment of reparameterization: its mechanics, why it dramatically reduces gradient variance, how to apply it to different distributions, and its limitations.
By the end of this page, you will: (1) Implement reparameterization for Gaussians and other distributions, (2) Understand geometrically why reparameterization reduces variance, (3) Apply the technique in PyTorch/TensorFlow, (4) Know when reparameterization fails and what alternatives exist.
The key insight is simple but profound: express randomness as external to the parameters.
Without Reparameterization: $$\mathbf{z} \sim q_\phi(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mu_\phi(\mathbf{x}), \sigma_\phi(\mathbf{x})^2)$$
Sampling from this distribution involves calling a random number generator that produces $\mathbf{z}$ directly. The gradient $\nabla_\phi \mathbf{z}$ is undefined—sampling is not differentiable.
With Reparameterization: $$\epsilon \sim \mathcal{N}(0, I)$$ $$\mathbf{z} = \mu_\phi(\mathbf{x}) + \sigma_\phi(\mathbf{x}) \odot \epsilon$$
Now $\mathbf{z}$ is a deterministic, differentiable function of $\phi$ (through $\mu$ and $\sigma$), plus external noise $\epsilon$. We can compute:
$$\nabla_\phi \mathbf{z} = \nabla_\phi \mu_\phi + \epsilon \odot \nabla_\phi \sigma_\phi$$
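A minimal PyTorch sketch of the contrast (values are illustrative): sampling directly severs the computation graph, while the reparameterized sample keeps it intact.

```python
import torch
from torch.distributions import Normal

mu = torch.tensor([0.5], requires_grad=True)
sigma = torch.tensor([1.2], requires_grad=True)

# Without reparameterization: sample() is detached from the graph
z_direct = Normal(mu, sigma).sample()
print(z_direct.grad_fn)  # None - no path back to mu or sigma

# With reparameterization: z is a differentiable function of mu and sigma
eps = torch.randn(1)
z_reparam = mu + sigma * eps
print(z_reparam.grad_fn)  # not None - gradients can flow
z_reparam.sum().backward()
print(mu.grad, sigma.grad)  # tensor([1.]) and a tensor equal to eps
```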
Think of reparameterization as creating a 'gradient highway' through the sampling operation. Gradients flow from the loss, through the decoder, through z, and into the encoder parameters—all via standard backpropagation.
```python
import torch
import torch.nn as nn

class ReparameterizedGaussian(nn.Module):
    """Gaussian encoder with reparameterization trick."""

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.fc_mu = nn.Linear(input_dim, latent_dim)
        self.fc_logvar = nn.Linear(input_dim, latent_dim)

    def forward(self, x):
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def sample(self, mu, logvar):
        """Sample using reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # External noise
        z = mu + std * eps           # Deterministic transformation
        return z

    def forward_sample(self, x):
        """Encode and sample in one call."""
        mu, logvar = self.forward(x)
        z = self.sample(mu, logvar)
        return z, mu, logvar

# Comparison: gradients flow correctly
encoder = ReparameterizedGaussian(784, 32)
x = torch.randn(16, 784)

# This works - gradients flow through z to encoder
z, mu, logvar = encoder.forward_sample(x)
loss = z.sum()  # Dummy loss
loss.backward()
print(f"Gradient exists: {encoder.fc_mu.weight.grad is not None}")
```

Let's derive the pathwise gradient formally. We want:
$$\nabla_\phi \mathbb{E}_{q_\phi(\mathbf{z})}[f(\mathbf{z})]$$
Step 1: Reparameterize
If $\mathbf{z} = g_\phi(\epsilon)$ where $\epsilon \sim p(\epsilon)$ is parameter-independent:
$$\mathbb{E}_{q_\phi(\mathbf{z})}[f(\mathbf{z})] = \mathbb{E}_{p(\epsilon)}[f(g_\phi(\epsilon))]$$
Step 2: Exchange Gradient and Expectation
Since $p(\epsilon)$ doesn't depend on $\phi$:
$$\nabla_\phi \mathbb{E}_{p(\epsilon)}[f(g_\phi(\epsilon))] = \mathbb{E}_{p(\epsilon)}[\nabla_\phi f(g_\phi(\epsilon))]$$
Step 3: Apply Chain Rule
$$= \mathbb{E}_{p(\epsilon)}\left[\nabla_\mathbf{z} f(\mathbf{z})\big|_{\mathbf{z}=g_\phi(\epsilon)} \cdot \nabla_\phi g_\phi(\epsilon)\right]$$
For Gaussian:
$g_\phi(\epsilon) = \mu_\phi + \sigma_\phi \epsilon$, so:
$$\nabla_\phi g_\phi(\epsilon) = \nabla_\phi \mu_\phi + \epsilon \nabla_\phi \sigma_\phi$$
Monte Carlo Estimation:
In practice, we estimate this expectation with samples:
$$\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] \approx \frac{1}{N}\sum_{i=1}^{N} \nabla_\phi f(g_\phi(\epsilon_i))$$
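A minimal sketch of this estimator on a problem with a known answer (values are illustrative): for $f(z) = z^2$ and $z \sim \mathcal{N}(\mu, \sigma^2)$, we have $\mathbb{E}[f] = \mu^2 + \sigma^2$, so the true gradients are $2\mu$ and $2\sigma$.

```python
import torch

mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.8, requires_grad=True)

eps = torch.randn(10_000)   # external noise, independent of mu and sigma
z = mu + sigma * eps        # reparameterized samples
loss = (z ** 2).mean()      # Monte Carlo estimate of E[f(z)] with f(z) = z^2
loss.backward()

print(mu.grad.item())     # ~ 2 * mu = 3.0
print(sigma.grad.item())  # ~ 2 * sigma = 1.6
```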
With automatic differentiation, we simply call `.backward()`—gradients flow automatically.

For VAEs with a single sample per data point (the common case):
$$\nabla_\phi \mathcal{L} \approx \nabla_\phi \left[\log p_\theta(\mathbf{x}|g_\phi(\epsilon)) - D_{\mathrm{KL}}(q_\phi \,\|\, p)\right]$$
For Gaussian $q$ and a standard normal prior $p = \mathcal{N}(0, I)$, the KL divergence has a closed-form expression: $D_{\mathrm{KL}} = \frac{1}{2}\sum_j \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)$. No sampling needed—it's analytically differentiable.
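As a quick sanity check (shapes are illustrative), the closed form matches PyTorch's built-in KL computation:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(16, 20)
logvar = torch.randn(16, 20)
std = torch.exp(0.5 * logvar)

# Closed form: 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1)
kl_closed = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)

# Reference: per-dimension KL from torch.distributions, summed
kl_ref = kl_divergence(Normal(mu, std), Normal(0.0, 1.0)).sum()
print(torch.allclose(kl_closed, kl_ref, atol=1e-4))  # True
```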
The variance reduction from reparameterization is often dramatic—orders of magnitude in practice. Let's understand why.
Score Function Variance:
$$\text{Var}[f(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})] = \mathbb{E}[(f(\mathbf{z}) \nabla \log q)^2] - (\mathbb{E}[f(\mathbf{z}) \nabla \log q])^2$$
This depends on $\text{Var}[f(\mathbf{z})]$ and $\text{Var}[\nabla \log q]$—both can be large.
Pathwise Variance:
$$\text{Var}[\nabla_\mathbf{z} f \cdot \nabla_\phi g_\phi] = \mathbb{E}[(\nabla f \cdot \nabla g)^2] - (\mathbb{E}[\nabla f \cdot \nabla g])^2$$
This depends on $\text{Var}[\nabla_\mathbf{z} f]$—often much smaller than $\text{Var}[f]$.
Geometric Intuition:
Imagine a smooth function $f(\mathbf{z})$ over the latent space:
Score function looks at function values at random points, weighted by how parameter changes affect point likelihood. High-value and low-value samples produce opposite gradient signals.
Pathwise gradient looks at function slopes at random points, projected onto the direction parameters move points. The slope is typically more consistent than the value.
Analogy: Estimating the derivative of a smooth curve by:

- Score function: sampling the curve's *values* at random points and correlating them with where the points landed—noisy, because the values themselves vary widely.
- Pathwise: measuring the *slope* directly at random points—far less noisy when the curve is smooth.
| Estimator | Relative Variance | Samples Needed for Same Accuracy |
|---|---|---|
| Score Function (raw) | 1000x - 10000x | 1000 - 10000 |
| Score Function + baseline | 10x - 100x | 10 - 100 |
| Reparameterization | 1x (baseline) | 1 |
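A minimal empirical sketch of this gap on a toy problem (numbers are illustrative and won't reproduce the table's exact ratios): for $f(z) = z^2$ with $z \sim \mathcal{N}(\mu, \sigma^2)$, both estimators target $\nabla_\mu \mathbb{E}[f] = 2\mu$, but their variances differ sharply.

```python
import torch

torch.manual_seed(0)
mu, sigma = 1.5, 0.8
z = mu + sigma * torch.randn(100_000)  # same samples for both estimators

# True gradient of E[z^2] w.r.t. mu is 2 * mu = 3.0

# Score function estimator: f(z) * d/dmu log q(z) = z^2 * (z - mu) / sigma^2
score_est = z**2 * (z - mu) / sigma**2

# Pathwise estimator: df/dz * dz/dmu = 2z * 1
pathwise_est = 2 * z

for name, est in [("score", score_est), ("pathwise", pathwise_est)]:
    print(f"{name}: mean={est.mean().item():.3f}, var={est.var().item():.2f}")
# Both means are ~3.0; here the score variance is roughly 20x the
# pathwise variance, and the gap widens in higher dimensions.
```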
Not all distributions can be reparameterized. The key requirement: the sample must be expressible as a differentiable transformation of parameter-independent noise.
| Distribution | Reparameterization | Noise Distribution |
|---|---|---|
| Normal(μ, σ²) | z = μ + σε | ε ~ N(0, 1) |
| LogNormal(μ, σ²) | z = exp(μ + σε) | ε ~ N(0, 1) |
| Exponential(λ) | z = -log(u) / λ | u ~ Uniform(0, 1) |
| Gamma(α, β) [α≥1] | z = (α - 1/3)(1 + ε/√(9α-3))³ / β | ε ~ N(0, 1) (approx) |
| Beta(α, β) | z = x/(x + y), x ~ Gamma(α, 1), y ~ Gamma(β, 1) | Two Gamma samples |
| Dirichlet(α) | z_i = x_i / Σ_j x_j, x_i ~ Gamma(α_i, 1) | k Gamma samples |
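The Gamma row above omits the accept/reject step of the full Marsaglia-Tsang sampler, so it is only approximate—but even the rejection-free transform reproduces the correct mean. A quick numeric check (parameter values are illustrative):

```python
import torch

# Rejection-free Marsaglia-Tsang transform from the table above.
# Approximate: the exact sampler rejects some eps, and (1 + c*eps)^3
# can occasionally go negative for small alpha.
alpha, beta = 4.0, 2.0
eps = torch.randn(1_000_000)
z = (alpha - 1/3) * (1 + eps / (9 * alpha - 3) ** 0.5) ** 3 / beta

print(z.mean().item())  # ~ alpha / beta = 2.0
print(z.var().item())   # ~ alpha / beta^2 = 1.0 (roughly)
```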
Non-Reparameterizable Distributions:

Discrete distributions—Bernoulli, Categorical, Poisson—cannot be reparameterized directly: any mapping from continuous noise to a discrete sample space is piecewise constant, so its gradient is zero almost everywhere.
Relaxations for Discrete:
For discrete distributions, several relaxations provide approximate reparameterization: the Gumbel-Softmax (Concrete) distribution replaces the discrete sample with a temperature-controlled softmax, and its straight-through variant keeps discrete samples in the forward pass while using the soft gradients in the backward pass—both appear in the code below.
```python
import torch
from torch.distributions import Normal, LogNormal, Gamma

def reparam_normal(mu, sigma, n_samples=1):
    """Standard Gaussian reparameterization."""
    eps = torch.randn(n_samples, *mu.shape)
    return mu + sigma * eps

def reparam_lognormal(mu, sigma, n_samples=1):
    """LogNormal via exponentiated Gaussian."""
    z_normal = reparam_normal(mu, sigma, n_samples)
    return torch.exp(z_normal)

def reparam_exponential(rate, n_samples=1):
    """Exponential via inverse CDF."""
    u = torch.rand(n_samples, *rate.shape)
    return -torch.log(u) / rate

# PyTorch distributions have rsample() for reparameterized sampling
normal = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
z = normal.rsample((10,))         # Reparameterized - gradients flow
z_no_grad = normal.sample((10,))  # Not reparameterized

# Gumbel-Softmax for categorical (relaxation)
def gumbel_softmax(logits, tau=1.0, hard=False):
    """Reparameterized categorical via Gumbel-Softmax."""
    gumbels = -torch.log(-torch.log(torch.rand_like(logits)))
    y_soft = torch.softmax((logits + gumbels) / tau, dim=-1)
    if hard:
        # Straight-through: discrete forward, soft backward
        index = y_soft.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft
```

Modern frameworks make reparameterization nearly automatic.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Complete VAE with reparameterization."""

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Encoder
        self.enc_fc1 = nn.Linear(input_dim, hidden_dim)
        self.enc_fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.enc_fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.dec_fc1 = nn.Linear(latent_dim, hidden_dim)
        self.dec_fc2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.enc_fc1(x))
        mu = self.enc_fc_mu(h)
        logvar = self.enc_fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """The reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        h = F.relu(self.dec_fc1(z))
        return torch.sigmoid(self.dec_fc2(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

    def loss_function(self, x, x_recon, mu, logvar):
        # Reconstruction: binary cross-entropy
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        # KL divergence: closed form for Gaussian
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kl_loss, recon_loss, kl_loss

# Training loop (assumes `dataloader` yields batches of binarized 28x28 images)
model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in dataloader:
    x = batch[0].view(-1, 784)
    x_recon, mu, logvar = model(x)
    loss, recon, kl = model.loss_function(x, x_recon, mu, logvar)
    optimizer.zero_grad()
    loss.backward()  # Gradients flow through reparameterization!
    optimizer.step()
```

PyTorch: use `distribution.rsample()` for reparameterized samples. TensorFlow Probability: distributions whose `reparameterization_type` is `FULLY_REPARAMETERIZED` produce pathwise gradients through `sample()` automatically. JAX/NumPyro: reparameterization is handled through `jax.random` and functional transforms.
Despite its power, reparameterization has limitations:

- It requires a differentiable path: both the transformation $g_\phi$ and the downstream function $f$ (decoder and loss) must be differentiable.
- It does not apply to discrete latents: relaxations such as Gumbel-Softmax introduce bias and require temperature tuning.
- Some distributions lack a simple explicit transform (e.g., samplers with acceptance/rejection steps), although implicit reparameterization extends coverage.
- Gradients remain Monte Carlo estimates: for highly curved $f$, the pathwise variance can still be significant.
Alternatives and Extensions:

- Score function (REINFORCE) estimator: works for any distribution, including discrete ones; covered in depth on the next page.
- Gumbel-Softmax / Concrete relaxation: approximate reparameterization for categorical variables.
- Straight-through estimators: discrete forward pass with a surrogate gradient in the backward pass.
- Implicit reparameterization gradients (Figurnov et al., 2018): differentiate through the CDF rather than an explicit transform, extending reparameterization to Gamma, Beta, Dirichlet, and von Mises distributions.
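As a sketch of that last point (parameter values are illustrative): PyTorch ships a differentiable `rsample()` for Gamma, even though the underlying sampler involves rejection—the sampler's gradient is derived analytically rather than by backpropagating through an explicit transform.

```python
import torch
from torch.distributions import Gamma

concentration = torch.tensor(3.0, requires_grad=True)
rate = torch.tensor(2.0, requires_grad=True)

dist = Gamma(concentration, rate)
print(dist.has_rsample)  # True

z = dist.rsample((100_000,))  # differentiable samples
z.mean().backward()           # E[z] = concentration / rate

print(concentration.grad.item())  # ~ 1 / rate = 0.5
print(rate.grad.item())           # ~ -concentration / rate^2 = -0.75
```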
You now have a complete understanding of the reparameterization trick—the technical innovation that made VAEs practical. Next, we'll explore the score function estimator in depth, which remains essential for discrete latents and non-differentiable settings.