The Variational Autoencoder framework sounds elegant in theory: encode observations to distributions, sample latent codes, decode to reconstructions, and backpropagate to train. But there's a fundamental problem that makes naive implementation impossible.
The problem: Backpropagation computes gradients by tracing computational paths from loss to parameters. But in a VAE, between the encoder output (distribution parameters) and the decoder input (latent sample), lies a stochastic sampling operation. We can't take gradients through randomness.
This page introduces the reparameterization trick—the ingenious solution that transforms stochastic sampling into a deterministic computation plus external randomness, enabling gradient flow and allowing VAEs to be trained end-to-end with standard backpropagation.
By the end of this page, you will: (1) Understand why naive sampling blocks gradient flow, (2) Derive and implement the reparameterization trick for Gaussian distributions, (3) Understand why reparameterization produces low-variance gradients, (4) Generalize the trick to other distributions, (5) Know alternatives when reparameterization is impossible.
Let's trace the computational graph of a VAE to see where the problem arises.
To update the encoder parameters $\boldsymbol{\phi}$, we need:
$$\nabla_{\boldsymbol{\phi}} \mathcal{L} = \nabla_{\boldsymbol{\phi}} \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z}|\mathbf{x})}[f(\mathbf{z})]$$
where $f(\mathbf{z})$ is the reconstruction loss.
The expectation is over a distribution that depends on $\boldsymbol{\phi}$ (through $\boldsymbol{\mu}_\phi$ and $\boldsymbol{\sigma}_\phi$). We cannot simply move the gradient inside:
$$\nabla_{\boldsymbol{\phi}} \mathbb{E}_{q_\phi}[f(\mathbf{z})] \neq \mathbb{E}_{q_\phi}[\nabla_{\boldsymbol{\phi}} f(\mathbf{z})]$$
The RHS is zero because once we sample $\mathbf{z}$, it's just a constant—the gradient of a constant with respect to $\boldsymbol{\phi}$ is zero.
Sampling $\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$ is a stochastic operation. Consider:
$$\mathbf{z} = \text{sample}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$$
Even with identical inputs, repeated calls produce different outputs due to randomness. There's no well-defined $\partial \mathbf{z} / \partial \boldsymbol{\mu}$—the output isn't a deterministic function of the input.
In automatic differentiation terms: the sampling node has no defined local derivative, so the backward pass stops there and no gradient signal ever reaches the encoder parameters.
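This distinction is visible directly in PyTorch: `torch.distributions.Normal.sample()` draws outside the autograd graph, while `.rsample()` applies the reparameterization trick internally. A small illustrative sketch:

```python
import torch
from torch.distributions import Normal

mu = torch.zeros(3, requires_grad=True)
sigma = torch.ones(3, requires_grad=True)
dist = Normal(mu, sigma)

# .sample() draws outside the autograd graph: the result carries no
# gradient path back to mu or sigma
z_blocked = dist.sample()
print(z_blocked.requires_grad)  # False

# .rsample() applies the reparameterization trick internally,
# so the sample is a differentiable function of mu and sigma
z_flow = dist.rsample()
print(z_flow.requires_grad)     # True
z_flow.sum().backward()
print(mu.grad)                  # all ones: dz/dmu = I
```

Calling `backward()` on the blocked sample would fail, since there is nothing to differentiate.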
One approach is the score function estimator (REINFORCE):
$$\nabla_{\boldsymbol{\phi}} \mathbb{E}_{q_\phi}[f(\mathbf{z})] = \mathbb{E}_{q_\phi}[f(\mathbf{z}) \nabla_{\boldsymbol{\phi}} \log q_\phi(\mathbf{z}|\mathbf{x})]$$
This is valid but has extremely high variance—making training slow and unstable. The reparameterization trick avoids this entirely.
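The variance gap is easy to see numerically. For $z \sim \mathcal{N}(\mu, 1)$ and $f(z) = z^2$, the true gradient of $\mathbb{E}[f(z)]$ with respect to $\mu$ is $2\mu$; both estimators below recover it, but with very different spread (an illustrative sketch, not from the original text):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(1.0)
n = 100_000
eps = torch.randn(n)
z = mu + eps                          # z ~ N(mu, 1)

# Score function: f(z) * grad_mu log N(z; mu, 1) = z^2 * (z - mu)
score_grads = z ** 2 * (z - mu)

# Reparameterization: grad_mu f(mu + eps) = 2 * (mu + eps)
reparam_grads = 2 * (mu + eps)

print(score_grads.mean(), score_grads.var())      # mean ~ 2, large variance
print(reparam_grads.mean(), reparam_grads.var())  # mean ~ 2, small variance
```

Both sample means land near the true gradient of 2, but the score-function estimates are far more dispersed around it.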
The distribution we sample from depends on the encoder parameters. We need gradients with respect to these parameters. But sampling is not a differentiable operation. We need a way to express the same computation where randomness enters as external input rather than internal operation.
The reparameterization trick, introduced by Kingma & Welling (2014), transforms stochastic sampling into a deterministic transformation of external noise.
Instead of sampling: $$\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))$$
We equivalently write: $$\boldsymbol{\epsilon} \sim \mathcal{N}(0, I)$$ $$\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$$
where $\odot$ is element-wise multiplication.
Why this works: $\mathbf{z}$ is now a deterministic, differentiable function of $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$; all randomness enters through $\boldsymbol{\epsilon}$, an external input that does not depend on the encoder parameters.
Now gradients can flow: $$\frac{\partial \mathbf{z}}{\partial \boldsymbol{\mu}} = I, \quad \frac{\partial \mathbf{z}}{\partial \boldsymbol{\sigma}} = \text{diag}(\boldsymbol{\epsilon})$$
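These partial derivatives can be checked directly with autograd (a small sketch; `eps_val` stands in for one fixed noise draw):

```python
import torch

eps_val = torch.tensor([0.3, -1.2])
mu = torch.zeros(2, requires_grad=True)
sigma = torch.tensor([1.0, 2.0], requires_grad=True)

# z = mu + sigma * eps is deterministic given the noise draw
z = mu + sigma * eps_val
z.sum().backward()

print(mu.grad)     # [1., 1.]     -> dz/dmu = I
print(sigma.grad)  # [0.3, -1.2]  -> dz/dsigma = diag(eps)
```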
Let's verify that the reparameterized sample has the correct distribution.
If $\boldsymbol{\epsilon} \sim \mathcal{N}(0, I)$, then $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$ has:
Mean: $$\mathbb{E}[\mathbf{z}] = \mathbb{E}[\boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}] = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \mathbb{E}[\boldsymbol{\epsilon}] = \boldsymbol{\mu}$$
Variance: $$\text{Var}[\mathbf{z}] = \text{Var}[\boldsymbol{\sigma} \odot \boldsymbol{\epsilon}] = \boldsymbol{\sigma}^2 \odot \text{Var}[\boldsymbol{\epsilon}] = \boldsymbol{\sigma}^2$$
Distribution: Linear transformation of Gaussian is Gaussian, so: $$\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))$$
Exactly what we wanted! But now expressed as a differentiable function.
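The derivation can also be checked empirically with a large batch of reparameterized samples (a quick sanity check; the parameter values are illustrative):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([1.0, -2.0])
sigma = torch.tensor([0.5, 2.0])

# Draw many reparameterized samples z = mu + sigma * eps
eps = torch.randn(1_000_000, 2)
z = mu + sigma * eps

print(z.mean(dim=0))  # ~ [ 1.0, -2.0]
print(z.std(dim=0))   # ~ [ 0.5,  2.0]
```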
```python
import torch
import torch.nn as nn


def reparameterize_gaussian(
    mu: torch.Tensor,
    log_var: torch.Tensor,
    training: bool = True
) -> torch.Tensor:
    """
    Reparameterization trick for diagonal Gaussian.

    z = mu + sigma * epsilon, where epsilon ~ N(0, I)

    Args:
        mu: Mean of q(z|x), shape [batch, latent_dim]
        log_var: Log variance of q(z|x), shape [batch, latent_dim]
        training: If False, return mean (no sampling)

    Returns:
        z: Sampled latent code, shape [batch, latent_dim]
    """
    if training:
        # Compute standard deviation
        std = torch.exp(0.5 * log_var)

        # Sample epsilon from standard normal
        # randn_like uses same device and dtype as std
        epsilon = torch.randn_like(std)

        # Reparameterized sample
        z = mu + std * epsilon
        return z
    else:
        # At inference, use deterministic mean
        return mu


class ReparameterizedSampling(nn.Module):
    """
    Module wrapper for reparameterized sampling.

    Useful when you want to include sampling in nn.Sequential
    or explicitly track it as a module.
    """

    def __init__(self):
        super().__init__()

    def forward(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample using reparameterization trick."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps


# Demonstration of gradient flow
if __name__ == "__main__":
    # Create parameters requiring gradients
    mu = torch.randn(2, 4, requires_grad=True)
    log_var = torch.randn(2, 4, requires_grad=True)

    # Reparameterized sampling
    z = reparameterize_gaussian(mu, log_var, training=True)

    # Some loss function of z
    loss = z.sum()

    # Compute gradients
    loss.backward()

    # Verify gradients exist!
    print(f"Gradient w.r.t. mu exists: {mu.grad is not None}")
    print(f"Gradient w.r.t. log_var exists: {log_var.grad is not None}")
    print(f"mu.grad: {mu.grad}")
    print(f"log_var.grad: {log_var.grad}")
```

The reparameterization trick doesn't just enable gradient computation—it produces low-variance gradient estimates, making optimization far more efficient.
Consider estimating $\nabla_{\boldsymbol{\phi}} \mathbb{E}_{q_\phi}[f(\mathbf{z})]$:
Score Function (REINFORCE) Estimator: $$\hat{g}_{\text{SF}} = f(\mathbf{z}) \nabla_{\boldsymbol{\phi}} \log q_\phi(\mathbf{z}|\mathbf{x}), \quad \mathbf{z} \sim q_\phi$$
Reparameterization Estimator: $$\hat{g}_{\text{rep}} = \nabla_{\boldsymbol{\phi}} f(g_\phi(\boldsymbol{\epsilon})), \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, I)$$
where $g_\phi(\boldsymbol{\epsilon}) = \boldsymbol{\mu}_\phi + \boldsymbol{\sigma}_\phi \odot \boldsymbol{\epsilon}$.
Both estimators are unbiased—their expectation equals the true gradient. But their variances differ dramatically.
High variance in gradient estimates means noisy parameter updates: the learning rate must be kept small, convergence is slow, and training can destabilize entirely.
Low variance means each update points reliably toward the true gradient, permitting larger learning rates and faster, more stable convergence.
Score function variance: The score function estimator has variance proportional to $\text{Var}[f(\mathbf{z})]$. If $f$ takes a wide range of values (which reconstruction loss does), variance is high.
Additionally, the estimator includes the term $\nabla \log q$, which can have high magnitude (especially near the tails of the distribution).
Reparameterization variance: The reparameterization estimator's variance depends on $\text{Var}_{\boldsymbol{\epsilon}}[\nabla_{\boldsymbol{\phi}} f(g_\phi(\boldsymbol{\epsilon}))]$. For smooth $f$, this is often much lower.
Crucially, the randomness ($\epsilon$) is independent of what we're differentiating with respect to ($\phi$). The gradient $\nabla_\phi$ only sees the deterministic part of the computation.
In practice, the two estimators compare as follows:
| Property | Score Function | Reparameterization |
|---|---|---|
| Unbiased | ✓ | ✓ |
| Variance | High | Low |
| Requires differentiable f | No | Yes |
| Requires reparameterizable distribution | No | Yes |
| Works for discrete latents | Yes | No |
| Practical for VAEs | Only with variance reduction | Yes (default choice) |
Reparameterization is also called the 'pathwise derivative' estimator. Instead of differentiating through the probability (score function), we differentiate through the sample path itself. Given a fixed noise realization ε, we trace a deterministic path from inputs to outputs and differentiate along that path.
Implementing the reparameterization trick correctly requires attention to numerical stability and gradient flow. Let's examine common issues and solutions.
The encoder outputs $\log \boldsymbol{\sigma}^2$, not $\boldsymbol{\sigma}$ or $\boldsymbol{\sigma}^2$ directly. Why? A log-variance is unconstrained: any real-valued network output is valid, so no positivity constraint is needed, and the exponential mapping back to $\boldsymbol{\sigma}$ is smooth and numerically convenient.
Converting to standard deviation: $$\boldsymbol{\sigma} = \exp(0.5 \cdot \log \boldsymbol{\sigma}^2)$$
Why $0.5$? Because $\exp(0.5 \cdot \log \sigma^2) = \exp(\log \sigma) = \sigma$.
Extreme log-variance values cause problems:
Too small (log_var << 0): $\sigma$ underflows toward zero, the posterior collapses to a near-deterministic point, and the $-\log \sigma^2$ term in the KL divergence blows up.
Too large (log_var >> 0): $\exp(\log \sigma^2)$ overflows, producing inf or NaN in the sample and the KL term and destroying training.
Solution: Clamp log_var to safe range, e.g., $[-10, 10]$ giving $\sigma \in [\approx 0.007, \approx 150]$.
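The quoted bounds follow directly from $\sigma = \exp(0.5 \cdot \log \sigma^2)$ (a two-line check; the $[-10, 10]$ range itself is a common heuristic, not a fixed standard):

```python
import math

# sigma = exp(0.5 * log_var) at the clamp boundaries
sigma_min = math.exp(0.5 * -10.0)  # log_var = -10
sigma_max = math.exp(0.5 * 10.0)   # log_var = +10

print(sigma_min)  # ~ 0.0067
print(sigma_max)  # ~ 148.4
```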
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def stable_reparameterize(
    mu: torch.Tensor,
    log_var: torch.Tensor,
    training: bool = True,
    min_logvar: float = -10.0,
    max_logvar: float = 10.0
) -> torch.Tensor:
    """
    Numerically stable reparameterization with safeguards.

    Args:
        mu: Latent mean
        log_var: Latent log-variance
        training: Whether in training mode
        min_logvar: Minimum allowed log-variance
        max_logvar: Maximum allowed log-variance

    Returns:
        Sampled latent code
    """
    # Clamp log_var to prevent numerical issues
    log_var = torch.clamp(log_var, min=min_logvar, max=max_logvar)

    if training:
        # Compute std: exp(0.5 * log_var) = exp(log(sigma)) = sigma
        std = torch.exp(0.5 * log_var)

        # Sample noise
        eps = torch.randn_like(std)

        # Reparameterized sample
        return mu + std * eps
    else:
        return mu


def compute_kl_divergence_stable(
    mu: torch.Tensor,
    log_var: torch.Tensor,
    reduce: str = 'sum'
) -> torch.Tensor:
    """
    Compute KL divergence with numerical stability.

    KL(N(mu, sigma^2) || N(0, 1))
        = 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2))

    Args:
        mu: Latent mean [batch, latent_dim]
        log_var: Latent log-variance [batch, latent_dim]
        reduce: 'sum', 'mean', or 'none'

    Returns:
        KL divergence per sample [batch] if reduce='sum'/'mean' over latent dims
    """
    # Clamp log_var for stability
    log_var = torch.clamp(log_var, min=-10.0, max=10.0)

    # KL formula: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    # Note: exp(log_var) = sigma^2
    kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())

    if reduce == 'sum':
        return kl_per_dim.sum(dim=-1)   # [batch]
    elif reduce == 'mean':
        return kl_per_dim.mean(dim=-1)  # [batch]
    else:
        return kl_per_dim               # [batch, latent_dim]


class VAEEncoder(nn.Module):
    """
    Encoder with proper reparameterization setup.
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Initialize logvar bias to produce approximately unit variance
        # log(1) = 0, so bias = 0 gives variance ~ 1
        nn.init.zeros_(self.fc_logvar.bias)
        # Small weights to keep initial log_var near 0
        nn.init.normal_(self.fc_logvar.weight, mean=0.0, std=0.01)

    def forward(self, x: torch.Tensor, sample: bool = True):
        """
        Returns:
            z: Sampled (or mean) latent code
            mu: Latent mean
            log_var: Latent log-variance
        """
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        z = stable_reparameterize(mu, log_var, training=sample)
        return z, mu, log_var
```

During training, always sample using the reparameterization trick to maintain gradient flow. At inference/evaluation, you can either: (1) use the mean (deterministic, reproducible), or (2) sample with temperature scaling (varied generation). When computing reconstruction metrics like the ELBO, you should sample. For deterministic tasks like encoding, use the mean.
The reparameterization trick isn't limited to Gaussians. Any distribution that can be expressed as a deterministic transformation of simpler noise can be reparameterized.
A distribution is reparameterizable if we can write: $$\mathbf{z} = g_{\boldsymbol{\theta}}(\boldsymbol{\epsilon}), \quad \boldsymbol{\epsilon} \sim p_0(\boldsymbol{\epsilon})$$
where: (1) $g_{\boldsymbol{\theta}}$ is a deterministic function, differentiable in $\boldsymbol{\theta}$, and (2) the base distribution $p_0(\boldsymbol{\epsilon})$ is fixed; it does not depend on $\boldsymbol{\theta}$.
Gaussian: $$\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\Sigma}^{1/2} \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, I)$$
Uniform: $$z = a + (b - a) \cdot u, \quad u \sim \text{Uniform}(0, 1)$$
Exponential: $$z = -\frac{1}{\lambda} \log(u), \quad u \sim \text{Uniform}(0, 1)$$
Logistic: $$z = \mu + s \cdot \log\left(\frac{u}{1-u}\right), \quad u \sim \text{Uniform}(0, 1)$$
Gumbel (for use in Gumbel-Softmax): $$z = \mu - \log(-\log(u)), \quad u \sim \text{Uniform}(0, 1)$$
```python
import torch
import torch.nn.functional as F


def reparameterize_full_covariance_gaussian(
    mu: torch.Tensor,
    L: torch.Tensor
) -> torch.Tensor:
    """
    Reparameterization for Gaussian with full covariance matrix.

    Args:
        mu: Mean [batch, dim]
        L: Lower triangular Cholesky factor of covariance
           [batch, dim, dim] (Sigma = L @ L.T)

    Returns:
        Sample from N(mu, Sigma)
    """
    eps = torch.randn_like(mu)
    return mu + torch.matmul(L, eps.unsqueeze(-1)).squeeze(-1)


def reparameterize_exponential(rate: torch.Tensor) -> torch.Tensor:
    """
    Reparameterization for Exponential distribution.

    Args:
        rate: Rate parameter lambda > 0

    Returns:
        Sample from Exp(lambda) using inverse CDF transform
    """
    u = torch.rand_like(rate)
    # Clamp u away from 0 and 1 for numerical stability
    u = torch.clamp(u, min=1e-7, max=1 - 1e-7)
    return -torch.log(u) / rate


def reparameterize_logistic(mu: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    """
    Reparameterization for Logistic distribution.

    Args:
        mu: Location parameter
        s: Scale parameter > 0

    Returns:
        Sample from Logistic(mu, s)
    """
    u = torch.rand_like(mu)
    u = torch.clamp(u, min=1e-7, max=1 - 1e-7)
    return mu + s * (torch.log(u) - torch.log(1 - u))


def reparameterize_gumbel(mu: torch.Tensor) -> torch.Tensor:
    """
    Reparameterization for Gumbel distribution.

    Args:
        mu: Location parameter

    Returns:
        Sample from Gumbel(mu, 1)
    """
    u = torch.rand_like(mu)
    u = torch.clamp(u, min=1e-7, max=1 - 1e-7)
    return mu - torch.log(-torch.log(u))


def gumbel_softmax(
    logits: torch.Tensor,
    temperature: float = 1.0,
    hard: bool = False
) -> torch.Tensor:
    """
    Gumbel-Softmax: Differentiable approximation to categorical sampling.

    Args:
        logits: Log-probabilities [batch, num_classes]
        temperature: Temperature for softmax (lower = harder)
        hard: If True, return one-hot in forward, soft in backward

    Returns:
        Soft (or hard) categorical sample
    """
    # Sample Gumbel noise
    gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-7) + 1e-7)

    # Add noise to logits and apply temperature-scaled softmax
    y_soft = F.softmax((logits + gumbels) / temperature, dim=-1)

    if hard:
        # Straight-through estimator: hard in forward, soft in backward
        index = y_soft.argmax(dim=-1)
        y_hard = torch.zeros_like(logits).scatter_(-1, index.unsqueeze(-1), 1.0)
        # Use soft gradients, but return hard values
        return y_hard - y_soft.detach() + y_soft
    else:
        return y_soft
```

Discrete distributions (Categorical, Bernoulli, Poisson) cannot be exactly reparameterized because there's no continuous transformation from continuous noise to discrete values. Approaches like Gumbel-Softmax provide differentiable approximations but introduce bias. The straight-through estimator uses hard values forward but soft gradients backward—another biased but practical solution.
A limitation of standard VAEs is that the approximate posterior is a simple diagonal Gaussian. Complex true posteriors may not be well-approximated. Normalizing flows extend reparameterization to create richer posteriors.
Start with a simple base distribution (diagonal Gaussian), then apply a chain of invertible, differentiable transformations:
$$\mathbf{z}_0 \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$$ $$\mathbf{z}_K = f_K \circ f_{K-1} \circ \cdots \circ f_1(\mathbf{z}_0)$$
Each $f_k$ is an invertible transformation with tractable Jacobian determinant.
The density of $\mathbf{z}_K$ is computed via change of variables:
$$\log q(\mathbf{z}_K) = \log q(\mathbf{z}_0) - \sum_{k=1}^{K} \log \left|\det \frac{\partial f_k}{\partial \mathbf{z}_{k-1}}\right|$$
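As a one-dimensional sanity check of this formula, take the single flow $f(z) = \exp(z)$ applied to a standard normal; the transformed density should match the log-normal (a small sketch using `torch.distributions`):

```python
import torch
from torch.distributions import Normal, LogNormal

base = Normal(0.0, 1.0)
z0 = torch.tensor(0.7)
z1 = torch.exp(z0)               # flow f(z) = exp(z), so f'(z) = exp(z)

# Change of variables: log q(z1) = log q(z0) - log|f'(z0)| = log q(z0) - z0
log_q1 = base.log_prob(z0) - z0

print(log_q1)
print(LogNormal(0.0, 1.0).log_prob(z1))  # matches the analytic density
```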
1. Planar Flows: $$f(\mathbf{z}) = \mathbf{z} + \mathbf{u} \cdot \tanh(\mathbf{w}^T \mathbf{z} + b)$$
2. Radial Flows: $$f(\mathbf{z}) = \mathbf{z} + \beta h(\alpha, r)(\mathbf{z} - \mathbf{z}_0)$$
3. Inverse Autoregressive Flow (IAF): Autoregressive transformation with masked networks—powerful but complex.
In a VAE with normalizing flows, the encoder outputs the base parameters $(\boldsymbol{\mu}, \boldsymbol{\sigma})$ (and possibly the flow parameters), $\mathbf{z}_0$ is drawn by standard reparameterization, the flow chain maps it to $\mathbf{z}_K$, and the decoder consumes $\mathbf{z}_K$ while the ELBO uses $\log q(\mathbf{z}_K)$ from the change-of-variables formula.
The reparameterization trick applies to both the initial sample and all flow transformations.
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PlanarFlow(nn.Module):
    """
    Planar flow transformation: z' = z + u * tanh(w^T z + b)

    Invertible transformation that 'folds' the space along a hyperplane.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

        # Flow parameters
        self.u = nn.Parameter(torch.randn(dim) * 0.01)
        self.w = nn.Parameter(torch.randn(dim) * 0.01)
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, z: torch.Tensor):
        """
        Apply flow and return transformed z and log-det-jacobian contribution.

        Args:
            z: Input [batch, dim]

        Returns:
            z': Transformed [batch, dim]
            log_det: Log determinant of Jacobian [batch]
        """
        # Ensure invertibility: u' = u + (m(w^Tu) - w^Tu) * w/||w||^2
        # where m(x) = -1 + softplus(x)
        wTu = torch.dot(self.w, self.u)
        m_wTu = -1 + F.softplus(wTu)
        u_hat = self.u + (m_wTu - wTu) * self.w / (torch.norm(self.w) ** 2)

        # Forward transformation
        activation = torch.tanh(F.linear(z, self.w.unsqueeze(0), self.b))  # [batch, 1]
        z_new = z + u_hat.unsqueeze(0) * activation

        # Log determinant of Jacobian
        # d(tanh(x))/dx = 1 - tanh(x)^2
        psi = (1 - activation ** 2) * self.w.unsqueeze(0)  # [batch, dim]
        log_det = torch.log(
            torch.abs(1 + torch.sum(psi * u_hat.unsqueeze(0), dim=1)) + 1e-7
        )

        return z_new, log_det


class NormalizingFlowPosterior(nn.Module):
    """
    Approximate posterior enhanced with normalizing flows.

    q(z|x) = q_0(z_0|x) * prod_k |det df_k/dz_{k-1}|^{-1}

    where z_K = f_K o ... o f_1(z_0) and z_0 ~ N(mu(x), sigma(x)^2)
    """

    def __init__(self, latent_dim: int, num_flows: int = 4):
        super().__init__()
        self.flows = nn.ModuleList([
            PlanarFlow(latent_dim) for _ in range(num_flows)
        ])

    def forward(self, z0: torch.Tensor, log_q0: torch.Tensor):
        """
        Apply flow chain to base sample.

        Args:
            z0: Sample from base Gaussian [batch, dim]
            log_q0: Log probability under base distribution [batch]

        Returns:
            zK: Final transformed sample
            log_qK: Log probability of zK under flow-transformed distribution
        """
        z = z0
        log_det_sum = 0.0
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_sum = log_det_sum + log_det

        # log q(z_K) = log q(z_0) - sum of log det jacobians
        log_qK = log_q0 - log_det_sum
        return z, log_qK


def vae_with_flow_loss(x, recon, zK, log_qK, beta=1.0):
    """
    VAE loss with normalizing flow posterior.

    ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))
         = E_q[log p(x|z)] + E_q[log p(z)] - E_q[log q(z|x)]
         ~ log p(x|z_K) + log p(z_K) - log q(z_K|x)
    """
    batch_size = x.size(0)

    # Reconstruction loss
    recon_loss = F.binary_cross_entropy_with_logits(
        recon.view(batch_size, -1),
        x.view(batch_size, -1),
        reduction='sum'
    ) / batch_size

    # log p(z_K) under standard Gaussian prior, evaluated at the FINAL
    # flow-transformed sample z_K (not the base sample z_0)
    log_pz = -0.5 * (zK.pow(2).sum(dim=1)
                     + zK.size(1) * math.log(2 * math.pi))

    # Negative ELBO: recon loss + E_q[log q(z_K|x)] - E_q[log p(z_K)]
    kl_term = (log_qK - log_pz).mean()

    return recon_loss + beta * kl_term, recon_loss, kl_term
```

Some distributions can't be reparameterized. What then? Several alternatives exist.
The score function estimator works for any distribution but has high variance. Variance reduction techniques make it practical:
1. Control Variates: Subtract a baseline $b$ from the function: $$\hat{g} = (f(\mathbf{z}) - b) \nabla \log q(\mathbf{z})$$ If $b \approx \mathbb{E}[f(\mathbf{z})]$, variance is reduced without changing the expected gradient.
2. Rao-Blackwellization: Analytically integrate out variables where possible, only using Monte Carlo for the remainder.
3. Multiple Samples: Average over many samples per update. Reduces variance as $1/K$ but increases computation.
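The effect of a baseline can be seen numerically. Reusing $f(z) = z^2$ with $z \sim \mathcal{N}(\mu, 1)$, subtracting $b \approx \mathbb{E}[f(z)]$ leaves the mean unchanged but shrinks the spread (an illustrative sketch):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(1.0)
z = mu + torch.randn(100_000)      # z ~ N(mu, 1)
f = z ** 2
score = z - mu                     # grad_mu log N(z; mu, 1)

plain = f * score                  # plain REINFORCE estimates
b = f.mean()                       # baseline ~ E[f(z)]
with_cv = (f - b) * score          # control-variate estimates

print(plain.mean(), with_cv.mean())  # both ~ 2 (still unbiased)
print(plain.var(), with_cv.var())    # variance drops with the baseline
```

The baseline term $b \, \nabla \log q$ has zero expectation, which is why the estimate stays unbiased.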
For discrete variables, use hard values forward, pretend they were soft backward: $$\text{Forward: } z = \text{argmax}(\text{logits})$$ $$\text{Backward: } \frac{\partial z}{\partial \text{logits}} \approx I$$
Biased but often works in practice.
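In PyTorch this trick is usually written with `detach()` (a minimal sketch; the weight vector in the loss is made up for illustration):

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5], requires_grad=True)
probs = torch.softmax(logits, dim=-1)

# Hard one-hot in the forward pass
hard = torch.zeros_like(probs)
hard[probs.argmax()] = 1.0

# Straight-through: forward value is hard, gradient flows through probs
z = hard.detach() + probs - probs.detach()
print(z)  # one-hot: [1., 0., 0.]

# Gradients reach logits despite the hard forward value
loss = (z * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()
print(logits.grad)
```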
A continuous relaxation of categorical distributions. As temperature $\tau \to 0$, samples approach true one-hot categorical draws, but gradients become increasingly high-variance (and the relaxation is no longer differentiable at $\tau = 0$).
$$y_i = \frac{\exp((\log \pi_i + g_i)/\tau)}{\sum_j \exp((\log \pi_j + g_j)/\tau)}$$
where $g_i$ are Gumbel noise samples.
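The temperature's effect is easy to see by sharing one Gumbel noise draw across several values of $\tau$ (an illustrative sketch):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([1.0, 0.5, -1.0])
g = -torch.log(-torch.log(torch.rand(3)))  # Gumbel(0, 1) noise

y_warm = F.softmax((logits + g) / 5.0, dim=-1)  # high temperature: diffuse
y_cold = F.softmax((logits + g) / 0.1, dim=-1)  # low temperature: near one-hot

print(y_warm)
print(y_cold)
# Lower temperature concentrates mass on argmax(logits + g)
```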
| Situation | Recommended Approach | Notes |
|---|---|---|
| Continuous Gaussian posterior | Reparameterization | Standard VAE, always prefer this |
| Other continuous distribution | Reparameterize if possible | Check if inverse CDF or similar exists |
| Discrete categorical | Gumbel-Softmax | Temperature annealing often needed |
| Binary Bernoulli | Gumbel-Softmax (2 classes) or REINFORCE | Relaxation preferred |
| Complex discrete | Score function + variance reduction | May need significant engineering |
| Mixed continuous/discrete | Hybrid: reparam for continuous, relaxation for discrete | VQ-VAE uses straight-through |
For most VAE applications, stick with Gaussian posteriors and standard reparameterization. Only explore alternatives when: (1) Your latent variables are inherently discrete, (2) Gaussian posteriors demonstrably fail to capture posterior structure, (3) You need structured latent spaces (e.g., VQ-VAE for discrete codebooks). The complexity added by non-reparameterizable distributions often isn't worth it.
The reparameterization trick is the technical innovation that makes VAE training practical. Without it, the stochastic sampling step would block gradient flow. With it, VAEs train smoothly with standard backpropagation.
What's Next:
With the reparameterization trick understood, we've covered all core VAE components. The final page explores VAE variants—the many extensions and modifications that address VAE limitations, including β-VAE, VQ-VAE, hierarchical VAEs, and more.
You now have complete understanding of the reparameterization trick—what problem it solves, how it works, why it produces good gradients, and how to implement it correctly. This technique enables VAE training and appears throughout probabilistic deep learning.