Stochastic variational inference works remarkably well in practice, but why does it work? When can we guarantee convergence? How fast should we expect progress? These questions are not merely academic—they inform crucial practical decisions about learning rate schedules, stopping criteria, and hyperparameter selection.
The theoretical analysis of SVI draws on decades of work in stochastic optimization, adapting classical results to the probabilistic inference setting. Understanding this theory provides principled guidance for choosing learning rate schedules, criteria for deciding when to stop, and realistic expectations for how much computation a given accuracy requires.
This page develops the convergence theory for stochastic variational inference, from classical Robbins-Monro conditions to modern non-convex convergence guarantees.
By the end of this page, you will understand the mathematical conditions that guarantee SVI convergence, convergence rates for convex and non-convex ELBOs, the role of learning rate schedules in balancing bias and variance, and practical diagnostics for monitoring convergence.
Stochastic variational inference is an instance of stochastic gradient ascent (or descent, depending on sign conventions). The foundational theory dates to Robbins and Monro (1951), who established conditions for convergence of stochastic approximation algorithms.
Recall the SVI parameter update:
$$\phi_{t+1} = \phi_t + \rho_t \hat{g}_t$$
where \(\phi_t\) denotes the variational parameters at iteration \(t\), \(\rho_t > 0\) is the learning rate (step size), and \(\hat{g}_t\) is a stochastic (typically mini-batch) estimate of the ELBO gradient \(\nabla_\phi \mathcal{L}(\phi_t)\).
The fundamental challenge:
Unlike deterministic gradient descent, stochastic updates have noise that doesn't vanish. Each gradient estimate has variance \(\sigma^2 > 0\). If we use constant step size \(\rho\), updates will oscillate around the optimum rather than converging to it.
Decreasing step sizes are necessary for convergence but must decrease slowly enough that we still make progress.
For stochastic gradient methods to converge to a stationary point, the learning rate sequence \(\{\rho_t\}\) must satisfy the Robbins-Monro conditions:

$$\sum_{t=1}^{\infty} \rho_t = \infty \qquad \text{and} \qquad \sum_{t=1}^{\infty} \rho_t^2 < \infty$$
Common choices: \(\rho_t = c/(t+\tau)\) or \(\rho_t = c/t^\alpha\) for \(\alpha \in (0.5, 1]\).
Intuition for Robbins-Monro:
Consider what happens with different learning rate behaviors:
Too fast decay (e.g., \(\rho_t = 1/t^2\)): the step sizes are summable, so the total distance the iterates can travel is finite. Starting far from the optimum, the algorithm stalls before reaching it; the first Robbins-Monro condition fails.
Too slow decay (e.g., \(\rho_t = 1\), constant): the gradient noise is never averaged away. The iterates quickly reach a neighborhood of the optimum but then oscillate within it indefinitely, with a radius that grows with \(\rho\) and the gradient variance; the second Robbins-Monro condition fails.
Just right (e.g., \(\rho_t = 1/t\)): the steps sum to infinity, so the iterates can travel arbitrarily far from a poor initialization, while the summable squared steps ensure the accumulated noise stays bounded and the iterates settle at the optimum.
The \(1/t\) learning rate:
The canonical choice \(\rho_t = c/(t + \tau)\) satisfies both conditions. The offset \(\tau\) controls initial step size: larger \(\tau\) means smaller initial steps (more conservative start).
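To make this concrete, the following minimal sketch (a made-up one-dimensional quadratic objective with Gaussian gradient noise, not from the original text) compares a constant step size, an overly aggressive \(1/t^2\)-style decay, and a Robbins-Monro-compliant \(1/t\)-style schedule:

```python
import numpy as np

def run_sga(schedule, steps=5000, noise_std=1.0, seed=0):
    """Stochastic gradient ascent on the toy objective f(x) = -0.5 * x^2.

    The true gradient is -x; Gaussian noise mimics mini-batch gradient
    estimates. The optimum is x* = 0.
    """
    rng = np.random.default_rng(seed)
    x = 5.0  # deliberately poor initialization
    for t in range(1, steps + 1):
        noisy_grad = -x + noise_std * rng.standard_normal()
        x = x + schedule(t) * noisy_grad
    return x

schedules = {
    "constant  rho = 0.1      ": lambda t: 0.1,
    "too fast  rho = 1/(t+1)^2": lambda t: 1.0 / (t + 1) ** 2,
    "1/t decay rho = 1/(t+1)  ": lambda t: 1.0 / (t + 1),
}

for name, sched in schedules.items():
    finals = [run_sga(sched, seed=s) for s in range(20)]
    print(f"{name}: mean |x_T| = {np.mean(np.abs(finals)):.4f}")
```

With these particular constants, the over-aggressive decay typically stalls well away from the optimum, the constant rate settles into a noise ball around it, and the \(1/t\)-style schedule gets closest.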
When the ELBO is a concave function of the variational parameters (equivalent to convex minimization of the negative ELBO), strong convergence guarantees exist.
Convexity arises mainly in special cases: for conditionally conjugate exponential family models with mean-field variational families, the ELBO is concave in the natural parameters of each factor taken individually (though generally not jointly in all factors at once).
Key convex convergence theorem:
Suppose \(\mathcal{L}(\phi)\) is concave, \(L\)-smooth (gradient Lipschitz), and the stochastic gradient has bounded variance \(\mathbb{E}[|\hat{g}_t - \nabla\mathcal{L}(\phi_t)|^2] \leq \sigma^2\). Then with learning rate \(\rho_t = c/\sqrt{t}\):
$$\mathcal{L}(\phi^*) - \mathbb{E}[\mathcal{L}(\bar{\phi}_T)] \leq O\left(\frac{1}{\sqrt{T}}\right)$$
where \(\bar{\phi}_T = \frac{1}{T}\sum_{t=1}^T \phi_t\) is the average iterate.
Interpretation: To achieve \(\epsilon\)-suboptimality, we need \(O(1/\epsilon^2)\) iterations.
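Note that the bound applies to the averaged iterate rather than the final one. A minimal sketch of this Polyak-Ruppert averaging, assuming an `elbo_gradient` function you supply that returns an unbiased stochastic gradient:

```python
import numpy as np

def sga_with_averaging(elbo_gradient, phi0, steps, c=1.0):
    """Stochastic gradient ascent returning both the last and the averaged iterate.

    elbo_gradient(phi) should return a stochastic estimate of grad ELBO(phi).
    Uses the rho_t = c / sqrt(t) schedule from the convex convergence theorem.
    """
    phi = np.asarray(phi0, dtype=float).copy()
    phi_avg = np.zeros_like(phi)
    for t in range(1, steps + 1):
        phi = phi + (c / np.sqrt(t)) * elbo_gradient(phi)
        # Running mean of all iterates seen so far: phi_avg = (1/t) * sum_{s<=t} phi_s
        phi_avg += (phi - phi_avg) / t
    return phi, phi_avg
```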
| Setting | Rate | Iterations for ε-optimal | Notes |
|---|---|---|---|
| Convex, stochastic gradient | O(1/√T) | O(1/ε²) | Standard SVI rate |
| Strongly convex, stochastic | O(1/T) | O(1/ε) | Faster with strong convexity |
| Convex, batch gradient | O(1/T) | O(1/ε) | Deterministic GD |
| Strongly convex, batch | O(exp(-cT)) | O(log(1/ε)) | Linear convergence |
Strong convexity and faster rates:
If the ELBO is additionally \(\mu\)-strongly concave:
$$\mathcal{L}(\phi) \leq \mathcal{L}(\phi') + \nabla\mathcal{L}(\phi')^T(\phi - \phi') - \frac{\mu}{2}|\phi - \phi'|^2$$
then convergence accelerates to \(O(1/T)\) with appropriate learning rate \(\rho_t = 2/(\mu(t+1))\).
Practical implication:
Strong convexity arises from regularization. Adding an \(L_2\) penalty on variational parameters (or equivalently, a prior on the variational posterior) induces strong convexity and accelerates convergence:
$$\mathcal{L}_{\text{reg}}(\phi) = \mathcal{L}(\phi) - \frac{\lambda}{2}|\phi|^2$$
This trades bias (the regularized optimum differs from the true optimum) for faster convergence.
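A minimal sketch of the regularized gradient and the matching strongly-concave learning rate, assuming an unbiased `elbo_grad` estimator you supply; the penalty weight `lam` is an illustrative value:

```python
import numpy as np

def regularized_elbo_grad(elbo_grad, phi, lam=1e-2):
    """Gradient of L_reg(phi) = L(phi) - (lam/2) * ||phi||^2.

    The quadratic penalty makes the objective lam-strongly concave
    (when L itself is concave), at the cost of biasing the optimum.
    """
    return elbo_grad(phi) - lam * phi

def strongly_concave_lr(t, lam=1e-2):
    """The rho_t = 2 / (mu * (t + 1)) schedule, with mu = lam from the penalty."""
    return 2.0 / (lam * (t + 1))
```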
Many practical variational inference problems involve non-convex ELBOs: non-conjugate models, mixture models (with their label-switching symmetries), Bayesian neural networks, and deep generative models such as VAEs all yield objectives with multiple local optima.
For non-convex objectives, we cannot guarantee convergence to global optima—the landscape may have many local maxima, saddle points, and plateaus.
For non-convex problems, standard SVI guarantees convergence to a stationary point (zero gradient), not a global optimum. A stationary point may be a local maximum, saddle point, or even a local minimum (if the ELBO has negative curvature regions).
Non-convex convergence theorem:
Suppose \(\mathcal{L}(\phi)\) is \(L\)-smooth (not necessarily concave), gradients have bounded variance \(\sigma^2\), and learning rate \(\rho_t = c/\sqrt{T}\) (constant over the run, tuned to horizon \(T\)). Then:
$$\min_{t \leq T} \mathbb{E}[|\nabla\mathcal{L}(\phi_t)|^2] \leq O\left(\frac{1}{\sqrt{T}}\right)$$
Interpretation: After \(T\) iterations, the smallest gradient norm observed is \(O(1/\sqrt{T})\). To find a point with gradient norm \(\leq \epsilon\), we need \(O(1/\epsilon^2)\) iterations.
The gap with convex analysis:
Notice we measure convergence by gradient norm, not function value. In non-convex settings, small gradients don't imply closeness to optimal value. A saddle point has zero gradient but may be far from any local maximum.
Escaping saddle points:
Modern theory shows that stochastic gradient noise actually helps escape saddle points. The randomness perturbs the trajectory, making it unlikely to get trapped at unstable equilibria. This is one advantage of stochastic over deterministic optimization for non-convex problems.
```python
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass


@dataclass
class ConvergenceMetrics:
    """Metrics for monitoring SVI convergence."""
    iteration: int
    elbo: float
    gradient_norm: float
    parameter_change: float
    learning_rate: float


class ConvergenceMonitor:
    """
    Monitor convergence of stochastic variational inference.

    Tracks ELBO, gradient norms, and parameter changes to detect
    convergence, divergence, or pathological behavior.
    """

    def __init__(
        self,
        window_size: int = 100,
        elbo_tolerance: float = 1e-4,
        gradient_tolerance: float = 1e-5,
        patience: int = 10
    ):
        """
        Args:
            window_size: Number of iterations for moving average
            elbo_tolerance: Relative improvement threshold for convergence
            gradient_tolerance: Gradient norm threshold for convergence
            patience: Consecutive windows without improvement before stopping
        """
        self.window_size = window_size
        self.elbo_tolerance = elbo_tolerance
        self.gradient_tolerance = gradient_tolerance
        self.patience = patience

        self.history: List[ConvergenceMetrics] = []
        self.no_improvement_count = 0
        self.best_elbo = float('-inf')

    def update(
        self,
        iteration: int,
        elbo: float,
        gradient: np.ndarray,
        params_prev: np.ndarray,
        params_curr: np.ndarray,
        learning_rate: float
    ) -> Tuple[bool, str]:
        """
        Record metrics and check convergence.

        Returns:
            converged: Whether convergence criteria are met
            status: Descriptive status message
        """
        gradient_norm = np.linalg.norm(gradient)
        param_change = np.linalg.norm(params_curr - params_prev)

        metrics = ConvergenceMetrics(
            iteration=iteration,
            elbo=elbo,
            gradient_norm=gradient_norm,
            parameter_change=param_change,
            learning_rate=learning_rate
        )
        self.history.append(metrics)

        # Check for divergence (NaN or extreme values)
        if np.isnan(elbo) or np.isnan(gradient_norm):
            return True, "DIVERGED: NaN detected"
        if np.abs(elbo) > 1e10:
            return True, "DIVERGED: ELBO explosion"

        # Not enough history yet
        if len(self.history) < 2 * self.window_size:
            return False, "WARMING_UP"

        # Check gradient norm convergence
        recent_grads = [m.gradient_norm for m in self.history[-self.window_size:]]
        avg_gradient = np.mean(recent_grads)
        if avg_gradient < self.gradient_tolerance:
            return True, f"CONVERGED: Avg gradient norm {avg_gradient:.2e} < {self.gradient_tolerance}"

        # Check ELBO improvement
        recent_elbos = [m.elbo for m in self.history[-self.window_size:]]
        older_elbos = [m.elbo for m in self.history[-2 * self.window_size:-self.window_size]]
        recent_avg = np.mean(recent_elbos)
        older_avg = np.mean(older_elbos)
        relative_improvement = (recent_avg - older_avg) / (np.abs(older_avg) + 1e-10)

        if relative_improvement > self.elbo_tolerance:
            self.no_improvement_count = 0
            if recent_avg > self.best_elbo:
                self.best_elbo = recent_avg
        else:
            self.no_improvement_count += 1
            if self.no_improvement_count >= self.patience:
                return True, f"CONVERGED: No improvement for {self.patience} windows"

        return False, f"RUNNING: Rel improvement {relative_improvement:.2e}"

    def get_summary(self) -> dict:
        """Get summary statistics of optimization run."""
        if not self.history:
            return {}

        elbos = [m.elbo for m in self.history]
        grads = [m.gradient_norm for m in self.history]

        return {
            "iterations": len(self.history),
            "final_elbo": elbos[-1],
            "best_elbo": max(elbos),
            "elbo_std": np.std(elbos[-self.window_size:]),
            "final_gradient_norm": grads[-1],
            "avg_gradient_norm": np.mean(grads[-self.window_size:]),
        }


def diagnose_convergence_issues(history: List[ConvergenceMetrics]) -> List[str]:
    """Analyze optimization history to diagnose common issues."""
    issues = []

    elbos = np.array([m.elbo for m in history])
    grads = np.array([m.gradient_norm for m in history])
    lrs = np.array([m.learning_rate for m in history])

    # Check for oscillation
    if len(elbos) > 100:
        recent_elbos = elbos[-100:]
        if np.std(recent_elbos) > 0.1 * np.abs(np.mean(recent_elbos)):
            issues.append("HIGH_VARIANCE: ELBO oscillating significantly. Consider reducing learning rate.")

    # Check for plateau
    if len(elbos) > 200:
        recent = elbos[-100:]
        older = elbos[-200:-100]
        if np.abs(np.mean(recent) - np.mean(older)) < 1e-6:
            issues.append("PLATEAU: ELBO not improving. May be converged or stuck at saddle point.")

    # Check gradient explosion
    if np.any(grads > 1e6):
        issues.append("GRADIENT_EXPLOSION: Very large gradients detected. Use gradient clipping.")

    # Check gradient vanishing
    if len(grads) > 50 and np.mean(grads[-50:]) < 1e-10:
        issues.append("GRADIENT_VANISHING: Gradients near zero. Check for numerical issues.")

    # Check learning rate
    if len(lrs) > 100 and lrs[-1] < 1e-8:
        issues.append("LR_TOO_SMALL: Learning rate may have decayed too aggressively.")

    return issues
```

The \(O(1/\sqrt{T})\) convergence rate of stochastic gradient methods is fundamentally limited by gradient variance. Variance reduction techniques can achieve faster rates by constructing lower-variance gradient estimators.
Why variance matters:
The convergence rate depends on the signal-to-noise ratio of gradients:
$$\text{Rate} \propto \frac{|\nabla\mathcal{L}|^2}{\sigma^2}$$
Reducing \(\sigma^2\) directly accelerates convergence.
SVRG for Variational Inference:
Stochastic Variance Reduced Gradient (SVRG) periodically computes a snapshot gradient \(\tilde{g} = \nabla\mathcal{L}(\tilde{\phi})\) at a reference point \(\tilde{\phi}\), then corrects mini-batch gradients:
$$\hat{g}_{\text{SVRG}} = \hat{g}_t - \hat{g}_t^{(ref)} + \tilde{g}$$
where \(\hat{g}_t^{(ref)}\) is the mini-batch gradient at the reference point.
Why this works: because \(\mathbb{E}[\hat{g}_t^{(ref)}] = \tilde{g}\), the correction term \(\tilde{g} - \hat{g}_t^{(ref)}\) has zero mean, so the estimator stays unbiased. And since \(\hat{g}_t\) and \(\hat{g}_t^{(ref)}\) are computed on the same mini-batch, they are strongly correlated; as long as \(\phi_t\) stays close to the reference point \(\tilde{\phi}\), they nearly cancel, leaving an estimator with much lower variance.
Convergence improvement: SVRG achieves \(O(1/T)\) convergence for convex problems (matching deterministic GD) while using only mini-batches between snapshots.
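A minimal sketch of an SVRG-style SVI loop under these assumptions; `elbo_grad_minibatch` and `data` are hypothetical stand-ins for your model-specific gradient estimator and dataset:

```python
import numpy as np

def svrg_svi(elbo_grad_minibatch, data, phi0, n_epochs=10, inner_steps=100,
             batch_size=32, rho=0.01, seed=0):
    """SVRG-style variance-reduced stochastic variational inference (sketch).

    elbo_grad_minibatch(phi, batch) must return an unbiased estimate of
    grad ELBO(phi) when `batch` is a uniform mini-batch of `data`.
    """
    rng = np.random.default_rng(seed)
    phi = np.asarray(phi0, dtype=float).copy()
    n = len(data)
    for epoch in range(n_epochs):
        # Snapshot: full-data gradient at the reference point phi_ref
        phi_ref = phi.copy()
        g_snapshot = elbo_grad_minibatch(phi_ref, data)
        for _ in range(inner_steps):
            batch = [data[i] for i in rng.choice(n, size=batch_size, replace=False)]
            # Corrected gradient: same mini-batch evaluated at phi and at the reference
            g = (elbo_grad_minibatch(phi, batch)
                 - elbo_grad_minibatch(phi_ref, batch)
                 + g_snapshot)
            phi = phi + rho * g
    return phi
```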
Variance reduction adds complexity and occasional full-batch computation. It is most beneficial when the dataset is moderately sized (so an occasional full-batch pass is feasible), when gradient variance rather than per-iteration computation is the bottleneck, and when high precision is required. For very large datasets where even one full pass is expensive, standard SVI may be preferable.
The choice of learning rate schedule profoundly affects SVI performance. While any schedule satisfying Robbins-Monro conditions converges asymptotically, practical performance varies dramatically.
Common schedules and their properties:
| Schedule | Formula | Properties | Best For |
|---|---|---|---|
| Inverse time | ρ_t = c/(t + τ) | Classical, proven convergence | Theoretical guarantees |
| Inverse sqrt | ρ_t = c/√t | Slower decay, more exploration | Non-convex, early training |
| Step decay | ρ_t = c × γ^{floor(t/s)} | Piecewise constant | Fine-tuning after warmup |
| Cosine annealing | ρ_t = ρ_min + ½(ρ_max-ρ_min)(1+cos(πt/T)) | Smooth, restarts possible | Modern deep learning |
| Warmup + decay | Linear increase then inverse | Stable start, convergent end | Large models, unstable starts |
The warmup phase:
Many modern systems use a warmup period where learning rate increases from near-zero to the target value:
$$\rho_t = \begin{cases} \rho_{\max} \cdot t / T_{\text{warmup}} & t < T_{\text{warmup}} \\ \rho_{\max} \cdot \text{decay}(t - T_{\text{warmup}}) & t \geq T_{\text{warmup}} \end{cases}$$
Warmup helps with: stabilizing the earliest updates, when the variational parameters are far from any reasonable region and gradient estimates are large and noisy; giving adaptive optimizers time to accumulate reliable moment estimates before large steps are taken; and avoiding the divergence that a full-size initial learning rate can trigger.
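As a minimal sketch (with illustrative constants, not from the original text), the piecewise warmup-then-decay schedule above can be implemented as:

```python
def warmup_then_decay(t, rho_max=0.05, warmup_steps=500, tau=100.0):
    """Linear warmup to rho_max, then inverse-time decay."""
    if t < warmup_steps:
        return rho_max * (t + 1) / warmup_steps
    return rho_max * tau / (tau + (t - warmup_steps))
```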
Practical learning rate tuning usually proceeds by grid-searching the initial rate over a coarse logarithmic scale, keeping the largest value that yields a stable ELBO trajectory, and only then fixing the decay schedule. The implementations below illustrate several common schedules.
```python
import numpy as np
from abc import ABC, abstractmethod


class LRSchedule(ABC):
    """Abstract base class for learning rate schedules."""

    @abstractmethod
    def get_lr(self, step: int) -> float:
        pass


class InverseTimeSchedule(LRSchedule):
    """ρ_t = initial_lr / (1 + decay_rate * t)"""

    def __init__(self, initial_lr: float = 0.1, decay_rate: float = 0.01):
        self.initial_lr = initial_lr
        self.decay_rate = decay_rate

    def get_lr(self, step: int) -> float:
        return self.initial_lr / (1 + self.decay_rate * step)


class RobbinsMonroSchedule(LRSchedule):
    """ρ_t = c / (t + τ)^α, satisfying Robbins-Monro conditions."""

    def __init__(self, c: float = 1.0, tau: float = 10.0, alpha: float = 0.75):
        """
        Args:
            c: Scale factor
            tau: Offset (larger = more conservative start)
            alpha: Decay exponent (must be in (0.5, 1] for convergence)
        """
        assert 0.5 < alpha <= 1, "Alpha must be in (0.5, 1] for Robbins-Monro"
        self.c = c
        self.tau = tau
        self.alpha = alpha

    def get_lr(self, step: int) -> float:
        return self.c / (step + self.tau) ** self.alpha


class CosineAnnealingSchedule(LRSchedule):
    """Cosine annealing with optional warm restarts."""

    def __init__(
        self,
        lr_max: float = 0.1,
        lr_min: float = 1e-6,
        period: int = 1000,
        warmup_steps: int = 100
    ):
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.period = period
        self.warmup_steps = warmup_steps

    def get_lr(self, step: int) -> float:
        if step < self.warmup_steps:
            # Linear warmup
            return self.lr_max * (step + 1) / self.warmup_steps

        # Cosine annealing
        post_warmup_step = step - self.warmup_steps
        progress = (post_warmup_step % self.period) / self.period
        return self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(np.pi * progress))


class AdaptiveLRSchedule(LRSchedule):
    """
    Adaptive learning rate based on optimization progress.

    Reduces the learning rate when improvement stalls.
    """

    def __init__(
        self,
        initial_lr: float = 0.01,
        min_lr: float = 1e-6,
        max_lr: float = 1.0,
        patience: int = 50,
        factor: float = 0.5
    ):
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.patience = patience
        self.factor = factor

        self.best_value = float('-inf')
        self.steps_without_improvement = 0

    def step(self, current_value: float):
        """Update schedule based on current objective value."""
        if current_value > self.best_value * 1.001:  # Small threshold for improvement
            self.best_value = current_value
            self.steps_without_improvement = 0
        else:
            self.steps_without_improvement += 1
            if self.steps_without_improvement >= self.patience:
                self.current_lr = max(self.min_lr, self.current_lr * self.factor)
                self.steps_without_improvement = 0

    def get_lr(self, step: int) -> float:
        return self.current_lr


def robbins_monro_verify(schedule: LRSchedule, check_steps: int = 100000) -> bool:
    """
    Verify that a schedule satisfies Robbins-Monro conditions (approximately).
    """
    lrs = [schedule.get_lr(t) for t in range(check_steps)]

    # Sum should be large (diverging)
    lr_sum = sum(lrs)

    # Sum of squares should be converging
    lr_sq_sum = sum(lr ** 2 for lr in lrs)

    # Heuristic checks
    sum_diverges = lr_sum > check_steps * 0.01  # Growing significantly
    sq_converges = lr_sq_sum < check_steps * lrs[0] ** 2  # Growing slower than linear

    return sum_diverges and sq_converges
```

Natural gradients often converge faster than Euclidean gradients in practice. The theory explains this through the condition number of the optimization problem.
Condition number and convergence:
The condition number \(\kappa\) measures how "stretched" the optimization landscape is:
$$\kappa = \frac{L}{\mu}$$
where \(L\) is the smoothness constant and \(\mu\) is the strong convexity constant.
Natural gradients improve conditioning:
The Fisher Information Matrix \(\mathbf{F}\) acts as a preconditioner. Multiplying by \(\mathbf{F}^{-1}\) transforms the problem to have near-identity curvature in the natural geometry:
$$\kappa_{\text{natural}} \approx 1 \quad \text{vs} \quad \kappa_{\text{Euclidean}} \gg 1$$
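As a concrete sketch, assume a fully factorized Gaussian \(q(\theta) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))\) parameterized by \(\mu\) and \(\log\sigma\); under that assumption the Fisher information is diagonal, so the natural gradient reduces to elementwise rescaling:

```python
import numpy as np

def natural_gradient_step(mu, log_sigma, grad_mu, grad_log_sigma, rho=0.1):
    """One natural-gradient step for a mean-field Gaussian q = N(mu, diag(sigma^2)).

    With the (mu, log sigma) parameterization the Fisher information is diagonal:
        F_mu        = 1 / sigma^2   (per coordinate)
        F_log_sigma = 2             (per coordinate)
    so the natural gradient is an elementwise rescaling of the Euclidean
    gradient and no matrix inversion is required.
    """
    sigma2 = np.exp(2.0 * log_sigma)
    nat_grad_mu = sigma2 * grad_mu              # F_mu^{-1} * grad
    nat_grad_log_sigma = grad_log_sigma / 2.0   # F_log_sigma^{-1} * grad
    return mu + rho * nat_grad_mu, log_sigma + rho * nat_grad_log_sigma
```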
For conditionally conjugate exponential family models, a natural gradient step with step size ρ_t = 1 on a single variational factor reproduces the exact coordinate ascent (CAVI) update for that factor: the step lands directly on that factor's coordinate-wise optimum, with no step-size tuning required.
Quantifying the speedup:
For a Gaussian variational distribution with \(d\) dimensions, the curvature of the ELBO along different directions is governed by the posterior covariance, whose eigenvalues can span many orders of magnitude. The resulting condition number \(\kappa\) can reach \(10^3\)–\(10^6\) for ill-conditioned posteriors, so Fisher preconditioning can translate into speedups of several orders of magnitude.
Natural gradients for neural network VI:
For Bayesian neural networks, exact natural gradients are impractical (inverting the full Fisher is \(O(d^3)\) for millions of parameters). Approximations like K-FAC retain much of the benefit by treating the Fisher as block-diagonal across layers and approximating each layer's block as a Kronecker product of two small matrices, which are cheap to invert.
The tradeoff is computational: K-FAC adds ~20-50% overhead per iteration but converges in many fewer iterations.
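A minimal sketch of the core K-FAC computation for a single fully connected layer, assuming you can collect the layer's input activations and backpropagated output gradients over a mini-batch; the damping constant is an illustrative stabilizer:

```python
import numpy as np

def kfac_precondition(grad_W, activations, output_grads, damping=1e-3):
    """Approximate natural gradient for one dense layer's weight matrix.

    K-FAC approximates the layer's Fisher block as a Kronecker product
    F ≈ A ⊗ G with A = E[a a^T] (input second moment) and G = E[g g^T]
    (output-gradient second moment). Using the identity
    (A ⊗ G)^{-1} vec(V) = vec(G^{-1} V A^{-1}), two small inversions
    replace inverting the full per-layer Fisher block.

    Args:
        grad_W: (d_out, d_in) Euclidean gradient of the ELBO w.r.t. weights
        activations: (batch, d_in) layer inputs a
        output_grads: (batch, d_out) gradients g w.r.t. the layer's outputs
    """
    n = activations.shape[0]
    A = activations.T @ activations / n + damping * np.eye(activations.shape[1])
    G = output_grads.T @ output_grads / n + damping * np.eye(output_grads.shape[1])
    # G^{-1} * grad_W * A^{-1}
    return np.linalg.solve(G, grad_W) @ np.linalg.inv(A)
```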
Theoretical convergence guarantees are asymptotic—they tell us what happens as \(T \to \infty\). In practice, we need diagnostics to determine when to stop and whether optimization is proceeding normally.
Detecting common pathologies:
1. Posterior collapse (VAEs): the approximate posterior collapses onto the prior, the KL term drops to near zero, and the latent variables stop carrying information about the data. Monitoring the per-dimension KL between \(q(z|x)\) and the prior catches this early (a detection sketch follows this list).
2. Oscillation: the ELBO fluctuates widely around a level without a clear upward trend, usually because the learning rate is too large or the mini-batches are too small. Reducing the rate or increasing the batch size typically restores steady progress.
3. Premature convergence: the ELBO plateaus early at a poor value, often because the learning rate decayed too aggressively or the optimizer settled near a poor local optimum or saddle point. Slower decay, warm restarts, or better initialization can help.
4. Slow convergence: the ELBO improves steadily but very slowly, typically a sign of ill-conditioning or an overly conservative learning rate. Natural gradients, preconditioning, or a larger initial rate are the usual remedies.
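For the first pathology, a common check is the average per-dimension KL between a diagonal-Gaussian encoder and the standard normal prior; a minimal sketch, where the 0.01-nat threshold is an illustrative cutoff:

```python
import numpy as np

def count_collapsed_dims(mu, log_var, threshold=0.01):
    """Flag latent dimensions with near-zero KL(q(z|x) || N(0, I)).

    Args:
        mu: (batch, latent_dim) posterior means
        log_var: (batch, latent_dim) posterior log-variances
        threshold: average KL (in nats) below which a dimension counts as collapsed

    Per-dimension KL for a diagonal Gaussian against a standard normal prior:
        0.5 * (mu^2 + sigma^2 - log sigma^2 - 1)
    """
    kl_per_dim = 0.5 * (mu ** 2 + np.exp(log_var) - log_var - 1.0)
    avg_kl = kl_per_dim.mean(axis=0)          # average over the batch
    collapsed = np.where(avg_kl < threshold)[0]
    return collapsed, avg_kl
```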
For practical applications, perfect convergence often isn't necessary. Early stopping when validation metrics plateau can actually improve generalization by acting as regularization. The key is ensuring the variational approximation is accurate enough for downstream tasks—sometimes a rough posterior is sufficient.
Understanding convergence theory transforms SVI from a heuristic into a principled algorithm with predictable behavior.
What's next:
With the theoretical foundations complete, we conclude this module with practical considerations—a synthesis of implementation wisdom covering initialization strategies, debugging techniques, and guidelines for when to use (and not use) stochastic variational inference.
You now understand the convergence theory underlying stochastic variational inference. This knowledge enables you to set hyperparameters with confidence, diagnose optimization issues systematically, and reason about the computational requirements of probabilistic inference at scale.