Real-world data often exhibits hierarchical structure: a document is composed of paragraphs, which are composed of sentences, which are composed of words. Images have global structure (scene layout), mid-level structure (objects), and local structure (textures). Hierarchical latent variable models capture this by introducing multiple levels of latent variables, each representing different scales of abstraction.
Hierarchical VI addresses the challenge of performing approximate inference in these models. The posterior over all latent levels becomes highly complex, with intricate dependencies that naive factorized approximations miss entirely.
By the end of this page, you will understand hierarchical latent variable models, master the bidirectional inference architecture, learn about ladder VAEs and their training, understand the posterior collapse problem in hierarchical models, and know techniques for encouraging all levels to be utilized.
A hierarchical latent variable model has a generative process that unfolds through multiple layers:
$$p(x, z_{1:L}) = p(z_L) \prod_{l=1}^{L-1} p(z_l | z_{l+1}) \cdot p(x | z_1)$$
The generative process is top-down: sample the most abstract code z_L from its prior, then sample each z_l from p(z_l | z_{l+1}) in turn, and finally render the observation x from p(x | z_1).
Intuition:
| Level | Abstraction | Example Features |
|---|---|---|
| z_L (top) | Global/semantic | Object category, scene type, identity |
| z_{L-1} | High-level structure | Pose, layout, composition |
| ... | Mid-level | Parts, edges, textures |
| z_2 | Local structure | Fine details, colors |
| z_1 (bottom) | Pixel-level | Specific pixel patterns |
Parameterization Options:
Conditional Gaussian (most common): $$p(z_l | z_{l+1}) = \mathcal{N}(z_l; \mu_l(z_{l+1}), \sigma_l^2(z_{l+1}))$$
where μ_l and σ_l are neural networks.
Autoregressive within levels: $$p(z_l | z_{l+1}) = \prod_j p(z_{l,j} | z_{l, <j}, z_{l+1})$$
Normalizing flows within levels: $$z_l = f_l(\epsilon; z_{l+1}), \quad \epsilon \sim \mathcal{N}(0, I)$$
Each parameterization trades off expressiveness against computational cost.
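As a concrete sketch of the conditional-Gaussian option, the snippet below runs ancestral (top-down) sampling through a hypothetical two-level hierarchy; the layer names and dimensions are illustrative only:

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical two-level hierarchy z2 -> z1 -> x; all sizes are illustrative.
z2_dim, z1_dim, x_dim = 4, 8, 16

# Networks emitting [mu, log_sigma] for each conditional Gaussian
prior_z1 = nn.Linear(z2_dim, 2 * z1_dim)      # parameterizes p(z1 | z2)
likelihood = nn.Linear(z1_dim, 2 * x_dim)     # parameterizes p(x | z1)

def ancestral_sample(batch_size=3):
    """Top-down generation: z_L ~ N(0, I), then z_l ~ p(z_l | z_{l+1}), then x."""
    z2 = Normal(0.0, 1.0).sample((batch_size, z2_dim))
    mu1, log_sig1 = prior_z1(z2).chunk(2, dim=-1)
    z1 = Normal(mu1, log_sig1.exp()).sample()
    mu_x, log_sig_x = likelihood(z1).chunk(2, dim=-1)
    return Normal(mu_x, log_sig_x.exp()).sample()

x = ancestral_sample()
print(x.shape)  # torch.Size([3, 16])
```

Swapping the `Normal` heads for an autoregressive or flow-based block changes only the per-level sampler; the top-down order stays the same.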
A deeper hierarchy (more levels) captures finer abstractions but is harder to train. A wider hierarchy (larger z_l dimensions) increases capacity at each level but not the abstraction hierarchy. In practice, 2-5 levels are common, with the effective depth depending on the skip connection structure.
In hierarchical models, the true posterior p(z_{1:L}|x) has complex dependencies across all levels. The challenge: how do we approximate this joint posterior efficiently?
Naive Factorization Fails:
A fully factorized approximation: $$q(z_{1:L}|x) = \prod_{l=1}^{L} q(z_l|x)$$
ignores all dependencies between levels. Since the true posterior has strong cross-level correlations, this leads to poor approximations.
Top-Down Factorization:
Mirrors the generative model structure: $$q(z_{1:L}|x) = q(z_L|x) \prod_{l=1}^{L-1} q(z_l|z_{l+1}, x)$$
Better captures dependencies but requires careful design.
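A minimal sketch of this factorization for L = 2, with made-up layer sizes: sample z_2 from q(z_2|x), then z_1 from q(z_1|z_2, x):

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

# Illustrative top-down posterior for L = 2: q(z2 | x), then q(z1 | z2, x).
# All layer names and sizes here are invented for the sketch.
x_dim, z2_dim, z1_dim = 16, 4, 8
q_z2 = nn.Linear(x_dim, 2 * z2_dim)
q_z1 = nn.Linear(x_dim + z2_dim, 2 * z1_dim)

x = torch.randn(3, x_dim)
mu2, log_sig2 = q_z2(x).chunk(2, dim=-1)
z2 = Normal(mu2, log_sig2.exp()).rsample()                         # q(z_L | x)
mu1, log_sig1 = q_z1(torch.cat([x, z2], dim=-1)).chunk(2, dim=-1)
z1 = Normal(mu1, log_sig1.exp()).rsample()                         # q(z_1 | z_2, x)
```

Because each lower level conditions on the sampled level above it, cross-level correlations in the posterior are representable, unlike in the fully factorized case.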
The Information Flow Problem:
In a hierarchical VAE, information about x must flow to all latent levels, and it travels along two competing paths: a bottom-up path from x through the encoder, and a top-down path from higher levels through the generative model.
If the bottom-up path is too powerful, higher levels may be ignored (all information encoded in z_1). If the top-down path dominates, the model may not use conditioning on x effectively.
The ELBO Decomposition:
For hierarchical models, the ELBO decomposes across levels:
$$\mathcal{L} = \mathbb{E}_{q}[\log p(x|z_1)] - \sum_{l=1}^{L} \text{KL}\big(q(z_l|\cdot) \,\|\, p(z_l|\cdot)\big)$$
Each KL term measures the "cost" of using level l. If a level is not needed, its KL goes to zero (posterior collapse).
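Each per-level term is an ordinary diagonal-Gaussian KL. A small helper (names are our own) makes the decomposition concrete:

```python
import torch

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) for diagonal Gaussians, summed over the latent dimension."""
    return 0.5 * (
        logvar_p - logvar_q
        + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
        - 1.0
    ).sum(dim=-1)

# If a level's posterior matches its prior exactly, its KL term is zero
# (the collapsed state): that level carries no information about x.
mu = torch.randn(5, 3)
kl = gaussian_kl(mu, torch.zeros(5, 3), mu, torch.zeros(5, 3))
print(kl)  # tensor of zeros
```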
In hierarchical models, posterior collapse can occur at any level. Often, higher levels collapse first—the model learns to encode everything in z_1 and ignores z_2, z_3, etc. This defeats the purpose of the hierarchy and degrades generation quality.
Modern hierarchical VAEs use bidirectional inference, combining bottom-up and top-down information to compute each level's posterior.
The Architecture:
Bottom-up pass: Process x through encoder to get deterministic features at each level $$h_l^{\text{bu}} = f_l^{\text{enc}}(h_{l-1}^{\text{bu}}), \quad h_0^{\text{bu}} = x$$
Top-down pass: Starting from z_L, combine with bottom-up features $$h_l^{\text{td}} = g_l^{\text{dec}}(z_{l+1}, h_l^{\text{bu}})$$
Posterior at each level: Computed from the merged representation $$q(z_l|z_{>l}, x) = \mathcal{N}\big(z_l; \mu_l(h_l^{\text{td}}), \sigma_l^2(h_l^{\text{td}})\big)$$
```python
import torch
import torch.nn as nn


class BidirectionalVAE(nn.Module):
    """
    Hierarchical VAE with bidirectional inference.
    Combines bottom-up encoding with top-down inference.
    """

    def __init__(self, input_dim, hidden_dim, latent_dims):
        """
        Args:
            input_dim: Observation dimensionality
            hidden_dim: Hidden layer size
            latent_dims: List of latent dims per level [z1_dim, z2_dim, ...]
        """
        super().__init__()
        self.num_levels = len(latent_dims)
        self.latent_dims = latent_dims

        # Bottom-up encoder
        self.bottom_up = nn.ModuleList()
        prev_dim = input_dim
        for l in range(self.num_levels):
            self.bottom_up.append(nn.Sequential(
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
            ))
            prev_dim = hidden_dim

        # Top-down blocks: merge z_{l+1} with bottom-up features
        self.top_down = nn.ModuleList()
        for l in range(self.num_levels - 1, -1, -1):  # L-1, L-2, ..., 0
            if l == self.num_levels - 1:
                # Top level: only bottom-up features
                in_dim = hidden_dim
            else:
                # Other levels: bottom-up features + z_{l+1}
                in_dim = hidden_dim + latent_dims[l + 1]
            self.top_down.append(nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.GELU(),
            ))

        # Posterior heads for each level
        self.posterior_mean = nn.ModuleList([
            nn.Linear(hidden_dim, z_dim) for z_dim in latent_dims
        ])
        self.posterior_logvar = nn.ModuleList([
            nn.Linear(hidden_dim, z_dim) for z_dim in latent_dims
        ])

        # Prior heads (conditioned on z_{l+1})
        self.prior_mean = nn.ModuleList()
        self.prior_logvar = nn.ModuleList()
        for l in range(self.num_levels):
            if l == self.num_levels - 1:
                # Top level: standard Gaussian prior, no learned head
                self.prior_mean.append(None)
                self.prior_logvar.append(None)
            else:
                self.prior_mean.append(nn.Linear(latent_dims[l + 1], latent_dims[l]))
                self.prior_logvar.append(nn.Linear(latent_dims[l + 1], latent_dims[l]))

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dims[0], hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def encode(self, x):
        """Bottom-up pass to get features at each level."""
        features = []
        h = x
        for encoder in self.bottom_up:
            h = encoder(h)
            features.append(h)
        return features

    def forward(self, x):
        """
        Full forward pass for training.
        Returns: reconstruction, list of z samples, list of KL terms per level.
        """
        # Bottom-up encoding
        bu_features = self.encode(x)

        # Top-down inference
        z_samples = [None] * self.num_levels
        kl_losses = []
        for l_idx, l in enumerate(range(self.num_levels - 1, -1, -1)):
            # Merge bottom-up features with the sample from the level above
            if l == self.num_levels - 1:
                td_input = bu_features[l]
            else:
                z_above = z_samples[l + 1]
                td_input = torch.cat([bu_features[l], z_above], dim=-1)
            h = self.top_down[l_idx](td_input)

            # Posterior parameters
            post_mean = self.posterior_mean[l](h)
            post_logvar = self.posterior_logvar[l](h)
            post_std = torch.exp(0.5 * post_logvar)

            # Prior parameters
            if l == self.num_levels - 1:
                prior_mean = torch.zeros_like(post_mean)
                prior_logvar = torch.zeros_like(post_logvar)
            else:
                z_above = z_samples[l + 1]
                prior_mean = self.prior_mean[l](z_above)
                prior_logvar = self.prior_logvar[l](z_above)
            prior_std = torch.exp(0.5 * prior_logvar)

            # Sample z_l via the reparameterization trick
            eps = torch.randn_like(post_mean)
            z_samples[l] = post_mean + post_std * eps

            # KL divergence between diagonal Gaussians
            kl = 0.5 * (
                prior_logvar - post_logvar
                + (post_std.pow(2) + (post_mean - prior_mean).pow(2)) / prior_std.pow(2)
                - 1
            ).sum(dim=-1)
            kl_losses.append(kl)

        # Decode from the lowest level z_1 (index 0)
        x_recon = self.decoder(z_samples[0])
        return x_recon, z_samples, kl_losses
```

The Ladder VAE is a specific hierarchical architecture designed to encourage all levels to be used. Its key innovation is a particular way of combining information in the generative model.
Key Design Principles:
Residual connections in latent space: Higher levels modulate rather than replace lower-level information
Shared structure between inference and generation: The encoder and decoder have matched architectures at each level
Precision-weighted combination: When merging bottom-up and top-down signals, weight by their respective precision (inverse variance)
$$z_l = f\left(\frac{h_l^{\text{bu}}}{\sigma_{\text{bu}}^2} + \frac{h_l^{\text{td}}}{\sigma_{\text{td}}^2}\right)$$
The Precision-Weighted Merge:
When combining bottom-up features d (from encoder) with top-down features μ (from prior), the posterior is:
$$\hat{\mu} = \frac{\mu/\sigma_{\text{prior}}^2 + d/\sigma_{\text{enc}}^2}{1/\sigma_{\text{prior}}^2 + 1/\sigma_{\text{enc}}^2}$$ $$\hat{\sigma}^2 = \frac{1}{1/\sigma_{\text{prior}}^2 + 1/\sigma_{\text{enc}}^2}$$
This is the optimal fusion under Gaussian assumptions and naturally balances the two information sources.
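A scalar sketch of the merge (values are arbitrary) shows the behavior that matters in practice:

```python
import torch

def precision_weighted(mu_a, var_a, mu_b, var_b):
    """Fuse two Gaussian estimates: precision-weighted mean, combined variance."""
    prec = 1.0 / var_a + 1.0 / var_b
    var = 1.0 / prec
    mu = (mu_a / var_a + mu_b / var_b) * var
    return mu, var

# The low-variance (high-precision) source dominates the merged mean.
mu, var = precision_weighted(
    torch.tensor(0.0), torch.tensor(0.01),    # confident source at 0
    torch.tensor(10.0), torch.tensor(100.0),  # uncertain source at 10
)
print(mu.item())  # ~0.001: pulled only slightly toward 10
```

The combined variance is always smaller than either input's, so the posterior sharpens whenever both paths carry information.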
Benefits: every level receives gradient signal from both the data and the prior, and when the encoder is uninformative at a level, the precision weighting lets the posterior fall back to the prior smoothly rather than fighting it.
```python
import torch
import torch.nn as nn


class LadderVAE(nn.Module):
    """
    Ladder VAE with precision-weighted inference.
    Uses residual connections and shared encoder-decoder structure.
    """

    def __init__(self, input_dim, hidden_dims, latent_dims):
        super().__init__()
        self.num_levels = len(latent_dims)
        self.latent_dims = latent_dims

        # Encoder (bottom-up)
        self.encoders = nn.ModuleList()
        self.enc_mean = nn.ModuleList()
        self.enc_logvar = nn.ModuleList()
        prev_dim = input_dim
        for h_dim, z_dim in zip(hidden_dims, latent_dims):
            self.encoders.append(nn.Sequential(
                nn.Linear(prev_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(),
            ))
            self.enc_mean.append(nn.Linear(h_dim, z_dim))
            self.enc_logvar.append(nn.Linear(h_dim, z_dim))
            prev_dim = h_dim

        # Decoder (top-down); entry 0 corresponds to the top level
        self.decoders = nn.ModuleList()
        self.dec_mean = nn.ModuleList()
        self.dec_logvar = nn.ModuleList()
        for i in range(self.num_levels - 1, -1, -1):
            z_dim = latent_dims[i]
            h_dim = hidden_dims[i]
            if i == self.num_levels - 1:
                # Top level: prior is N(0, I), no learned head
                self.decoders.append(None)
                self.dec_mean.append(None)
                self.dec_logvar.append(None)
            else:
                z_above = latent_dims[i + 1]
                self.decoders.append(nn.Sequential(
                    nn.Linear(z_above, h_dim),
                    nn.ReLU(),
                ))
                self.dec_mean.append(nn.Linear(h_dim, z_dim))
                self.dec_logvar.append(nn.Linear(h_dim, z_dim))

        # Output decoder
        self.output_decoder = nn.Sequential(
            nn.Linear(latent_dims[0], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], input_dim),
        )

    def precision_weighted_merge(self, d_mean, d_logvar, p_mean, p_logvar):
        """
        Merge encoder (d) and prior (p) distributions using precision weighting.
        Returns posterior mean and logvar.
        """
        # Precision = 1 / variance
        d_prec = torch.exp(-d_logvar)
        p_prec = torch.exp(-p_logvar)

        # Combined precision and variance
        combined_prec = d_prec + p_prec
        combined_logvar = torch.log(1.0 / combined_prec)

        # Precision-weighted mean
        combined_mean = (d_mean * d_prec + p_mean * p_prec) / combined_prec
        return combined_mean, combined_logvar

    def forward(self, x):
        # Bottom-up pass
        enc_features = []
        h = x
        for encoder in self.encoders:
            h = encoder(h)
            enc_features.append(h)

        # Encoder distributions at each level
        enc_means = []
        enc_logvars = []
        for i, (mean_net, logvar_net) in enumerate(zip(self.enc_mean, self.enc_logvar)):
            enc_means.append(mean_net(enc_features[i]))
            enc_logvars.append(logvar_net(enc_features[i]))

        # Top-down pass with precision-weighted merge
        z_samples = [None] * self.num_levels
        kl_losses = []
        for i in range(self.num_levels - 1, -1, -1):
            # Prior parameters
            if i == self.num_levels - 1:
                prior_mean = torch.zeros_like(enc_means[i])
                prior_logvar = torch.zeros_like(enc_logvars[i])
            else:
                z_above = z_samples[i + 1]
                dec_idx = self.num_levels - 1 - i
                h = self.decoders[dec_idx](z_above)
                prior_mean = self.dec_mean[dec_idx](h)
                prior_logvar = self.dec_logvar[dec_idx](h)

            # Precision-weighted merge of encoder and prior
            post_mean, post_logvar = self.precision_weighted_merge(
                enc_means[i], enc_logvars[i], prior_mean, prior_logvar
            )

            # Sample via the reparameterization trick
            std = torch.exp(0.5 * post_logvar)
            eps = torch.randn_like(std)
            z_samples[i] = post_mean + std * eps

            # KL divergence between diagonal Gaussians
            kl = 0.5 * (
                prior_logvar - post_logvar
                + (std.pow(2) + (post_mean - prior_mean).pow(2)) / torch.exp(prior_logvar)
                - 1
            ).sum(dim=-1)
            kl_losses.append(kl)

        # Decode from the lowest level
        x_recon = self.output_decoder(z_samples[0])
        return x_recon, z_samples, kl_losses
```

Posterior collapse in hierarchical models is even more challenging than in single-level VAEs. Higher levels often collapse first, as the model learns to encode everything in the lowest level, which has direct access to the reconstruction.
Strategies to Encourage Hierarchical Usage: per-level free bits (a minimum KL budget that is not penalized), staggered KL warm-up schedules across levels, per-level KL weighting, and diagnostics that monitor the mean KL of every level during training.
```python
import torch


class HierarchicalVAETrainer:
    """
    Training utilities for hierarchical VAEs.
    Implements strategies to prevent hierarchical posterior collapse.
    """

    def __init__(self, model, num_levels, free_bits_per_level=None,
                 kl_warmup_epochs=None, level_weights=None):
        self.model = model
        self.num_levels = num_levels

        # Free bits per level (default: none)
        if free_bits_per_level is None:
            self.free_bits = [0.0] * num_levels
        else:
            self.free_bits = free_bits_per_level

        # Warmup epochs per level (higher levels reach full weight first)
        if kl_warmup_epochs is None:
            self.warmup_epochs = [10 * (num_levels - l) for l in range(num_levels)]
        else:
            self.warmup_epochs = kl_warmup_epochs

        # Per-level weights
        if level_weights is None:
            self.level_weights = [1.0] * num_levels
        else:
            self.level_weights = level_weights

        self.current_epoch = 0

    def get_kl_weight(self, level):
        """Get KL weight for a specific level based on its warmup schedule."""
        warmup = self.warmup_epochs[level]
        if self.current_epoch >= warmup:
            return 1.0
        return self.current_epoch / warmup

    def apply_free_bits(self, kl_per_level):
        """Apply the free-bits constraint per level."""
        adjusted_kl = []
        for l, kl in enumerate(kl_per_level):
            if self.free_bits[l] > 0:
                # Do not penalize KL below the free-bits floor at this level
                adjusted_kl.append(torch.clamp(kl, min=self.free_bits[l]))
            else:
                adjusted_kl.append(kl)
        return adjusted_kl

    def compute_loss(self, x, recon_loss_fn):
        """Compute the hierarchical ELBO with all techniques applied."""
        x_recon, z_samples, kl_per_level = self.model(x)

        # Reconstruction loss
        recon_loss = recon_loss_fn(x_recon, x)

        # Apply free bits
        kl_losses = self.apply_free_bits(kl_per_level)

        # Apply per-level weights and warmup
        total_kl = 0.0
        kl_breakdown = {}
        for l, kl in enumerate(kl_losses):
            kl_weight = self.get_kl_weight(l) * self.level_weights[l]
            total_kl = total_kl + kl_weight * kl.mean()
            kl_breakdown[f'kl_level_{l}'] = kl.mean().item()
            kl_breakdown[f'kl_weight_{l}'] = kl_weight

        # Total loss = -ELBO
        loss = recon_loss.mean() + total_kl
        return {
            'loss': loss,
            'recon_loss': recon_loss.mean().item(),
            'total_kl': total_kl.item(),
            **kl_breakdown,
        }

    def step_epoch(self):
        """Call at the end of each epoch."""
        self.current_epoch += 1

    def diagnose_collapse(self, x_batch, threshold=0.1):
        """
        Check which levels are actively used.
        Returns a list of per-level diagnostics.
        """
        with torch.no_grad():
            _, z_samples, kl_per_level = self.model(x_batch)
        results = []
        for l, kl in enumerate(kl_per_level):
            mean_kl = kl.mean().item()
            status = "active" if mean_kl > threshold else "COLLAPSED"
            results.append({
                'level': l,
                'mean_kl': mean_kl,
                'status': status,
            })
        return results
```

Hierarchical VI enables powerful models across various domains:
Image generation has been the flagship application, with hierarchical VAEs such as NVAE and VQ-VAE-2 producing high-fidelity samples; in language modeling, multi-level latents can mirror word-, sentence-, and document-level structure.
| Model | Domain | Hierarchy Structure | Key Innovation |
|---|---|---|---|
| NVAE | Images | 30+ levels, residual | Depth through residual cells |
| VQ-VAE-2 | Images | 2-level discrete | Vector quantization at each level |
| DRAW | Images | Recurrent hierarchy | Iterative refinement |
| Hierarchical VAE-Text | Text | Word/sentence/document | Multi-scale structure |
| Neural Process | Functions | Context/target | Functional data hierarchy |
Match hierarchy depth to data structure. For images with global and local patterns, 2-4 levels often suffice. For structured data (documents, sequences), align levels with natural groupings. Monitor KL at each level during training to ensure all levels are utilized.
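In code, that monitoring can be as simple as thresholding logged per-level KLs; the numbers here are hypothetical:

```python
# Sketch of per-level KL monitoring; the logged values below are hypothetical.
mean_kls = {0: 14.2, 1: 3.7, 2: 0.02}  # mean KL (nats) per level from a training run
threshold = 0.1

status = {
    level: ("active" if kl > threshold else "collapsed")
    for level, kl in mean_kls.items()
}
print(status)  # {0: 'active', 1: 'active', 2: 'collapsed'}
```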
We've explored hierarchical variational inference for models with multiple levels of latent variables. The key takeaways: hierarchical models generate top-down from abstract to concrete; the posterior is best approximated with a structured, level-wise factorization rather than a fully factorized one; bidirectional inference merges bottom-up evidence with top-down context; ladder VAEs fuse the two signals by precision weighting; and free bits, KL warm-up, and per-level KL monitoring keep every level in use.
You now understand hierarchical variational inference and techniques for training deep latent variable models. Next, we'll explore VI for deep learning, where variational methods enable uncertainty quantification in neural networks.