The central insight enabling modern variational inference is viewing it as an optimization problem. Rather than solving integrals analytically, we apply gradient-based optimization to maximize the ELBO. This perspective transforms VI from a specialized statistical technique into something that fits naturally into the machine learning toolkit.
However, there's a fundamental challenge: the ELBO involves expectations over the variational distribution $q$. How do we compute gradients through expectations? This page addresses this central question, developing the techniques that make large-scale variational inference possible.
By the end of this page, you will understand how VI is optimized in practice, from the foundational score function estimator to the revolutionary reparameterization trick that enabled VAEs and modern deep generative models.
By completing this page, you will: (1) Understand why gradient estimation through expectations is non-trivial, (2) Master the reparameterization trick for continuous latent variables, (3) Understand the score function (REINFORCE) estimator for discrete variables, (4) Learn variance reduction techniques, and (5) Know practical considerations for VI optimization.
To optimize the ELBO with gradient descent, we need:
$$\nabla_\phi \mathcal{L}(\phi) = \nabla_\phi \, \mathbb{E}_{q_\phi(\mathbf{z})}[f(\mathbf{z})]$$
where $f(\mathbf{z}) = \log p(\mathbf{x}, \mathbf{z}) - \log q_\phi(\mathbf{z})$ and $\phi$ are the variational parameters.
The difficulty is that $\phi$ appears in two places: in the distribution $q_\phi$ from which we sample $\mathbf{z}$, and inside the integrand $f(\mathbf{z})$ through the $\log q_\phi(\mathbf{z})$ term.
We cannot simply push the gradient inside the expectation:
$$\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] \neq \mathbb{E}_{q_\phi}[\nabla_\phi f(\mathbf{z})]$$
The right-hand side ignores how changing $\phi$ changes what distribution we're sampling from. We must account for both effects.
There are two main strategies for estimating $\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})]$:
1. Score Function Estimator (REINFORCE)
Use the log-derivative trick to move the gradient inside the expectation: $$\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] = \mathbb{E}_{q_\phi}[f(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})]$$
2. Reparameterization Trick
Rewrite sampling as a deterministic function of parameters and external noise: $$\mathbf{z} = g_\phi(\boldsymbol{\epsilon}), \quad \boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon})$$
Then: $$\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] = \mathbb{E}_{p(\boldsymbol{\epsilon})}[\nabla_\phi f(g_\phi(\boldsymbol{\epsilon}))]$$
| Property | Score Function | Reparameterization |
|---|---|---|
| Applicability | Any distribution | Reparameterizable only |
| Discrete latents | ✅ Yes | ❌ No (needs relaxations) |
| Continuous latents | ✅ Yes | ✅ Yes |
| Variance | High | Low |
| Bias | Unbiased | Unbiased |
| Requires | $\nabla_\phi \log q_\phi$ | $\nabla_\phi g_\phi$ |
| Common use | RL, discrete VAEs | VAEs, normalizing flows |
The reparameterization trick is perhaps the single most important technical contribution enabling modern variational autoencoders and deep generative models. It provides low-variance gradient estimates by expressing the random sampling process as a deterministic transformation of parameter-free noise.
Instead of sampling directly: $$\mathbf{z} \sim q_\phi(\mathbf{z})$$
We sample noise and transform: $$\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon}), \quad \mathbf{z} = g_\phi(\boldsymbol{\epsilon})$$
The distribution $p(\boldsymbol{\epsilon})$ doesn't depend on $\phi$. Now the gradient can flow through $g_\phi$:
$$\nabla_\phi f(\mathbf{z}) = \nabla_\phi f(g_\phi(\boldsymbol{\epsilon})) = \nabla_{\mathbf{z}} f(\mathbf{z}) \cdot \nabla_\phi g_\phi(\boldsymbol{\epsilon})$$
For $q_\phi(\mathbf{z}) = \mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))$ with $\phi = (\boldsymbol{\mu}, \boldsymbol{\sigma})$:
$$\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \quad \mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$$
Gradients: $$\frac{\partial \mathbf{z}}{\partial \boldsymbol{\mu}} = \mathbf{I}, \quad \frac{\partial \mathbf{z}}{\partial \boldsymbol{\sigma}} = \text{diag}(\boldsymbol{\epsilon})$$
```python
import torch
import torch.nn as nn
import numpy as np

class ReparameterizedGaussian(nn.Module):
    """
    Demonstrates the reparameterization trick for Gaussian distributions.

    Key insight: instead of sampling z ~ N(mu, sigma^2) directly, we use
        eps ~ N(0, I),  z = mu + sigma * eps
    so gradients can flow through mu and sigma.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mu = nn.Parameter(torch.zeros(dim))
        self.log_sigma = nn.Parameter(torch.zeros(dim))

    def sample_without_reparam(self, n_samples):
        """
        Direct sampling - BREAKS gradient flow!

        Normal.sample() draws values without recording how they depend on
        mu and sigma, so gradients cannot flow back through the sampling.
        """
        sigma = torch.exp(self.log_sigma)
        # .sample() is non-differentiable: the result has no grad_fn
        return torch.distributions.Normal(self.mu, sigma).sample((n_samples,))

    def sample_with_reparam(self, n_samples):
        """
        Reparameterized sampling - gradients flow through mu and sigma!

        z = mu + sigma * eps, where eps ~ N(0, I)

        The randomness comes from eps, which doesn't depend on parameters.
        mu and sigma enter deterministically, allowing backprop.
        """
        sigma = torch.exp(self.log_sigma)
        eps = torch.randn(n_samples, self.dim)   # Parameter-free noise
        z = self.mu + sigma * eps                # Deterministic function of (mu, sigma, eps)
        return z

    def elbo_objective(self, z):
        """
        Simple ELBO: E_q[log p(z)] + H[q]
        for a standard normal prior p(z) = N(0, I).
        """
        # log p(z) under the standard normal prior
        log_p = -0.5 * torch.sum(z**2, dim=-1) - 0.5 * self.dim * np.log(2 * np.pi)

        # Entropy of q (analytical for a diagonal Gaussian)
        entropy = 0.5 * self.dim * (1 + np.log(2 * np.pi)) + torch.sum(self.log_sigma)

        return log_p.mean() + entropy


# Demonstrate gradient flow
print("=== Reparameterization Trick Demo ===\n")

model = ReparameterizedGaussian(dim=10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

print("Before optimization:")
print(f"  mu:    {model.mu.data[:3].numpy()}")
print(f"  sigma: {torch.exp(model.log_sigma.data)[:3].numpy()}")

# Optimize to match the standard normal prior
for step in range(100):
    optimizer.zero_grad()

    # Sample using reparameterization
    z = model.sample_with_reparam(n_samples=100)

    # Compute ELBO (we want to maximize, so negate for the loss)
    elbo = model.elbo_objective(z)
    loss = -elbo

    # Gradients flow through z back to mu and sigma!
    loss.backward()
    optimizer.step()

print("\nAfter optimization (should match N(0, I)):")
print(f"  mu:    {model.mu.data[:3].numpy()}")
print(f"  sigma: {torch.exp(model.log_sigma.data)[:3].numpy()}")

# Verify gradients exist
z = model.sample_with_reparam(n_samples=10)
loss = -model.elbo_objective(z)
loss.backward()
print(f"\nGradient on mu: {model.mu.grad[:3].numpy()}")
print(f"Gradient on log_sigma: {model.log_sigma.grad[:3].numpy()}")
print("\n✓ Gradients exist! Reparameterization enables backprop.")
```

Not all distributions can be reparameterized. Common reparameterizable families include:
- Gaussian: $z = \mu + \sigma \cdot \epsilon$, $\epsilon \sim \mathcal{N}(0, 1)$
- Uniform: $z = a + (b-a) \cdot \epsilon$, $\epsilon \sim \mathrm{Uniform}(0, 1)$
- Exponential: $z = -\beta \log \epsilon$, $\epsilon \sim \mathrm{Uniform}(0, 1)$
- Location-scale families: generally reparameterizable
Not directly reparameterizable: Categorical, Bernoulli, Poisson (discrete distributions)
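As a quick illustration of these transforms, here is a minimal PyTorch sketch. The helper names are ours, not a standard API; `torch.distributions` exposes the same idea through `rsample()` for reparameterizable families.

```python
import torch

# Illustrative helpers: each returns a sample that is a differentiable
# function of its parameters, with all randomness confined to eps.

def reparam_gaussian(mu, sigma, shape):
    eps = torch.randn(shape)               # eps ~ N(0, 1), parameter-free
    return mu + sigma * eps                # location-scale transform

def reparam_uniform(a, b, shape):
    eps = torch.rand(shape)                # eps ~ Uniform(0, 1)
    return a + (b - a) * eps

def reparam_exponential(beta, shape):
    eps = torch.rand(shape)                # inverse-CDF transform
    return -beta * torch.log(eps + 1e-10)  # z ~ Exponential with mean beta

mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.2, requires_grad=True)
z = reparam_gaussian(mu, sigma, (1000,))
z.mean().backward()
print(mu.grad, sigma.grad)                 # gradients flow through the samples
```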
The reparameterization trick reduces variance because it exploits the structure of the gradient:
$$\nabla_\phi f(g_\phi(\boldsymbol{\epsilon})) = \nabla_{\mathbf{z}} f(\mathbf{z}) \cdot \nabla_\phi g_\phi(\boldsymbol{\epsilon})$$
In contrast, the score function estimator multiplies $f$ by $\nabla_\phi \log q_\phi$—two potentially high-variance quantities whose product has even higher variance.
Empirical comparison: In VAE training, reparameterization typically gives 10-100× lower gradient variance than REINFORCE, enabling faster and more stable training.
When reparameterization isn't possible (e.g., discrete latent variables), we fall back to the score function estimator, also known as REINFORCE in the reinforcement learning literature.
We want $\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})]$. Using the identity:
$$\nabla_\phi q_\phi(\mathbf{z}) = q_\phi(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})$$
We derive:
$$\begin{aligned} \nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] &= \nabla_\phi \int q_\phi(\mathbf{z}) f(\mathbf{z}) \, d\mathbf{z} \\ &= \int f(\mathbf{z}) \nabla_\phi q_\phi(\mathbf{z}) \, d\mathbf{z} \\ &= \int f(\mathbf{z}) q_\phi(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z}) \, d\mathbf{z} \\ &= \mathbb{E}_{q_\phi}[f(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})] \end{aligned}$$
The term $\nabla_\phi \log q_\phi(\mathbf{z})$ is called the score function, giving the estimator its name.
The score function estimator is:
$$\nabla_\phi \mathcal{L}(\phi) \approx \frac{1}{K} \sum_{k=1}^K f(\mathbf{z}^{(k)}) \nabla_\phi \log q_\phi(\mathbf{z}^{(k)})$$
where $\mathbf{z}^{(k)} \sim q_\phi(\mathbf{z})$.
The Problem: High Variance
The score function estimator is unbiased but has notoriously high variance.
Why the High Variance?
The score $\nabla_\phi \log q_\phi$ points in different directions for different samples. Multiplying by $f$—which varies substantially across the distribution—creates wildly varying gradient estimates. The signal (mean gradient) gets lost in noise.
```python
import torch
import torch.nn as nn
import numpy as np

def score_function_gradient_estimate(log_q, f, samples, n_samples=100):
    """
    Score function (REINFORCE) gradient estimator.

    ∇_φ E_q[f(z)] ≈ (1/K) Σ f(z_k) · ∇_φ log q(z_k)

    Args:
        log_q: Function computing log q_φ(z); must be differentiable in φ
        f: Function computing f(z); the objective we're taking expectation of
        samples: Samples z ~ q_φ(z)

    Returns:
        Gradient estimate (via backward pass on the surrogate loss)
    """
    # Compute surrogate loss for REINFORCE
    # gradient of E[f] ≈ E[f · ∇log q]
    # This is implemented as: loss = E[f.detach() * log q]
    # Taking gradient of this loss gives the REINFORCE estimator
    f_values = f(samples).detach()   # Stop gradient through f
    log_q_values = log_q(samples)    # Keep gradient through log q

    # Surrogate loss: when differentiated, gives the REINFORCE gradient
    surrogate_loss = -(f_values * log_q_values).mean()
    return surrogate_loss


def compare_estimator_variances():
    """
    Compare variance of score function vs reparameterization estimators.
    """
    torch.manual_seed(42)

    # Simple 1D Gaussian variational distribution
    mu = nn.Parameter(torch.tensor(1.0))
    log_sigma = nn.Parameter(torch.tensor(0.0))

    def get_sigma():
        return torch.exp(log_sigma)

    # Target: minimize E_q[(z - 2)^2] (should push mu toward 2)
    def f(z):
        return -(z - 2)**2  # Negative because we maximize the ELBO

    def log_q(z):
        sigma = get_sigma()
        return -0.5 * ((z - mu) / sigma)**2 - log_sigma - 0.5 * np.log(2 * np.pi)

    n_estimates = 100
    n_samples = 10

    # Collect gradient estimates from both methods
    reparam_grads = []
    score_grads = []

    for _ in range(n_estimates):
        # === Reparameterization Method ===
        eps = torch.randn(n_samples)
        z_reparam = mu + get_sigma() * eps
        loss_reparam = -f(z_reparam).mean()
        loss_reparam.backward()
        reparam_grads.append(mu.grad.item())
        mu.grad.zero_()
        log_sigma.grad.zero_()

        # === Score Function Method ===
        # Sample (no reparameterization)
        with torch.no_grad():
            z_score = mu + get_sigma() * torch.randn(n_samples)
        f_vals = f(z_score).detach()
        log_q_vals = log_q(z_score)
        surrogate = -(f_vals * log_q_vals).mean()
        surrogate.backward()
        score_grads.append(mu.grad.item())
        mu.grad.zero_()
        log_sigma.grad.zero_()

    print("Gradient Estimator Variance Comparison")
    print("=" * 50)
    print(f"Samples per estimate: {n_samples}")
    print(f"Number of estimates: {n_estimates}")
    print()
    print("Reparameterization:")
    print(f"  Mean gradient: {np.mean(reparam_grads):.4f}")
    print(f"  Std deviation: {np.std(reparam_grads):.4f}")
    print()
    print("Score Function (REINFORCE):")
    print(f"  Mean gradient: {np.mean(score_grads):.4f}")
    print(f"  Std deviation: {np.std(score_grads):.4f}")
    print()
    print(f"Variance ratio: {np.var(score_grads) / np.var(reparam_grads):.1f}x higher for REINFORCE")


compare_estimator_variances()
```

When we must use the score function estimator (e.g., for discrete latents), variance reduction is essential. Several techniques can dramatically improve gradient quality.
The idea: subtract a quantity with known expectation to reduce variance without introducing bias.
Observe that for any $b$ not depending on $\mathbf{z}$:
$$\mathbb{E}_{q_\phi}[\nabla_\phi \log q_\phi(\mathbf{z})] = 0$$
(The score function has zero expectation.)
Therefore, we can modify the REINFORCE estimator:
$$\nabla_\phi \mathcal{L} \approx \frac{1}{K} \sum_k (f(\mathbf{z}^{(k)}) - b) \nabla_\phi \log q_\phi(\mathbf{z}^{(k)})$$
This is still unbiased! The optimal $b$ minimizes variance:
$$b^* = \frac{\mathbb{E}\left[f(\mathbf{z}) \, \|\nabla_\phi \log q_\phi(\mathbf{z})\|^2\right]}{\mathbb{E}\left[\|\nabla_\phi \log q_\phi(\mathbf{z})\|^2\right]}$$
In practice, a running average of $f$ works well.
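A minimal sketch of such a running-average baseline, assuming an exponential moving average with an illustrative decay of 0.9 (the class name and usage pattern are ours, not from the text):

```python
import torch

class RunningBaseline:
    """Exponential moving average of f(z), used as a REINFORCE baseline."""
    def __init__(self, decay=0.9):
        self.decay = decay
        self.value = None

    def update(self, f_values):
        batch_mean = f_values.mean().item()
        if self.value is None:
            self.value = batch_mean
        else:
            self.value = self.decay * self.value + (1 - self.decay) * batch_mean
        return self.value

baseline = RunningBaseline()
for f_vals in [torch.tensor([1.0, 3.0]), torch.tensor([2.0, 4.0])]:
    b = baseline.update(f_vals)   # scalar baseline, no gradient attached
    # REINFORCE surrogate with baseline would then be:
    # surrogate = -((f_vals - b).detach() * log_q_vals).mean()
print(f"current baseline: {b:.2f}")
```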
For categorical latent variables, the Gumbel-Softmax (or Concrete) distribution provides a differentiable approximation:
$$y_i = \frac{\exp((\log \pi_i + g_i) / \tau)}{\sum_j \exp((\log \pi_j + g_j) / \tau)}$$
where $g_i \sim \text{Gumbel}(0, 1)$ and $\tau$ is a temperature parameter.
This enables reparameterized gradients for categorical variables at the cost of a biased approximation.
```python
import torch
import torch.nn.functional as F
import numpy as np

def reinforce_with_baseline():
    """
    Demonstrate variance reduction with control variates (baselines).
    """
    torch.manual_seed(42)

    # Categorical distribution over 5 classes
    logits = torch.tensor([1.0, 2.0, 0.5, 1.5, 0.8], requires_grad=True)

    # Reward function: r(k) = k (higher classes are better)
    def reward(k):
        return float(k)

    def estimate_gradient(use_baseline=False, n_samples=100):
        probs = F.softmax(logits, dim=0)

        # Sample from the categorical distribution
        samples = torch.multinomial(probs, n_samples, replacement=True)

        # Compute rewards
        rewards = torch.tensor([reward(k.item()) for k in samples])

        # Baseline: mean of the rewards in this batch
        baseline = rewards.mean() if use_baseline else 0.0

        # Log probabilities of the sampled classes
        log_probs = torch.log(probs[samples])

        # REINFORCE gradient estimate
        # ∇E[r] ≈ E[(r - b) ∇log π]
        advantages = rewards - baseline
        surrogate_loss = -(advantages * log_probs).mean()
        surrogate_loss.backward()

        grad = logits.grad.clone()
        logits.grad.zero_()
        return grad

    # Compare variance with and without baseline
    n_estimates = 200
    grads_no_baseline = [estimate_gradient(False, n_samples=50) for _ in range(n_estimates)]
    grads_with_baseline = [estimate_gradient(True, n_samples=50) for _ in range(n_estimates)]

    var_no_baseline = torch.stack(grads_no_baseline).var(dim=0).mean()
    var_with_baseline = torch.stack(grads_with_baseline).var(dim=0).mean()

    print("Variance Reduction with Baselines")
    print("=" * 40)
    print(f"Without baseline: variance = {var_no_baseline:.4f}")
    print(f"With baseline:    variance = {var_with_baseline:.4f}")
    print(f"Reduction factor: {var_no_baseline / var_with_baseline:.1f}x")


def gumbel_softmax_demo():
    """
    Demonstrate Gumbel-Softmax for differentiable discrete sampling.
    """
    def sample_gumbel(shape):
        """Sample from Gumbel(0, 1)"""
        u = torch.rand(shape)
        return -torch.log(-torch.log(u + 1e-10) + 1e-10)

    def gumbel_softmax(logits, temperature):
        """
        Gumbel-Softmax: differentiable approximation to categorical sampling.
        """
        gumbels = sample_gumbel(logits.shape)
        y = (logits + gumbels) / temperature
        return F.softmax(y, dim=-1)

    # Categorical logits
    logits = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)

    print("\nGumbel-Softmax Demo")
    print("=" * 40)
    print("Logits:", logits.data.numpy())
    print("True probs:", F.softmax(logits, dim=0).data.numpy())

    for temp in [0.1, 0.5, 1.0, 2.0]:
        samples = [gumbel_softmax(logits, temp) for _ in range(5)]
        mean_sample = torch.stack(samples).mean(dim=0)
        print(f"\nTemp={temp}: Sample mean = {mean_sample.data.numpy()}")
        # Show a single sample
        single = samples[0].data.numpy()
        print(f"  Single sample = {single}")

    print("\nNote: Lower temperature → more discrete (one-hot)")
    print("      Higher temperature → more uniform")

    # Verify gradients flow
    y = gumbel_softmax(logits, temperature=0.5)
    loss = (y * torch.tensor([0., 0., 1.])).sum()  # Want class 2
    loss.backward()
    print(f"\nGradients exist: {logits.grad.numpy()}")


reinforce_with_baseline()
gumbel_softmax_demo()
```

Successfully training variational models requires attention to several practical details beyond choosing the right gradient estimator.
Modern VI typically uses adaptive optimizers such as Adam or RMSProp, which cope well with noisy stochastic ELBO gradients.
Learning rate: Start with 1e-3 to 1e-4 for Adam; schedule to decay over training.
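A sketch of this setup; the exponential decay schedule and its rate are illustrative assumptions, not a prescription:

```python
import torch

phi = torch.nn.Parameter(torch.zeros(10))     # stand-in variational parameters
optimizer = torch.optim.Adam([phi], lr=1e-3)  # 1e-3 to 1e-4 is a common starting point
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

for step in range(1000):
    optimizer.zero_grad()
    loss = (phi ** 2).sum()                   # placeholder for the negative ELBO estimate
    loss.backward()
    optimizer.step()
    scheduler.step()                          # decay the learning rate over training

print(f"final lr: {scheduler.get_last_lr()[0]:.2e}")
```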
For large datasets, stochastic VI uses mini-batches:
$$\mathcal{L}(\phi) \approx \frac{N}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathcal{L}_i(\phi)$$
where $\mathcal{B}$ is a mini-batch of size $|\mathcal{B}|$ from $N$ total examples.
This enables VI on datasets with millions of examples—each gradient step uses only a subset of data.
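A minimal sketch of the rescaling, with a hypothetical `local_elbo` standing in for the per-example ELBO terms:

```python
import torch

N = 1_000_000          # total dataset size (hypothetical)
batch_size = 256

def local_elbo(batch, phi):
    """Hypothetical per-example ELBO contributions; returns one value per example."""
    return -((batch - phi) ** 2).sum(dim=1)   # stand-in for L_i(phi)

phi = torch.zeros(10, requires_grad=True)
optimizer = torch.optim.Adam([phi], lr=1e-2)

for _ in range(200):
    optimizer.zero_grad()
    batch = torch.randn(batch_size, 10)       # mini-batch drawn from the data
    # Rescale by N / |B| so the mini-batch sum is an unbiased estimate of the full ELBO
    elbo_hat = (N / batch_size) * local_elbo(batch, phi).sum()
    (-elbo_hat).backward()                    # maximize ELBO = minimize its negative
    optimizer.step()
```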
| Problem | Symptoms | Solutions |
|---|---|---|
| Posterior collapse | KL → 0, reconstruction poor | KL annealing, free bits, architecture changes |
| High gradient variance | Unstable training, slow convergence | More samples, baselines, reparameterization |
| Mode collapse | Generated samples lack diversity | Richer variational family, adversarial training |
| Underfitting | ELBO plateaus early | Increase model capacity, check implementation |
| Numerical instability | NaN/Inf in loss | Gradient clipping, smaller LR, numerical safeguards |
Posterior collapse is a common failure mode where $q(\mathbf{z}|\mathbf{x}) \approx p(\mathbf{z})$, ignoring the data. The model finds it easier to set $D_{\text{KL}} = 0$ than to use informative latents.
KL Annealing: Gradually increase the KL weight during training:
$$\mathcal{L}_\beta = \mathbb{E}_{q}[\log p(\mathbf{x} \mid \mathbf{z})] - \beta_t \cdot D_{\text{KL}}(q \,\|\, p)$$
where $\beta_t$ increases from 0 to 1 over training. This allows the model to first learn good reconstructions, then regularize.
Free Bits: Ensure minimum information flow through the latent:
$$D_{\text{KL}}^{\text{free}}(q \,\|\, p) = \max\left(D_{\text{KL}}(q \,\|\, p), \lambda\right)$$
Don't penalize KL below threshold $\lambda$, ensuring at least $\lambda$ nats of information passes through.
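A sketch of one common way to apply this, clamping the KL per latent dimension (the per-dimension clamping and the value of $\lambda$ here are assumptions; some implementations clamp the total KL instead):

```python
import torch

def free_bits_kl(mu, log_var, lam=0.5):
    """KL(q || N(0, I)) per latent dimension, with each dimension floored at lam nats."""
    kl_per_dim = 0.5 * (torch.exp(log_var) + mu**2 - 1.0 - log_var)  # shape (batch, dim)
    kl_per_dim = torch.clamp(kl_per_dim, min=lam)                    # no reward for pushing below lam
    return kl_per_dim.sum(dim=1).mean()

mu = torch.zeros(8, 16, requires_grad=True)
log_var = torch.zeros(8, 16, requires_grad=True)
print(free_bits_kl(mu, log_var))   # equals lam * dim when q already matches the prior
```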
Track the ELBO, the reconstruction term, the KL term, and the annealing weight $\beta$ separately throughout training:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class VAETrainer:
    """
    Production-ready VAE training with best practices.
    """
    def __init__(self, model, lr=1e-3, kl_anneal_steps=10000):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.kl_anneal_steps = kl_anneal_steps
        self.step = 0

        # Tracking metrics
        self.history = {
            'elbo': [], 'recon': [], 'kl': [], 'beta': []
        }

    def get_beta(self):
        """
        KL annealing: linearly increase beta from 0 to 1.

        This prevents posterior collapse by allowing the model
        to first learn good reconstructions.
        """
        if self.step < self.kl_anneal_steps:
            return self.step / self.kl_anneal_steps
        return 1.0

    def train_step(self, x):
        """
        One training step with monitoring.
        """
        self.optimizer.zero_grad()

        # Forward pass
        x_recon, mu, log_var = self.model(x)

        # Compute loss components
        recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum') / x.shape[0]
        kl_loss = 0.5 * torch.mean(torch.sum(
            torch.exp(log_var) + mu**2 - 1 - log_var, dim=1
        ))

        # Apply KL annealing
        beta = self.get_beta()
        loss = recon_loss + beta * kl_loss

        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
        self.optimizer.step()

        # Track metrics
        elbo = -(recon_loss + kl_loss).item()
        self.history['elbo'].append(elbo)
        self.history['recon'].append(-recon_loss.item())
        self.history['kl'].append(kl_loss.item())
        self.history['beta'].append(beta)

        self.step += 1

        return {
            'loss': loss.item(),
            'elbo': elbo,
            'recon': -recon_loss.item(),
            'kl': kl_loss.item(),
            'beta': beta
        }

    def check_posterior_collapse(self, threshold=0.1):
        """
        Check if posterior collapse is occurring.

        Collapse = KL divergence near zero = q ≈ p(z) = N(0, I)
        """
        recent_kl = np.mean(self.history['kl'][-100:]) if len(self.history['kl']) > 100 else None

        if recent_kl is not None and recent_kl < threshold:
            print(f"⚠️ Warning: Possible posterior collapse! KL = {recent_kl:.4f}")
            print("   Consider: KL annealing, free bits, or architecture changes")
            return True
        return False

    def print_diagnostics(self):
        """
        Print training diagnostics.
        """
        n = min(100, len(self.history['elbo']))
        print(f"\nTraining Diagnostics (last {n} steps):")
        print(f"  ELBO:  {np.mean(self.history['elbo'][-n:]):.2f}")
        print(f"  Recon: {np.mean(self.history['recon'][-n:]):.2f}")
        print(f"  KL:    {np.mean(self.history['kl'][-n:]):.2f}")
        print(f"  Beta:  {self.history['beta'][-1]:.3f}")


# Example training snippet
print("Training Configuration Example")
print("=" * 50)
print("""
# Best practices for VAE training:

1. Use Adam optimizer with lr=1e-3 to 1e-4
2. Apply KL annealing over 10k-50k steps
3. Monitor ELBO, reconstruction, and KL separately
4. Watch for posterior collapse (KL → 0)
5. Use gradient clipping (max_norm=10)
6. Consider learning rate scheduling
7. Multiple MC samples per example can help
8. Validate with sample quality, not just ELBO
""")
```

We have now completed a comprehensive tour of the Variational Inference Framework. Let's consolidate the key ideas:
Variational inference transforms intractable Bayesian inference into tractable optimization: we choose a variational family, maximize the ELBO with stochastic gradient estimates, and select the estimator that fits the model (reparameterization for continuous latents, the score function with variance reduction for discrete ones).
This module establishes the foundations; the remaining modules in this chapter build on them.
Congratulations! You have mastered the foundational concepts of variational inference. You understand why VI exists (intractability), how it works (optimization of the ELBO), and how to implement it (reparameterization, gradient estimation). This knowledge forms the basis for all modern probabilistic machine learning, from VAEs to Bayesian neural networks to large-scale topic models.