Variable elimination computes marginals one query at a time—but what if we need all marginals? Running VE n times (once per variable) is computationally wasteful because independent runs recompute the same intermediate factors.
Belief Propagation (BP) provides an elegant solution: organize inference as message passing on the graph structure. Nodes (variables or factors) exchange local summaries called messages, and after sufficient propagation, each node has accumulated enough information to compute its marginal. On trees, BP computes all marginals in just two passes—a stunning improvement over separate VE runs.
By the end of this page, you will understand: (1) the factor graph representation as the native home of belief propagation, (2) the sum-product algorithm with complete message update equations, (3) why BP is exact on trees and how it relates to variable elimination, (4) the max-product variant for MAP inference, and (5) the computational and memory trade-offs of message passing.
While belief propagation can be formulated on Bayesian networks or Markov random fields, its most natural expression is on factor graphs—bipartite graphs that explicitly represent the factorization structure of a probability distribution.
Definition: Factor Graph
A factor graph G = (V, F, E) contains:

- Variable nodes V: one node per random variable (conventionally drawn as circles)
- Factor nodes F: one node per factor fₐ (drawn as squares)
- Edges E: an edge between variable x and factor fₐ whenever x appears in fₐ's scope
The joint distribution is:
P(X) = (1/Z) ∏_{a∈F} fₐ(Xₐ)
where Xₐ denotes the variables connected to factor fₐ, and Z is the normalizing constant.
Why Factor Graphs?
- Unified representation: both directed (Bayesian networks) and undirected (Markov random fields) models convert to factor graphs
- Explicit factorization: factor nodes make the factorization structure visible, clarifying which variables interact
- Natural BP formulation: messages flow along edges between variables and factors, matching the bipartite structure
- Higher-order interactions: factor graphs easily represent potentials over any subset of variables (not just pairs)
Bayesian network: each CPT P(Xᵢ | Parents) becomes a factor node connected to Xᵢ and its parents. Markov random field: each clique potential becomes a factor node connected to clique variables. The conversion is straightforward and preserves the distribution.
```python
from typing import Dict, List, Set, Tuple
from dataclasses import dataclass, field
import numpy as np


@dataclass
class FactorNode:
    """Represents a factor (potential function) in a factor graph."""
    id: str
    variables: Tuple[str, ...]       # Connected variable nodes
    cardinalities: Tuple[int, ...]   # Domain size of each variable
    potential: np.ndarray            # The potential function values

    @property
    def scope(self) -> Set[str]:
        return set(self.variables)

    def get_variable_axis(self, var: str) -> int:
        """Get the axis index for a variable in the potential array."""
        return self.variables.index(var)


@dataclass
class VariableNode:
    """Represents a variable in a factor graph."""
    id: str
    cardinality: int                                     # Number of possible values
    neighbors: List[str] = field(default_factory=list)   # Factor IDs


@dataclass
class FactorGraph:
    """
    Complete factor graph representation for belief propagation.

    Bipartite graph with variable nodes (circles) and factor nodes (squares).
    Edges connect variables to factors that include them in their scope.
    """
    variables: Dict[str, VariableNode]
    factors: Dict[str, FactorNode]

    def __init__(self):
        self.variables = {}
        self.factors = {}

    def add_variable(self, var_id: str, cardinality: int):
        """Add a variable node to the graph."""
        self.variables[var_id] = VariableNode(
            id=var_id,
            cardinality=cardinality
        )

    def add_factor(
        self,
        factor_id: str,
        variables: Tuple[str, ...],
        potential: np.ndarray
    ):
        """Add a factor node connected to specified variables."""
        cardinalities = tuple(
            self.variables[v].cardinality for v in variables
        )
        self.factors[factor_id] = FactorNode(
            id=factor_id,
            variables=variables,
            cardinalities=cardinalities,
            potential=potential
        )
        # Update variable neighbors
        for var in variables:
            self.variables[var].neighbors.append(factor_id)

    def get_neighbors(self, node_id: str) -> List[str]:
        """Get neighbors of a node (variable or factor)."""
        if node_id in self.variables:
            return self.variables[node_id].neighbors
        elif node_id in self.factors:
            return list(self.factors[node_id].variables)
        else:
            raise ValueError(f"Unknown node: {node_id}")

    def is_tree(self) -> bool:
        """Check if the factor graph is a tree (no cycles)."""
        # A factor graph is a tree if |E| = |V| + |F| - 1
        num_edges = sum(
            len(f.variables) for f in self.factors.values()
        )
        num_nodes = len(self.variables) + len(self.factors)
        return num_edges == num_nodes - 1

    @classmethod
    def from_bayesian_network(cls, bn_factors: List[Dict]) -> 'FactorGraph':
        """
        Convert a Bayesian network to a factor graph.

        Args:
            bn_factors: List of dicts with 'child', 'parents', 'cpt'
                representing P(child | parents)
        """
        fg = cls()

        # Add all variables
        for i, factor in enumerate(bn_factors):
            child = factor['child']
            if child not in fg.variables:
                fg.add_variable(child, factor['cpt'].shape[-1])
            for j, parent in enumerate(factor['parents']):
                if parent not in fg.variables:
                    fg.add_variable(parent, factor['cpt'].shape[j])

        # Add factors for each CPT
        for i, factor in enumerate(bn_factors):
            variables = tuple(factor['parents']) + (factor['child'],)
            fg.add_factor(f"f_{i}", variables, factor['cpt'])

        return fg
```

The sum-product algorithm (also called belief propagation) computes marginals by passing messages along factor graph edges. Messages summarize the 'belief' about a variable from one part of the graph to another.
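To make the representation concrete before any messages are passed, here is a minimal usage sketch of the FactorGraph class above on a three-variable chain. The variable names and potential values are invented for illustration.

```python
import numpy as np

# Hypothetical binary chain  A -- f_AB -- B -- f_BC -- C  (illustrative numbers only)
fg = FactorGraph()
for name in ("A", "B", "C"):
    fg.add_variable(name, cardinality=2)

# Pairwise potentials: phi[a, b] scores each joint configuration
phi_AB = np.array([[4.0, 1.0],
                   [1.0, 4.0]])   # favors A == B
phi_BC = np.array([[2.0, 1.0],
                   [1.0, 3.0]])   # mildly favors B == C
fg.add_factor("f_AB", ("A", "B"), phi_AB)
fg.add_factor("f_BC", ("B", "C"), phi_BC)

print(fg.is_tree())              # True: a chain has no cycles
print(fg.get_neighbors("B"))     # ['f_AB', 'f_BC']
```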
Two Types of Messages:

- Variable-to-factor messages μ_{x→f}(x): sent from a variable node x to a neighboring factor node f
- Factor-to-variable messages μ_{f→x}(x): sent from a factor node f to a neighboring variable node x

Both are functions of x's values, i.e., vectors of length |Val(x)|.
Message Computation Rules:
For a variable-to-factor message:
μ_{x→f}(x) = ∏_{g ∈ N(x)\f} μ_{g→x}(x)
where N(x)\f means all factor neighbors of x except f.
For a factor-to-variable message:
μ_{f→x}(x) = Σ_{x_f\x} f(x_f) ∏_{y ∈ N(f)\x} μ_{y→f}(y)
where x_f denotes all variables in factor f's scope, and the sum is over all variables except x.
The factor-to-variable message computes a marginal of the factor potential, weighted by incoming messages from other variables. The variable-to-factor message multiplies together all incoming messages except from the target factor. This 'exclude self' pattern prevents double-counting information.
```python
class SumProductBP:
    """
    Sum-Product Belief Propagation on factor graphs.

    Computes marginal distributions by message passing.
    Exact on trees; see LoopyBP for cyclic graphs.
    """

    def __init__(self, factor_graph: FactorGraph):
        self.fg = factor_graph
        self.messages: Dict[Tuple[str, str], np.ndarray] = {}
        self._initialize_messages()

    def _initialize_messages(self):
        """Initialize all messages to uniform distributions."""
        # Variable-to-factor messages
        for var_id, var in self.fg.variables.items():
            for factor_id in var.neighbors:
                key = (var_id, factor_id)
                self.messages[key] = np.ones(var.cardinality)

        # Factor-to-variable messages
        for factor_id, factor in self.fg.factors.items():
            for var_id in factor.variables:
                key = (factor_id, var_id)
                var = self.fg.variables[var_id]
                self.messages[key] = np.ones(var.cardinality)

    def compute_variable_to_factor_message(
        self,
        var_id: str,
        target_factor_id: str
    ) -> np.ndarray:
        r"""
        Compute message from variable to factor.

        μ_{x→f}(x) = ∏_{g ∈ N(x)\f} μ_{g→x}(x)

        Product of all incoming factor-to-variable messages,
        excluding the target factor.
        """
        var = self.fg.variables[var_id]
        message = np.ones(var.cardinality)

        for factor_id in var.neighbors:
            if factor_id != target_factor_id:
                incoming = self.messages[(factor_id, var_id)]
                message = message * incoming

        # Normalize for numerical stability
        message = message / (np.sum(message) + 1e-10)
        return message

    def compute_factor_to_variable_message(
        self,
        factor_id: str,
        target_var_id: str
    ) -> np.ndarray:
        r"""
        Compute message from factor to variable.

        μ_{f→x}(x) = Σ_{x_f\x} f(x_f) ∏_{y ∈ N(f)\x} μ_{y→f}(y)

        Marginalize the factor potential weighted by incoming
        variable-to-factor messages from other variables.
        """
        factor = self.fg.factors[factor_id]
        target_var = self.fg.variables[target_var_id]

        # Start with the factor potential
        result = factor.potential.copy()

        # Multiply in incoming messages from other variables
        for var_id in factor.variables:
            if var_id != target_var_id:
                incoming = self.messages[(var_id, factor_id)]
                # Expand message to match factor dimensions
                axis = factor.get_variable_axis(var_id)
                shape = [1] * len(factor.variables)
                shape[axis] = len(incoming)
                incoming_expanded = incoming.reshape(shape)
                result = result * incoming_expanded

        # Marginalize over all variables except target
        target_axis = factor.get_variable_axis(target_var_id)
        axes_to_sum = [i for i in range(len(factor.variables))
                       if i != target_axis]

        # Sum out all other variables
        for axis in sorted(axes_to_sum, reverse=True):
            result = np.sum(result, axis=axis)

        # Normalize for numerical stability
        result = result / (np.sum(result) + 1e-10)
        return result

    def run_on_tree(self) -> Dict[str, np.ndarray]:
        """
        Run sum-product on a tree factor graph.

        Two passes: collect messages toward root, then distribute back.
        Returns marginal distributions for all variables.
        """
        if not self.fg.is_tree():
            raise ValueError("Graph has cycles; use loopy BP instead")

        # Choose arbitrary root (first variable)
        root = list(self.fg.variables.keys())[0]

        # Pass 1: Leaves to root
        self._upward_pass(root)

        # Pass 2: Root to leaves
        self._downward_pass(root)

        # Compute marginals
        return self.compute_marginals()

    def _upward_pass(self, root: str, visited: Set[str] = None):
        """Collect messages from leaves toward root (DFS post-order)."""
        if visited is None:
            visited = set()
        visited.add(root)

        if root in self.fg.variables:
            # Process factor neighbors first
            for factor_id in self.fg.variables[root].neighbors:
                if factor_id not in visited:
                    self._upward_pass(factor_id, visited)
                    # After children processed, send message to this var
                    self.messages[(factor_id, root)] = \
                        self.compute_factor_to_variable_message(factor_id, root)
        else:
            # Factor node: process variable neighbors
            factor = self.fg.factors[root]
            for var_id in factor.variables:
                if var_id not in visited:
                    self._upward_pass(var_id, visited)
                    # After children processed, send message to this factor
                    self.messages[(var_id, root)] = \
                        self.compute_variable_to_factor_message(var_id, root)

    def _downward_pass(self, root: str, parent: str = None, visited: Set = None):
        """Distribute messages from root to leaves (DFS pre-order)."""
        if visited is None:
            visited = set()
        visited.add(root)

        # Send message to all unvisited neighbors
        neighbors = self.fg.get_neighbors(root)
        for neighbor in neighbors:
            if neighbor not in visited:
                # Compute and send message
                if root in self.fg.variables:
                    self.messages[(root, neighbor)] = \
                        self.compute_variable_to_factor_message(root, neighbor)
                else:
                    self.messages[(root, neighbor)] = \
                        self.compute_factor_to_variable_message(root, neighbor)
                self._downward_pass(neighbor, root, visited)

    def compute_marginals(self) -> Dict[str, np.ndarray]:
        """
        Compute marginal distributions for all variables.

        Marginal = product of all incoming factor-to-variable messages.
        """
        marginals = {}
        for var_id, var in self.fg.variables.items():
            marginal = np.ones(var.cardinality)
            for factor_id in var.neighbors:
                incoming = self.messages[(factor_id, var_id)]
                marginal = marginal * incoming
            # Normalize to valid probability distribution
            marginals[var_id] = marginal / np.sum(marginal)
        return marginals
```

A remarkable property of belief propagation is its exactness on trees: when the factor graph has no cycles, two passes of message propagation compute the exact marginal distribution for every variable.
Why Trees are Special:
In a tree, there is exactly one path between any two nodes. This means:

- No information can circulate: a message never depends, directly or indirectly, on itself
- Each message μ_{f→x} depends only on the subtree 'behind' f, so it can be finalized as soon as that subtree has been processed
- One upward pass (leaves to root) and one downward pass (root to leaves) deliver every message exactly once
Belief propagation on a tree is equivalent to running Variable Elimination from each variable's perspective simultaneously! Each message μ_{f→x} represents the intermediate factor created when eliminating variables 'beyond' factor f. The genius of BP is caching and reusing these intermediate computations.
Formal Correctness (Sketch):
Let the joint distribution factor as P(X) = (1/Z) ∏_{a} f_a(X_a).
The marginal P(xᵢ) is obtained by summing over all other variables:
P(xᵢ) = (1/Z) Σ_{X\xᵢ} ∏_{a} f_a(X_a)
In a tree, we can group this sum hierarchically. The factor-to-variable message μ_{f→x}(x) exactly computes the partial sum over the subtree 'behind' f relative to x:
μ_{f→xᵢ}(xᵢ) = Σ_{variables in the subtree behind f} ∏_{factors a in that subtree} f_a(X_a)
The product of all incoming messages to xᵢ thus equals P(xᵢ) up to normalization:
∏_{f ∈ N(xᵢ)} μ_{f→xᵢ}(xᵢ) ∝ P(xᵢ)
This is why tree BP is exact: each message independently and correctly summarizes its subtree.
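This can be checked numerically. The sketch below (reusing the chain and classes from the earlier examples) enumerates the unnormalized joint ϕ_AB(a, b)·ϕ_BC(b, c) by brute force and compares the exact marginal of B with the BP result:

```python
# Brute-force joint over all 2 x 2 x 2 configurations of the chain
joint = np.einsum('ab,bc->abc', phi_AB, phi_BC)   # unnormalized P(A, B, C)
joint = joint / joint.sum()

exact_B = joint.sum(axis=(0, 2))                  # marginalize out A and C
bp_B = SumProductBP(fg).run_on_tree()["B"]

print(np.allclose(exact_B, bp_B))                 # True: tree BP is exact
```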
| Aspect | Variable Elimination | Belief Propagation (Trees) |
|---|---|---|
| Single marginal query | O(n · k^w) | O(n · k^w) for the first query; later marginals reuse the cached messages |
| All marginals | O(n² · k^w) (one VE run per variable) | O(n · k^w): a single two-pass run computes all |
| Memory usage | O(k^w) for intermediate factors | O(E · k) for messages + O(n · k^w) for factors |
| Incremental updates | Must rerun from scratch | Can update messages locally |
| Parallelization | Limited by elimination order | Messages along different branches are independent |
| Code complexity | Factor manipulation | Message passing protocol |
For graphs with cycles, one approach is to select a spanning tree and run exact BP on it—but this ignores some factors. The resulting marginals are approximations. Better approaches (junction trees, loopy BP) handle cycles directly, as we'll see in later pages.
The order in which messages are computed affects both correctness (on trees) and convergence (on cyclic graphs). Different scheduling strategies offer trade-offs between simplicity, parallelization, and convergence speed.
On Trees:
Any schedule that respects the following constraint produces correct results: a node may send its message to a neighbor only after it has received messages from all of its other neighbors.
This naturally leads to a two-phase schedule: an upward phase that sends messages from the leaves toward a chosen root, then a downward phase that sends messages from the root back out to the leaves.
```python
from typing import Callable, Dict, List, Tuple
from collections import deque

import numpy as np


class MessageScheduler:
    """Different strategies for scheduling BP message updates."""

    @staticmethod
    def tree_schedule(fg: FactorGraph, root: str) -> List[Tuple[str, str]]:
        """
        Generate message schedule for tree BP.

        Returns list of (source, target) pairs in correct order.
        First phase: leaves to root
        Second phase: root to leaves
        """
        schedule = []

        # BFS to find tree structure and levels
        levels = {root: 0}
        queue = deque([root])
        parent = {root: None}

        while queue:
            node = queue.popleft()
            neighbors = fg.get_neighbors(node)
            for neighbor in neighbors:
                if neighbor not in levels:
                    levels[neighbor] = levels[node] + 1
                    parent[neighbor] = node
                    queue.append(neighbor)

        # Phase 1: Sort nodes by level (deepest first) - leaves to root
        nodes_by_level = sorted(levels.keys(), key=lambda n: -levels[n])
        for node in nodes_by_level:
            if parent[node] is not None:
                schedule.append((node, parent[node]))

        # Phase 2: Sort by level (shallowest first) - root to leaves
        nodes_by_level = sorted(levels.keys(), key=lambda n: levels[n])
        for node in nodes_by_level:
            for neighbor in fg.get_neighbors(node):
                if parent.get(neighbor) == node:
                    schedule.append((node, neighbor))

        return schedule

    @staticmethod
    def synchronous_schedule(
        fg: FactorGraph
    ) -> List[List[Tuple[str, str]]]:
        """
        Generate synchronous BP schedule.

        Returns list of rounds, where each round contains messages
        that can be computed in parallel (all using previous
        iteration's values).
        """
        all_messages = []

        # Variable-to-factor messages
        for var_id, var in fg.variables.items():
            for factor_id in var.neighbors:
                all_messages.append((var_id, factor_id))

        # Factor-to-variable messages
        for factor_id, factor in fg.factors.items():
            for var_id in factor.variables:
                all_messages.append((factor_id, var_id))

        # In a synchronous schedule, all messages form one "round"
        # that is repeated until convergence
        return [all_messages]

    @staticmethod
    def residual_bp_schedule(
        fg: FactorGraph,
        messages: Dict[Tuple[str, str], np.ndarray],
        compute_message: Callable[[str, str], np.ndarray]
    ) -> Tuple[Tuple[str, str], float]:
        """
        Select next message to update using residual-based priority.

        Returns the (source, target) message with the largest potential
        change, together with that residual.
        """
        max_residual = -1.0
        best_message = None

        # Check all possible messages
        for var_id, var in fg.variables.items():
            for factor_id in var.neighbors:
                # Compute what the new message would be
                new_msg = compute_message(var_id, factor_id)
                old_msg = messages.get((var_id, factor_id),
                                       np.ones(var.cardinality))
                # Residual = change in message
                residual = np.max(np.abs(new_msg - old_msg))
                if residual > max_residual:
                    max_residual = residual
                    best_message = (var_id, factor_id)

        for factor_id, factor in fg.factors.items():
            for var_id in factor.variables:
                new_msg = compute_message(factor_id, var_id)
                old_msg = messages.get((factor_id, var_id),
                                       np.ones(fg.variables[var_id].cardinality))
                residual = np.max(np.abs(new_msg - old_msg))
                if residual > max_residual:
                    max_residual = residual
                    best_message = (factor_id, var_id)

        return best_message, max_residual
```

Belief propagation naturally extends from marginal inference to MAP inference by replacing summation with maximization. This variant is called max-product (or max-sum when working in log space).
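On the small chain from the earlier examples, the tree schedule makes the two phases explicit (a sketch assuming the classes above; the root choice is arbitrary):

```python
schedule = MessageScheduler.tree_schedule(fg, root="A")
for src, tgt in schedule:
    print(f"{src} -> {tgt}")
# First half flows toward the root:  C -> f_BC, f_BC -> B, B -> f_AB, f_AB -> A
# Second half flows back out:        A -> f_AB, f_AB -> B, B -> f_BC, f_BC -> C
```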
Max-Product Message Updates:
Variable-to-factor (unchanged): μ_{x→f}(x) = ∏_{g ∈ N(x)\f} μ_{g→x}(x)
Factor-to-variable (max instead of sum): μ_{f→x}(x) = max_{x_f\x} f(x_f) ∏_{y ∈ N(f)\x} μ_{y→f}(y)
The only change is replacing Σ with max in the factor-to-variable message. This propagates 'best configurations' instead of 'total probability mass.'
In practice, max-product is implemented as max-sum in log space. Log-probabilities are added (products become sums), and we take the max. This is numerically stable and exactly equivalent to max-product. The Viterbi algorithm for HMMs is max-sum BP on a chain graph!
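To make the Viterbi connection concrete, here is a tiny self-contained max-sum recursion on a three-node chain X1 - X2 - X3. The potentials are invented for illustration, and nothing here depends on the class below.

```python
import numpy as np

# Log-potentials on a binary chain X1 - X2 - X3
log_phi12 = np.log(np.array([[4.0, 1.0], [1.0, 4.0]]))
log_phi23 = np.log(np.array([[2.0, 1.0], [1.0, 3.0]]))

# Forward max-sum messages (exactly the Viterbi recursion)
m12 = log_phi12.max(axis=0)                 # m_{1→2}(x2) = max_{x1} log φ12(x1, x2)
back12 = log_phi12.argmax(axis=0)           # best x1 for each x2

scores = m12[:, None] + log_phi23           # m_{1→2}(x2) + log φ23(x2, x3)
m23 = scores.max(axis=0)                    # m_{2→3}(x3)
back23 = scores.argmax(axis=0)              # best x2 for each x3

# Backtrack from the best x3 to recover the MAP assignment
x3 = int(m23.argmax())
x2 = int(back23[x3])
x1 = int(back12[x2])
print((x1, x2, x3))                         # (1, 1, 1): the highest-scoring configuration
```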
```python
class MaxSumBP:
    """
    Max-Sum Belief Propagation for MAP inference.

    Works in log-space for numerical stability.
    Computes the most probable configuration.
    """

    def __init__(self, factor_graph: FactorGraph):
        self.fg = factor_graph
        # Messages stored in log space
        self.log_messages: Dict[Tuple[str, str], np.ndarray] = {}
        # Track argmax for backtracking
        self.traceback: Dict[Tuple[str, str], np.ndarray] = {}
        self._initialize_messages()

    def _initialize_messages(self):
        """Initialize log-messages to zero (uniform in probability space)."""
        for var_id, var in self.fg.variables.items():
            for factor_id in var.neighbors:
                self.log_messages[(var_id, factor_id)] = np.zeros(var.cardinality)

        for factor_id, factor in self.fg.factors.items():
            for var_id in factor.variables:
                self.log_messages[(factor_id, var_id)] = np.zeros(
                    self.fg.variables[var_id].cardinality
                )

    def compute_factor_to_variable_message(
        self,
        factor_id: str,
        target_var_id: str
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute max-sum message from factor to variable.

        Returns:
            - log_message: The log-space message values
            - argmax: Best assignments to other variables for backtracking
        """
        factor = self.fg.factors[factor_id]

        # Start with log of factor potential
        log_potential = np.log(factor.potential + 1e-10)
        result = log_potential.copy()

        # Add incoming log-messages from other variables
        for var_id in factor.variables:
            if var_id != target_var_id:
                incoming = self.log_messages[(var_id, factor_id)]
                axis = factor.get_variable_axis(var_id)
                shape = [1] * len(factor.variables)
                shape[axis] = len(incoming)
                result = result + incoming.reshape(shape)

        # Max over all variables except target, keeping track of argmax
        target_axis = factor.get_variable_axis(target_var_id)
        axes_to_max = [i for i in range(len(factor.variables))
                       if i != target_axis]

        # Store argmax for backtracking
        argmax_values = []
        for axis in sorted(axes_to_max, reverse=True):
            argmax_values.insert(0, np.argmax(result, axis=axis))
            result = np.max(result, axis=axis)

        # Normalize for numerical stability
        result = result - np.max(result)
        return result, argmax_values

    def compute_map_assignment(self) -> Dict[str, int]:
        """
        Compute the MAP assignment after message passing.

        Uses backtracking through stored argmax values.
        """
        if not self.fg.is_tree():
            raise ValueError("Exact MAP only guaranteed on trees")

        # Run message passing: the same two-pass schedule as sum-product,
        # but with max-sum messages, storing each factor's argmax in
        # self.traceback along the way (method not shown here).
        root = list(self.fg.variables.keys())[0]
        self._run_message_passing(root)

        # Find best assignment for each variable
        assignment = {}

        # Start with root: take argmax of its belief
        root_belief = self._compute_belief(root)
        assignment[root] = int(np.argmax(root_belief))

        # Backtrack through tree to recover full assignment
        self._backtrack(root, assignment)
        return assignment

    def _compute_belief(self, var_id: str) -> np.ndarray:
        """Compute log-belief for a variable (sum of incoming log-messages)."""
        var = self.fg.variables[var_id]
        belief = np.zeros(var.cardinality)
        for factor_id in var.neighbors:
            belief = belief + self.log_messages[(factor_id, var_id)]
        return belief

    def _backtrack(self, current: str, assignment: Dict[str, int],
                   visited: Set[str] = None):
        """Recursively backtrack to recover MAP assignment."""
        if visited is None:
            visited = set()
        visited.add(current)

        # For each unvisited neighbor, determine its best value
        # given the current assignment
        for factor_id in self.fg.variables[current].neighbors:
            factor = self.fg.factors[factor_id]
            for var_id in factor.variables:
                if var_id not in visited and var_id not in assignment:
                    # Use stored traceback to find best value
                    # (as stored, this indexing assumes pairwise factors)
                    if (factor_id, var_id) in self.traceback:
                        idx = assignment[current]
                        assignment[var_id] = int(
                            self.traceback[(factor_id, var_id)][idx]
                        )
                    self._backtrack(var_id, assignment, visited)
```

When multiple configurations have the same maximum probability, max-product may return different results depending on tie-breaking. Also, the 'max-marginals' that max-product computes at each node don't directly give marginal probabilities: the entry for each value is the probability of the best complete assignment consistent with that value.
Understanding the computational complexity of belief propagation reveals when it's efficient and what factors drive runtime.
Message Size:
Each message μ_{x→f}(x) or μ_{f→x}(x) is a vector of size |Val(x)| = k (the cardinality of variable x). If all variables are binary, all messages have size 2.
Number of Messages:
Each edge (x, f) carries two messages: one in each direction. Total messages = 2|E|, where |E| is the number of edges in the factor graph.
| Operation | Time Complexity | Notes |
|---|---|---|
| Variable-to-factor message | O(d · k) | d = degree of variable, k = cardinality. Product of d-1 incoming messages. |
| Factor-to-variable message | O(k^f) | f = factor arity. Must sum/max over k^(f-1) configurations. |
| One complete iteration | O(|E| · k^{max arity}) | Sum over all message computations. |
| Tree BP (exact) | O(|E| · k^{max arity}) | Two passes suffice. |
| Storage | O(|E| · k) | Store one vector per directed edge. |
The Factor Arity Bottleneck:
The most expensive operation is the factor-to-variable message, which requires summing (or maximizing) over all configurations of the other variables in the factor. For a factor of arity f over variables of cardinality k, each outgoing message touches all k^f entries of the potential: k^(f-1) configurations are reduced for each of the k values of the target variable.
This means high-arity factors (factors over many variables) create computational bottlenecks. Strategies to mitigate this include:

- Exploiting structure inside the factor (deterministic, sparse, or noisy-OR-style potentials) so the reduction does not need to enumerate every configuration
- Decomposing a high-arity factor into several lower-arity factors, possibly by introducing auxiliary variables
- Restricting the model to pairwise factors, a common choice in practice (e.g., grid-structured MRFs in vision)

A rough per-iteration cost estimate is sketched below.
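The helper below gives a back-of-the-envelope sense of how quickly the arity term dominates. It assumes the FactorGraph class from earlier on this page (and reuses the small chain `fg` from the earlier sketches); the counts are rough estimates, not exact flop counts.

```python
import numpy as np

def bp_cost_estimate(fg: FactorGraph) -> dict:
    """Rough operation counts for one full round of message updates."""
    num_edges = sum(len(f.variables) for f in fg.factors.values())
    # Each factor-to-variable message touches every entry of the potential,
    # and a factor of arity f sends f such messages per round.
    factor_msg_ops = sum(
        len(f.variables) * int(np.prod(f.cardinalities))
        for f in fg.factors.values()
    )
    return {
        "directed_messages": 2 * num_edges,    # one message per direction per edge
        "factor_message_ops": factor_msg_ops,  # dominated by the k^f terms
    }

# For binary variables the k^f term doubles with every extra variable in a factor:
# arity 10 already means 2**10 = 1024 potential entries per outgoing message.
print(bp_cost_estimate(fg))   # on the small chain: 8 directed messages, 16 ops
```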
BP stores O(|E| · k) for messages, far less than the O(k^w) VE may need for intermediate factors. This memory efficiency is why BP is preferred in many implementations, especially for loopy graphs where junction trees (requiring full clique storage) would be prohibitive.
Production-quality BP implementations incorporate several optimizations beyond the basic algorithm. Here are key techniques used in libraries like libDAI, pgmpy, and OpenGM.
```python
import numpy as np
from scipy.special import logsumexp
from typing import Dict, Tuple


class EfficientSumProductBP:
    """
    Optimized sum-product BP with:
    - Log-space computation
    - Vectorized operations
    - Message damping
    - Convergence tracking
    """

    def __init__(
        self,
        factor_graph: FactorGraph,
        damping: float = 0.5,
        log_space: bool = True
    ):
        self.fg = factor_graph
        self.damping = damping
        self.log_space = log_space
        self.messages: Dict[Tuple[str, str], np.ndarray] = {}

        # Pre-compute per-message metadata for each factor
        self.contraction_paths = {}
        self._precompute_contractions()
        self._initialize_messages()

    def _precompute_contractions(self):
        """
        Pre-compute, for every factor-to-variable message, which axes are
        marginalized and which variables contribute incoming messages.
        (A full implementation could also cache opt_einsum contraction paths.)
        """
        for factor_id, factor in self.fg.factors.items():
            for target_var in factor.variables:
                # Other variables contribute incoming messages
                other_vars = [v for v in factor.variables if v != target_var]
                # Target variable index in the potential array
                target_idx = factor.variables.index(target_var)
                # Cache the variable ordering for this message
                self.contraction_paths[(factor_id, target_var)] = {
                    'target_idx': target_idx,
                    'other_vars': other_vars,
                    'factor_shape': factor.potential.shape
                }

    def _initialize_messages(self):
        """Initialize messages to uniform (zeros in log-space)."""
        for var_id, var in self.fg.variables.items():
            for factor_id in var.neighbors:
                init_val = np.zeros(var.cardinality) if self.log_space \
                    else np.ones(var.cardinality) / var.cardinality
                self.messages[(var_id, factor_id)] = init_val.copy()

        for factor_id, factor in self.fg.factors.items():
            for var_id in factor.variables:
                var = self.fg.variables[var_id]
                init_val = np.zeros(var.cardinality) if self.log_space \
                    else np.ones(var.cardinality) / var.cardinality
                self.messages[(factor_id, var_id)] = init_val.copy()

    def compute_factor_to_variable_message_optimized(
        self,
        factor_id: str,
        target_var_id: str
    ) -> np.ndarray:
        """Compute factor-to-variable message with vectorized operations."""
        factor = self.fg.factors[factor_id]
        path_info = self.contraction_paths[(factor_id, target_var_id)]

        if self.log_space:
            # Log-space computation
            log_potential = np.log(factor.potential + 1e-10)
            result = log_potential.copy()

            # Add incoming messages for other variables
            for var_id in path_info['other_vars']:
                incoming = self.messages[(var_id, factor_id)]
                axis = factor.variables.index(var_id)
                shape = [1] * len(factor.variables)
                shape[axis] = len(incoming)
                result = result + incoming.reshape(shape)

            # Marginalize with logsumexp
            target_idx = path_info['target_idx']
            axes_to_marginalize = tuple(
                i for i in range(len(factor.variables)) if i != target_idx
            )
            for axis in sorted(axes_to_marginalize, reverse=True):
                result = logsumexp(result, axis=axis)

            # Normalize
            result = result - logsumexp(result)
        else:
            # Standard space
            result = factor.potential.copy()
            for var_id in path_info['other_vars']:
                incoming = self.messages[(var_id, factor_id)]
                axis = factor.variables.index(var_id)
                shape = [1] * len(factor.variables)
                shape[axis] = len(incoming)
                result = result * incoming.reshape(shape)

            target_idx = path_info['target_idx']
            axes_to_marginalize = tuple(
                i for i in range(len(factor.variables)) if i != target_idx
            )
            for axis in sorted(axes_to_marginalize, reverse=True):
                result = np.sum(result, axis=axis)

            result = result / (np.sum(result) + 1e-10)

        return result

    def _compute_var_to_factor_msg(
        self,
        var_id: str,
        target_factor_id: str
    ) -> np.ndarray:
        """Variable-to-factor message (referenced below; filled in here).

        Combines incoming factor-to-variable messages, excluding the target
        factor: a sum in log space, a product otherwise.
        """
        var = self.fg.variables[var_id]
        if self.log_space:
            msg = np.zeros(var.cardinality)
            for factor_id in var.neighbors:
                if factor_id != target_factor_id:
                    msg = msg + self.messages[(factor_id, var_id)]
            return msg - logsumexp(msg)
        else:
            msg = np.ones(var.cardinality)
            for factor_id in var.neighbors:
                if factor_id != target_factor_id:
                    msg = msg * self.messages[(factor_id, var_id)]
            return msg / (np.sum(msg) + 1e-10)

    def run_iteration(self) -> float:
        """
        Run one iteration of BP, updating all messages.

        Returns:
            Maximum message change (for convergence check)
        """
        max_change = 0.0
        new_messages = {}

        # Compute all new messages from the previous iteration's values
        for (src, tgt), old_msg in self.messages.items():
            if src in self.fg.variables:
                new_msg = self._compute_var_to_factor_msg(src, tgt)
            else:
                new_msg = self.compute_factor_to_variable_message_optimized(
                    src, tgt
                )

            # Apply damping (here `damping` is the weight on the new message)
            if self.damping > 0:
                new_msg = self.damping * new_msg + (1 - self.damping) * old_msg

            # Track convergence
            change = np.max(np.abs(new_msg - old_msg))
            max_change = max(max_change, change)
            new_messages[(src, tgt)] = new_msg

        self.messages = new_messages
        return max_change

    def _compute_all_marginals(self) -> Dict[str, np.ndarray]:
        """Combine incoming messages into a normalized marginal per variable
        (referenced below; filled in here)."""
        marginals = {}
        for var_id, var in self.fg.variables.items():
            if self.log_space:
                belief = np.zeros(var.cardinality)
                for factor_id in var.neighbors:
                    belief = belief + self.messages[(factor_id, var_id)]
                marginals[var_id] = np.exp(belief - logsumexp(belief))
            else:
                belief = np.ones(var.cardinality)
                for factor_id in var.neighbors:
                    belief = belief * self.messages[(factor_id, var_id)]
                marginals[var_id] = belief / np.sum(belief)
        return marginals

    def run_until_convergence(
        self,
        max_iters: int = 100,
        tolerance: float = 1e-6
    ) -> Tuple[Dict[str, np.ndarray], int]:
        """
        Run BP until convergence or max iterations.

        Returns:
            (marginals, iterations_used)
        """
        for iteration in range(max_iters):
            max_change = self.run_iteration()
            if max_change < tolerance:
                break

        marginals = self._compute_all_marginals()
        return marginals, iteration + 1
```

Belief propagation reframes inference as a message-passing protocol on factor graphs. On trees, it computes all marginals exactly in time linear in the number of edges—a dramatic improvement over running separate variable elimination queries.
Core Concepts:

- Factor graphs: a bipartite representation with variable and factor nodes that makes the factorization explicit and hosts both directed and undirected models
- Sum-product messages: variable-to-factor products and factor-to-variable marginalizations, each excluding the recipient to avoid double-counting
- Exactness on trees: two passes (leaves to root, then root to leaves) yield every marginal, equivalent to running VE from every variable simultaneously
- Max-product / max-sum: replacing Σ with max turns the same message passing into MAP inference, with backtracking to recover the assignment
- Complexity: O(|E| · k^{max arity}) time per run and O(|E| · k) message storage
Connection to What's Next:
Belief propagation on trees is beautiful but limited—real-world models often have cycles. The next page introduces the Junction Tree Algorithm, which converts any graph into a tree of cliques where exact inference remains tractable. This is the gold standard for exact inference in moderately-sized graphical models.
You now understand Belief Propagation—the message-passing paradigm for graphical model inference. You can implement the sum-product and max-product algorithms, analyze their complexity, and appreciate why BP is exact on trees. Next, we'll see how junction trees extend exact inference to graphs with cycles.