Traditional variational inference faces a fundamental computational barrier: every iteration requires processing the entire dataset. For a dataset with \(N\) observations, computing the Evidence Lower Bound (ELBO) gradient demands \(O(N)\) operations per update—a cost that becomes prohibitive when \(N\) reaches millions or billions.
This bottleneck isn't merely inconvenient; it fundamentally limits the applicability of Bayesian methods to real-world problems. Modern machine learning datasets routinely contain billions of data points—web-scale text corpora, social network graphs, genomic sequences, recommendation system interactions. Without a scalable inference strategy, variational methods remain confined to toy problems.
Stochastic Variational Inference (SVI) resolves this crisis by embracing a simple but profound insight: we can estimate gradients using random mini-batches of data, trading some variance for dramatic computational savings. This page develops the mathematical foundations of mini-batch optimization for variational inference, establishing the principles that enable modern Bayesian deep learning.
By the end of this page, you will understand how to formulate unbiased gradient estimators from mini-batches, the mathematical conditions that guarantee convergence, the variance-bias tradeoffs inherent in stochastic estimation, and the practical considerations for implementing mini-batch VI at scale.
To understand why mini-batch optimization is necessary, we must first appreciate the computational structure of the ELBO. Recall that for a model with latent variables \(z\), observations \(\mathbf{x} = \{x_1, \ldots, x_N\}\), and variational distribution \(q(z; \phi)\), the ELBO takes the form:
$$\mathcal{L}(\phi) = \mathbb{E}_{q(z; \phi)}[\log p(\mathbf{x}, z)] - \mathbb{E}_{q(z; \phi)}[\log q(z; \phi)]$$
Assuming conditionally independent observations given the latent variables—a standard assumption in most models—the joint likelihood factorizes:
$$\log p(\mathbf{x}, z) = \log p(z) + \sum_{i=1}^{N} \log p(x_i | z)$$
Substituting into the ELBO:
$$\mathcal{L}(\phi) = \mathbb{E}_{q(z; \phi)}\left[\log p(z) + \sum_{i=1}^{N} \log p(x_i | z)\right] - \mathbb{E}_{q(z; \phi)}[\log q(z; \phi)]$$
The inner summation over all \(N\) data points is the computational bottleneck. Every gradient computation requires evaluating log-likelihoods for every observation, making the per-iteration cost \(O(N)\). For \(N = 10^9\), even a single gradient step becomes prohibitively expensive.
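To make the bottleneck concrete, here is a minimal sketch on a toy model chosen purely for illustration (\(z \sim \mathcal{N}(0,1)\), \(x_i \mid z \sim \mathcal{N}(z,1)\), \(q(z) = \mathcal{N}(m, s^2)\)); note that every ELBO evaluation sums over the entire dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model for illustration only: z ~ N(0, 1), x_i | z ~ N(z, 1), q(z) = N(m, s^2)
N = 1_000_000
data = rng.normal(2.0, 1.0, size=N)
m, s = 1.0, 0.5

def full_batch_elbo_estimate(num_z_samples: int = 10) -> float:
    """Monte Carlo ELBO estimate; note the sum over all N points for each z sample."""
    z = rng.normal(m, s, size=num_z_samples)
    log_prior = -0.5 * (z ** 2 + np.log(2 * np.pi))
    log_q = -0.5 * (((z - m) / s) ** 2 + np.log(2 * np.pi * s ** 2))
    # O(N) inner term: sum of log N(x_i; z, 1) over the entire dataset
    log_lik = np.array([-0.5 * np.sum((data - zk) ** 2 + np.log(2 * np.pi)) for zk in z])
    return float(np.mean(log_lik + log_prior - log_q))

print(full_batch_elbo_estimate())   # every call touches all N observations
```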
Decomposing the ELBO:
We can separate the ELBO into data-dependent and data-independent terms:
$$\mathcal{L}(\phi) = \underbrace{\mathbb{E}_{q}[\log p(z)] - \mathbb{E}_{q}[\log q(z; \phi)]}_{\text{KL regularization term}} + \underbrace{\sum_{i=1}^{N} \mathbb{E}_{q}[\log p(x_i | z)]}_{\text{Data likelihood term}}$$
The KL regularization term involves only the prior \(p(z)\) and variational distribution \(q(z; \phi)\)—this can be computed efficiently, often in closed form for conjugate families. The computational burden lies entirely in the data likelihood summation.
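For instance, when \(q\) is a diagonal Gaussian and the prior is a standard normal, the KL term has the familiar closed form \(\tfrac{1}{2}\sum_d (\sigma_d^2 + \mu_d^2 - 1 - \log \sigma_d^2)\). A minimal sketch (the function name is illustrative):

```python
import numpy as np

def gaussian_kl_to_standard_normal(mu: np.ndarray, log_sigma: np.ndarray) -> float:
    """KL[ N(mu, sigma^2) || N(0, 1) ] for a diagonal Gaussian, summed over dimensions.

    Closed form: 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2) -- no data required.
    """
    sigma2 = np.exp(2.0 * log_sigma)
    return 0.5 * float(np.sum(sigma2 + mu ** 2 - 1.0 - 2.0 * log_sigma))

# Example: q = N(0.5, 0.8^2) in each of 3 dimensions against a standard normal prior
mu = np.full(3, 0.5)
log_sigma = np.log(np.full(3, 0.8))
print(gaussian_kl_to_standard_normal(mu, log_sigma))   # small positive number
```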
The gradient inherits this structure:
$$\nabla_\phi \mathcal{L}(\phi) = \nabla_\phi \left(-\text{KL}[q(z; \phi) \,\|\, p(z)]\right) + \sum_{i=1}^{N} \nabla_\phi \mathbb{E}_{q(z; \phi)}[\log p(x_i | z)]$$
The sum over all \(N\) data points persists in the gradient, making standard gradient ascent impractical for large datasets.
The key insight enabling scalability is that the data likelihood term is an average over observations, which can be approximated by sampling.
Reformulating as an expectation:
Define the per-datapoint contribution to the ELBO:
$$\ell_i(\phi) = \mathbb{E}_{q(z; \phi)}[\log p(x_i | z)]$$
The full data likelihood term becomes:
$$\sum_{i=1}^{N} \ell_i(\phi) = N \cdot \frac{1}{N} \sum_{i=1}^{N} \ell_i(\phi) = N \cdot \mathbb{E}_{i \sim \text{Uniform}(1, N)}[\ell_i(\phi)]$$
This reformulation reveals that the summation is equivalent to \(N\) times an expectation over uniformly sampled data indices. We can estimate this expectation using Monte Carlo sampling.
Given a mini-batch \(B\) of size \(|B| = M\), sampled uniformly from the dataset, an unbiased estimator of the data likelihood term is:
$$\sum_{i=1}^{N} \ell_i(\phi) \approx \frac{N}{M} \sum_{j \in B} \ell_j(\phi)$$
The scaling factor \(N/M\) ensures the estimator is unbiased: its expectation equals the true sum.
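This is easy to verify numerically. In the sketch below, random numbers stand in for the per-datapoint terms \(\ell_i(\phi)\), and the \(N/M\)-scaled mini-batch estimator is averaged over many random batches:

```python
import numpy as np

rng = np.random.default_rng(0)

N, M = 10_000, 100
ell = rng.normal(size=N)          # stand-ins for the per-datapoint terms l_i(phi)
true_sum = ell.sum()              # the full data likelihood term

# Average the N/M-scaled mini-batch estimator over many random batches
estimates = []
for _ in range(2_000):
    batch = rng.choice(N, size=M, replace=False)
    estimates.append((N / M) * ell[batch].sum())

print(true_sum, np.mean(estimates))   # the two values agree up to Monte Carlo error
```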
The stochastic ELBO estimator:
Combining with the KL term, the mini-batch ELBO estimator is:
$$\hat{\mathcal{L}}(\phi; B) = -\text{KL}[q(z; \phi) \,\|\, p(z)] + \frac{N}{M} \sum_{j \in B} \mathbb{E}_{q(z; \phi)}[\log p(x_j | z)]$$
Unbiasedness proof:
$$\mathbb{E}_B[\hat{\mathcal{L}}(\phi; B)] = -\text{KL}[q(z; \phi) \,\|\, p(z)] + \frac{N}{M} \cdot M \cdot \frac{1}{N} \sum_{i=1}^{N} \ell_i(\phi) = \mathcal{L}(\phi)$$
The expectation over random mini-batches recovers the true ELBO exactly. This unbiasedness property is crucial—it guarantees that stochastic optimization will converge to the same optimum as full-batch optimization, given sufficient iterations.
The stochastic gradient:
$$\hat{g}(\phi; B) = \nabla_\phi \left(-\text{KL}[q(z; \phi) | p(z)]\right) + \frac{N}{M} \sum_{j \in B} \nabla_\phi \mathbb{E}_{q(z; \phi)}[\log p(x_j | z)]$$
This gradient estimator is also unbiased: \(\mathbb{E}_B[\hat{g}(\phi; B)] = \nabla_\phi \mathcal{L}(\phi)\).
While mini-batch gradient estimators are unbiased, they introduce variance that affects optimization dynamics. Understanding this variance is critical for practical implementation.
Variance decomposition:
The variance of the mini-batch gradient estimator arises from two sources: the random selection of the mini-batch itself, and the Monte Carlo estimation of the expectations over \(q(z; \phi)\) (addressed later on this page).
For the mini-batch contribution, the variance scales as:
$$\text{Var}_B\left[\frac{N}{M} \sum_{j \in B} \nabla_\phi \ell_j(\phi)\right] = \frac{N^2}{M^2} \cdot M \cdot \text{Var}_i[\nabla_\phi \ell_i(\phi)] = \frac{N^2}{M} \sigma^2$$
where \(\sigma^2 = \text{Var}_i[\nabla_\phi \ell_i(\phi)]\) is the variance across individual data point gradients. (This treats the batch indices as sampled with replacement; sampling without replacement adds a finite-population correction that drives the variance to zero as \(M \to N\), which is why the full-batch row in the table below shows zero variance.)
| Batch Size (M) | Estimator Variance | Computation per Update | Updates per Epoch |
|---|---|---|---|
| 1 (pure SGD) | N²σ² | O(1) | N |
| √N | N^(3/2)σ² | O(√N) | √N |
| N/10 | 10Nσ² | O(N/10) | 10 |
| N (full batch) | 0 | O(N) | 1 |
The variance-computation tradeoff:
Increasing batch size reduces variance but increases computation per update: variance decreases as \(1/M\), while computation grows linearly with \(M\).
This suggests that smaller batches with more updates may achieve similar performance to larger batches with fewer updates, at equal computational cost—a principle confirmed empirically across many domains.
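The \(1/M\) scaling is easy to check empirically. In this sketch, random numbers again stand in for the per-datapoint gradients:

```python
import numpy as np

rng = np.random.default_rng(1)

N = 10_000
grads = rng.normal(size=N)        # stand-ins for the per-datapoint gradients of l_i(phi)

def estimator_variance(M: int, trials: int = 2_000) -> float:
    """Empirical variance of the N/M-scaled mini-batch gradient estimator."""
    estimates = [(N / M) * grads[rng.choice(N, size=M, replace=False)].sum()
                 for _ in range(trials)]
    return float(np.var(estimates))

for M in (10, 100, 1000):
    # Roughly N^2 * sigma^2 / M: each tenfold increase in batch size cuts
    # the estimator variance by roughly a factor of ten.
    print(M, estimator_variance(M))
```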
Practical implications:
$$\text{Effective progress} \propto \frac{\text{Number of updates}}{\sqrt{\text{Variance per update}}} \propto \frac{N/M}{N/\sqrt{M}} = \frac{\sqrt{M}}{M} = \frac{1}{\sqrt{M}}$$
At first glance, this analysis suggests that larger batches make slower progress per unit of computation. The reality is more nuanced: with very small batches, the gradient noise becomes severe enough that the learning rate must be reduced, partially offsetting the advantage. There exists an optimal batch-size regime that balances these effects.
When increasing the batch size by a factor of \(k\), one can often increase the learning rate by a factor of \(k\) (the 'linear scaling rule') or \(\sqrt{k}\) (the square-root variant) while maintaining stable optimization. This partially compensates for the reduced number of updates, but in many settings it does not fully eliminate the efficiency advantage of smaller batches.
Beyond mini-batch sampling, computing \(\nabla_\phi \mathbb{E}_{q(z; \phi)}[\log p(x | z)]\) itself requires estimation. Two principal approaches exist, each with distinct variance characteristics and applicability.
The challenge:
The gradient \(\nabla_\phi \mathbb{E}_{q(z; \phi)}[f(z)]\) cannot be computed by simply sampling from \(q\) and differentiating \(f\): the expectation depends on \(\phi\) through the distribution itself, not just the integrand. The two standard solutions are the score function (REINFORCE) estimator and the reparameterization trick; we focus on the latter.
Reparameterization in detail:
For a Gaussian variational distribution \(q(z; \mu, \sigma) = \mathcal{N}(z; \mu, \sigma^2)\), we can write:
$$z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)$$
Now the gradient becomes:
$$\nabla_{\mu, \sigma} \mathbb{E}_{q}[f(z)] = \mathbb{E}_{\epsilon}\left[\nabla_{\mu, \sigma} f(\mu + \sigma \cdot \epsilon)\right]$$
The expectation is now over \(\epsilon\), which does not depend on the variational parameters. We can therefore estimate the gradient by drawing samples of \(\epsilon\), transforming them through \(z = \mu + \sigma \epsilon\), and differentiating through the transformation.
This is the foundation of variational autoencoders (VAEs) and modern amortized inference.
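A minimal numerical sketch for a case with a known answer: with \(f(z) = z^2\) we have \(\mathbb{E}_q[z^2] = \mu^2 + \sigma^2\), so the true gradients are \(2\mu\) and \(2\sigma\), and the reparameterized Monte Carlo estimates should match them.

```python
import numpy as np

rng = np.random.default_rng(2)

mu, sigma = 1.5, 0.7
S = 100_000                               # Monte Carlo samples

eps = rng.standard_normal(S)
z = mu + sigma * eps                      # reparameterized samples, z ~ N(mu, sigma^2)

# f(z) = z^2, so df/dz = 2z; the chain rule gives the gradients w.r.t. mu and sigma
grad_mu_est = np.mean(2 * z)              # dz/dmu = 1
grad_sigma_est = np.mean(2 * z * eps)     # dz/dsigma = eps

print(grad_mu_est, 2 * mu)                # both approximately 3.0
print(grad_sigma_est, 2 * sigma)          # both approximately 1.4
```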
Empirically, the reparameterization trick typically yields gradient estimates with 10x–1000x lower variance than the score function estimator for continuous latent variables. This dramatic reduction is why reparameterization-based methods dominate modern variational inference.
We now have all components to assemble the Stochastic Variational Inference algorithm. The procedure combines mini-batch sampling with gradient estimation to optimize the ELBO efficiently.
```python
import numpy as np

def stochastic_variational_inference(
    data: np.ndarray,            # Shape: (N, D)
    model,                       # Probabilistic model with log_prior, log_likelihood
    q_family,                    # Variational distribution family
    batch_size: int = 100,
    learning_rate: float = 0.01,
    num_iterations: int = 10000,
    num_samples: int = 1         # MC samples for expectation estimation
):
    """
    Stochastic Variational Inference with mini-batch optimization.

    Args:
        data: Dataset of N observations, each D-dimensional
        model: Defines p(z) via log_prior(z) and p(x|z) via log_likelihood(x, z)
        q_family: Variational distribution q(z; phi), supports sample() and log_prob()
        batch_size: Mini-batch size M
        learning_rate: Step size for gradient ascent
        num_iterations: Total optimization iterations
        num_samples: Number of MC samples for ELBO gradient estimation

    Returns:
        Optimized variational parameters phi
    """
    N = len(data)
    phi = q_family.initialize_parameters()

    for iteration in range(num_iterations):
        # Step 1: Sample mini-batch uniformly
        batch_indices = np.random.choice(N, size=batch_size, replace=False)
        batch = data[batch_indices]

        # Step 2: Compute stochastic ELBO gradient
        grad_accumulator = np.zeros_like(phi)
        for _ in range(num_samples):
            # Sample from variational distribution (using reparameterization)
            epsilon = np.random.standard_normal(q_family.latent_dim)
            z = q_family.reparameterize(phi, epsilon)

            # KL term gradient (often analytic for exponential families)
            grad_kl = q_family.grad_kl_divergence(phi, model.prior)

            # Likelihood term gradient (scaled by N/M for unbiasedness)
            grad_likelihood = np.zeros_like(phi)
            for x in batch:
                grad_likelihood += q_family.grad_log_likelihood(phi, z, x, model)
            grad_likelihood *= (N / batch_size)

            grad_accumulator += -grad_kl + grad_likelihood

        # Average over MC samples
        gradient = grad_accumulator / num_samples

        # Step 3: Update variational parameters (gradient ascent on the ELBO)
        phi = phi + learning_rate * gradient

        # Optional: decay the learning rate; for the classical convergence
        # guarantee, use a Robbins-Monro schedule (see below) rather than
        # this simple geometric decay.
        if iteration > 0 and iteration % 1000 == 0:
            learning_rate *= 0.9

    return phi
```

Algorithm correctness:
The algorithm produces an unbiased gradient estimate at each step. Under standard stochastic optimization assumptions (convexity or appropriate non-convex conditions, decreasing learning rate satisfying Robbins-Monro conditions \(\sum_t \alpha_t = \infty\) and \(\sum_t \alpha_t^2 < \infty\)), the iterates converge to a local optimum of the true ELBO.
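A standard family of schedules satisfying these conditions is the polynomial decay \(\alpha_t = a(b + t)^{-\kappa}\) with \(\kappa \in (0.5, 1]\). A small sketch (the constants are illustrative):

```python
def robbins_monro_step_size(t: int, a: float = 1.0, b: float = 10.0, kappa: float = 0.75) -> float:
    """Step size a * (b + t)^(-kappa); for kappa in (0.5, 1] the step sizes
    sum to infinity while their squares sum to a finite value."""
    return a * (b + t) ** (-kappa)

print([round(robbins_monro_step_size(t), 4) for t in (0, 10, 100, 1000, 10000)])
```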
Computational complexity:
Each SVI iteration costs \(O(M \cdot S \cdot C)\), where \(M\) is the mini-batch size, \(S\) the number of Monte Carlo samples, and \(C\) the cost of one per-datapoint likelihood (and gradient) evaluation. Compared to batch VI's \(O(T \cdot N \cdot S \cdot C)\), where \(T\) is the number of iterations, SVI trades more iterations for far cheaper per-iteration cost, typically yielding dramatic speedups.
Many probabilistic models distinguish between global latent variables (shared across all observations) and local latent variables (specific to each observation). This distinction has profound implications for SVI.
Examples include topic models such as latent Dirichlet allocation (global topic-word distributions; local per-document topic proportions and word assignments) and Gaussian mixture models (global component means, covariances, and weights; local cluster assignments).
The factorization assumption:
For models with this structure, the variational distribution typically factorizes:
$$q(z_{\text{global}}, z_{\text{local}}) = q(z_{\text{global}}) \prod_{n=1}^{N} q(z_{\text{local}}^{(n)})$$
Local variational parameters scale with the dataset size: there are \(N\) separate \(q(z_{\text{local}}^{(n)})\) factors to optimize. This creates \(O(N)\) variational parameters, potentially negating the benefits of mini-batch optimization if handled naively.
The SVI solution for local variables:
Hoffman et al. (2013) introduced an elegant solution: update local variational parameters to their optimal values given current global parameters, rather than taking gradient steps.
For many models (especially those in the conjugate exponential family), the optimal local variational parameters have closed-form solutions:
$$\phi_{\text{local}}^{(n)*} = \arg\max_{\phi_{\text{local}}^{(n)}} \mathcal{L}(\phi_{\text{global}}, \phi_{\text{local}}^{(n)})$$
This is called local coordinate ascent within the mini-batch: for each observation in the current mini-batch, set its local variational parameters to their optimal values given the current global parameters, then use those optimized local parameters to form the stochastic update of the global parameters.
This approach ensures local parameters never need to be stored for the entire dataset—only for the current mini-batch.
Complete algorithm with local/global separation:
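The following self-contained sketch illustrates the structure on a deliberately simple toy model: a two-component Gaussian mixture with unit observation variance, uniform weights, and a \(\mathcal{N}(0, 10^2)\) prior on the component means. All names and constants are chosen for illustration, and the global update blends natural parameters with a decaying step size, anticipating the natural-gradient view introduced on the next page.

```python
import numpy as np

rng = np.random.default_rng(3)

# Toy model for illustration: two-component Gaussian mixture with unit observation
# variance, uniform weights, and prior mu_k ~ N(0, 10^2).
# Global latents: the component means. Local latents: the per-point assignments.
N = 20_000
true_means = np.array([-2.0, 3.0])
data = rng.normal(true_means[rng.integers(0, 2, size=N)], 1.0)

# Variational family: q(mu_k) = N(m_k, s2_k), q(c_n) = Categorical(r_n).
prior_var = 10.0 ** 2
m = np.array([-1.0, 1.0])        # initial variational means (deliberately rough)
s2 = np.array([1.0, 1.0])        # initial variational variances

batch_size, num_epochs, t = 200, 5, 0
for epoch in range(num_epochs):
    for _ in range(N // batch_size):
        x = data[rng.choice(N, size=batch_size, replace=False)]

        # Local step (closed form): responsibilities given the current q(mu),
        # using E_q[log N(x_n; mu_k, 1)] = -0.5 * ((x_n - m_k)^2 + s2_k) + const.
        logits = -0.5 * ((x[:, None] - m[None, :]) ** 2 + s2[None, :])
        logits -= logits.max(axis=1, keepdims=True)
        r = np.exp(logits)
        r /= r.sum(axis=1, keepdims=True)                     # shape (batch_size, 2)

        # Global step: "intermediate" posterior for each mu_k as if the mini-batch
        # were the whole dataset (note the N / M scaling), in natural parameters.
        scale = N / batch_size
        hat_prec = 1.0 / prior_var + scale * r.sum(axis=0)
        hat_mean_times_prec = scale * (r * x[:, None]).sum(axis=0)

        # Blend current and intermediate natural parameters with a decaying step size.
        t += 1
        rho = (1.0 + t) ** (-0.75)
        prec = 1.0 / s2
        eta1 = (1.0 - rho) * (m * prec) + rho * hat_mean_times_prec
        new_prec = (1.0 - rho) * prec + rho * hat_prec
        s2, m = 1.0 / new_prec, eta1 / new_prec

        # Local parameters (r) for this mini-batch can now be discarded;
        # nothing of size O(N) is ever stored.

print(np.sort(m))   # approaches the true component means (-2, 3) after a few epochs
```

After a few epochs the variational means approach the true component means, while only the current mini-batch's responsibilities are ever held in memory.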
Translating the SVI algorithm into production code requires careful attention to numerical stability, efficiency, and practical optimization heuristics.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class StochasticVITrainer:
    """
    Production-ready SVI trainer with best practices.
    """

    def __init__(
        self,
        model: nn.Module,             # Encoder-decoder or variational model
        dataset: TensorDataset,
        batch_size: int = 128,
        learning_rate: float = 1e-3,
        mc_samples: int = 1,
        device: str = "cuda"
    ):
        self.model = model.to(device)
        self.device = device
        self.mc_samples = mc_samples

        # DataLoader with efficient batching
        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,       # Parallel data loading
            pin_memory=True,     # Faster GPU transfer
            drop_last=True       # Consistent batch sizes
        )

        # Optimizer with weight decay (L2 regularization on parameters)
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=10,      # Restart every 10 epochs
            T_mult=2     # Double period after each restart
        )

        self.N = len(dataset)

    def train_epoch(self) -> float:
        """Train for one epoch, return average ELBO."""
        self.model.train()
        total_elbo = 0.0
        num_batches = 0

        for batch in self.dataloader:
            x = batch[0].to(self.device)
            batch_size = x.shape[0]

            # Zero gradients
            self.optimizer.zero_grad(set_to_none=True)  # More efficient

            # Forward pass with MC sampling
            elbo = 0.0
            for _ in range(self.mc_samples):
                # Reparameterized sampling and ELBO computation
                # (Implementation depends on specific model)
                z, kl_div = self.model.encode_and_sample(x)
                log_likelihood = self.model.decode_log_prob(x, z)

                # Per-datapoint ELBO for the mini-batch:
                # mean over the batch of E[log p(x|z)] - KL[q(z|x) || p(z)].
                # This is an unbiased estimate of the full ELBO divided by N;
                # the constant factor only rescales the gradient and is
                # absorbed by the learning rate.
                elbo += (log_likelihood.sum() / batch_size - kl_div.mean())

            elbo /= self.mc_samples

            # Negate for minimization (PyTorch minimizes by default)
            loss = -elbo

            # Backward pass with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)

            # Update parameters
            self.optimizer.step()

            total_elbo += elbo.item()
            num_batches += 1

        # Step the scheduler
        self.scheduler.step()

        return total_elbo / num_batches
```

We have established the mathematical foundations of mini-batch optimization for variational inference—the key innovation that enables Bayesian methods to scale to modern datasets.
What's next:
Mini-batch optimization provides the computational foundation for scalable VI, but the choice of gradient direction matters profoundly. Standard Euclidean gradients can lead to slow convergence in the probability simplex and other constrained spaces. The next page introduces natural gradients—a principled approach that respects the geometry of probability distributions and dramatically accelerates convergence.
You now understand the fundamental mechanics of mini-batch optimization for variational inference. This technique—combining unbiased gradient estimation with stochastic optimization—is the cornerstone of all modern scalable Bayesian methods, from VAEs to Bayesian neural networks to massive topic models.