Junction trees provide exact inference—but at a cost exponential in treewidth. For many real-world models (grid-structured images, densely connected networks, high-dimensional factor graphs), treewidth is too high for exact computation. What do we do when exact inference is infeasible?
Loopy Belief Propagation (LBP) offers a surprisingly effective answer: simply run standard belief propagation on the cyclic graph anyway! Despite lacking theoretical guarantees of correctness, LBP often produces remarkably accurate approximate marginals and has become a workhorse algorithm in computer vision, error-correcting codes, and machine learning.
By the end of this page, you will understand: (1) why and when loopy BP works despite its theoretical limitations, (2) convergence behavior and conditions, (3) practical techniques for improving LBP performance, (4) the Bethe free energy interpretation linking LBP to variational inference, and (5) successful applications where LBP shines.
Loopy belief propagation applies the standard sum-product message updates to graphs with cycles. The algorithm is identical to tree BP—the difference is purely in the graph structure.
The Algorithm:
Because the graph has cycles, messages depend on themselves through loops. After one iteration, messages reflect local information; after many iterations, they incorporate information from increasingly distant parts of the graph.
```python
import numpy as np
from typing import Dict, Tuple, List, Set, Optional
from dataclasses import dataclass
from scipy.special import logsumexp

# FactorGraph (with .variables and .factors containers) is assumed to be
# defined earlier in the chapter.


class LoopyBeliefPropagation:
    """
    Loopy Belief Propagation for approximate inference on cyclic graphs.

    Uses iterative message passing until convergence or max iterations.
    Supports both synchronous and asynchronous updates.
    """

    def __init__(
        self,
        factor_graph: FactorGraph,
        damping: float = 0.0,
        log_space: bool = True
    ):
        """
        Initialize loopy BP.

        Args:
            factor_graph: The factor graph (may contain cycles)
            damping: Message damping coefficient [0, 1)
                Higher damping = more stability, slower convergence
            log_space: Whether to compute in log-space (recommended)
        """
        self.fg = factor_graph
        self.damping = damping
        self.log_space = log_space

        # Messages indexed by (source_node, target_node)
        self.messages: Dict[Tuple[str, str], np.ndarray] = {}
        self._initialize_messages()

    def _initialize_messages(self):
        """Initialize all messages to uniform distribution."""
        # Variable-to-factor messages
        for var_id, var in self.fg.variables.items():
            for factor_id in var.neighbors:
                if self.log_space:
                    # Uniform in log space = zeros
                    self.messages[(var_id, factor_id)] = np.zeros(var.cardinality)
                else:
                    self.messages[(var_id, factor_id)] = np.ones(var.cardinality) / var.cardinality

        # Factor-to-variable messages
        for factor_id, factor in self.fg.factors.items():
            for var_id in factor.variables:
                var = self.fg.variables[var_id]
                if self.log_space:
                    self.messages[(factor_id, var_id)] = np.zeros(var.cardinality)
                else:
                    self.messages[(factor_id, var_id)] = np.ones(var.cardinality) / var.cardinality

    def run(
        self,
        max_iter: int = 100,
        tolerance: float = 1e-6,
        schedule: str = 'synchronous'
    ) -> Tuple[Dict[str, np.ndarray], bool, int]:
        """
        Run loopy BP until convergence or max iterations.

        Args:
            max_iter: Maximum number of iterations
            tolerance: Convergence threshold (max message change)
            schedule: 'synchronous' or 'asynchronous'

        Returns:
            - Approximate marginals for each variable
            - Whether algorithm converged
            - Number of iterations used
        """
        converged = False
        for iteration in range(max_iter):
            if schedule == 'synchronous':
                max_change = self._synchronous_iteration()
            else:
                max_change = self._asynchronous_iteration()

            if max_change < tolerance:
                converged = True
                break

        marginals = self._compute_marginals()
        return marginals, converged, iteration + 1

    def _synchronous_iteration(self) -> float:
        """
        Perform one synchronous iteration.

        All messages are updated simultaneously using previous iteration's
        values. Returns maximum message change for convergence check.
        """
        new_messages = {}
        max_change = 0.0

        # Compute all new messages
        for (src, tgt), old_msg in self.messages.items():
            if src in self.fg.variables:
                new_msg = self._compute_var_to_factor_message(src, tgt)
            else:
                new_msg = self._compute_factor_to_var_message(src, tgt)

            # Apply damping if enabled
            if self.damping > 0:
                new_msg = (1 - self.damping) * new_msg + self.damping * old_msg

            # Track maximum change
            change = np.max(np.abs(new_msg - old_msg))
            max_change = max(max_change, change)
            new_messages[(src, tgt)] = new_msg

        # Update all messages at once
        self.messages = new_messages
        return max_change

    def _asynchronous_iteration(self) -> float:
        """
        Perform one asynchronous iteration.

        Messages are updated sequentially; each update immediately uses the
        latest values. Often converges faster than synchronous.
        """
        max_change = 0.0

        # Update messages in some order (here: arbitrary)
        for (src, tgt), old_msg in list(self.messages.items()):
            if src in self.fg.variables:
                new_msg = self._compute_var_to_factor_message(src, tgt)
            else:
                new_msg = self._compute_factor_to_var_message(src, tgt)

            if self.damping > 0:
                new_msg = (1 - self.damping) * new_msg + self.damping * old_msg

            change = np.max(np.abs(new_msg - old_msg))
            max_change = max(max_change, change)
            self.messages[(src, tgt)] = new_msg  # Immediate update

        return max_change

    def _compute_var_to_factor_message(self, var_id: str, factor_id: str) -> np.ndarray:
        """Variable-to-factor message: product of incoming factor messages."""
        var = self.fg.variables[var_id]

        if self.log_space:
            message = np.zeros(var.cardinality)
            for f_id in var.neighbors:
                if f_id != factor_id:
                    message += self.messages[(f_id, var_id)]
            # Normalize in log space
            message -= np.max(message)  # For numerical stability
        else:
            message = np.ones(var.cardinality)
            for f_id in var.neighbors:
                if f_id != factor_id:
                    message *= self.messages[(f_id, var_id)]
            message /= (np.sum(message) + 1e-10)

        return message

    def _compute_factor_to_var_message(self, factor_id: str, var_id: str) -> np.ndarray:
        """Factor-to-variable message: marginalize weighted potential."""
        factor = self.fg.factors[factor_id]
        var = self.fg.variables[var_id]

        if self.log_space:
            # Start with log potential
            log_pot = np.log(factor.potential + 1e-10)
            result = log_pot.copy()

            # Add incoming log-messages from other variables
            for v_id in factor.variables:
                if v_id != var_id:
                    incoming = self.messages[(v_id, factor_id)]
                    axis = factor.variables.index(v_id)
                    shape = [1] * len(factor.variables)
                    shape[axis] = len(incoming)
                    result += incoming.reshape(shape)

            # Marginalize with logsumexp
            target_axis = factor.variables.index(var_id)
            for axis in sorted(range(len(factor.variables)), reverse=True):
                if axis != target_axis:
                    result = logsumexp(result, axis=axis)

            # Normalize
            result -= logsumexp(result)
        else:
            result = factor.potential.copy()
            for v_id in factor.variables:
                if v_id != var_id:
                    incoming = self.messages[(v_id, factor_id)]
                    axis = factor.variables.index(v_id)
                    shape = [1] * len(factor.variables)
                    shape[axis] = len(incoming)
                    result *= incoming.reshape(shape)

            target_axis = factor.variables.index(var_id)
            for axis in sorted(range(len(factor.variables)), reverse=True):
                if axis != target_axis:
                    result = np.sum(result, axis=axis)

            result /= (np.sum(result) + 1e-10)

        return result

    def _compute_marginals(self) -> Dict[str, np.ndarray]:
        """Compute approximate marginals from final messages."""
        marginals = {}
        for var_id, var in self.fg.variables.items():
            if self.log_space:
                log_belief = np.zeros(var.cardinality)
                for factor_id in var.neighbors:
                    log_belief += self.messages[(factor_id, var_id)]
                # Convert to probability
                log_belief -= logsumexp(log_belief)
                marginals[var_id] = np.exp(log_belief)
            else:
                belief = np.ones(var.cardinality)
                for factor_id in var.neighbors:
                    belief *= self.messages[(factor_id, var_id)]
                marginals[var_id] = belief / np.sum(belief)
        return marginals
```

Unlike tree BP, loopy BP may not converge! Messages can oscillate indefinitely, especially on graphs with short cycles or strong potentials. Damping and careful initialization help, but convergence is never guaranteed for arbitrary graphs.
Understanding when loopy BP converges—and when its fixed points are accurate—requires analyzing the interplay between graph structure and potential strength.
Factors Affecting Convergence:
| Condition | Mathematical Form | Practical Implication |
|---|---|---|
| Walk-summability | ρ(|D⁻¹J|) < 1 for Gaussian MRFs | Spectral condition on interaction matrix; ensures unique fixed point |
| Contraction | Message update is a contraction mapping | Guarantees convergence to unique fixed point; checkable locally |
| Weak potentials | All log-potentials bounded by constant | Near-uniform distributions converge well |
| Large girth | Shortest cycle length ≥ O(log n) | Long loops prevent rapid information cycling |
| Belief propagation polytope | Messages stay in valid region | Ensures beliefs remain valid distributions |
The Correlation Decay Perspective:
Loopy BP tends to work when the model exhibits correlation decay: the influence of one variable on another decreases with graph distance. If setting one variable strongly affects distant variables, cycles can amplify this effect, causing convergence problems.
Gaussian MRFs: For Gaussian graphical models, convergence is comparatively well understood. Write the precision matrix as Λ = D − J, where D is diagonal and J encodes the pairwise interactions. BP is guaranteed to converge whenever ρ(|D⁻¹J|) < 1, where ρ denotes the spectral radius. This condition is called walk-summability and can be checked before running BP.
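The walk-summability check takes only a few lines of NumPy. This is a minimal sketch; the function name and the 3-node example matrices are ours, chosen for illustration:

```python
import numpy as np

def is_walk_summable(precision: np.ndarray) -> bool:
    """Check rho(|D^-1 J|) < 1 for a Gaussian MRF with precision Lambda = D - J."""
    D = np.diag(np.diag(precision))
    J = D - precision                              # off-diagonal interactions
    R = np.linalg.inv(D) @ np.abs(J)               # |D^-1 J|, entrywise
    rho = np.max(np.abs(np.linalg.eigvals(R)))     # spectral radius
    return bool(rho < 1.0)

# Weakly coupled 3-cycle: rho = 0.4, so Gaussian BP is guaranteed to converge
weak = np.array([[1.0, 0.2, 0.2],
                 [0.2, 1.0, 0.2],
                 [0.2, 0.2, 1.0]])

# Strongly coupled 3-cycle: rho = 1.2, so no guarantee
strong = np.array([[1.0, 0.6, 0.6],
                   [0.6, 1.0, 0.6],
                   [0.6, 0.6, 1.0]])

print(is_walk_summable(weak), is_walk_summable(strong))   # True False
```

Running this check before BP lets you decide up front whether to trust plain message passing or to reach for damping and other stabilizers.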
Even when loopy BP converges, the fixed point beliefs may not equal true marginals! Convergence only guarantees that messages have stabilized—not that they've stabilized to correct values. The accuracy of converged beliefs depends on the graph structure and how much 'double-counting' occurs.
When loopy BP oscillates or diverges, several techniques can improve stability. The most common is damping, which blends new messages with old ones to prevent rapid changes.
Damped Message Update:
μ_new = (1 - α) × μ_computed + α × μ_old
where α ∈ [0, 1) is the damping coefficient. Higher damping means slower but more stable convergence.
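A toy scalar iteration makes the effect of damping concrete. This sketch (our own example, not from the text) uses an update whose undamped iteration oscillates forever between two values, while the damped version contracts to the fixed point:

```python
def damped_update(f, x0, alpha, n_iter=50):
    """Iterate x <- (1 - alpha) * f(x) + alpha * x, mirroring damped BP messages."""
    x = x0
    for _ in range(n_iter):
        x = (1 - alpha) * f(x) + alpha * x
    return x

f = lambda x: 1.0 - x   # undamped iteration flips 0, 1, 0, 1, ... (period-2 oscillation)

print(damped_update(f, 0.0, alpha=0.0, n_iter=51))   # 1.0 -- still oscillating
print(damped_update(f, 0.0, alpha=0.3))              # ~0.5 -- converged to the fixed point
```

With α = 0.3 the damped map is x ↦ 0.7 − 0.4x, a contraction with factor 0.4, so convergence is geometric; the same mechanism tames oscillating BP messages.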
```python
class ResidualBeliefPropagation(LoopyBeliefPropagation):
    """
    Residual BP: prioritize updates to messages with largest residuals.

    Often converges much faster than synchronous or random asynchronous
    schedules, especially for large sparse graphs.
    """

    def __init__(self, factor_graph: FactorGraph, **kwargs):
        super().__init__(factor_graph, **kwargs)
        # Track residuals for each message
        self.residuals: Dict[Tuple[str, str], float] = {}
        self._initialize_residuals()

    def _initialize_residuals(self):
        """Initialize all residuals to infinity (all need updating)."""
        for edge in self.messages.keys():
            self.residuals[edge] = float('inf')

    def run_residual(
        self,
        max_updates: int = 10000,
        tolerance: float = 1e-6
    ) -> Tuple[Dict[str, np.ndarray], bool, int]:
        """
        Run residual BP until convergence.

        Args:
            max_updates: Maximum total message updates
            tolerance: Stop when max residual falls below this

        Returns:
            - Approximate marginals
            - Whether converged
            - Number of updates performed
        """
        for update_count in range(max_updates):
            # Find message with largest residual
            max_edge = max(self.residuals.keys(), key=lambda e: self.residuals[e])
            max_residual = self.residuals[max_edge]

            if max_residual < tolerance:
                break

            # Update this message
            src, tgt = max_edge
            old_msg = self.messages[max_edge]
            if src in self.fg.variables:
                new_msg = self._compute_var_to_factor_message(src, tgt)
            else:
                new_msg = self._compute_factor_to_var_message(src, tgt)

            if self.damping > 0:
                new_msg = (1 - self.damping) * new_msg + self.damping * old_msg

            self.messages[max_edge] = new_msg
            self.residuals[max_edge] = 0.0  # Just updated

            # Update residuals of neighbors
            self._update_neighbor_residuals(tgt)

        marginals = self._compute_marginals()
        converged = max_residual < tolerance
        return marginals, converged, update_count + 1

    def _update_neighbor_residuals(self, node: str):
        """Update residuals for messages leaving a node."""
        if node in self.fg.variables:
            neighbors = self.fg.variables[node].neighbors
            for factor_id in neighbors:
                # Message from node to factor could have changed
                edge = (node, factor_id)
                new_msg = self._compute_var_to_factor_message(node, factor_id)
                old_msg = self.messages[edge]
                self.residuals[edge] = np.max(np.abs(new_msg - old_msg))
        else:
            factor = self.fg.factors[node]
            for var_id in factor.variables:
                edge = (node, var_id)
                new_msg = self._compute_factor_to_var_message(node, var_id)
                old_msg = self.messages[edge]
                self.residuals[edge] = np.max(np.abs(new_msg - old_msg))


def annealed_bp(
    factor_graph: FactorGraph,
    temperatures: List[float] = [10.0, 5.0, 2.0, 1.0],
    **bp_kwargs
) -> Dict[str, np.ndarray]:
    """
    Annealed BP: gradually strengthen potentials for smoother convergence.

    Start with high temperature (weak potentials), cool down to T=1
    (true potentials). Use previous solution to warm-start next temperature.
    """
    # Create copy with modifiable potentials
    current_fg = factor_graph.copy()
    prev_messages = None

    for temp in temperatures:
        # Scale potentials: potential^(1/temp)
        for factor_id, factor in current_fg.factors.items():
            factor.potential = factor_graph.factors[factor_id].potential ** (1.0 / temp)

        bp = LoopyBeliefPropagation(current_fg, **bp_kwargs)
        # Warm-start from the previous temperature's converged messages
        if prev_messages is not None:
            bp.messages = prev_messages.copy()

        # Run until convergence
        marginals, converged, iters = bp.run(max_iter=200)
        if not converged:
            print(f"Warning: did not converge at temperature {temp}")
        prev_messages = bp.messages

    return marginals
```

Loopy BP gains theoretical grounding through its connection to the Bethe free energy—a variational approximation to the log-partition function. This connection explains why LBP often works and provides a framework for improvement.
The Variational Perspective:
The goal of inference is to compute marginals from the partition function Z. Equivalently, we seek the Gibbs free energy F = -log Z. The Bethe approximation replaces the intractable entropy with a tractable approximation based on local consistency constraints.
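This view can be made precise in one line. Writing p(x) = e^(−E(x))/Z for energy E(x), the free energy of any trial distribution b decomposes as:

```latex
F(b) \;=\; \mathbb{E}_{b}\!\left[E(x)\right] - H(b)
      \;=\; -\log Z + \mathrm{KL}\!\left(b \,\|\, p\right)
      \;\ge\; -\log Z,
```

with equality exactly when b = p. Minimizing F over all distributions would recover −log Z and the true marginals; the Bethe approximation makes the problem tractable by replacing H(b) with a surrogate, so its minimum only approximates −log Z.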
The Bethe Free Energy:
For a pairwise MRF with node beliefs bᵢ(xᵢ) and edge beliefs bᵢⱼ(xᵢ, xⱼ):
F_Bethe = Σᵢⱼ Σ_{xᵢ,xⱼ} bᵢⱼ(xᵢ,xⱼ) log[bᵢⱼ(xᵢ,xⱼ) / ψᵢⱼ(xᵢ,xⱼ)] + Σᵢ (1 - dᵢ) Σ_{xᵢ} bᵢ(xᵢ) log[bᵢ(xᵢ) / ψᵢ(xᵢ)]
where dᵢ is the degree of node i.
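The formula above translates directly into code. This is a minimal sketch (the function and the two-node example are ours); it evaluates F_Bethe for a tiny two-node chain, where the Bethe approximation is exact and recovers −log Z:

```python
import numpy as np

def bethe_free_energy(node_beliefs, edge_beliefs, node_pots, edge_pots, degrees):
    """Bethe free energy for a pairwise MRF, following the formula above.

    node_beliefs / node_pots: dict  i -> 1-D array over states of x_i
    edge_beliefs / edge_pots: dict (i, j) -> 2-D array over (x_i, x_j)
    degrees: dict i -> degree d_i of node i
    """
    f = 0.0
    for ij, b_ij in edge_beliefs.items():
        f += np.sum(b_ij * np.log(b_ij / edge_pots[ij]))
    for i, b_i in node_beliefs.items():
        f += (1 - degrees[i]) * np.sum(b_i * np.log(b_i / node_pots[i]))
    return f

# Two binary variables joined by one edge (a tree, so Bethe is exact)
psi_ab = np.array([[2.0, 1.0],
                   [1.0, 2.0]])
Z = psi_ab.sum()                                   # exact partition function = 6
edge_b = {('a', 'b'): psi_ab / Z}                  # exact edge marginal
node_b = {'a': psi_ab.sum(1) / Z, 'b': psi_ab.sum(0) / Z}
node_psi = {'a': np.ones(2), 'b': np.ones(2)}

F = bethe_free_energy(node_b, edge_b, node_psi, {('a', 'b'): psi_ab}, {'a': 1, 'b': 1})
print(np.isclose(F, -np.log(Z)))   # True
```

On a loopy graph the same computation runs fine, but F_Bethe at the BP fixed point is only an approximation to −log Z.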
The Bethe free energy approximates the true free energy by treating the entropy as if the graph were a tree: pairwise entropies are counted exactly on each edge, and the (1 − dᵢ) node terms correct for the over-counting at variables shared by multiple edges. On a tree this accounting is exact; on a loopy graph it ignores the higher-order dependencies that cycles introduce.
Key Result: The fixed points of loopy BP are exactly the stationary points of F_Bethe!
The Bethe-BP connection means: (1) LBP is doing principled approximate inference, not ad-hoc heuristics; (2) when LBP converges, it found a stationary point of an approximation to the free energy; (3) we can improve on LBP by using better free energy approximations (e.g., Kikuchi, region graphs).
Limitations:
Bethe approximation can be non-convex: multiple local minima, BP may find suboptimal one
Beliefs may violate probability constraints at non-global minima
Approximation error depends on loop structure; worse for short, strong cycles
Improvements:
Convex BP: Modify updates to ensure convexity; guarantees unique global minimum
Kikuchi methods: Larger region approximations reduce error
Tree-reweighted BP: Convex combination of tree distributions
Max-product loopy BP applies the same approximation philosophy to MAP inference: replace sums with maxes in the message updates and hope for convergence to a good solution.
The Algorithm:
Identical to sum-product LBP, except factor-to-variable messages use max instead of sum:
μ_{f→x}(x) = max_{x_{N(f)∖x}} ψ_f(x, x_{N(f)∖x}) × ∏_{y ∈ N(f)∖x} μ_{y→f}(x_y)
After convergence, each variable's MAP estimate is:
x*ᵢ = argmax_{xᵢ} ∏_{f ∈ neighbors(i)} μ_{f→i}(xᵢ)
Max-product loopy BP has an additional problem: even when it converges, locally optimal assignments may be globally inconsistent! Variable i's argmax may conflict with variable j's argmax if the factors between them prefer a different joint assignment. Tree BP automatically ensures consistency; loopy BP does not.
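A two-variable toy case (our own hypothetical numbers) shows the failure mode. When max-marginals have ties, decoding each variable independently can pick a jointly suboptimal assignment:

```python
import numpy as np

# A single pairwise factor that prefers the two variables to DISAGREE
psi = np.array([[1.0, 2.0],
                [2.0, 1.0]])

# Max-marginals, as max-product beliefs would report them
mm_x = psi.max(axis=1)   # [2., 2.] -- both states of x look equally good
mm_y = psi.max(axis=0)   # [2., 2.] -- same for y

# Independent argmax decoding breaks the ties independently...
x_star, y_star = int(np.argmax(mm_x)), int(np.argmax(mm_y))
print(psi[x_star, y_star], psi.max())   # 1.0 2.0 -- jointly suboptimal
```

Here each variable's max-marginal is tied, so argmax picks (0, 0) for both, scoring 1.0 even though the joint optimum scores 2.0. Resolving such ties requires joint decoding (e.g., back-tracking from a root, as in tree BP).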
```python
class MaxProductLoopyBP:
    """
    Max-product loopy BP for approximate MAP inference.

    Like sum-product LBP but uses max instead of sum.
    Returns approximate MAP assignment.
    """

    def __init__(self, factor_graph: FactorGraph, damping: float = 0.5):
        self.fg = factor_graph
        self.damping = damping
        self.log_messages: Dict[Tuple[str, str], np.ndarray] = {}
        self._initialize_messages()

    def _initialize_messages(self):
        """Initialize log-messages to zeros (uniform)."""
        for var_id, var in self.fg.variables.items():
            for factor_id in var.neighbors:
                self.log_messages[(var_id, factor_id)] = np.zeros(var.cardinality)
        for factor_id, factor in self.fg.factors.items():
            for var_id in factor.variables:
                var = self.fg.variables[var_id]
                self.log_messages[(factor_id, var_id)] = np.zeros(var.cardinality)

    def _compute_var_to_factor_message(self, var_id: str, factor_id: str) -> np.ndarray:
        """Variable-to-factor message: sum of other incoming log-messages."""
        var = self.fg.variables[var_id]
        message = np.zeros(var.cardinality)
        for f_id in var.neighbors:
            if f_id != factor_id:
                message += self.log_messages[(f_id, var_id)]
        return message - np.max(message)  # Normalize for stability

    def _compute_factor_to_var_message_max(
        self, factor_id: str, var_id: str
    ) -> np.ndarray:
        """Factor-to-variable message using max instead of sum."""
        factor = self.fg.factors[factor_id]

        # Start with log potential
        log_pot = np.log(factor.potential + 1e-10)
        result = log_pot.copy()

        # Add incoming messages from other variables
        for v_id in factor.variables:
            if v_id != var_id:
                incoming = self.log_messages[(v_id, factor_id)]
                axis = factor.variables.index(v_id)
                shape = [1] * len(factor.variables)
                shape[axis] = len(incoming)
                result += incoming.reshape(shape)

        # MAX over all variables except target
        target_axis = factor.variables.index(var_id)
        for axis in sorted(range(len(factor.variables)), reverse=True):
            if axis != target_axis:
                result = np.max(result, axis=axis)

        # Normalize for numerical stability
        result -= np.max(result)
        return result

    def _iteration(self) -> float:
        """One synchronous sweep over all messages; returns max change."""
        new_messages = {}
        max_change = 0.0
        for (src, tgt), old_msg in self.log_messages.items():
            if src in self.fg.variables:
                new_msg = self._compute_var_to_factor_message(src, tgt)
            else:
                new_msg = self._compute_factor_to_var_message_max(src, tgt)
            if self.damping > 0:
                new_msg = (1 - self.damping) * new_msg + self.damping * old_msg
            max_change = max(max_change, np.max(np.abs(new_msg - old_msg)))
            new_messages[(src, tgt)] = new_msg
        self.log_messages = new_messages
        return max_change

    def run(
        self,
        max_iter: int = 100,
        tolerance: float = 1e-6
    ) -> Tuple[Dict[str, int], bool, int]:
        """
        Run max-product LBP to find approximate MAP.

        Returns:
            - MAP assignment (variable -> value)
            - Whether converged
            - Number of iterations
        """
        for iteration in range(max_iter):
            max_change = self._iteration()
            if max_change < tolerance:
                break

        # Extract MAP assignment from converged messages
        assignment = self._extract_assignment()
        return assignment, max_change < tolerance, iteration + 1

    def _extract_assignment(self) -> Dict[str, int]:
        """Extract MAP assignment from final messages."""
        assignment = {}
        for var_id, var in self.fg.variables.items():
            # Compute max-marginal
            log_belief = np.zeros(var.cardinality)
            for factor_id in var.neighbors:
                log_belief += self.log_messages[(factor_id, var_id)]
            # Take argmax
            assignment[var_id] = int(np.argmax(log_belief))
        return assignment

    def verify_consistency(self, assignment: Dict[str, int]) -> List[str]:
        """
        Check if assignment is consistent with factor preferences.

        Returns list of factors where local assignment differs from
        factor's preferred joint assignment.
        """
        inconsistent = []
        for factor_id, factor in self.fg.factors.items():
            # Check if the assignment maximizes this factor
            local_assignment = tuple(assignment[v] for v in factor.variables)
            factor_value = factor.potential[local_assignment]
            max_value = np.max(factor.potential)
            if factor_value < max_value * 0.99:  # Some tolerance
                inconsistent.append(factor_id)
        return inconsistent
```

Despite its theoretical limitations, loopy BP has proven extraordinarily successful in several application domains. Understanding why it works in these cases illuminates the conditions under which LBP excels.
Why LDPC Codes Work So Well:
Sparse graphs: Few edges per node means long cycles
Large n: Many variables dilute per-variable error
Random construction: Avoids worst-case cycle patterns
Near-uniform potentials: Noise model is close to uniform
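In LDPC decoding, each factor is a parity check over a handful of bits; the channel evidence enters as near-uniform single-variable factors. A sketch of a parity-check factor's potential table (a hypothetical 3-bit check, constructed for illustration):

```python
import numpy as np
from itertools import product

def parity_check_potential(n_bits: int) -> np.ndarray:
    """Potential table for a parity-check factor: 1 if the bits XOR to 0, else 0."""
    pot = np.zeros((2,) * n_bits)
    for bits in product([0, 1], repeat=n_bits):
        if sum(bits) % 2 == 0:
            pot[bits] = 1.0
    return pot

pot = parity_check_potential(3)
print(pot[0, 1, 1], pot[1, 1, 1])   # 1.0 0.0 -- even parity allowed, odd forbidden
```

Plugging such factors into a sum-product LBP implementation like the one above yields the classic iterative LDPC decoder; the sparse, long-cycle structure of the code graph is what makes the approximation so accurate.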
Why Vision Works Well:
Grid structure: Regular, predictable cycle lengths
Truncated potentials: Limiting potential strength improves convergence
Hierarchical: Coarse-to-fine reduces effective cycle impact
Robustness: Exact marginals not needed; approximate is sufficient
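The truncated potentials used in stereo and denoising are easy to construct. This sketch builds a truncated-linear smoothness potential (parameter names are ours) whose penalty is capped at a truncation threshold:

```python
import numpy as np

def truncated_linear_potential(n_labels: int, strength: float = 1.0,
                               trunc: float = 2.0) -> np.ndarray:
    """Pairwise potential exp(-strength * min(|a - b|, trunc)) over label pairs.

    Capping the penalty keeps the potential from becoming too strong,
    which helps loopy BP converge on grid graphs.
    """
    labels = np.arange(n_labels)
    diff = np.abs(labels[:, None] - labels[None, :])
    return np.exp(-strength * np.minimum(diff, trunc))

pot = truncated_linear_potential(5, strength=1.0, trunc=2.0)
print(pot[0, 4] == pot[0, 2])   # True: penalties cap at the truncation point
```

Beyond the truncation point, all large label jumps cost the same, so edges in the image (large disparities) are not over-penalized and messages stay bounded away from extreme values.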
LBP works best when: (1) cycles are long relative to 'correlation length' of potentials, (2) potentials are not too strong, (3) the application is tolerant of approximate marginals, and (4) alternatives (junction trees, sampling) are too expensive. Many practical problems fit these criteria!
Research since the original loopy BP has produced many refinements addressing its limitations. These extensions improve convergence, accuracy, or applicability to new problem types.
| Extension | Key Idea | Benefit |
|---|---|---|
| Tree-Reweighted BP (TRW) | Convex combination of tree distributions | Provable bounds on log-partition function; convex optimization |
| Generalized BP (Kikuchi) | Larger regions instead of edges | Better approximation for densely connected subgraphs |
| Fractional BP | Fractional edge/node counting numbers | Convex Bethe; guaranteed convergence |
| Neural Message Passing | Learn message functions via neural networks | Handles continuous, complex, or unknown factors |
| Variational Message Passing | Exponential family message approximations | Extends BP to continuous latent variables |
| Lifted BP | Exploit symmetry to reduce computation | Efficient inference in relational/first-order models |
```python
class TreeReweightedBP(LoopyBeliefPropagation):
    """
    Tree-Reweighted BP: a convex variant with provable bounds.

    Key idea: express the model as a convex combination of tree
    distributions, then run BP with edge weights reflecting tree
    membership probabilities. Provides upper bound on log-partition
    function.
    """

    def __init__(
        self,
        factor_graph: FactorGraph,
        edge_weights: Optional[Dict[Tuple[str, str], float]] = None
    ):
        """
        Initialize TRW-BP.

        Args:
            factor_graph: Factor graph (pairwise MRF)
            edge_weights: Probability that each edge appears in a random
                spanning tree. If None, computed automatically.
        """
        super().__init__(factor_graph)  # sets up self.fg and uniform log-messages
        if edge_weights is None:
            self.edge_weights = self._compute_uniform_tree_weights()
        else:
            self.edge_weights = edge_weights

    def _compute_uniform_tree_weights(self) -> Dict[Tuple[str, str], float]:
        """
        Compute edge appearance probabilities for uniform spanning trees.

        Uses the matrix-tree theorem for exact computation. For
        simplicity here, we use an approximation.
        """
        # Simplified: uniform weights based on graph structure
        weights = {}
        for factor_id, factor in self.fg.factors.items():
            if len(factor.variables) == 2:
                v1, v2 = factor.variables
                # Weight = 2/degree is a simple heuristic
                d1 = len(self.fg.variables[v1].neighbors)
                d2 = len(self.fg.variables[v2].neighbors)
                weights[(v1, v2)] = 2.0 / max(d1, d2)
        return weights

    def _compute_factor_to_var_message_trw(
        self, factor_id: str, var_id: str
    ) -> np.ndarray:
        """
        TRW-BP factor-to-variable message.

        Differs from standard BP by using edge weights to scale the
        factor potential contribution.
        """
        factor = self.fg.factors[factor_id]
        if len(factor.variables) != 2:
            # Only pairwise factors handled here
            return super()._compute_factor_to_var_message(factor_id, var_id)

        # Get edge weight
        v1, v2 = factor.variables
        other_var = v2 if var_id == v1 else v1
        rho = self.edge_weights.get((v1, v2), 1.0)

        # Scale potential by edge weight
        log_pot = np.log(factor.potential + 1e-10) * rho

        # Standard message computation with scaled potential
        result = log_pot.copy()
        incoming = self.messages[(other_var, factor_id)]
        axis = factor.variables.index(other_var)
        shape = [1] * len(factor.variables)
        shape[axis] = len(incoming)
        result += incoming.reshape(shape)

        target_axis = factor.variables.index(var_id)
        for axis in sorted(range(len(factor.variables)), reverse=True):
            if axis != target_axis:
                result = logsumexp(result, axis=axis)

        return result - logsumexp(result)

    def compute_upper_bound(self) -> float:
        """
        Compute upper bound on log-partition function.

        TRW-BP fixed points provide valid upper bounds.
        """
        # After convergence, compute TRW free energy
        # This is an upper bound on -log Z
        # Implementation details omitted for brevity
        pass
```

Loopy belief propagation applies the elegant message-passing framework of BP to graphs where it isn't theoretically justified—and often gets surprisingly good results. Its practical success in coding, vision, and other domains has made it one of the most important algorithms in probabilistic inference.
Connection to What's Next:
Loopy BP is just one approach to approximate inference. The next page covers Approximate Inference more broadly, including sampling methods (MCMC, importance sampling) and variational approaches (mean-field, stochastic VI). Understanding when to use LBP versus these alternatives is key to practical inference in graphical models.
You now understand Loopy Belief Propagation—the pragmatic workhorse of approximate inference in graphical models. You can implement LBP with damping and residual scheduling, understand its Bethe free energy interpretation, and recognize when LBP is likely to succeed or fail.