At the heart of variational inference lies a single mathematical object that governs everything we do: the Evidence Lower Bound (ELBO). Understanding the ELBO deeply—not just its formula, but its internal structure—is the key to mastering variational inference and building intuition for why certain algorithms work the way they do.
In this page, we will dissect the ELBO into its component parts, revealing a beautiful and interpretable structure. This decomposition isn't merely an academic exercise; it provides the conceptual foundation for understanding VAEs, choosing regularization strategies, diagnosing training problems, and developing new variational methods.
By the end of this page, you will: (1) Understand multiple equivalent forms of the ELBO and when each is useful, (2) Decompose the ELBO into reconstruction and regularization terms, (3) Interpret each component probabilistically and information-theoretically, (4) See how the decomposition guides algorithm design and hyperparameter tuning, and (5) Connect ELBO optimization to broader themes in machine learning.
Before diving into decomposition, let's ensure we have a rock-solid understanding of where the ELBO comes from and why it's the objective we optimize in variational inference.
The Inference Problem:
Given observed data $\mathbf{x}$ and a probabilistic model with latent variables $\mathbf{z}$, we want to compute the posterior distribution:
$$p(\mathbf{z}|\mathbf{x}) = \frac{p(\mathbf{x}|\mathbf{z})p(\mathbf{z})}{p(\mathbf{x})}$$
The challenge is the denominator—the marginal likelihood or evidence:
$$p(\mathbf{x}) = \int p(\mathbf{x}|\mathbf{z})p(\mathbf{z})d\mathbf{z}$$
This integral is typically intractable. It requires integrating over all possible configurations of the latent variables, and the cost of doing so exactly grows exponentially with the dimensionality of $\mathbf{z}$.
Consider a simple example: if z consists of 100 binary latent variables, computing p(x) exactly would require summing over 2^100 ≈ 10^30 configurations. Even with continuous latent variables, the integral has no closed form for most interesting models. This computational barrier is what motivates approximate inference.
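To make the blow-up tangible, here is a tiny brute-force sketch (the log-joint is invented purely for illustration) that computes the evidence exactly by enumerating every binary configuration; the number of terms doubles with each added latent dimension, which is why this is hopeless at $D = 100$:

```python
import itertools
import numpy as np

def exact_log_evidence(log_joint, D):
    """Brute-force log p(x) = log sum_z exp(log p(x, z)) over all 2^D binary z."""
    log_px = -np.inf
    for z in itertools.product([0.0, 1.0], repeat=D):
        log_px = np.logaddexp(log_px, log_joint(np.array(z)))
    return log_px

# Hypothetical log-joint: the point is the 2^D-term loop, not the model itself.
rng = np.random.default_rng(0)
for D in [4, 8, 16]:
    w = rng.normal(size=D)
    log_px = exact_log_evidence(lambda z, w=w: float(z @ w - 0.5 * z.sum()), D)
    print(f"D={D:3d}: {2**D:>6d} terms, log p(x) = {log_px:.3f}")
```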
The Variational Approach:
Variational inference sidesteps intractability by introducing an approximate posterior $q_\phi(\mathbf{z}|\mathbf{x})$ from a tractable family of distributions (parameterized by $\phi$), and then optimizing $q$ to be as close as possible to the true posterior.
The measure of closeness is the Kullback-Leibler divergence:
$$D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}|\mathbf{x})) = \mathbb{E}_{q_\phi}\left[\log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p(\mathbf{z}|\mathbf{x})}\right]$$
But there's a problem: this KL divergence itself depends on $p(\mathbf{z}|\mathbf{x})$, which contains the intractable $p(\mathbf{x})$ in its denominator. How can we optimize an objective we can't compute?
The ELBO Derivation:
The resolution comes from a clever manipulation. Starting with the log-evidence:
$$\log p(\mathbf{x}) = \log \int p(\mathbf{x}, \mathbf{z})d\mathbf{z}$$
We introduce our approximate posterior through importance weighting:
$$\log p(\mathbf{x}) = \log \int q_\phi(\mathbf{z}|\mathbf{x}) \frac{p(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}d\mathbf{z}$$
Applying Jensen's inequality (since $\log$ is concave):
$$\log p(\mathbf{x}) \geq \int q_\phi(\mathbf{z}|\mathbf{x}) \log \frac{p(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}d\mathbf{z}$$
This lower bound is the Evidence Lower Bound (ELBO):
$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log \frac{p(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}\right]$$
The ELBO is computable! It only requires sampling from q_φ(z|x) (which we control) and evaluating the joint p(x, z) (which is part of our model specification). We never need to compute p(x) directly. Maximizing the ELBO with respect to φ tightens the bound, making q closer to the true posterior.
The Gap Between ELBO and Log-Evidence:
How tight is this bound? The gap is precisely the KL divergence we wanted to minimize:
$$\log p(\mathbf{x}) = \mathcal{L}(\phi) + D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}|\mathbf{x}))$$
Since KL divergence is always non-negative, the ELBO is indeed a lower bound. Moreover, maximizing the ELBO is equivalent to minimizing the KL divergence (since $\log p(\mathbf{x})$ is constant with respect to $\phi$).
This relationship is profound: by maximizing a tractable objective (ELBO), we are implicitly minimizing an intractable one (KL to true posterior).
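As a sanity check you can run yourself, the sketch below verifies this identity on a toy conjugate model (unit-Gaussian prior, Gaussian likelihood with known noise) where $p(\mathbf{x})$ and $p(\mathbf{z}|\mathbf{x})$ are available in closed form; the particular numbers and the choice of $q$ are arbitrary illustrations:

```python
import torch
from torch.distributions import Normal, kl_divergence

# Toy conjugate model: z ~ N(0, 1), x | z ~ N(z, sigma^2), so
# p(x) = N(x; 0, 1 + sigma^2) and p(z|x) is Gaussian with known parameters.
torch.manual_seed(0)
sigma = 0.5
x = torch.tensor(1.3)

prior = Normal(0.0, 1.0)
log_px = Normal(0.0, (1 + sigma**2) ** 0.5).log_prob(x)                     # exact evidence
posterior = Normal(x / (1 + sigma**2), (sigma**2 / (1 + sigma**2)) ** 0.5)  # exact posterior

q = Normal(torch.tensor(0.4), torch.tensor(0.8))   # a deliberately imperfect q
z = q.sample((200_000,))                           # Monte Carlo samples for the ELBO

log_joint = Normal(z, sigma).log_prob(x) + prior.log_prob(z)
elbo = (log_joint - q.log_prob(z)).mean()
gap = kl_divergence(q, posterior)                  # KL(q || p(z|x))

print(f"log p(x)  = {log_px.item():.4f}")
print(f"ELBO + KL = {(elbo + gap).item():.4f}  (equal up to Monte Carlo noise)")
print(f"ELBO      = {elbo.item():.4f}  <=  log p(x)")
```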
Now we arrive at the central topic: decomposing the ELBO into interpretable components. Starting from the ELBO definition:
$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log \frac{p(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}\right]$$
We use the factorization of the joint: $p(\mathbf{x}, \mathbf{z}) = p(\mathbf{x}|\mathbf{z})p(\mathbf{z})$:
$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log \frac{p(\mathbf{x}|\mathbf{z})p(\mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}\right]$$
Splitting the logarithm:
$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}|\mathbf{z})] + \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\left[\log \frac{p(\mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}\right]$$
Recognizing the second term as a negative KL divergence:
$$\mathcal{L}(\phi) = \underbrace{\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}|\mathbf{z})]}_{\text{Reconstruction Term}} - \underbrace{D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))}_{\text{Regularization Term}}$$
ELBO = Reconstruction - Regularization. This decomposition reveals that variational inference balances two competing objectives: faithfully reconstructing the data while keeping the approximate posterior close to the prior. This tension is the essence of probabilistic modeling.
Let's examine each term in depth:
The Reconstruction Term: $\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}|\mathbf{z})]$
This term measures how well latent codes sampled from $q$ explain the observed data. Think of it as an encode-then-decode check: draw $\mathbf{z}$ from the approximate posterior for a given $\mathbf{x}$, then ask how probable the original $\mathbf{x}$ is under the decoder distribution $p(\mathbf{x}|\mathbf{z})$.
Maximizing this term pushes the model to learn latent representations that make the data likely. If $q$ encodes data into latent codes that decode back to something completely different, this term will be very negative.
Interpretation: The reconstruction term is the expected log-likelihood under the variational posterior. In autoencoders, this corresponds to how well the decoder reconstructs inputs from encoded representations.
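For a concrete special case, if the decoder outputs Bernoulli probabilities (e.g. for binarized images), the reconstruction term is exactly the negative binary cross-entropy between the decoder's output and the data. A minimal sketch (the shapes and the random stand-in for the decoder output are illustrative only):

```python
import torch
import torch.nn.functional as F

x = torch.randint(0, 2, (32, 784)).float()   # hypothetical binarized batch
logits = torch.randn(32, 784)                # stand-in for decoder(z) logits

# E_q[log p(x|z)] for a Bernoulli decoder = -BCE, summed over pixels
log_px_given_z = -F.binary_cross_entropy_with_logits(
    logits, x, reduction="none"
).sum(dim=-1)

print(log_px_given_z.shape)  # one log-likelihood per example: torch.Size([32])
```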
The Regularization Term: $D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))$
This term measures the divergence between our approximate posterior and the prior. Minimizing this (by subtracting it in the ELBO) encourages $q$ to stay close to $p(\mathbf{z})$.
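For the common case of a diagonal-Gaussian posterior and a standard-normal prior, this KL has a simple closed form (the same expression the later code snippets on this page use):

```python
import torch

def gaussian_kl_to_standard_normal(z_mean, z_logvar):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) per example:
    0.5 * sum( mu^2 + sigma^2 - log sigma^2 - 1 )."""
    return 0.5 * torch.sum(z_mean**2 + torch.exp(z_logvar) - z_logvar - 1.0, dim=-1)

# A posterior identical to the prior pays zero regularization cost.
print(gaussian_kl_to_standard_normal(torch.zeros(4, 8), torch.zeros(4, 8)))  # all zeros
```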
Several interpretations illuminate this term:
1. Complexity Penalty: The KL term acts like a complexity regularizer. Without it, $q$ could become arbitrarily complex—learning a different, narrow distribution for each data point. The prior keeps $q$ grounded.
2. Information Rate: From an information-theoretic perspective, the average of $D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))$ over the data upper-bounds the mutual information between $\mathbf{x}$ and $\mathbf{z}$. This limits how much information about $\mathbf{x}$ can be encoded in $\mathbf{z}$.
3. Latent Space Structure: The prior $p(\mathbf{z})$ (often a standard Gaussian) imposes structure on the latent space. Forcing $q$ toward this prior encourages smooth, continuous latent representations.
4. Generative Quality: For generative models, we sample $\mathbf{z} \sim p(\mathbf{z})$ at generation time. If $q$-learned latent codes differ dramatically from the prior, generated samples won't resemble training data.
| Term | Mathematical Form | Encourages | Discourages |
|---|---|---|---|
| Reconstruction | $\mathbb{E}_q[\log p(\mathbf{x}|\mathbf{z})]$ | Expressive latent codes that perfectly describe each data point | Lossy representations that discard data-specific information |
| Regularization | $-D_{KL}(q | p)$ | Simple, prior-like posteriors that generalize well | Complex, data-specific posteriors that overfit to individual examples |
The reconstruction-regularization decomposition is the most common, but alternative decompositions provide additional insights. Each reveals different aspects of what the ELBO is doing.
Energy-Based Decomposition:
We can also write the ELBO as:
$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}, \mathbf{z})] + H(q_\phi(\mathbf{z}|\mathbf{x}))$$
where $H(q)$ is the entropy of the approximate posterior. This decomposes the ELBO into an expected log-joint term (a negative "energy") that rewards latent configurations under which $(\mathbf{x}, \mathbf{z})$ is probable, and an entropy term that rewards keeping the posterior spread out rather than collapsing to a point.
This view connects variational inference to energy-based modeling. The energy term says "find probable configurations" while entropy says "maintain uncertainty."
Free Energy Perspective:
In statistical physics, the ELBO corresponds to the negative variational free energy:
$$\mathcal{F}(\phi) = -\mathcal{L}(\phi) = \mathbb{E}_{q_\phi}[-\log p(\mathbf{x}, \mathbf{z})] - H(q_\phi)$$
Minimizing free energy means trading off low expected energy against high entropy, just as a physical system at fixed temperature settles into the state that minimizes its Helmholtz free energy. This connection explains why variational inference is sometimes called "free energy minimization."
Bits-Back Interpretation:
From a coding theory perspective:
$$-\mathcal{L}(\phi) = \underbrace{\mathbb{E}_{q_\phi}[-\log p(\mathbf{x}, \mathbf{z})]}_{\text{nats to encode } (\mathbf{x}, \mathbf{z}) \text{ under the model}} - \underbrace{\mathbb{E}_{q_\phi}[-\log q_\phi(\mathbf{z}|\mathbf{x})]}_{\text{nats refunded by bits-back}}$$
In a bits-back scheme, the sender encodes a latent $\mathbf{z} \sim q_\phi(\mathbf{z}|\mathbf{x})$ under the prior, encodes $\mathbf{x}$ under $p(\mathbf{x}|\mathbf{z})$, and the receiver later recovers the random bits that were used to select $\mathbf{z}$. The negative ELBO is the expected net code length of this scheme, so a higher ELBO means the model needs fewer nats to describe the data, i.e., it compresses the data better.
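To relate this to familiar compression numbers, here is a tiny helper (the example values are made up) that converts a per-example ELBO in nats into bits per dimension:

```python
import math

def elbo_to_bits_per_dim(elbo_nats, num_dims):
    """-ELBO upper-bounds -log p(x), the ideal code length in nats,
    so -ELBO / (num_dims * ln 2) gives (an upper bound on) bits per dimension."""
    return -elbo_nats / (num_dims * math.log(2))

# Hypothetical numbers: an ELBO of -2500 nats on 28x28 images.
print(f"{elbo_to_bits_per_dim(-2500.0, 28 * 28):.3f} bits/dim")
```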
```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def compute_elbo_standard(x, encoder, decoder, prior):
    """
    Standard ELBO decomposition: Reconstruction - KL

    Args:
        x: Input data [batch_size, ...]
        encoder: Returns q(z|x) distribution parameters
        decoder: Returns p(x|z) distribution parameters
        prior: Prior p(z) distribution
    """
    # Encode: get approximate posterior parameters
    z_mean, z_logvar = encoder(x)
    z_std = torch.exp(0.5 * z_logvar)

    # Approximate posterior q(z|x)
    q_z = Normal(z_mean, z_std)

    # Sample using reparameterization
    z = q_z.rsample()

    # Decode: compute reconstruction likelihood
    x_recon_mean, x_recon_logvar = decoder(z)

    # Reconstruction term: E_q[log p(x|z)]
    # For Gaussian likelihood with learned variance
    recon_loss = -0.5 * torch.sum(
        x_recon_logvar + (x - x_recon_mean)**2 / torch.exp(x_recon_logvar),
        dim=-1
    )

    # Regularization term: KL(q(z|x) || p(z))
    kl_loss = kl_divergence(q_z, prior).sum(dim=-1)

    # ELBO = Reconstruction - KL
    elbo = recon_loss - kl_loss

    return elbo.mean(), recon_loss.mean(), kl_loss.mean()


def compute_elbo_energy(x, encoder, decoder, prior):
    """
    Energy-based decomposition: E_q[log p(x,z)] + H(q)
    """
    z_mean, z_logvar = encoder(x)
    z_std = torch.exp(0.5 * z_logvar)
    q_z = Normal(z_mean, z_std)
    z = q_z.rsample()

    x_recon_mean, x_recon_logvar = decoder(z)

    # Log joint: log p(x,z) = log p(x|z) + log p(z)
    log_px_given_z = -0.5 * torch.sum(
        x_recon_logvar + (x - x_recon_mean)**2 / torch.exp(x_recon_logvar),
        dim=-1
    )
    log_pz = prior.log_prob(z).sum(dim=-1)
    log_joint = log_px_given_z + log_pz

    # Entropy of q(z|x): H(q) = 0.5 * log(2πe) * d + 0.5 * sum(log_var)
    # For Gaussian: H = 0.5 * sum(1 + log(2π) + log_var)
    entropy = 0.5 * torch.sum(
        1 + torch.log(2 * torch.pi * torch.ones_like(z_std)) + z_logvar,
        dim=-1
    )

    # ELBO = E[log p(x,z)] + H(q)
    elbo = log_joint + entropy

    return elbo.mean(), log_joint.mean(), entropy.mean()


def verify_decomposition_equivalence(x, encoder, decoder, prior, n_samples=1000):
    """
    Verify that both decompositions give the same ELBO.
    This is a sanity check for implementation correctness.
    """
    elbos_standard = []
    elbos_energy = []

    for _ in range(n_samples):
        elbo_std, _, _ = compute_elbo_standard(x, encoder, decoder, prior)
        elbo_eng, _, _ = compute_elbo_energy(x, encoder, decoder, prior)
        elbos_standard.append(elbo_std.item())
        elbos_energy.append(elbo_eng.item())

    mean_std = torch.tensor(elbos_standard).mean()
    mean_eng = torch.tensor(elbos_energy).mean()

    print(f"Standard decomposition ELBO: {mean_std:.4f}")
    print(f"Energy decomposition ELBO: {mean_eng:.4f}")
    print(f"Difference: {abs(mean_std - mean_eng):.6f}")

    return abs(mean_std - mean_eng) < 0.01  # Should be very close
```

One of the most illuminating perspectives on the ELBO comes from rate-distortion theory, a branch of information theory that studies optimal lossy compression.
Setting Up the Connection:
Consider what VAEs and variational inference are doing: an encoder maps each data point $\mathbf{x}$ to a stochastic code $\mathbf{z}$, a decoder reconstructs $\mathbf{x}$ from that code, and the prior constrains how much information the code is allowed to carry.
This is precisely the setting of rate-distortion theory!
Formal Definitions:
Rate: The average KL divergence from the approximate posterior to the prior: $$R = \mathbb{E}_{p(\mathbf{x})}[D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))]$$
This upper-bounds the mutual information $I(\mathbf{x}; \mathbf{z})$, i.e., how many nats of information about $\mathbf{x}$ are encoded in $\mathbf{z}$.
Distortion: The expected reconstruction error: $$D = \mathbb{E}_{p(\mathbf{x})}\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[d(\mathbf{x}, \hat{\mathbf{x}})]$$
where $d$ is some distortion measure (e.g., squared error) and $\hat{\mathbf{x}}$ is the reconstruction.
The Tradeoff:
Rate-distortion theory proves that there's a fundamental tradeoff: lower distortion requires higher rate (more bits). The rate-distortion curve captures the optimal tradeoff achievable by any encoder/decoder pair.
The ELBO can be viewed as a Lagrangian for the rate-distortion optimization problem: minimize distortion (negative reconstruction) subject to a rate constraint (KL term). The coefficient β in β-VAE controls this tradeoff—higher β means lower rate but higher distortion.
The β-VAE Objective:
The β-VAE modifies the standard ELBO:
$$\mathcal{L}_\beta = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] - \beta \cdot D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))$$
This parameter traces out points on the rate-distortion curve:
| β Value | Rate | Distortion | Latent Properties | Use Case |
|---|---|---|---|---|
| β << 1 | Very High | Very Low | Nearly deterministic, memorizes data | Perfect reconstruction needed |
| β = 1 | Moderate | Moderate | Balanced compression | Standard generative modeling |
| β > 1 | Low | Higher | Disentangled, interpretable factors | Representation learning |
| β >> 1 | Very Low | Very High | Ignores data, posterior ≈ prior | Pure regularization (degenerate) |
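A minimal sketch of the corresponding training loss, assuming the same encoder/decoder interfaces and Gaussian decoder likelihood as the earlier snippets on this page (the default β value is arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

def beta_vae_loss(x, encoder, decoder, beta=4.0):
    """Negative beta-ELBO for one minibatch (single-sample Monte Carlo estimate)."""
    z_mean, z_logvar = encoder(x)
    q_z = Normal(z_mean, torch.exp(0.5 * z_logvar))
    z = q_z.rsample()                                  # reparameterized sample

    x_recon_mean, x_recon_logvar = decoder(z)
    recon = -0.5 * torch.sum(                          # E_q[log p(x|z)]
        x_recon_logvar + (x - x_recon_mean) ** 2 / torch.exp(x_recon_logvar),
        dim=-1,
    )
    kl = kl_divergence(q_z, Normal(0.0, 1.0)).sum(dim=-1)

    loss = -(recon - beta * kl).mean()                 # minimize the negative beta-ELBO
    return loss, recon.mean().detach(), kl.mean().detach()
```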
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict


def analyze_rate_distortion(model, data_loader, beta_values):
    """
    Analyze the rate-distortion tradeoff across different β values.

    For each β, computes:
    - Rate: Average KL divergence
    - Distortion: Average reconstruction error
    """
    results = defaultdict(list)

    for beta in beta_values:
        # Set the beta value
        model.beta = beta

        total_kl = 0
        total_recon = 0
        total_samples = 0

        with torch.no_grad():
            for x in data_loader:
                # Forward pass
                z_mean, z_logvar = model.encoder(x)
                z_std = torch.exp(0.5 * z_logvar)

                # Compute KL divergence (rate)
                kl = 0.5 * torch.sum(
                    z_mean**2 + z_std**2 - z_logvar - 1,
                    dim=-1
                ).mean()

                # Sample and reconstruct
                z = z_mean + z_std * torch.randn_like(z_std)
                x_recon = model.decoder(z)

                # Compute reconstruction error (distortion)
                recon_error = torch.sum((x - x_recon)**2, dim=-1).mean()

                total_kl += kl.item() * len(x)
                total_recon += recon_error.item() * len(x)
                total_samples += len(x)

        avg_rate = total_kl / total_samples
        avg_distortion = total_recon / total_samples

        results['beta'].append(beta)
        results['rate'].append(avg_rate)
        results['distortion'].append(avg_distortion)

        print(f"β={beta:.2f}: Rate={avg_rate:.4f}, Distortion={avg_distortion:.4f}")

    return results


def plot_rate_distortion_curve(results):
    """
    Plot the rate-distortion curve from experimental results.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Rate-Distortion curve
    axes[0].scatter(results['rate'], results['distortion'],
                    c=results['beta'], cmap='viridis', s=100)
    axes[0].plot(results['rate'], results['distortion'], 'k--', alpha=0.5)
    axes[0].set_xlabel('Rate (KL Divergence)', fontsize=12)
    axes[0].set_ylabel('Distortion (Reconstruction Error)', fontsize=12)
    axes[0].set_title('Rate-Distortion Curve', fontsize=14)

    # Add colorbar for beta values
    sm = plt.cm.ScalarMappable(
        cmap='viridis',
        norm=plt.Normalize(min(results['beta']), max(results['beta']))
    )
    plt.colorbar(sm, ax=axes[0], label='β value')

    # Rate and Distortion vs Beta
    ax1 = axes[1]
    ax2 = ax1.twinx()

    line1, = ax1.plot(results['beta'], results['rate'], 'b-o', label='Rate')
    line2, = ax2.plot(results['beta'], results['distortion'], 'r-s', label='Distortion')

    ax1.set_xlabel('β', fontsize=12)
    ax1.set_ylabel('Rate', color='b', fontsize=12)
    ax2.set_ylabel('Distortion', color='r', fontsize=12)
    ax1.set_title('Rate and Distortion vs β', fontsize=14)

    # Combine legends
    lines = [line1, line2]
    labels = [l.get_label() for l in lines]
    ax1.legend(lines, labels, loc='center right')

    plt.tight_layout()
    return fig
```

The ELBO decomposition has deep connections to information theory that provide profound insights into what variational inference is achieving.
Mutual Information Bound:
The KL regularization term bounds the mutual information between data and latents:
$$I(\mathbf{x}; \mathbf{z}) = H(\mathbf{z}) - H(\mathbf{z}|\mathbf{x})$$
For the variational distribution: $$I_q(\mathbf{x}; \mathbf{z}) \leq \mathbb{E}_{p(\mathbf{x})}[D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))]$$
This says the KL term is an upper bound on how much information about $\mathbf{x}$ is captured in $\mathbf{z}$. Minimizing KL directly limits this information flow.
This connects VAEs to the Information Bottleneck principle: learn representations that compress X into Z while preserving information relevant for some downstream task. In unsupervised VAEs, the 'task' is reconstruction of X itself.
Decomposing Mutual Information:
We can further decompose the mutual information:
$$I(\mathbf{x}; \mathbf{z}) = \sum_i I(\mathbf{x}; z_i | z_{<i})$$
This reveals how different latent dimensions contribute to capturing data information. In a well-trained VAE, the information is spread across a set of active dimensions that each contribute a nonzero share of $I(\mathbf{x}; \mathbf{z})$, while unused dimensions carry essentially nothing and sit at the prior.
The Posterior Collapse Problem:
A notorious issue in VAEs is posterior collapse, where $q_\phi(\mathbf{z}|\mathbf{x}) \approx p(\mathbf{z})$ for many or all dimensions. Information-theoretically, the collapsed dimensions carry essentially zero mutual information with $\mathbf{x}$: their per-dimension KL is near zero, and the decoder simply ignores them.
This happens when the regularization term dominates reconstruction, or when the decoder is so powerful it doesn't need the latent code.
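One mitigation referenced in the diagnostics later on this page is "free bits". Here is a sketch of one common variant, which stops charging a latent dimension once its KL falls below a small floor (the floor value is a tunable assumption):

```python
import torch

def free_bits_kl(kl_per_dim, free_bits=0.5):
    """Each latent dimension only pays for KL above `free_bits` nats, removing
    the incentive to push already-small per-dimension KLs all the way to zero.

    kl_per_dim: tensor of shape [batch, latent_dim] with per-dimension KL values.
    """
    return torch.clamp(kl_per_dim, min=free_bits).sum(dim=-1)

# Dimensions below the floor are charged the floor, not their tiny KL:
print(free_bits_kl(torch.tensor([[0.01, 0.2, 1.5]])))  # tensor([2.5000])
```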
```python
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence


def compute_active_units(model, data_loader, threshold=0.01):
    """
    Count 'active' latent dimensions based on KL contribution.

    A dimension is active if its average KL contribution exceeds threshold.
    Inactive dimensions indicate posterior collapse for that dimension.
    """
    kl_per_dim = None
    n_samples = 0

    with torch.no_grad():
        for x in data_loader:
            z_mean, z_logvar = model.encoder(x)

            # Per-dimension KL: 0.5 * (μ² + σ² - 1 - log σ²)
            kl_dim = 0.5 * (z_mean**2 + torch.exp(z_logvar) - 1 - z_logvar)

            if kl_per_dim is None:
                kl_per_dim = kl_dim.sum(dim=0)
            else:
                kl_per_dim += kl_dim.sum(dim=0)
            n_samples += len(x)

    avg_kl_per_dim = kl_per_dim / n_samples
    active_units = (avg_kl_per_dim > threshold).sum().item()
    total_units = len(avg_kl_per_dim)

    print(f"Active units: {active_units}/{total_units}")
    print(f"KL per dimension: {avg_kl_per_dim.cpu().numpy()}")

    return active_units, avg_kl_per_dim


def estimate_mutual_information(model, data_loader, n_samples=10):
    """
    Estimate the mutual information I(X; Z) induced by the encoder.

    Uses I(X;Z) ≈ E_p(x)[ E_q(z|x)[log q(z|x)] - E_q(z|x)[log q(z)] ],
    where the aggregate marginal q(z) is approximated by a diagonal Gaussian
    fitted to posterior samples pooled over the data.
    """
    # First pass: collect aggregate statistics for the marginal q(z)
    z_samples_all = []
    posteriors = []

    with torch.no_grad():
        for x in data_loader:
            z_mean, z_logvar = model.encoder(x)
            z_std = torch.exp(0.5 * z_logvar)

            # Store posterior parameters
            posteriors.append((z_mean, z_std))

            # Sample from posterior
            for _ in range(n_samples):
                z = z_mean + z_std * torch.randn_like(z_std)
                z_samples_all.append(z)

    z_samples_all = torch.cat(z_samples_all, dim=0)

    # Fit a Gaussian to the aggregated samples (approximate marginal)
    z_marginal_mean = z_samples_all.mean(dim=0)
    z_marginal_var = z_samples_all.var(dim=0)

    # Estimate mutual information
    mi_estimate = 0
    n_total = 0

    for z_mean, z_std in posteriors:
        batch_size = z_mean.shape[0]

        # Entropy of posterior: H(q(z|x))
        # For a Gaussian: H = 0.5 * d * (1 + log(2π)) + 0.5 * sum(log σ²)
        posterior_entropy = 0.5 * (
            z_mean.shape[-1] * (1 + torch.log(torch.tensor(2 * torch.pi)))
            + torch.sum(2 * torch.log(z_std), dim=-1)
        )

        # Cross-entropy with the marginal: E_q(z|x)[log q(z)],
        # approximating q(z) by the fitted Gaussian
        z = z_mean + z_std * torch.randn_like(z_std)
        log_qz = Normal(z_marginal_mean, torch.sqrt(z_marginal_var)).log_prob(z).sum(dim=-1)

        # MI contribution: I(X;Z) ≈ E[log q(z|x)] - E[log q(z)]
        #                         = -H(q(z|x)) - E[log q(z)]
        mi_batch = -posterior_entropy - log_qz

        mi_estimate += mi_batch.sum().item()
        n_total += batch_size

    mi_estimate /= n_total

    print(f"Estimated Mutual Information I(X;Z): {mi_estimate:.4f} nats")
    return mi_estimate
```

The standard ELBO is just one lower bound on log-evidence. Researchers have developed tighter bounds that can improve learning, especially when the variational family is limited.
The Importance Weighted ELBO (IWELBO):
The IWAE bound uses multiple samples from $q$ to tighten the bound:
$$\mathcal{L}_K = \mathbb{E}_{z_1,\dots,z_K \sim q_\phi}\left[\log \frac{1}{K}\sum_{k=1}^K \frac{p(\mathbf{x}, z_k)}{q_\phi(z_k|\mathbf{x})}\right]$$
Properties: the bound is monotonically non-decreasing in $K$, reduces to the standard ELBO at $K=1$, and converges to $\log p(\mathbf{x})$ as $K \to \infty$ (under mild conditions on the importance weights).
However, the signal-to-noise ratio of gradients degrades with large $K$, trading off bound tightness against gradient quality.
Multi-Sample Bounds:
Several variants of multi-sample objectives exist:
MIWAE (Multiple Importance Weighted AE): averages $M$ independent $K$-sample importance-weighted estimates, giving an explicit knob for trading bound tightness against gradient variance.
DReG (Doubly Reparameterized Gradients): A gradient estimator that maintains signal-to-noise ratio even with many importance samples, addressing IWAE's SNR degradation.
SUMO/OVIS: SUMO removes the fixed-$K$ limitation by using randomized (Russian-roulette) truncation to produce an unbiased estimate of the log-evidence itself, while OVIS focuses on variance-reduced gradient estimators for importance-weighted objectives.
| Bound | Tightness | Compute Cost | Gradient Quality | Best Use Case |
|---|---|---|---|---|
| Standard ELBO | Loose | O(1) samples | High SNR | General-purpose VAE training |
| IWAE (K=5) | Tighter | O(K) samples | Medium SNR | Better evidence estimation |
| IWAE (K=50) | Very Tight | O(K) samples | Low SNR | Evaluation, not training |
| DReG | Tight | O(K) samples | High SNR | Training with many samples |
| Hierarchical VI | Depends | O(depth) | High SNR | Expressive posteriors |
```python
import torch
from torch.distributions import Normal


def compute_iwae_bound(x, encoder, decoder, prior, K=5):
    """
    Compute the Importance Weighted Autoencoder (IWAE) bound.

    Args:
        x: Input data [batch_size, ...]
        encoder: Returns q(z|x) parameters
        decoder: Returns p(x|z) parameters
        prior: Prior p(z) distribution
        K: Number of importance samples

    Returns:
        IWAE bound estimate
    """
    batch_size = x.shape[0]

    # Get posterior parameters
    z_mean, z_logvar = encoder(x)
    z_std = torch.exp(0.5 * z_logvar)
    q_z = Normal(z_mean, z_std)

    # Sample K latent codes for each data point
    # Shape: [K, batch_size, latent_dim]
    z_samples = q_z.rsample((K,))

    # Compute log weights for each sample
    log_weights = []
    for k in range(K):
        z = z_samples[k]

        # Log joint: log p(x, z)
        x_recon_mean, x_recon_logvar = decoder(z)
        log_px_z = -0.5 * torch.sum(
            x_recon_logvar + (x - x_recon_mean)**2 / torch.exp(x_recon_logvar),
            dim=-1
        )
        log_pz = prior.log_prob(z).sum(dim=-1)
        log_joint = log_px_z + log_pz

        # Log proposal: log q(z|x)
        log_qz = q_z.log_prob(z).sum(dim=-1)

        # Log importance weight
        log_w = log_joint - log_qz
        log_weights.append(log_w)

    # Stack: [K, batch_size]
    log_weights = torch.stack(log_weights, dim=0)

    # IWAE bound: log(1/K * sum(exp(log_w)))
    # Use log-sum-exp for numerical stability
    iwae_bound = torch.logsumexp(log_weights, dim=0) - torch.log(torch.tensor(K, dtype=torch.float))

    return iwae_bound.mean()


def compare_bounds(x, encoder, decoder, prior, K_values=[1, 5, 10, 50]):
    """
    Compare ELBO tightness across different K values.
    """
    print("Comparing bounds (higher is tighter/better):")
    print("-" * 40)

    bounds = {}
    for K in K_values:
        bound = compute_iwae_bound(x, encoder, decoder, prior, K=K)
        bounds[K] = bound.item()
        print(f"K={K:3d}: IWAE bound = {bound.item():.4f}")

    # Verify monotonicity
    prev = None
    monotonic = True
    for K in sorted(K_values):
        if prev is not None and bounds[K] < prev - 0.01:  # Allow small numerical error
            monotonic = False
        prev = bounds[K]

    print("-" * 40)
    print(f"Bounds monotonically increasing: {monotonic}")

    return bounds
```

Understanding the ELBO decomposition has immediate practical consequences for training and debugging variational models.
Monitoring Training:
During training, you should track both components separately: log the reconstruction term and the KL term (and their sum, the ELBO) rather than only the total loss, since the same ELBO value can hide very different reconstruction/KL balances.
Warning Signs: a KL term that sinks toward zero signals posterior collapse; a KL term that explodes or oscillates wildly signals unstable training; a reconstruction term that plateaus early suggests an underpowered decoder or a poorly tuned learning rate; very few active latent dimensions means most of the latent space is being ignored. The ELBOMonitor sketch further below automates these checks.
A common technique is to anneal β from 0 to 1 during training. Start with β=0 (pure autoencoder, maximizing reconstruction), then gradually increase to β=1 (full ELBO). This helps the model first learn useful reconstructions before enforcing the prior regularization.
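A minimal sketch of such a schedule (the warm-up horizon is arbitrary); the resulting β multiplies the KL term exactly as in the β-VAE objective above:

```python
def linear_kl_warmup(step, warmup_steps=10_000, beta_max=1.0):
    """Linear KL annealing: beta ramps from 0 to beta_max over warmup_steps."""
    return beta_max * min(1.0, step / warmup_steps)

# Inside the training loop (names illustrative):
#   beta = linear_kl_warmup(global_step)
#   loss = -(recon - beta * kl).mean()
```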
```python
import torch
import wandb
from collections import deque


class ELBOMonitor:
    """
    Monitor ELBO components during training for diagnostics.
    """

    def __init__(self, window_size=100):
        self.reconstruction_history = deque(maxlen=window_size)
        self.kl_history = deque(maxlen=window_size)
        self.elbo_history = deque(maxlen=window_size)
        self.active_units_history = deque(maxlen=window_size)

    def update(self, recon_loss, kl_loss, elbo, active_units=None):
        """Record a training step's metrics."""
        self.reconstruction_history.append(recon_loss)
        self.kl_history.append(kl_loss)
        self.elbo_history.append(elbo)
        if active_units is not None:
            self.active_units_history.append(active_units)

    def diagnose(self):
        """Analyze current training state and provide diagnostics."""
        diagnostics = {}

        if len(self.kl_history) < 10:
            return {"status": "insufficient_data"}

        avg_kl = sum(self.kl_history) / len(self.kl_history)
        avg_recon = sum(self.reconstruction_history) / len(self.reconstruction_history)

        # Check for posterior collapse
        if avg_kl < 0.1:
            diagnostics["posterior_collapse"] = {
                "detected": True,
                "severity": "severe" if avg_kl < 0.01 else "moderate",
                "recommendation": "Try KL warmup, free bits, or weaker decoder"
            }

        # Check for training instability
        kl_std = torch.std(torch.tensor(list(self.kl_history))).item()
        if kl_std > avg_kl * 0.5:
            diagnostics["unstable_training"] = {
                "detected": True,
                "kl_std": kl_std,
                "recommendation": "Reduce learning rate or add gradient clipping"
            }

        # Check reconstruction progress
        recent_recon = list(self.reconstruction_history)[-10:]
        old_recon = list(self.reconstruction_history)[:10]
        if len(old_recon) == 10:
            recon_improvement = sum(old_recon) / 10 - sum(recent_recon) / 10
            if recon_improvement < 0.1:
                diagnostics["reconstruction_plateau"] = {
                    "detected": True,
                    "recommendation": "Increase decoder capacity or learning rate"
                }

        # Active units check
        if len(self.active_units_history) > 0:
            recent_active = self.active_units_history[-1]
            if recent_active < 3:
                diagnostics["low_active_units"] = {
                    "detected": True,
                    "count": recent_active,
                    "recommendation": "Reduce latent dim or use auxiliary losses"
                }

        diagnostics["summary"] = {
            "avg_reconstruction": avg_recon,
            "avg_kl": avg_kl,
            "avg_elbo": sum(self.elbo_history) / len(self.elbo_history)
        }

        return diagnostics

    def log_to_wandb(self, step):
        """Log current metrics to Weights & Biases."""
        if len(self.elbo_history) == 0:
            return

        wandb.log({
            "elbo/reconstruction": self.reconstruction_history[-1],
            "elbo/kl_divergence": self.kl_history[-1],
            "elbo/total": self.elbo_history[-1],
            "elbo/reconstruction_avg": sum(self.reconstruction_history) / len(self.reconstruction_history),
            "elbo/kl_avg": sum(self.kl_history) / len(self.kl_history),
        }, step=step)
```

We've taken a deep dive into the ELBO, revealing its internal structure and the profound insights this structure provides. Let's consolidate what we've learned.
Connections to Other Topics: the reconstruction-regularization split is exactly the training objective of the VAE; scaling the KL term gives the β-VAE and its rate-distortion interpretation; the KL term's bound on mutual information links variational inference to the Information Bottleneck principle; the energy/entropy view connects it to free-energy minimization in statistical physics; and importance weighting (IWAE and its relatives) shows how the same bound can be tightened.
What's Next:
In the next page, we'll take an even deeper look at the reconstruction-regularization tradeoff, examining the tension between these terms, techniques to balance them (β-VAE, free bits, warmup), and the conceptual implications for representation learning.
You now have a comprehensive understanding of the ELBO and its decomposition. This forms the conceptual foundation for all variational inference methods. The reconstruction-regularization split will guide your intuitions about VAE behavior, training dynamics, and representation quality.