Exact inference—variable elimination, junction trees—is the gold standard when tractable. But many real-world graphical models have structure that makes exact computation infeasible: high treewidth, continuous variables, or simply too many states to enumerate. When exact methods fail, we turn to approximate inference.
Approximate inference is a rich field with two major paradigms: sampling (Monte Carlo methods) and optimization (variational methods). Each paradigm offers different trade-offs in accuracy, speed, and theoretical guarantees. Mastering both is essential for practical probabilistic modeling.
By the end of this page, you will understand: (1) the fundamental divide between sampling and variational methods, (2) key sampling algorithms including importance sampling, MCMC, and Gibbs sampling, (3) variational inference basics including mean-field approximation, (4) practical trade-offs and selection criteria, and (5) how these methods apply to graphical model inference.
Approximate inference methods can be organized along several dimensions: stochastic vs. deterministic, local vs. global, and asymptotically exact vs. biased. Understanding this taxonomy helps in selecting the right method for a given problem.
The Two Major Paradigms:
Sampling (Monte Carlo) Methods: Generate samples from (or approximately from) the target distribution. Compute expectations by averaging over samples. Asymptotically exact as sample count increases.
Variational Methods: Approximate the target distribution with a simpler one from a tractable family. Optimize to make the approximation as close as possible. Deterministic and fast, but inherently biased.
| Aspect | Sampling Methods | Variational Methods |
|---|---|---|
| Core idea | Draw samples, estimate by averaging | Optimize within tractable family |
| Stochastic? | Yes—different random samples each run | No—deterministic optimization |
| Asymptotic behavior | Converges to true answer with enough samples | Converges to best approximation (may be biased) |
| Error characterization | Variance (decreases with samples) | Bias (depends on approximation family) |
| Speed | Can be slow (many samples needed) | Often fast (fixed optimization steps) |
| Parallelization | Highly parallelizable (independent samples) | Sequential updates; some parallel variants |
| Continuous variables | Handles naturally | Requires analytic tractability |
| Multimodal distributions | Can explore all modes (with care) | Often collapses to one mode |
Loopy belief propagation, covered in the previous page, is actually a variational method! Its fixed points minimize the Bethe free energy, a variational approximation. This connects message-passing to the broader variational framework.
Monte Carlo methods estimate expectations by averaging over random samples. If we want E[f(X)] where X ~ P, and we can draw samples x₁, x₂, ..., xₙ from P, then:
Ê[f(X)] = (1/n) Σᵢ f(xᵢ) → E[f(X)] as n → ∞
For graphical models, we want marginals P(Xᵢ = xᵢ), which are expectations of indicator functions. The challenge is: how do we sample from complex, high-dimensional joint distributions?
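Before turning to that question, the estimator itself can be sketched in a few lines of plain NumPy (the distribution and test function below are illustrative, not a graphical model): samples from a known distribution recover both a general expectation and a marginal, since a marginal is just the expectation of an indicator.

```python
import numpy as np

rng = np.random.default_rng(0)

# A known discrete distribution over {0, 1, 2}
p = np.array([0.2, 0.5, 0.3])
samples = rng.choice(3, size=100_000, p=p)

# Monte Carlo estimate of E[f(X)] for f(x) = x^2
f_estimate = np.mean(samples.astype(float) ** 2)
f_exact = np.sum(p * np.arange(3) ** 2)  # 0.2*0 + 0.5*1 + 0.3*4 = 1.7

# A marginal P(X = 1) is E[1{X = 1}]: average an indicator
marginal_estimate = np.mean(samples == 1)
```

With 100,000 samples both estimates land within a fraction of a percent of the true values, and the error shrinks as 1/√n.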
```python
import numpy as np
from typing import List, Dict, Callable
from scipy.special import logsumexp


def ancestral_sampling(
    bayesian_network: Dict[str, Dict],
    num_samples: int = 1000
) -> List[Dict[str, int]]:
    """
    Ancestral (forward) sampling from a Bayesian network.

    Works by sampling variables in topological order,
    conditioning on already-sampled parents.

    This is the simplest sampling method for BNs:
    - Exact samples from the joint distribution
    - Only works for Bayesian networks (not MRFs)
    - Cannot incorporate evidence easily

    Args:
        bayesian_network: Dict mapping variable to {parents, cpt}
        num_samples: Number of samples to draw

    Returns:
        List of sample dicts {var: value}
    """
    # Get topological order
    order = _topological_sort(bayesian_network)

    samples = []
    for _ in range(num_samples):
        sample = {}
        for var in order:
            info = bayesian_network[var]
            parents = info['parents']
            cpt = info['cpt']

            # Get parent values
            parent_values = tuple(sample[p] for p in parents)

            # Index into CPT to get distribution for this variable
            # CPT shape: (parent1_card, parent2_card, ..., var_card)
            if parents:
                dist = cpt[parent_values]
            else:
                dist = cpt

            # Sample from conditional distribution
            sample[var] = np.random.choice(len(dist), p=dist)

        samples.append(sample)

    return samples


def rejection_sampling(
    target_unnorm: Callable,
    proposal: Callable,
    proposal_sample: Callable,
    M: float,
    num_samples: int = 1000
) -> np.ndarray:
    """
    Rejection sampling from unnormalized target distribution.

    Requires a proposal distribution q(x) such that:
    - We can sample from q(x)
    - We can evaluate q(x)
    - target(x) <= M * q(x) for all x

    Acceptance rate = Z / M, where Z is target's normalizing constant.
    Poor for high dimensions or when M >> Z.

    Args:
        target_unnorm: Function returning unnormalized target density
        proposal: Function returning proposal density
        proposal_sample: Function to sample from proposal
        M: Upper bound constant (target <= M * proposal everywhere)
        num_samples: Number of accepted samples desired

    Returns:
        Array of accepted samples
    """
    samples = []
    while len(samples) < num_samples:
        # Sample from proposal
        x = proposal_sample()

        # Compute acceptance probability
        p_accept = target_unnorm(x) / (M * proposal(x))

        # Accept with this probability
        if np.random.random() < p_accept:
            samples.append(x)

    return np.array(samples)


def _topological_sort(bn: Dict) -> List[str]:
    """Return variables in topological order."""
    visited = set()
    order = []

    def visit(var):
        if var in visited:
            return
        visited.add(var)
        for parent in bn[var]['parents']:
            visit(parent)
        order.append(var)

    for var in bn:
        visit(var)

    return order
```

Importance sampling avoids rejection by reweighting samples from a different distribution. Instead of sampling from P (hard), we sample from a proposal Q (easy) and weight each sample by the importance ratio.
The Core Identity:
E_P[f(X)] = E_Q[f(X) · P(X)/Q(X)] = E_Q[f(X) · w(X)]
where w(X) = P(X)/Q(X) is the importance weight.
For unnormalized targets where P(X) = ψ(X)/Z, we use self-normalized importance sampling:
Ê[f(X)] = Σᵢ w̃ᵢ f(xᵢ) where w̃ᵢ = wᵢ / Σⱼ wⱼ
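As a sanity check of these identities, the sketch below (illustrative 1-D densities, not a graphical model) estimates moments of a standard normal target, known only up to its normalizing constant, using samples from a deliberately mismatched Gaussian proposal.

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(1)

# Unnormalized target: psi(x) = exp(-x^2 / 2), i.e. N(0, 1) up to Z
def log_psi(x):
    return -0.5 * x ** 2

# Proposal Q = N(1, 2^2): easy to sample from and to evaluate
mu_q, sigma_q = 1.0, 2.0
x = rng.normal(mu_q, sigma_q, size=200_000)
log_q = -0.5 * ((x - mu_q) / sigma_q) ** 2 - np.log(sigma_q * np.sqrt(2 * np.pi))

# Self-normalized weights: w_i proportional to psi(x_i) / q(x_i)
log_w = log_psi(x) - log_q
w = np.exp(log_w - logsumexp(log_w))

# Estimate E_P[X] (true value 0) and E_P[X^2] (true value 1)
mean_est = np.sum(w * x)
second_moment_est = np.sum(w * x ** 2)
```

Note that Z never appears: the normalizing constant cancels when the weights are normalized, which is exactly why self-normalized IS works for unnormalized targets.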
```python
from typing import Tuple  # missing from the earlier imports


def importance_sampling(
    target_unnorm: Callable,
    proposal_log_prob: Callable,
    proposal_sample: Callable,
    f: Callable,
    num_samples: int = 1000
) -> Tuple[float, float]:
    """
    Self-normalized importance sampling.

    Estimates E_P[f(X)] using samples from proposal Q.
    Handles unnormalized target distributions.

    Args:
        target_unnorm: Function returning log of unnormalized target
        proposal_log_prob: Function returning log proposal density
        proposal_sample: Function to sample from proposal
        f: Function to compute expectation of
        num_samples: Number of samples

    Returns:
        - Estimated expectation
        - Effective sample size (ESS)
    """
    samples = [proposal_sample() for _ in range(num_samples)]

    # Compute log importance weights (unnormalized)
    log_weights = []
    for x in samples:
        log_w = target_unnorm(x) - proposal_log_prob(x)
        log_weights.append(log_w)
    log_weights = np.array(log_weights)

    # Normalize weights using logsumexp for stability
    log_norm = logsumexp(log_weights)
    normalized_weights = np.exp(log_weights - log_norm)

    # Compute weighted average
    f_values = np.array([f(x) for x in samples])
    estimate = np.sum(normalized_weights * f_values)

    # Effective sample size: measures weight concentration
    # ESS = 1 / sum(w_i^2), where w_i are normalized weights
    ess = 1.0 / np.sum(normalized_weights ** 2)

    return estimate, ess


def likelihood_weighting(
    bayesian_network: Dict[str, Dict],
    evidence: Dict[str, int],
    query_var: str,
    num_samples: int = 1000
) -> np.ndarray:
    """
    Likelihood weighting: importance sampling for Bayesian networks.

    Proposal: sample non-evidence variables ancestrally, fix evidence.
    Weight: product of evidence likelihoods given parents.

    Much more efficient than rejection sampling with evidence.

    Args:
        bayesian_network: BN structure
        evidence: Observed variable assignments
        query_var: Variable whose marginal we want
        num_samples: Number of weighted samples

    Returns:
        Estimated marginal distribution for query_var
    """
    order = _topological_sort(bayesian_network)
    query_card = bayesian_network[query_var]['cpt'].shape[-1]

    weighted_counts = np.zeros(query_card)
    total_weight = 0.0

    for _ in range(num_samples):
        sample = {}
        log_weight = 0.0

        for var in order:
            info = bayesian_network[var]
            parents = info['parents']
            cpt = info['cpt']

            parent_values = tuple(sample[p] for p in parents)
            if parents:
                dist = cpt[parent_values]
            else:
                dist = cpt

            if var in evidence:
                # Evidence variable: don't sample, add to weight
                sample[var] = evidence[var]
                log_weight += np.log(dist[evidence[var]] + 1e-10)
            else:
                # Hidden variable: sample as usual
                sample[var] = np.random.choice(len(dist), p=dist)

        weight = np.exp(log_weight)
        weighted_counts[sample[query_var]] += weight
        total_weight += weight

    return weighted_counts / total_weight


class SequentialMonteCarlo:
    """
    Sequential Monte Carlo (particle filtering) for dynamic models.

    Maintains a weighted set of particles (samples) that are
    propagated, reweighted, and resampled as evidence arrives.
    Essential for online inference in temporal graphical models
    (HMMs, DBNs).
    """

    def __init__(
        self,
        transition: Callable,
        emission: Callable,
        initial: Callable,
        num_particles: int = 100
    ):
        """
        Initialize SMC.

        Args:
            transition: P(x_t | x_{t-1}) sampler
            emission: P(y_t | x_t) density evaluator
            initial: P(x_0) sampler
            num_particles: Number of particles to maintain
        """
        self.transition = transition
        self.emission = emission
        self.initial = initial
        self.num_particles = num_particles
        self.particles = None
        self.weights = None

    def initialize(self):
        """Sample initial particles from prior."""
        self.particles = [self.initial() for _ in range(self.num_particles)]
        self.weights = np.ones(self.num_particles) / self.num_particles

    def step(self, observation):
        """
        Process one observation: propagate, reweight, resample.
        """
        # Propagate: sample new states from transition
        new_particles = [self.transition(p) for p in self.particles]

        # Reweight: multiply by observation likelihood
        log_weights = np.log(self.weights + 1e-10)
        for i, particle in enumerate(new_particles):
            log_weights[i] += np.log(self.emission(observation, particle) + 1e-10)

        # Normalize
        log_norm = logsumexp(log_weights)
        self.weights = np.exp(log_weights - log_norm)
        self.particles = new_particles

        # Resample if effective sample size is low
        ess = 1.0 / np.sum(self.weights ** 2)
        if ess < self.num_particles / 2:
            self._resample()

    def _resample(self):
        """Multinomial resampling to rejuvenate the particle set."""
        indices = np.random.choice(
            self.num_particles, size=self.num_particles, p=self.weights
        )
        self.particles = [self.particles[i] for i in indices]
        self.weights = np.ones(self.num_particles) / self.num_particles
```

Importance sampling suffers from weight degeneracy in high dimensions: a few samples dominate, and effective sample size becomes tiny. This is why naive IS doesn't scale to complex graphical models. Sequential methods (SMC) and MCMC address this limitation.
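The degeneracy is easy to demonstrate: keep the same slightly mismatched Gaussian proposal in every dimension and watch the effective sample size collapse as the dimension grows (the dimensions and variances below are arbitrary choices for illustration).

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(3)
n = 20_000

def ess_fraction(d, sigma_q=1.5):
    """ESS/n for target N(0, I_d) with proposal N(0, sigma_q^2 I_d)."""
    x = rng.normal(0.0, sigma_q, size=(n, d))
    # log w = log p(x) - log q(x); normalizing constants cancel
    # after self-normalization, so they are dropped here
    log_w = -0.5 * np.sum(x ** 2, axis=1) + 0.5 * np.sum((x / sigma_q) ** 2, axis=1)
    w = np.exp(log_w - logsumexp(log_w))
    return 1.0 / (n * np.sum(w ** 2))

# ESS/n degrades roughly geometrically in the dimension d
fractions = [ess_fraction(d) for d in (1, 10, 50)]
```

Even with a per-dimension mismatch this mild, the ESS fraction drops from roughly 0.8 at d = 1 to a tiny fraction of the sample budget by d = 50.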
MCMC constructs a Markov chain whose stationary distribution is the target P. By running the chain long enough, samples from the chain approximate samples from P—even when we only know P up to a normalizing constant.
Key Insight:
We don't need to sample P directly. Instead, we design a transition kernel T(x' | x) such that: (1) P is a stationary distribution of the chain, i.e. Σₓ P(x) T(x' | x) = P(x'), and (2) the chain is ergodic, so it converges to that stationary distribution from any starting state.
A sufficient condition is detailed balance: P(x) T(x' | x) = P(x') T(x | x')
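Detailed balance can be made concrete by building the Metropolis–Hastings kernel for a small discrete target and checking the condition numerically. (The 3-state target and symmetric proposal here are illustrative; the MH acceptance rule itself is introduced in the code below this section.)

```python
import numpy as np

# Target distribution on 3 states
p = np.array([0.2, 0.5, 0.3])

# Symmetric proposal: from each state, propose each other state w.p. 1/2
Q = (np.ones((3, 3)) - np.eye(3)) / 2

# Metropolis acceptance: alpha(x -> x') = min(1, p(x') / p(x))
T = np.zeros((3, 3))
for x in range(3):
    for x2 in range(3):
        if x != x2:
            T[x, x2] = Q[x, x2] * min(1.0, p[x2] / p[x])
    T[x, x] = 1.0 - T[x].sum()  # rejected moves stay put

# Detailed balance: p(x) T(x'|x) = p(x') T(x|x') for every pair
flow = p[:, None] * T
assert np.allclose(flow, flow.T)

# Hence p is stationary: p T = p
assert np.allclose(p @ T, p)
```

Note that only the ratios p(x')/p(x) enter the kernel, which is why MH works when P is known only up to a normalizing constant.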
```python
from dataclasses import dataclass


@dataclass
class Factor:
    """Minimal factor: a potential table over named discrete variables.

    (Assumed structure; a Factor class is referenced on this page
    but not defined elsewhere.)
    """
    variables: List[str]
    potential: np.ndarray  # one axis per variable, in order


class MetropolisHastings:
    """
    Metropolis-Hastings MCMC for graphical model inference.

    General-purpose MCMC that works with any proposal distribution.
    Automatically satisfies detailed balance via accept/reject step.
    """

    def __init__(
        self,
        log_target: Callable,
        proposal_sample: Callable,
        proposal_log_prob: Callable = None
    ):
        """
        Initialize MH sampler.

        Args:
            log_target: Function returning log unnormalized target density
            proposal_sample: Function(current_x) -> proposed_x
            proposal_log_prob: Function(x, x') -> log Q(x' | x).
                If None, assumes a symmetric proposal.
        """
        self.log_target = log_target
        self.proposal_sample = proposal_sample
        self.proposal_log_prob = proposal_log_prob
        self.symmetric = proposal_log_prob is None

    def sample(
        self,
        initial: np.ndarray,
        num_samples: int,
        burn_in: int = 100,
        thin: int = 1
    ) -> Tuple[List[np.ndarray], float]:
        """
        Run MH to collect samples.

        Args:
            initial: Starting state
            num_samples: Number of samples to return
            burn_in: Initial samples to discard
            thin: Keep every thin-th sample

        Returns:
            - List of samples
            - Acceptance rate
        """
        samples = []
        current = initial
        current_log_prob = self.log_target(current)

        total_steps = burn_in + num_samples * thin
        accepted = 0

        for step in range(total_steps):
            # Propose new state
            proposed = self.proposal_sample(current)
            proposed_log_prob = self.log_target(proposed)

            # Compute acceptance probability
            log_alpha = proposed_log_prob - current_log_prob
            if not self.symmetric:
                # Add Hastings correction for asymmetric proposals
                log_alpha += self.proposal_log_prob(proposed, current)
                log_alpha -= self.proposal_log_prob(current, proposed)

            # Accept or reject
            if np.log(np.random.random()) < log_alpha:
                current = proposed
                current_log_prob = proposed_log_prob
                accepted += 1

            # Collect sample after burn-in, with thinning
            if step >= burn_in and (step - burn_in) % thin == 0:
                samples.append(current.copy())

        acceptance_rate = accepted / total_steps
        return samples, acceptance_rate


class GibbsSampling:
    """
    Gibbs sampling for graphical models.

    Special case of MH where each variable is sampled from its full
    conditional distribution. Always accepts (acceptance rate = 1).

    Particularly efficient for graphical models because full
    conditionals depend only on the Markov blanket.
    """

    def __init__(
        self,
        variables: List[str],
        cardinalities: Dict[str, int],
        factors: List[Factor]
    ):
        """
        Initialize Gibbs sampler for a factor graph.

        Args:
            variables: List of variable names
            cardinalities: Dict mapping variable to number of values
            factors: List of factors defining the distribution
        """
        self.variables = variables
        self.cardinalities = cardinalities
        self.factors = factors

        # Precompute which factors involve each variable
        self.var_to_factors = {v: [] for v in variables}
        for factor in factors:
            for var in factor.variables:
                self.var_to_factors[var].append(factor)

    def compute_full_conditional(
        self,
        var: str,
        current_state: Dict[str, int]
    ) -> np.ndarray:
        """
        Compute P(var | all other variables) = P(var | Markov blanket).

        The full conditional only depends on factors containing var.
        """
        card = self.cardinalities[var]
        log_probs = np.zeros(card)

        for value in range(card):
            # Temporarily set this value
            test_state = current_state.copy()
            test_state[var] = value

            # Compute product of relevant factors
            log_prob = 0.0
            for factor in self.var_to_factors[var]:
                idx = tuple(test_state[v] for v in factor.variables)
                log_prob += np.log(factor.potential[idx] + 1e-10)

            log_probs[value] = log_prob

        # Normalize to get a valid distribution
        log_probs -= logsumexp(log_probs)
        return np.exp(log_probs)

    def sample(
        self,
        initial: Dict[str, int],
        num_samples: int,
        burn_in: int = 100
    ) -> List[Dict[str, int]]:
        """
        Run Gibbs sampling to collect samples.

        One iteration updates all variables in order.
        """
        samples = []
        current = initial.copy()

        for iteration in range(burn_in + num_samples):
            # Update each variable from its full conditional
            for var in self.variables:
                full_cond = self.compute_full_conditional(var, current)
                current[var] = np.random.choice(
                    self.cardinalities[var], p=full_cond
                )

            if iteration >= burn_in:
                samples.append(current.copy())

        return samples

    def estimate_marginals(
        self,
        samples: List[Dict[str, int]]
    ) -> Dict[str, np.ndarray]:
        """Estimate marginal distributions from Gibbs samples."""
        marginals = {
            v: np.zeros(self.cardinalities[v]) for v in self.variables
        }

        for sample in samples:
            for var, value in sample.items():
                marginals[var][value] += 1

        for var in marginals:
            marginals[var] /= len(samples)

        return marginals
```

Gibbs sampling exploits graphical model structure beautifully: each full conditional depends only on the Markov blanket (neighbors in the graph). This makes updates local and efficient. For many graphical models, Gibbs is the go-to MCMC method.
Variational inference (VI) transforms inference into optimization. Instead of sampling from P, we find the best approximation Q from a tractable family that minimizes the divergence from P.
The Key Objective:
We minimize KL(Q || P), but since we don't have access to P's normalizing constant, we equivalently maximize the Evidence Lower Bound (ELBO):
ELBO(Q) = E_Q[log P(X,E)] - E_Q[log Q(X)] = E_Q[log P(X,E)] + H(Q)
where E is evidence and X are latent variables. This is a lower bound on log P(E).
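The bound can be verified numerically on a tiny model (the joint table and the choice of Q below are arbitrary): for any Q, ELBO(Q) = log P(E) − KL(Q ‖ P(X|E)), so the ELBO never exceeds log P(E).

```python
import numpy as np

# Tiny model: latent X in {0, 1}, evidence E fixed at one observed value.
# The joint slice P(X = x, E = e) for x = 0, 1:
p_x_and_e = np.array([0.12, 0.28])

log_p_e = np.log(p_x_and_e.sum())          # log P(E = e)
posterior = p_x_and_e / p_x_and_e.sum()    # P(X | E = e)

# An arbitrary variational distribution Q(X)
q = np.array([0.5, 0.5])

# ELBO = E_Q[log P(X, E)] + H(Q)
elbo = np.sum(q * np.log(p_x_and_e)) - np.sum(q * np.log(q))

# KL(Q || P(X | E))
kl = np.sum(q * (np.log(q) - np.log(posterior)))

# Identity: ELBO = log P(E) - KL, hence ELBO <= log P(E)
assert np.isclose(elbo, log_p_e - kl)
assert elbo <= log_p_e
```

Maximizing the ELBO over Q is therefore equivalent to minimizing KL(Q ‖ P(X|E)), without ever computing P(E).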
Mean-Field Approximation:
The simplest VI assumes Q fully factorizes:
Q(X) = ∏ᵢ Qᵢ(Xᵢ)
This ignores all dependencies in the approximate posterior! Despite this strong assumption, mean-field often works surprisingly well.
Coordinate ascent update for factor Qⱼ:
log Qⱼ(Xⱼ) = E_{Q_{-j}}[log P(X, E)] + const
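On a two-variable pairwise model this update is just a weighted row/column average of log potentials. The sketch below (potentials chosen arbitrarily) runs the coordinate ascent to convergence; for this symmetric model Q recovers the exact marginals, yet the product form still misses the strong correlation between the variables.

```python
import numpy as np

# Pairwise model over binary X0, X1: P(x0, x1) proportional to psi[x0, x1]
psi = np.array([[4.0, 1.0],
                [1.0, 4.0]])   # favors agreement
log_psi = np.log(psi)

def normalize_exp(log_q):
    q = np.exp(log_q - log_q.max())
    return q / q.sum()

# Mean-field factors, initialized away from the fixed point
q0 = np.array([0.5, 0.5])
q1 = np.array([0.7, 0.3])

# Coordinate ascent: log Q_j(x_j) = E_{Q_-j}[log psi] + const
for _ in range(100):
    q0 = normalize_exp(log_psi @ q1)   # expectation over X1
    q1 = normalize_exp(q0 @ log_psi)   # expectation over X0

# Q converges to the exact (uniform) marginals here...
exact_m0 = (psi / psi.sum()).sum(axis=1)

# ...but the factorized Q misrepresents the joint: the true model
# has P(X0 = X1) = 0.8, while Q0 x Q1 gives 0.5.
p_agree_true = (psi[0, 0] + psi[1, 1]) / psi.sum()
p_agree_q = q0 @ q1
```

Getting the marginals right while losing the dependence is the mean-field trade-off in miniature.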
Because KL(Q || P) is zero-forcing (Q avoids putting mass where P has none), mean-field VI tends to underestimate posterior variance. The approximation concentrates on a mode of P but neglects the tails.
```python
class MeanFieldVI:
    """
    Mean-field variational inference for discrete graphical models.

    Approximates the posterior with a fully factorized distribution
    and optimizes via coordinate ascent.
    """

    def __init__(
        self,
        variables: List[str],
        cardinalities: Dict[str, int],
        factors: List[Factor]
    ):
        """
        Initialize mean-field VI.

        Args:
            variables: List of variable names
            cardinalities: Dict mapping variable to cardinality
            factors: List of factors defining the model
        """
        self.variables = variables
        self.cardinalities = cardinalities
        self.factors = factors

        # Variational parameters: Q_i(x_i) for each variable
        # Initialize to uniform
        self.q_params = {
            v: np.ones(cardinalities[v]) / cardinalities[v]
            for v in variables
        }

        # Precompute factor associations
        self.var_to_factors = {v: [] for v in variables}
        for factor in factors:
            for var in factor.variables:
                self.var_to_factors[var].append(factor)

    def update_q(self, var: str):
        """
        Update Q for one variable (coordinate ascent step).

        log Q_j(x_j) = E_{Q_{-j}}[log P(X, E)] + const

        For discrete factors, this is tractable.
        """
        card = self.cardinalities[var]
        log_q = np.zeros(card)

        for value in range(card):
            # Compute expected log potential for this assignment
            for factor in self.var_to_factors[var]:
                log_q[value] += self._expected_log_factor(
                    factor, var, value
                )

        # Normalize to get a proper distribution
        log_q -= logsumexp(log_q)
        self.q_params[var] = np.exp(log_q)

    def _expected_log_factor(
        self,
        factor: Factor,
        fixed_var: str,
        fixed_value: int
    ) -> float:
        """
        Compute E_Q[log factor(X)] with one variable fixed.

        This is a sum over all configurations, weighted by
        Q probabilities for non-fixed variables.
        """
        log_potential = np.log(factor.potential + 1e-10)

        # Sum over all configurations
        result = 0.0
        for idx in np.ndindex(factor.potential.shape):
            # Check if this config is consistent with fixed assignment
            var_idx = factor.variables.index(fixed_var)
            if idx[var_idx] != fixed_value:
                continue

            # Compute Q probability of this configuration
            log_q_prob = 0.0
            for i, v in enumerate(factor.variables):
                if v != fixed_var:
                    log_q_prob += np.log(self.q_params[v][idx[i]] + 1e-10)

            result += np.exp(log_q_prob) * log_potential[idx]

        return result

    def run(
        self,
        max_iter: int = 100,
        tolerance: float = 1e-6
    ) -> Tuple[Dict[str, np.ndarray], float]:
        """
        Run coordinate ascent until convergence.

        Returns:
            - Final variational approximation Q
            - Final ELBO value
        """
        prev_elbo = float('-inf')

        for iteration in range(max_iter):
            # Update each Q in turn
            for var in self.variables:
                self.update_q(var)

            # Check convergence via ELBO
            elbo = self.compute_elbo()
            if abs(elbo - prev_elbo) < tolerance:
                break
            prev_elbo = elbo

        return self.q_params, elbo

    def compute_elbo(self) -> float:
        """
        Compute the Evidence Lower Bound.

        ELBO = E_Q[log P(X)] + H(Q)
        """
        # Expected log joint
        expected_log_joint = 0.0
        for factor in self.factors:
            expected_log_joint += self._expected_log_factor_full(factor)

        # Entropy of Q (factorized, so sum of individual entropies)
        entropy = 0.0
        for var in self.variables:
            q = self.q_params[var]
            entropy -= np.sum(q * np.log(q + 1e-10))

        return expected_log_joint + entropy

    def _expected_log_factor_full(self, factor: Factor) -> float:
        """Compute E_Q[log factor] summing over all configs."""
        log_potential = np.log(factor.potential + 1e-10)

        result = 0.0
        for idx in np.ndindex(factor.potential.shape):
            q_prob = 1.0
            for i, v in enumerate(factor.variables):
                q_prob *= self.q_params[v][idx[i]]
            result += q_prob * log_potential[idx]

        return result
```

Selecting the right inference method depends on the problem structure, desired accuracy, computational budget, and downstream use of the results. Here's a practical guide:
| Situation | Recommended Method | Rationale |
|---|---|---|
| Low treewidth (≤ 15-20) | Junction Tree | Exact, deterministic, optimal when tractable |
| Tree-structured model | Belief Propagation | Exact in linear time; use this for chains, trees |
| High treewidth, sparse | Loopy BP | Fast, often accurate for sparse graphs |
| Need exact samples | MCMC (Gibbs) | Eventually exact; good for graphical models |
| Sequential/online data | Particle Filtering (SMC) | Handles streaming evidence naturally |
| Speed critical, some bias OK | Mean-Field VI | Fast, deterministic, good for exploration |
| Need uncertainty estimates | MCMC or SVI (stochastic variational inference) | Sampling gives full posterior; useful for decisions |
| Very complex, continuous latent | Variational + sampling | Combine VI for approximation, MCMC for refinement |
In practice, combining methods often works best. Use VI to find a good initialization for MCMC. Run loopy BP to get approximate marginals, then refine with importance sampling. Use MCMC within a junction tree for mixed discrete-continuous models.
The field of approximate inference continues to advance rapidly. Here are some important modern developments that extend the classical methods.
The Connection to Deep Learning:
Modern approximate inference increasingly leverages deep learning:
Variational Autoencoders (VAEs): Use neural networks for both the generative model and the variational approximation. The amortized inference network learns to predict posteriors.
Graph Neural Networks: GNNs can be viewed as generalized message passing, connecting to belief propagation on factor graphs.
Attention Mechanisms: Various attention architectures implicitly perform approximate inference by weighting different parts of the input.
Understanding classical approximate inference provides the foundation for appreciating these modern developments.
Approximate inference provides the tools to handle graphical models that exceed exact methods' tractability limits. Whether through sampling or optimization, these methods make probabilistic reasoning practical for complex real-world problems.
Module Complete:
You have now completed the Inference in Graphical Models module. You understand exact inference (variable elimination and junction trees), message passing (belief propagation and its loopy variant), and the approximate methods covered here: Monte Carlo sampling and variational inference.
These tools form the computational backbone of probabilistic graphical models, enabling everything from medical diagnosis systems to speech recognition to computer vision.
Congratulations! You have mastered inference in graphical models. You can now select and implement appropriate inference algorithms for a wide range of probabilistic models, from exact methods for tractable cases to sophisticated approximations for complex real-world problems.