In machine learning, we constantly compare probability distributions. When training a classifier, we compare the model's predicted distribution to the true label distribution. When training a variational autoencoder, we compare our learned latent distribution to a prior. When doing Bayesian inference, we compare our approximate posterior to the true posterior.
But how do we quantify "distributional difference"? The answer is the Kullback-Leibler (KL) divergence, named after Solomon Kullback and Richard Leibler who introduced it in 1951. Despite its name, KL divergence is not a true distance metric—it's asymmetric and doesn't satisfy the triangle inequality. Yet this very asymmetry gives it deep meaning and makes it the right tool for many ML applications.
The KL divergence D_KL(P || Q) measures the extra bits needed to encode samples from P when using a code optimized for Q, rather than the optimal code for P. It answers: "How costly is it to pretend the data comes from Q when it really comes from P?"
This page establishes KL divergence rigorously, explores its asymmetry, connects it to variational inference, and shows why it appears throughout modern machine learning.
By the end of this page, you will understand KL divergence's definition and information-theoretic meaning, appreciate why asymmetry matters and when to use each direction, connect KL divergence to variational inference and the ELBO, and recognize KL divergence in VAEs, GANs, policy optimization, and more.
The Kullback-Leibler divergence D_KL(P || Q), read as "the KL divergence of P from Q" or "the KL divergence from Q to P", is defined as:
```python
# KL Divergence Definition
# ========================
#
# D_KL(P || Q) = ∑ᵢ P(xᵢ) · log[P(xᵢ) / Q(xᵢ)]
#              = ∑ᵢ P(xᵢ) · [log P(xᵢ) - log Q(xᵢ)]
#              = -H(P) + H(P, Q)
#              = H(P, Q) - H(P)
#
# Equivalent forms:
# D_KL(P || Q) = E_P[log(P/Q)]              # Expectation form
#              = E_P[log P] - E_P[log Q]    # Difference of expectations
#              = H(P, Q) - H(P)             # Cross-entropy minus entropy
#
# For continuous distributions:
# D_KL(P || Q) = ∫ p(x) · log[p(x) / q(x)] dx

# Python Implementation:
import numpy as np

def kl_divergence(p, q, epsilon=1e-15):
    """
    Compute KL divergence D_KL(P || Q).

    Args:
        p: True distribution (array of probabilities)
        q: Approximate distribution (array of probabilities)
        epsilon: Small constant to avoid log(0) and division by zero

    Returns:
        KL divergence in bits (using log base 2)
    """
    p = np.array(p) + epsilon
    q = np.array(q) + epsilon
    p = p / p.sum()  # Normalize
    q = q / q.sum()
    return np.sum(p * np.log2(p / q))

# Examples
p = np.array([0.4, 0.4, 0.2])
q1 = np.array([0.4, 0.4, 0.2])     # Exact match
q2 = np.array([0.33, 0.33, 0.34])  # Close but not exact
q3 = np.array([0.1, 0.1, 0.8])     # Very different

print(f"D_KL(P || P) = {kl_divergence(p, p):.6f}")  # ≈ 0
print(f"D_KL(P || Q_close) = {kl_divergence(p, q2):.4f}")
print(f"D_KL(P || Q_far) = {kl_divergence(p, q3):.4f}")

# Note the asymmetry!
print(f"\nAsymmetry demonstration:")
print(f"D_KL(P || Q_far) = {kl_divergence(p, q3):.4f}")
print(f"D_KL(Q_far || P) = {kl_divergence(q3, p):.4f}")
```

Understanding the formula:
D_KL(P || Q) = Σ P(x) · log[P(x) / Q(x)]
Each term P(x) · log[P(x)/Q(x)] measures:
When P(x) > Q(x), the log ratio is positive: Q underestimates x's probability. When P(x) < Q(x), the log ratio is negative: Q overestimates x's probability.
The expectation over P weights these disagreements by how often they matter under the true distribution.
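To see these per-outcome contributions numerically, here is a short sketch (the specific distributions are only illustrative): individual terms can be negative where Q overestimates P, but the P-weighted sum never is.

```python
import numpy as np

# Per-outcome contributions P(x) * log2(P(x) / Q(x)).
# Terms where Q overestimates are negative, but the weighted total is >= 0.
p = np.array([0.4, 0.4, 0.2])
q = np.array([0.1, 0.1, 0.8])

terms = p * np.log2(p / q)
for x, t in enumerate(terms):
    print(f"x={x}: P={p[x]:.2f}, Q={q[x]:.2f}, contribution={t:+.4f} bits")
print(f"D_KL(P || Q) = {terms.sum():.4f} bits")
```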
KL divergence = "extra bits" penalty for using the wrong distribution.
If you encode data from P using an optimal code for Q:

• You'd use H(P, Q) bits on average (cross-entropy)
• The optimal code for P uses H(P) bits (entropy)
• The difference H(P, Q) - H(P) = D_KL(P || Q) is the "waste"
This waste is always ≥ 0, with equality only when P = Q.
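As a quick numerical check of this coding interpretation, here is a minimal sketch (the entropy and cross_entropy helpers are defined inline and measure in bits; the distributions are illustrative):

```python
import numpy as np

def entropy(p):
    """H(P) in bits."""
    p = np.asarray(p)
    return -np.sum(p * np.log2(p))

def cross_entropy(p, q):
    """H(P, Q) in bits: average code length when coding data from P with Q's code."""
    p, q = np.asarray(p), np.asarray(q)
    return -np.sum(p * np.log2(q))

p = [0.4, 0.4, 0.2]
q = [0.2, 0.2, 0.6]

print(f"H(P)                 = {entropy(p):.4f} bits")
print(f"H(P, Q)              = {cross_entropy(p, q):.4f} bits")
print(f"Waste = D_KL(P || Q) = {cross_entropy(p, q) - entropy(p):.4f} bits")
```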
KL divergence has several important properties that determine when and how it can be used:
```python
import numpy as np

def kl_divergence(p, q, epsilon=1e-15):
    p = np.array(p) + epsilon
    q = np.array(q) + epsilon
    p, q = p / p.sum(), q / q.sum()
    return np.sum(p * np.log2(p / q))

# Property 1: Non-negativity
print("Property: Non-negativity")
for _ in range(5):
    p = np.random.dirichlet([1, 1, 1, 1])
    q = np.random.dirichlet([1, 1, 1, 1])
    kl = kl_divergence(p, q)
    print(f"  D_KL(P || Q) = {kl:.6f} >= 0: {kl >= 0}")
print()

# Property 2: Zero iff equal
print("Property: Zero iff equal")
p = [0.25, 0.25, 0.25, 0.25]
print(f"  D_KL(P || P) = {kl_divergence(p, p):.10f}")
print()

# Property 3: Asymmetry
print("Property: Asymmetry")
p = [0.9, 0.1]
q = [0.5, 0.5]
print(f"  D_KL(P || Q) = {kl_divergence(p, q):.4f}")
print(f"  D_KL(Q || P) = {kl_divergence(q, p):.4f}")
print(f"  Difference: {abs(kl_divergence(p, q) - kl_divergence(q, p)):.4f}")
print()

# Property 4: Support mismatch (undefined / infinite)
print("Property: Support mismatch")
p = [0.5, 0.3, 0.2]
q_good = [0.4, 0.4, 0.2]
q_bad = [0.6, 0.4, 0.0]  # Zero where P > 0

print(f"  D_KL with matching support: {kl_divergence(p, q_good):.4f}")
# Note: With epsilon, we avoid infinity, but the true KL would be ∞
print(f"  D_KL with mismatched support (using epsilon): {kl_divergence(p, q_bad):.4f}")
print("  (Without epsilon, this would be infinity: log(0.2/0) = ∞)")
```

If Q assigns zero probability to any outcome that P assigns positive probability, D_KL(P || Q) = ∞. Minimizing the forward KL therefore forces Q to put mass everywhere P does (its mass-covering, "zero-avoiding" behavior). In practice, we often add a small epsilon to Q or use distributions that always have full support (like Gaussians). This issue fundamentally affects model design in variational inference.
The asymmetry of KL divergence is not a bug—it's a feature with deep implications. The two directions have fundamentally different behaviors that determine which is appropriate for a given task.
Visual intuition for multimodal P:
Imagine the true distribution P has two well-separated modes (like a bimodal distribution), and we must approximate it with a unimodal Q (like a single Gaussian):
Forward KL D_KL(P || Q): Q will stretch to cover both modes, placing mass between them where P has little. It's "mode-averaging."
Reverse KL D_KL(Q || P): Q will concentrate on one mode, ignoring the other entirely. It's "mode-seeking."
Neither is universally better—it depends on whether missing modes is worse than putting mass where it doesn't belong.
```python
import numpy as np
from scipy.stats import norm
from scipy.optimize import minimize

def kl_forward(p_samples, q_mean, q_std, p_weights=None):
    """
    Approximate D_KL(P || Q) using samples from P.
    Q is Gaussian with given mean and std.
    """
    if p_weights is None:
        p_weights = np.ones(len(p_samples)) / len(p_samples)
    # E_P[log P] is constant for optimization
    log_q = norm.logpdf(p_samples, q_mean, q_std)
    return -np.sum(p_weights * log_q)  # -E_P[log Q]

def kl_reverse(q_samples, p_pdf, q_mean, q_std):
    """
    Approximate D_KL(Q || P) using samples from Q.
    p_pdf is a function giving P's density.
    """
    log_p = np.log(p_pdf(q_samples) + 1e-10)
    log_q = norm.logpdf(q_samples, q_mean, q_std)
    return np.mean(log_q - log_p)  # E_Q[log Q - log P]

# True distribution: mixture of two Gaussians
def p_pdf(x):
    return 0.5 * norm.pdf(x, -3, 1) + 0.5 * norm.pdf(x, 3, 1)

def sample_p(n):
    """Sample from bimodal distribution."""
    samples = []
    for _ in range(n):
        if np.random.rand() < 0.5:
            samples.append(np.random.normal(-3, 1))
        else:
            samples.append(np.random.normal(3, 1))
    return np.array(samples)

# Find optimal Gaussian Q under forward KL (moment matching)
p_samples = sample_p(10000)
forward_opt_mean = np.mean(p_samples)
forward_opt_std = np.std(p_samples)

print("Bimodal P with modes at x=-3 and x=3")
print("=" * 50)
print(f"\nForward KL (mode-averaging) optimal Q:")
print(f"  Mean: {forward_opt_mean:.2f} (between modes)")
print(f"  Std: {forward_opt_std:.2f} (large, covering both)")

# For reverse KL, Q will focus on one mode
# (This is more complex to optimize; showing concept)
print(f"\nReverse KL (mode-seeking) tends to give Q centered at")
print(f"  Mean: ≈-3 or ≈3 (one mode)")
print(f"  Std: ≈1 (matching single mode's width)")

# Summary
print(f"\nKey insight:")
print("- Forward KL: Q covers everything, even gaps between modes")
print("- Reverse KL: Q focuses on one mode perfectly")
```

Variational inference minimizes D_KL(Q || P_true), the reverse KL. Why? Because we can compute E_Q[...] by sampling from Q (our approximation), but we cannot easily sample from the true posterior P. The reverse KL is tractable; the forward KL often isn't. The cost is mode-seeking behavior—the variational posterior may miss some modes of the true posterior.
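To make the mode-seeking behavior concrete, here is a small follow-up sketch (our own addition, reusing p_pdf, kl_reverse, and the minimize import from the block above; the helper names are illustrative). It minimizes the sampled reverse KL over a Gaussian's mean and log-std and, depending on the initialization, collapses onto one mode or the other:

```python
# Sketch: numerically minimize the (sampled) reverse KL over Gaussian parameters.
# A fixed random seed keeps the Monte Carlo objective deterministic and smooth
# enough for Nelder-Mead.
def reverse_kl_objective(params, n_samples=5000, seed=0):
    q_mean, q_log_std = params
    q_std = np.exp(q_log_std)  # optimize log-std so std stays positive
    rng = np.random.default_rng(seed)
    q_samples = rng.normal(q_mean, q_std, n_samples)
    return kl_reverse(q_samples, p_pdf, q_mean, q_std)

for init_mean in (-1.0, 1.0):
    result = minimize(reverse_kl_objective, x0=[init_mean, 0.0], method="Nelder-Mead")
    q_mean, q_log_std = result.x
    print(f"Init mean {init_mean:+.1f} -> Q ≈ N({q_mean:.2f}, {np.exp(q_log_std):.2f}²)")
```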
The KL divergence between Gaussian distributions has a closed-form expression, making it extremely useful for variational autoencoders and other latent variable models.
```python
# KL Divergence Between Univariate Gaussians
# ==========================================
#
# For P = N(μ₁, σ₁²) and Q = N(μ₂, σ₂²):
#
# D_KL(P || Q) = log(σ₂/σ₁) + (σ₁² + (μ₁ - μ₂)²) / (2σ₂²) - 1/2
#
# Special case: P = N(μ, σ²), Q = N(0, 1) (VAE regularization):
# D_KL(N(μ, σ²) || N(0, 1)) = (1/2) · (σ² + μ² - 1 - log(σ²))
#
# KL Divergence Between Multivariate Gaussians
# ============================================
#
# For P = N(μ₁, Σ₁) and Q = N(μ₂, Σ₂) in d dimensions:
#
# D_KL(P || Q) = (1/2) · [ log(|Σ₂|/|Σ₁|)            # Log determinant ratio
#                          - d                        # Dimension
#                          + tr(Σ₂⁻¹ Σ₁)              # Trace term
#                          + (μ₂-μ₁)ᵀ Σ₂⁻¹ (μ₂-μ₁) ]  # Mean difference term

# Python Implementation:
import numpy as np
import torch

def kl_gaussian_univariate(mu1, sigma1, mu2, sigma2):
    """
    KL divergence D_KL(N(μ₁, σ₁²) || N(μ₂, σ₂²)).
    """
    return (np.log(sigma2 / sigma1)
            + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2)
            - 0.5)

def kl_gaussian_standard_normal(mu, sigma):
    """
    KL divergence D_KL(N(μ, σ²) || N(0, 1)).
    This is the VAE regularization term.
    """
    return 0.5 * (sigma**2 + mu**2 - 1 - np.log(sigma**2))

# VAE-style: batch of latent distributions
def kl_divergence_vae(mu, log_var):
    """
    KL divergence for VAE, where we parameterize log(σ²) for stability.

    Args:
        mu: Mean of approximate posterior, shape (batch, latent_dim)
        log_var: Log variance of approximate posterior

    Returns:
        KL divergence summed over latent dimensions, shape (batch,)
    """
    # D_KL(N(μ, σ²) || N(0, 1)) = 0.5 * (σ² + μ² - 1 - log(σ²))
    return 0.5 * torch.sum(torch.exp(log_var) + mu**2 - 1 - log_var, dim=-1)

# Example
print("Univariate Gaussian KL Examples:")
print(f"D_KL(N(0,1) || N(0,1)) = {kl_gaussian_univariate(0, 1, 0, 1):.4f}")
print(f"D_KL(N(1,1) || N(0,1)) = {kl_gaussian_univariate(1, 1, 0, 1):.4f}")
print(f"D_KL(N(0,2) || N(0,1)) = {kl_gaussian_univariate(0, 2, 0, 1):.4f}")
print(f"D_KL(N(2,0.5) || N(0,1)) = {kl_gaussian_univariate(2, 0.5, 0, 1):.4f}")

# VAE usage
print("\nVAE KL regularization:")
mu = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 0.5]])
log_var = torch.tensor([[0.0, 0.0], [0.5, 0.5], [-0.5, 0.2]])
kl = kl_divergence_vae(mu, log_var)
print(f"Batch KL values: {kl}")
```

Understanding the Gaussian KL formula:
For D_KL(N(μ, σ²) || N(0, 1)) = ½(σ² + μ² − 1 − log σ²): the μ² term penalizes means far from 0, and σ² − 1 − log σ² penalizes variances away from 1 (it is minimized, at zero, when σ² = 1). The divergence vanishes exactly when μ = 0 and σ² = 1.
This is why VAE latents tend toward the standard normal—the KL term explicitly encourages this. Without it, the model would learn deterministic (degenerate) encodings.
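As a sanity check on the closed form, here is a small sketch (the sample size and parameters are arbitrary) comparing it against a Monte Carlo estimate of E_P[log p(x) − log q(x)]:

```python
import numpy as np
from scipy.stats import norm

# Closed form vs. Monte Carlo estimate of D_KL(N(mu, sigma^2) || N(0, 1)).
mu, sigma = 1.5, 0.7
closed_form = 0.5 * (sigma**2 + mu**2 - 1 - np.log(sigma**2))

rng = np.random.default_rng(0)
x = rng.normal(mu, sigma, 200_000)  # samples from P
mc_estimate = np.mean(norm.logpdf(x, mu, sigma) - norm.logpdf(x, 0, 1))

print(f"Closed form: {closed_form:.4f} nats")
print(f"Monte Carlo: {mc_estimate:.4f} nats")
```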
The VAE loss = Reconstruction Loss + β × KL Loss
• Reconstruction: Make decoded output match input (E_q[−log p(x|z)])
• KL: Keep encoder distribution close to prior (D_KL(q(z|x) || p(z)))
The KL term acts as regularization, preventing the encoder from learning arbitrary latent distributions. β controls the tradeoff (β-VAE uses β > 1 for disentanglement).
KL divergence is central to variational inference through the Evidence Lower Bound (ELBO). Understanding this connection illuminates why we use KL divergence to train VAEs and other latent variable models.
The Setup: We model data x with latent variables z, via a prior p(z) and a likelihood p(x|z). The true posterior p(z|x) = p(x, z) / p(x) is intractable because the evidence p(x) requires marginalizing over all z, so we approximate it with a tractable distribution q(z|x).
Deriving the ELBO:
```
# ELBO Derivation from KL Divergence
# ==================================

# We want to approximate the true posterior p(z|x) with q(z|x).
# The KL divergence measures how far q is from the true posterior:

D_KL(q(z|x) || p(z|x)) = E_q[log q(z|x)] - E_q[log p(z|x)]

# By Bayes' rule: p(z|x) = log p(x,z) / p(x), so
#                 log p(z|x) = log p(x,z) - log p(x)

D_KL(q || p_posterior) = E_q[log q(z|x)] - E_q[log p(x,z) - log p(x)]
                       = E_q[log q(z|x)] - E_q[log p(x,z)] + log p(x)

# Rearranging:
log p(x) = D_KL(q(z|x) || p(z|x)) + E_q[log p(x,z)] - E_q[log q(z|x)]
         = D_KL(q(z|x) || p(z|x)) + ELBO

# Where the ELBO is:
ELBO = E_q[log p(x,z)] - E_q[log q(z|x)]
     = E_q[log p(x|z)] + E_q[log p(z)] - E_q[log q(z|x)]
     = E_q[log p(x|z)]  -  D_KL(q(z|x) || p(z))
            ↑                      ↑
      Reconstruction         Regularization

# Since D_KL ≥ 0:
log p(x) ≥ ELBO

# The ELBO is a lower bound on the log evidence log p(x).
# Maximizing the ELBO both:
#   1. Maximizes log p(x) (evidence)
#   2. Minimizes D_KL(q || p_posterior) (better approximation)
```

The ELBO decomposition:
ELBO = E_q(z|x)[log p(x|z)] − D_KL(q(z|x) || p(z))
| Term | Name | Meaning |
|---|---|---|
| E_q[log p(x\|z)] | Reconstruction | How well the decoded output matches the input, averaged over z ~ q(z\|x) |
| D_KL(q(z\|x) \|\| p(z)) | Regularization | How close the encoder's distribution stays to the prior p(z) |
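The identity log p(x) = ELBO + D_KL(q(z|x) || p(z|x)) can be verified directly on a toy discrete model, as in this sketch (all numbers are illustrative):

```python
import numpy as np

# Toy discrete check of: log p(x) = ELBO + D_KL(q(z|x) || p(z|x)), with z in {0, 1}.
p_z = np.array([0.3, 0.7])           # prior p(z)
p_x_given_z = np.array([0.9, 0.2])   # likelihood p(x|z) for one fixed observation x
q_z = np.array([0.6, 0.4])           # an arbitrary approximate posterior q(z|x)

p_xz = p_z * p_x_given_z             # joint p(x, z)
log_px = np.log(p_xz.sum())          # evidence log p(x)
p_z_given_x = p_xz / p_xz.sum()      # true posterior p(z|x)

elbo = np.sum(q_z * (np.log(p_xz) - np.log(q_z)))
kl = np.sum(q_z * (np.log(q_z) - np.log(p_z_given_x)))

print(f"log p(x)  = {log_px:.4f}")
print(f"ELBO + KL = {elbo + kl:.4f}")  # matches log p(x)
print(f"ELBO      = {elbo:.4f} <= log p(x), gap = KL = {kl:.4f}")
```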
Training VAEs:
The reparameterization trick allows gradient flow through the sampling step, making the entire objective differentiable.
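A minimal sketch of the trick (the reparameterize helper is our own naming, not a library function): sampling z = μ + σ·ε with ε ~ N(0, I) keeps μ and log σ² inside the computation graph, so gradients reach the encoder parameters.

```python
import torch

def reparameterize(mu, log_var):
    """z = mu + sigma * eps with eps ~ N(0, I); mu and log_var stay differentiable."""
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + std * eps

mu = torch.zeros(4, 2, requires_grad=True)
log_var = torch.zeros(4, 2, requires_grad=True)

z = reparameterize(mu, log_var)
z.sum().backward()  # gradients flow back through the sampled z
print(mu.grad.shape, log_var.grad.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```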
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAELoss(nn.Module):
    """
    VAE loss = Reconstruction + KL divergence.
    """
    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta  # β-VAE weighting

    def forward(self, x, x_recon, mu, log_var):
        """
        Args:
            x: Original input (batch_size, channels, height, width)
            x_recon: Reconstructed input
            mu: Encoder mean (batch_size, latent_dim)
            log_var: Encoder log variance

        Returns:
            total_loss, recon_loss, kl_loss
        """
        batch_size = x.size(0)

        # Reconstruction loss (per-sample, summed over dimensions)
        # For continuous data: Gaussian likelihood → MSE
        # For binary data: Bernoulli likelihood → BCE
        recon_loss = F.mse_loss(x_recon, x, reduction='sum') / batch_size
        # Alternative for binary:
        # recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum') / batch_size

        # KL divergence: D_KL(N(μ, σ²) || N(0, 1))
        # = 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var)
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size

        # Total loss (ELBO with negative sign)
        total_loss = recon_loss + self.beta * kl_loss

        return total_loss, recon_loss, kl_loss

# Example usage
loss_fn = VAELoss(beta=1.0)

# Dummy data
batch_size = 32
latent_dim = 16
x = torch.rand(batch_size, 1, 28, 28)  # MNIST-like
x_recon = torch.rand(batch_size, 1, 28, 28)
mu = torch.randn(batch_size, latent_dim)
log_var = torch.randn(batch_size, latent_dim)

total, recon, kl = loss_fn(x, x_recon, mu, log_var)
print(f"Total loss: {total.item():.4f}")
print(f"Reconstruction: {recon.item():.4f}")
print(f"KL divergence: {kl.item():.4f}")

# With β-VAE (β > 1 for disentanglement)
loss_fn_beta = VAELoss(beta=4.0)
total_beta, _, kl_beta = loss_fn_beta(x, x_recon, mu, log_var)
print(f"\nβ=4.0 Total loss: {total_beta.item():.4f}")
print(f"KL contribution: {4.0 * kl.item():.4f}")
```

Beyond VAEs, KL divergence appears throughout machine learning. Here are some key applications:
```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, true_labels,
                      temperature=4.0, alpha=0.7):
    """
    Knowledge distillation loss.

    Combines:
    1. Hard target loss (student vs true labels)
    2. Soft target loss (student vs teacher distributions)

    Args:
        student_logits: Raw outputs from student model
        teacher_logits: Raw outputs from teacher model (pretrained)
        true_labels: Ground truth class indices
        temperature: Softens distributions (higher = softer)
        alpha: Weight for soft target loss (1-alpha for hard target)

    Returns:
        Combined distillation loss
    """
    # Hard target loss (standard cross-entropy)
    hard_loss = F.cross_entropy(student_logits, true_labels)

    # Soft target loss (KL divergence with temperature)
    # Higher temperature makes distributions more uniform
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)

    # KL divergence: sum over classes, mean over batch
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

    # Scale soft loss by T² to match gradient magnitudes
    soft_loss = soft_loss * (temperature ** 2)

    # Combine losses
    return alpha * soft_loss + (1 - alpha) * hard_loss

# Example
batch_size = 16
num_classes = 10

student_logits = torch.randn(batch_size, num_classes)
teacher_logits = torch.randn(batch_size, num_classes) * 2  # Teacher is more confident
true_labels = torch.randint(0, num_classes, (batch_size,))

loss = distillation_loss(student_logits, teacher_logits, true_labels)
print(f"Distillation loss: {loss.item():.4f}")

# Compare to plain cross-entropy
plain_ce = F.cross_entropy(student_logits, true_labels)
print(f"Plain cross-entropy: {plain_ce.item():.4f}")
```
```python
import torch

def ppo_policy_loss(log_probs_new, log_probs_old, advantages,
                    clip_epsilon=0.2, kl_target=0.01):
    """
    PPO clipped objective with optional KL penalty.

    The clipping mechanism is a surrogate for a KL constraint:
    - Prevents the new policy from deviating too far from the old one
    - More computationally efficient than exact KL computation
    """
    # Probability ratio r(θ) = π_new(a|s) / π_old(a|s)
    ratio = torch.exp(log_probs_new - log_probs_old)

    # Clipped objective
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages

    # Take minimum (pessimistic bound)
    policy_loss = -torch.min(unclipped, clipped).mean()

    # Optional: Approximate KL for monitoring/adaptive clipping
    approx_kl = 0.5 * ((log_probs_new - log_probs_old) ** 2).mean()

    return policy_loss, approx_kl

# Example
batch_size = 64
log_probs_old = torch.randn(batch_size) - 1  # Log probs are negative
log_probs_new = log_probs_old + torch.randn(batch_size) * 0.1  # Small update
advantages = torch.randn(batch_size)  # Could be positive or negative

loss, kl = ppo_policy_loss(log_probs_new, log_probs_old, advantages)
print(f"PPO policy loss: {loss.item():.4f}")
print(f"Approximate KL: {kl.item():.6f}")

# Large update (KL would be too high)
log_probs_new_large = log_probs_old + torch.randn(batch_size) * 1.0
_, kl_large = ppo_policy_loss(log_probs_new_large, log_probs_old, advantages)
print(f"Large update KL: {kl_large.item():.6f} (too high, would trigger early stopping)")
```

In TRPO (Trust Region Policy Optimization), the constraint is an explicit D_KL(π_old || π_new) ≤ δ bound, enforced via conjugate gradient. PPO approximates this with clipping, which is simpler but less theoretically grounded. Both aim to prevent the "policy cliff" problem where one bad update destroys a good policy.
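When the action distribution is available in closed form, the exact per-state KL can also be computed for monitoring or a TRPO-style constraint check. A small sketch using torch.distributions (the batch size and action count here are illustrative):

```python
import torch
from torch.distributions import Categorical, kl_divergence

# Exact per-state KL between old and new discrete policies.
logits_old = torch.randn(64, 6)                     # 64 states, 6 actions
logits_new = logits_old + 0.1 * torch.randn(64, 6)  # small policy update

pi_old = Categorical(logits=logits_old)
pi_new = Categorical(logits=logits_new)

mean_kl = kl_divergence(pi_old, pi_new).mean()  # E_s[ D_KL(π_old(·|s) || π_new(·|s)) ]
print(f"Mean KL(old || new): {mean_kl.item():.6f}")
```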
KL divergence is powerful but not always the right choice. Alternative divergences address different needs:
| Divergence | Definition | Properties | Use Cases |
|---|---|---|---|
| KL Divergence | Σ P log(P/Q) | Asymmetric, unbounded, requires Q > 0 | VAEs, VI, MLE |
| Jensen-Shannon | ½D_KL(P\|\|M) + ½D_KL(Q\|\|M), M=(P+Q)/2 | Symmetric, bounded [0, log 2], always finite | GANs, distribution comparison |
| Total Variation | ½Σ\|P − Q\| | Symmetric, metric, bounded [0, 1] | Differential privacy, robust stats |
| Wasserstein | inf E[d(X,Y)] | Symmetric, metric, considers geometry | WGANs, optimal transport |
| f-Divergences | Σ Q f(P/Q) | Generalizes KL, TV, χ² | f-GAN, statistical theory |
```python
import numpy as np

def kl_divergence(p, q):
    """KL divergence D_KL(P || Q)."""
    p, q = np.array(p) + 1e-10, np.array(q) + 1e-10
    return np.sum(p * np.log(p / q))

def jensen_shannon(p, q):
    """Jensen-Shannon divergence (symmetric)."""
    p, q = np.array(p), np.array(q)
    m = (p + q) / 2
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def total_variation(p, q):
    """Total variation distance."""
    return 0.5 * np.sum(np.abs(np.array(p) - np.array(q)))

def hellinger(p, q):
    """Hellinger distance."""
    return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q))**2))

# Compare distributions
p = [0.1, 0.4, 0.4, 0.1]       # True
q1 = [0.1, 0.4, 0.4, 0.1]      # Identical
q2 = [0.25, 0.25, 0.25, 0.25]  # Uniform
q3 = [0.01, 0.01, 0.01, 0.97]  # Very different

print("Comparing various divergences")
print("=" * 60)
print(f"{'Distribution':<15} {'KL(P||Q)':<12} {'JS(P,Q)':<12} {'TV(P,Q)':<12} {'Hellinger':<12}")
print("-" * 60)

for name, q in [("Identical", q1), ("Uniform", q2), ("Extreme", q3)]:
    kl = kl_divergence(p, q)
    js = jensen_shannon(p, q)
    tv = total_variation(p, q)
    h = hellinger(p, q)
    print(f"{name:<15} {kl:<12.4f} {js:<12.4f} {tv:<12.4f} {h:<12.4f}")

print()
print("Key observations:")
print("- KL can be very large (unbounded) for extreme differences")
print("- JS is bounded by log(2) ≈ 0.693")
print("- TV is bounded by 1")
print("- All are 0 for identical distributions")
```

When to use which divergence:

• KL: When you have a "reference" distribution (P) and want Q to cover it. Standard for MLE/VI.
• JS: When comparing symmetric pairs; bounded and always defined. Good for comparing model outputs.
• Wasserstein: When distributions don't overlap much. Provides gradients even when supports are disjoint.
• Total Variation: When you need strict bounds on the probability of any event differing.
KL divergence is a cornerstone of modern machine learning. To consolidate what we've learned: it measures the extra bits paid for coding data from P with a code built for Q; it is non-negative, zero only when P = Q, and asymmetric, with forward KL mass-covering and reverse KL mode-seeking; it has a closed form between Gaussians that powers the VAE objective; and it underlies the ELBO, knowledge distillation, and trust-region policy updates, with Jensen-Shannon, total variation, and Wasserstein as alternatives when its assumptions don't fit.
What's next:
We've covered entropy, cross-entropy, and KL divergence. The next page explores mutual information—a symmetric measure of shared information between random variables that arises from KL divergence and has profound applications in feature selection, representation learning, and understanding what neural networks learn.
You now understand KL divergence as a fundamental measure of distributional difference, can reason about forward vs reverse KL for different applications, and know how it underpins variational inference, policy optimization, and knowledge distillation.