Stochastic variational inference is a powerful framework, but translating the elegant mathematics into robust, production-ready code requires navigating a landscape of practical challenges that textbooks rarely discuss.
This page distills years of collective experience into actionable guidance for practitioners. We address the questions that arise when implementing SVI: How do I initialize? Why is my ELBO oscillating? When should I use SVI versus alternatives? How do I debug a model that isn't learning?
The goal is to transform you from someone who understands SVI to someone who can deploy it confidently on real problems.
By the end of this page, you will know how to initialize SVI for stable convergence, diagnose and fix common training pathologies, decide when SVI is (and isn't) the right approach, and implement production-grade variational inference systems.
Proper initialization is critical for successful SVI. Poor initialization can lead to slow convergence, convergence to suboptimal local optima, or outright divergence.
General principles:
Start near the prior: Initialize variational parameters so \(q(z; \phi_0) \approx p(z)\). This ensures valid probability distributions and often stable initial gradients.
Use small variances carefully: Very small initial variances can cause gradient explosion in the likelihood term; very large variances dilute the signal.
Leverage problem structure: Use domain knowledge to initialize near plausible posterior modes.
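To make the first two principles concrete, here is a minimal sketch (assuming a standard normal prior, a mean-field Gaussian variational family, and a hypothetical `latent_dim`) that initializes \(q\) at the prior and checks the initial KL:

```python
import torch
import torch.nn as nn

latent_dim = 32  # hypothetical latent dimensionality

# Mean-field Gaussian q(z) = N(mu, diag(sigma^2)), initialized at the prior N(0, I)
mu = nn.Parameter(torch.zeros(latent_dim))         # mu = 0 matches the prior mean
log_sigma = nn.Parameter(torch.zeros(latent_dim))  # sigma = 1 matches the prior variance
# (use a small negative log_sigma instead if the likelihood term explodes early)

q = torch.distributions.Normal(mu, log_sigma.exp())
prior = torch.distributions.Normal(torch.zeros(latent_dim), torch.ones(latent_dim))
print("Initial KL to prior:", torch.distributions.kl_divergence(q, prior).sum().item())  # ≈ 0
```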
| Model | Variational Family | Recommended Initialization |
|---|---|---|
| Gaussian posterior | Diagonal Gaussian | μ = 0, log σ = 0 (unit variance) |
| Topic models (LDA) | Dirichlet | α = 1 + small noise (near uniform) |
| VAE encoder | Gaussian | Xavier/He init for network; μ→0, σ→1 |
| Bayesian neural net | Gaussian weights | μ = pretrained or Xavier; σ = small (0.01) |
| Mixture models | Categorical + Gaussian | K-means clustering for means; uniform mixing |
Initialization for VAEs:
Variational Autoencoders require careful initialization of both encoder and decoder networks:
```python
import torch.nn as nn

# Encoder initialization
# Mean projection: Xavier weights and zero bias → z_mean ≈ 0 initially
nn.init.xavier_uniform_(encoder_mean.weight)
nn.init.zeros_(encoder_mean.bias)

# Log-variance projection: negative bias → small initial variance
nn.init.xavier_uniform_(encoder_logvar.weight)
nn.init.constant_(encoder_logvar.bias, -2)  # var = exp(-2) ≈ 0.14, std ≈ 0.37: conservative

# Decoder: standard initialization for the linear layers
for layer in decoder:
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
```
Why this works: with z_mean ≈ 0 and a modest initial standard deviation, the approximate posterior starts near the standard normal prior, so the KL term begins at a moderate, well-behaved value and the early reconstruction gradients are not swamped by extreme samples.

Warm-starting from simpler models:

For complex models, initialize from solutions to simpler problems:
• For hierarchical VAEs: train bottom-up, one level at a time
• For Bayesian neural networks: initialize from maximum likelihood (point-estimate) weights
• For deep topic models: initialize from shallow LDA
This 'curriculum' of increasing complexity dramatically improves convergence.
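For instance, a Bayesian neural network can be warm-started by copying maximum-likelihood weights into the variational means and giving every weight a small initial standard deviation. The sketch below is one way to do this under a mean-field Gaussian posterior with a softplus parameterization; the `pretrained` module and `init_std` value are assumptions:

```python
import math
import torch
import torch.nn as nn

def init_from_pretrained(pretrained: nn.Module, init_std: float = 0.01):
    """Build mean-field variational parameters (mu, rho), warm-started from a
    maximum-likelihood point estimate, where sigma = softplus(rho)."""
    rho_init = math.log(math.expm1(init_std))  # inverse softplus, so softplus(rho) == init_std
    variational_params = {}
    for name, w in pretrained.named_parameters():
        mu = nn.Parameter(w.detach().clone())             # mean at the point estimate
        rho = nn.Parameter(torch.full_like(w, rho_init))  # small initial standard deviation
        variational_params[name] = (mu, rho)
    return variational_params
```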
When SVI fails to converge or produces poor results, systematic debugging is essential. Here we catalog common failure modes and their remedies.
Problem 1: ELBO is NaN or -Inf
Causes:
• Taking the log of zero (or numerically zero) probabilities
• Overflow in exp/softmax computations
• Variances collapsing to zero, causing division by zero or log(0)

Solutions:
```python
import torch

# Add numerical stability to log computations
log_prob = torch.log(prob + 1e-10)

# Use the log-sum-exp trick for a numerically stable softmax
def stable_softmax(logits):
    max_logits = logits.max(dim=-1, keepdim=True).values
    exp_logits = torch.exp(logits - max_logits)
    return exp_logits / exp_logits.sum(dim=-1, keepdim=True)

# Clamp variance to prevent division by zero
var = torch.clamp(var, min=1e-6)
```
Problem 3: Posterior Collapse (VAEs)
Symptom: KL divergence → 0, decoder ignores latent code, all samples look similar.
Diagnosis:
```python
# Check per-dimension KL between q(z|x) = N(mu, var) and the N(0, I) prior
kl_per_dim = 0.5 * (mu**2 + var - 1 - torch.log(var))
print("KL per dimension:", kl_per_dim.mean(dim=0))
# If all dimensions are near zero → the posterior has collapsed
```
Solutions:

KL annealing (warm-up): ramp the KL weight from 0 to 1 over the first epochs so the decoder learns to use the latent code before the KL term is fully enforced.

```python
# Anneal the KL weight linearly over warmup_epochs
beta = min(1.0, epoch / warmup_epochs)
loss = -reconstruction + beta * kl
```

Free bits: give each latent dimension a small KL budget that is not penalized, so dimensions are not pushed all the way back to the prior.

```python
# Clamp per-dimension KL at a 'free bits' floor before summing
free_bits = 0.1
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits)
kl = kl_per_dim.sum(dim=-1)
```
Weaker decoder: Use simpler decoder (fewer layers, less capacity) so latent code is necessary.
Input dropout: Drop out input features to force reliance on the latent code (a minimal sketch follows).
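A minimal sketch of input dropout; the corruption probability and where it is applied (typically the features the decoder conditions on directly, e.g. autoregressive inputs) are assumptions to be tuned:

```python
import torch

def drop_inputs(x, p=0.25, training=True):
    """Randomly zero out input features during training so reconstruction
    cannot bypass the latent code. No rescaling is applied."""
    if not training:
        return x
    keep = (torch.rand_like(x) > p).to(x.dtype)  # Bernoulli keep-mask
    return x * keep
```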
Problem 4: Exploding Gradients

Symptom: NaN gradients or the ELBO suddenly drops to -Inf.
Solutions:
• Gradient clipping: torch.nn.utils.clip_grad_norm_(params, max_norm=5.0)
• Reduce learning rate by 10×
• Check for numerical instability in log/exp operations
• Use float64 for debugging, then switch back to float32
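The last two bullets can be combined into a short debugging harness. This sketch assumes a model exposing a `compute_elbo(x)` method (as in the trainer later on this page); PyTorch's anomaly detection and float64 are enabled only temporarily because both are slow:

```python
import torch

def debug_step(model, x):
    """Run one step in float64 with anomaly detection to localize NaN/Inf sources."""
    model = model.double()
    x = x.double()
    with torch.autograd.set_detect_anomaly(True):
        elbo, _ = model.compute_elbo(x)  # hypothetical model API
        loss = -elbo.mean()
        loss.backward()                  # anomaly mode reports the op that produced NaN
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            print(f"Non-finite gradient in parameter: {name}")
```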
SVI has several hyperparameters that significantly affect performance. Here we provide guidance for setting them without exhaustive grid search.
Batch size:
The optimal batch size depends on computational resources and data characteristics.

Rule of thumb: Start with 128, increase the batch size if training is unstable (gradient noise too high), decrease it if memory-limited.
| Hyperparameter | Starting Value | Tuning Strategy | Signs of Mistuning |
|---|---|---|---|
| Learning rate | 1e-3 (Adam), 0.01 (SGD) | LR range test | Oscillation (too high), no progress (too low) |
| Batch size | 128 | Double until memory limit | Noisy training (too small) |
| MC samples | 1 | Increase if variance high | Slow convergence with high variance |
| Latent dimension | Model-dependent | Cross-validation on held-out likelihood | Underfitting (too small), overfitting (too large) |
| KL weight (β) | 1.0 | Anneal from 0 | Posterior collapse (β too high early) |
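The "LR range test" mentioned in the table can be done in a few lines: sweep the learning rate geometrically over a short run, record the loss, and pick a value somewhat below the point where the loss starts to diverge. A sketch assuming the `compute_elbo` model interface used elsewhere on this page:

```python
import torch

def lr_range_test(model, loader, lr_min=1e-6, lr_max=1.0, num_steps=100):
    """Sweep the learning rate geometrically and record (lr, loss) pairs."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_min)
    gamma = (lr_max / lr_min) ** (1.0 / num_steps)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    history, data_iter = [], iter(loader)
    for _ in range(num_steps):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter)
        optimizer.zero_grad()
        elbo, _ = model.compute_elbo(batch[0])  # hypothetical model API
        loss = -elbo.mean()
        loss.backward()
        optimizer.step()
        history.append((optimizer.param_groups[0]['lr'], loss.item()))
        scheduler.step()
    return history  # plot loss vs. lr and choose an lr below the divergence point
```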
Monte Carlo samples for gradient estimation:
The number of samples \(S\) used to estimate \(\mathbb{E}_{q}[\cdot]\) trades variance against computation.
Pro tip: Use S = 1 during training (for speed) but S = 100+ for final evaluation (for accurate ELBO estimates).
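A sketch of the S-sample estimator, reusing the hypothetical model interface from the IWAE example later on this page; averaging S reparameterized samples reduces gradient variance at S times the cost:

```python
import torch

def elbo_estimate(model, x, S=1):
    """Monte Carlo ELBO estimate averaged over S reparameterized samples.
    S = 1 is usually enough for training; use larger S only for evaluation."""
    estimates = []
    for _ in range(S):
        z, log_q = model.sample_and_log_prob(x)  # rsample + log q(z|x), hypothetical API
        log_p = model.joint_log_prob(x, z)       # log p(x, z), hypothetical API
        estimates.append(log_p - log_q)
    return torch.stack(estimates, dim=0).mean(dim=0)
```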
Number of latent dimensions:
For VAEs and similar models, treat the latent dimension as a model-dependent hyperparameter and select it by cross-validation on held-out likelihood (see the table above).
Automatic approach: Use automatic relevance determination (ARD) priors that prune unused dimensions.
SVI is a powerful tool, but it's not always the right choice. Understanding its strengths and limitations helps select the best inference method for each problem.
Comparison with alternatives:
| Method | Scalability | Posterior Quality | Flexibility | Complexity |
|---|---|---|---|---|
| SVI | Excellent | Approximate | High | Medium |
| MCMC | Poor | Asymptotically exact | High | Low |
| Expectation Propagation | Moderate | Often better than VI | Low | High |
| Laplace Approximation | Excellent | Gaussian only | Low | Low |
| Maximum Likelihood | Excellent | Point estimate | High | Low |
Decision flowchart intuition: if the dataset is small and you need asymptotically exact posteriors, prefer MCMC; if the dataset is large or the model contains neural networks, SVI is usually the practical choice; if a Gaussian approximation around a single mode suffices, the Laplace approximation is cheapest; and if you only need a point estimate, maximum likelihood is simplest.
Deploying SVI in production requires attention to reliability, monitoring, and operational concerns beyond pure algorithmic performance.
```python
import torch
import numpy as np
from typing import Dict, Optional, Callable
from dataclasses import dataclass
import logging
import json
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class SVIConfig:
    """Configuration for production SVI training."""
    batch_size: int = 128
    learning_rate: float = 1e-3
    max_epochs: int = 100
    patience: int = 10
    min_delta: float = 1e-4
    gradient_clip: float = 5.0
    checkpoint_every: int = 10
    validate_every: int = 1
    seed: int = 42

    def to_dict(self) -> dict:
        return {k: getattr(self, k) for k in self.__dataclass_fields__}


class ProductionSVITrainer:
    """
    Production-grade SVI trainer with:
    - Early stopping
    - Checkpointing
    - Logging and monitoring
    - Reproducibility
    - Error handling
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        config: SVIConfig,
        output_dir: Path,
        device: str = "cuda"
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.output_dir = Path(output_dir)
        self.device = device

        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(self.output_dir / "config.json", "w") as f:
            json.dump(config.to_dict(), f, indent=2)

        # Set seeds for reproducibility
        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=1e-5
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',  # Maximizing ELBO
            factor=0.5,
            patience=5,
            min_lr=1e-6
        )

        # Tracking
        self.best_val_elbo = float('-inf')
        self.epochs_without_improvement = 0
        self.history = {
            'train_elbo': [],
            'val_elbo': [],
            'learning_rate': [],
            'gradient_norm': []
        }

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch."""
        self.model.train()
        total_elbo = 0.0
        total_grad_norm = 0.0
        num_batches = 0

        for batch in self.train_loader:
            x = batch[0].to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            try:
                elbo, metrics = self.model.compute_elbo(x)
            except RuntimeError as e:
                logger.error(f"Forward pass failed: {e}")
                raise

            # Backward pass
            loss = -elbo.mean()  # Minimize negative ELBO
            loss.backward()

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.gradient_clip
            )

            # Check for NaN gradients
            if torch.isnan(grad_norm):
                logger.warning("NaN gradient detected, skipping batch")
                self.optimizer.zero_grad()
                continue

            self.optimizer.step()

            total_elbo += elbo.mean().item()
            total_grad_norm += grad_norm.item()
            num_batches += 1

        return {
            'train_elbo': total_elbo / num_batches,
            'gradient_norm': total_grad_norm / num_batches
        }

    @torch.no_grad()
    def validate(self) -> float:
        """Compute validation ELBO."""
        self.model.eval()
        total_elbo = 0.0
        num_batches = 0

        for batch in self.val_loader:
            x = batch[0].to(self.device)
            elbo, _ = self.model.compute_elbo(x)
            total_elbo += elbo.mean().item()
            num_batches += 1

        return total_elbo / num_batches

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Save model checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_elbo': self.best_val_elbo,
            'history': self.history,
            'config': self.config.to_dict()
        }

        path = self.output_dir / f"checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, path)

        if is_best:
            best_path = self.output_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            logger.info(f"Saved best model with val_elbo={self.best_val_elbo:.4f}")

    def load_checkpoint(self, path: Path):
        """Load from checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_elbo = checkpoint['best_val_elbo']
        self.history = checkpoint['history']
        return checkpoint['epoch']

    def train(self) -> Dict[str, float]:
        """Full training loop with early stopping."""
        logger.info(f"Starting training for up to {self.config.max_epochs} epochs")

        for epoch in range(self.config.max_epochs):
            # Training
            train_metrics = self.train_epoch()
            self.history['train_elbo'].append(train_metrics['train_elbo'])
            self.history['gradient_norm'].append(train_metrics['gradient_norm'])
            self.history['learning_rate'].append(
                self.optimizer.param_groups[0]['lr']
            )

            # Validation
            if epoch % self.config.validate_every == 0:
                val_elbo = self.validate()
                self.history['val_elbo'].append(val_elbo)

                # Learning rate scheduling
                self.scheduler.step(val_elbo)

                # Check for improvement
                if val_elbo > self.best_val_elbo + self.config.min_delta:
                    self.best_val_elbo = val_elbo
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, is_best=True)
                else:
                    self.epochs_without_improvement += 1

                logger.info(
                    f"Epoch {epoch}: train_elbo={train_metrics['train_elbo']:.4f}, "
                    f"val_elbo={val_elbo:.4f}, lr={self.optimizer.param_groups[0]['lr']:.2e}"
                )

            # Periodic checkpointing
            if epoch % self.config.checkpoint_every == 0:
                self.save_checkpoint(epoch)

            # Early stopping
            if self.epochs_without_improvement >= self.config.patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break

        # Save final history
        with open(self.output_dir / "training_history.json", "w") as f:
            json.dump(self.history, f, indent=2)

        return {
            'best_val_elbo': self.best_val_elbo,
            'final_epoch': epoch,
            'final_train_elbo': self.history['train_elbo'][-1]
        }
```

Evaluating variational inference requires care: the ELBO is a lower bound on the log-likelihood, not the log-likelihood itself. Several metrics and techniques provide better insight into model quality.
Importance-weighted ELBO:
The IWAE bound uses \(K\) samples to provide a tighter lower bound:
$$\log p(x) \geq \mathcal{L}_K = \mathbb{E}_{z_1, \ldots, z_K \sim q}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, z_k)}{q(z_k)}\right]$$
As \(K \to \infty\), \(\mathcal{L}_K \to \log p(x)\). For evaluation:
```python
import numpy as np
import torch

def compute_iwae(model, x, K=100):
    """Compute the importance-weighted ELBO (IWAE bound) with K samples."""
    log_weights = []
    for _ in range(K):
        z, log_q = model.sample_and_log_prob(x)
        log_p = model.joint_log_prob(x, z)
        log_weights.append(log_p - log_q)
    # Log-sum-exp for numerical stability
    log_weights = torch.stack(log_weights, dim=0)
    iwae = torch.logsumexp(log_weights, dim=0) - np.log(K)
    return iwae.mean()
```
Model selection:
For selecting among models (different architectures, hyperparameters), compare IWAE or ELBO estimates computed on held-out validation data.
Important: Don't use training ELBO for model selection—it rewards overfitting.
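In practice this reduces to computing the importance-weighted bound on held-out data for every candidate and keeping the best one. A minimal sketch reusing `compute_iwae` from above (the candidate list and loader are assumptions):

```python
import torch

@torch.no_grad()
def validation_iwae(model, val_loader, K=100):
    """Average IWAE bound over a held-out set (higher is better)."""
    scores = [compute_iwae(model, batch[0], K=K).item() for batch in val_loader]
    return sum(scores) / len(scores)

# best_model = max(candidate_models, key=lambda m: validation_iwae(m, val_loader))
```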
Even experienced practitioners encounter pitfalls when implementing SVI. Here we catalog the most common mistakes and their solutions.
| Pitfall | Symptom | Solution |
|---|---|---|
| Forgetting the N/M scaling | ELBO is wrong by factor of N/batch_size | Always scale likelihood term by N/M in mini-batch ELBO |
| Using non-reparameterized gradients | Very high variance, slow convergence | Use rsample() not sample() for continuous latents |
| KL divergence computed incorrectly | Negative KL, training instability | Use library functions; verify on known distributions |
| Variance parameterization issues | NaN or negative variance | Parameterize as log(σ²) or use softplus(·) |
| Not using validation set | Overfitting undetected | Always hold out data for monitoring |
| Ignoring prior mismatch | Poor posterior approximation | Ensure prior matches problem structure |
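To guard against the "KL divergence computed incorrectly" pitfall, compare a hand-written KL against the library implementation on distributions where the answer is known in closed form. A quick sanity check for the diagonal-Gaussian-to-standard-normal case:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([0.5, -1.0])
log_var = torch.tensor([0.2, -0.3])
var = log_var.exp()

# Hand-written KL[ N(mu, var) || N(0, 1) ] per dimension
kl_manual = 0.5 * (mu**2 + var - 1.0 - log_var)

# Library reference
kl_library = kl_divergence(Normal(mu, var.sqrt()), Normal(torch.zeros(2), torch.ones(2)))

assert torch.allclose(kl_manual, kl_library, atol=1e-6), "KL implementation is wrong"
```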
Pitfall deep-dive: The N/M scaling factor
The most common bug in SVI implementations is incorrect scaling. The stochastic ELBO is:
$$\hat{\mathcal{L}} = \underbrace{-\text{KL}[q \,\|\, p]}_{\text{Not scaled}} + \underbrace{\frac{N}{M} \sum_{j \in \text{batch}} \log p(x_j \mid z)}_{\text{Scaled by } N/M}$$
Wrong (common mistake):
```python
# WRONG: treats the mini-batch as if it were the entire dataset
loss = kl_divergence + reconstruction_loss.mean()
```
Correct:
```python
# CORRECT: scales the reconstruction term up to the full dataset size
N = len(full_dataset)
M = batch_size
loss = kl_divergence + (N / M) * reconstruction_loss.sum()

# Or equivalently, as a per-datapoint loss:
loss = (kl_divergence / N) + reconstruction_loss.mean()
```
The scaling matters because the stochastic ELBO must be an unbiased estimate of the full-dataset ELBO: the KL term for global latent variables appears once regardless of batch size, while the likelihood sum is estimated from only M of the N datapoints and must be scaled up by N/M. Omitting the factor effectively over-weights the prior and over-regularizes the posterior toward it.
In PyTorch, sample() and rsample() are different!
• sample(): Non-differentiable sampling (blocks gradients)
• rsample(): Reparameterized sampling (gradients flow through)
Always use z = dist.rsample() for variational inference with continuous latents. Using sample() will silently produce zero gradients for the encoder.
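A quick demonstration of the difference; the gradient with respect to the variational mean only exists with rsample():

```python
import torch
from torch.distributions import Normal

mu = torch.zeros(3, requires_grad=True)
dist = Normal(mu, torch.ones(3))

z = dist.sample()
print(z.requires_grad)  # False: sample() detaches, so encoder gradients are silently zero

z = dist.rsample()      # reparameterized: z = mu + sigma * eps
z.sum().backward()
print(mu.grad)          # tensor([1., 1., 1.]): gradients flow through the sample
```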
This page has equipped you with the practical knowledge to implement, debug, and deploy stochastic variational inference in real-world applications.
Module complete:
You have now completed the module on Stochastic Variational Inference. You understand how the ELBO is constructed and optimized with stochastic, mini-batch gradients, how to initialize and tune a model, how to diagnose the common failure modes, how to evaluate the resulting approximation, and how to deploy SVI in production.
With this knowledge, you are prepared to apply SVI to real problems—from training VAEs on million-image datasets to fitting Bayesian neural networks for uncertainty-aware predictions to scaling topic models across document corpora.
Congratulations! You have mastered Stochastic Variational Inference. You now possess both the theoretical understanding and practical skills to implement scalable Bayesian inference systems. The techniques in this module form the foundation of modern probabilistic deep learning, from variational autoencoders to Bayesian neural networks to large-scale generative models.