Variable elimination and belief propagation are powerful—but they face a fundamental limitation: graphs with cycles. Standard BP is only guaranteed correct on trees; VE's efficiency depends critically on elimination ordering. What if we could transform any graphical model into a tree structure where exact inference is tractable?
The Junction Tree Algorithm accomplishes exactly this. By carefully clustering variables into cliques and organizing them into a tree, we create a structure where BP produces exact marginals for all cliques simultaneously. This is the gold standard for exact inference in probabilistic graphical models.
By the end of this page, you will understand: (1) why cycles break belief propagation and how junction trees resolve this, (2) the complete construction of junction trees from graphical models, (3) the running intersection property and its crucial role in correctness, (4) the message-passing algorithm on junction trees, and (5) complexity analysis and practical implementation considerations.
Before constructing junction trees, we must understand why cycles cause problems for belief propagation. The issue is fundamental: cycles allow information to loop back, leading to double-counting of evidence.
The Double-Counting Problem:
Consider a simple triangle graph with variables A, B, C and factors on edges: f₁(A,B), f₂(B,C), f₃(A,C).
When we run BP on this triangle, A's message to B is folded into B's message to C, which is folded into C's message back to A — so A's next message to B already contains A's own earlier information. This creates a feedback loop where the same evidence gets counted multiple times. The resulting 'beliefs' don't equal true marginals.
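To see the double-counting concretely, the sketch below (with made-up 2×2 edge potentials) runs plain sum-product message passing directly on the triangle and compares the resulting belief at A with the exact marginal from brute-force enumeration; the two generally disagree.

```python
import numpy as np
from itertools import product

# Pairwise potentials on the triangle's edges (illustrative, made-up values)
psi = {
    ('A', 'B'): np.array([[2.0, 1.0], [1.0, 3.0]]),
    ('B', 'C'): np.array([[1.0, 2.0], [3.0, 1.0]]),
    ('A', 'C'): np.array([[3.0, 1.0], [1.0, 1.0]]),
}
nodes = ['A', 'B', 'C']
neighbors = {'A': ['B', 'C'], 'B': ['A', 'C'], 'C': ['A', 'B']}

def pot(i, j, xi, xj):
    """Look up the edge potential regardless of orientation."""
    return psi[(i, j)][xi, xj] if (i, j) in psi else psi[(j, i)][xj, xi]

# Exact marginal P(A) by brute-force enumeration of the joint
weights = np.zeros(2)
for a, b, c in product(range(2), repeat=3):
    weights[a] += pot('A', 'B', a, b) * pot('B', 'C', b, c) * pot('A', 'C', a, c)
exact_pA = weights / weights.sum()

# Loopy sum-product: iterate normalized messages around the cycle
msgs = {(i, j): np.ones(2) for i in nodes for j in neighbors[i]}
for _ in range(200):
    new = {}
    for i in nodes:
        for j in neighbors[i]:
            m = np.zeros(2)
            for xj in range(2):
                for xi in range(2):
                    inc = np.prod([msgs[(k, i)][xi] for k in neighbors[i] if k != j])
                    m[xj] += pot(i, j, xi, xj) * inc
            new[(i, j)] = m / m.sum()
    msgs = new

belief_A = msgs[('B', 'A')] * msgs[('C', 'A')]
belief_A = belief_A / belief_A.sum()
print("exact P(A):", exact_pA, " loopy belief:", belief_A)
```

On a single loop with positive potentials, loopy BP converges, but the converged belief is biased by the repeated circulation of the same evidence.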
Symptoms of Loopy BP:
- Incorrect marginals: beliefs may not equal true P(X)
- Non-convergence: messages may oscillate forever
- Multiple fixed points: different schedules may converge to different (wrong) answers
- Overconfident beliefs: repeated evidence counting makes beliefs too peaked
Despite theoretical issues, loopy BP often works well in practice! It's widely used in error-correcting codes, computer vision, and other applications. But when exact inference is needed, junction trees provide the principled solution.
The Key Insight:
The problem with cycles is that messages aren't independent—each message may include information that came from the same evidence via a different path. In a tree, this can't happen: there's only one path between any two nodes, so messages carry genuinely disjoint information.
The Junction Tree Solution:
A junction tree solves this by:
- Clustering variables into cliques large enough to contain every original factor's scope
- Connecting the cliques into a tree, so there is exactly one path between any two cliques
- Enforcing the running intersection property, so cliques sharing a variable stay consistent along that path
With these properties, BP on the junction tree is exact—each clique's message genuinely summarizes independent evidence.
Junction tree construction proceeds in three main phases: moralization, triangulation, and tree construction. Each phase transforms the graph while preserving the original probability distribution.
Phase 1: Moralization (for Directed Graphs)
If starting from a Bayesian network (directed graph):
1. For every node, add an undirected edge between each pair of its parents ("marrying" the co-parents).
2. Drop the direction of every edge.
The result is an undirected graph called the moral graph. Moralization ensures that all variables in each conditional probability table are connected.
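As a concrete sketch (the DAG is a made-up example and `moral_edges` is a hypothetical helper), moralizing the v-structure A → C ← B must add the marriage edge A–B even though the DAG has no arc between A and B:

```python
from itertools import combinations

def moral_edges(dag):
    """dag: node -> list of parents. Returns the moral graph's edge set."""
    edges = set()
    for child, parents in dag.items():
        for p in parents:                        # keep each parent-child edge, undirected
            edges.add(frozenset((p, child)))
        for p1, p2 in combinations(parents, 2):  # marry co-parents
            edges.add(frozenset((p1, p2)))
    return edges

# v-structure A -> C <- B, plus C -> D
dag = {'A': [], 'B': [], 'C': ['A', 'B'], 'D': ['C']}
edges = moral_edges(dag)
print(sorted(tuple(sorted(e)) for e in edges))
```

The marriage edge A–B appears because the CPT P(C | A, B) couples A and B; without it, no clique could hold that factor's full scope.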
Phase 2: Triangulation (Chordal Graph Construction)
A graph is chordal (or triangulated) if every cycle of length ≥ 4 has a chord (an edge between non-adjacent cycle vertices). Triangulation adds edges to make the graph chordal.
Why triangulation matters: only chordal graphs are guaranteed to admit a clique tree satisfying the running intersection property, and the maximal cliques of a chordal graph can always be organized into such a tree.
Triangulation is equivalent to running variable elimination and noting which edges the fill-in creates. Different elimination orderings produce different triangulations—the goal is to minimize the maximum clique size (which determines inference complexity).
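The correspondence can be sketched directly: simulating elimination on a 4-cycle A–B–C–D reveals the single fill edge that makes it chordal (`eliminate` below is a hypothetical helper, not the article's implementation):

```python
def eliminate(edges, order):
    """Simulate variable elimination on an undirected graph,
    returning the set of fill-in edges it creates."""
    edges = set(edges)
    fill = set()
    for node in order:
        # Current neighbors of the node being eliminated
        nbrs = sorted({v for e in edges for v in e if node in e and v != node})
        # Connect all pairs of neighbors; record edges that are new
        for i in range(len(nbrs)):
            for j in range(i + 1, len(nbrs)):
                e = frozenset((nbrs[i], nbrs[j]))
                if e not in edges:
                    fill.add(e)
                    edges.add(e)
        # Remove the eliminated node's edges
        edges = {e for e in edges if node not in e}
    return fill

# 4-cycle A-B-C-D-A: eliminating A first connects its neighbors B and D
cycle = [frozenset(p) for p in [('A', 'B'), ('B', 'C'), ('C', 'D'), ('D', 'A')]]
fill = eliminate(cycle, ['A', 'B', 'C', 'D'])
print(sorted(tuple(sorted(e)) for e in fill))
```

The chord B–D is exactly the edge a triangulation of the 4-cycle must add; adding the fill edges of any elimination ordering always yields a chordal graph.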
```python
from typing import Dict, List, Set, Tuple, FrozenSet
from dataclasses import dataclass
from collections import defaultdict
import numpy as np


@dataclass
class UndirectedGraph:
    """Simple undirected graph for junction tree construction."""
    nodes: Set[str]
    edges: Set[FrozenSet[str]]

    def add_edge(self, u: str, v: str):
        self.edges.add(frozenset([u, v]))

    def neighbors(self, node: str) -> Set[str]:
        result = set()
        for edge in self.edges:
            if node in edge:
                result.update(edge - {node})
        return result

    def copy(self) -> 'UndirectedGraph':
        return UndirectedGraph(
            nodes=self.nodes.copy(),
            edges=self.edges.copy()
        )


def moralize(dag: Dict[str, List[str]]) -> UndirectedGraph:
    """
    Convert a directed acyclic graph to its moral graph.

    Args:
        dag: Dict mapping each node to its list of parents

    Returns:
        Undirected moral graph
    """
    nodes = set(dag.keys())
    edges = set()
    for child, parents in dag.items():
        # Add edges from each parent to child
        for parent in parents:
            edges.add(frozenset([parent, child]))
        # "Marry" the parents: connect all pairs
        for i, p1 in enumerate(parents):
            for p2 in parents[i+1:]:
                edges.add(frozenset([p1, p2]))
    return UndirectedGraph(nodes=nodes, edges=edges)


def triangulate(graph: UndirectedGraph,
                order: List[str] = None) -> Tuple[UndirectedGraph, List[Set[str]]]:
    """
    Triangulate a graph using variable elimination ordering.

    Args:
        graph: Undirected graph to triangulate
        order: Elimination order (if None, uses min-fill heuristic)

    Returns:
        - Triangulated graph
        - List of cliques formed during elimination
    """
    result = graph.copy()
    cliques = []
    remaining = graph.nodes.copy()

    if order is None:
        order = _min_fill_ordering(graph)

    for node in order:
        # Get current neighbors
        neighbors = result.neighbors(node) & remaining
        # Record the clique (node + neighbors)
        clique = neighbors | {node}
        cliques.append(clique)
        # Add fill edges: connect all pairs of neighbors
        neighbors_list = list(neighbors)
        for i, n1 in enumerate(neighbors_list):
            for n2 in neighbors_list[i+1:]:
                result.add_edge(n1, n2)
        remaining.remove(node)

    return result, cliques


def _min_fill_ordering(graph: UndirectedGraph) -> List[str]:
    """Compute min-fill elimination ordering."""
    result = graph.copy()
    remaining = graph.nodes.copy()
    order = []
    while remaining:
        # Find node with minimum fill-in
        best_node = None
        best_fill = float('inf')
        for node in remaining:
            neighbors = result.neighbors(node) & remaining
            fill = 0
            neighbors_list = list(neighbors)
            for i, n1 in enumerate(neighbors_list):
                for n2 in neighbors_list[i+1:]:
                    if frozenset([n1, n2]) not in result.edges:
                        fill += 1
            if fill < best_fill:
                best_fill = fill
                best_node = node
        # Add fill edges
        neighbors = result.neighbors(best_node) & remaining
        neighbors_list = list(neighbors)
        for i, n1 in enumerate(neighbors_list):
            for n2 in neighbors_list[i+1:]:
                result.add_edge(n1, n2)
        order.append(best_node)
        remaining.remove(best_node)
    return order


def find_maximal_cliques(triangulated: UndirectedGraph,
                         cliques: List[Set[str]]) -> List[Set[str]]:
    """
    From elimination cliques, extract maximal cliques only.

    A clique is maximal if no other clique is a proper superset.
    """
    maximal = []
    for clique in cliques:
        is_maximal = True
        for other in cliques:
            if clique < other:  # Proper subset
                is_maximal = False
                break
        if is_maximal and clique not in maximal:
            maximal.append(clique)
    return maximal
```

The maximum clique size minus 1 equals the treewidth of the elimination ordering. Finding the optimal triangulation (minimum treewidth) is NP-hard, but min-fill and min-degree heuristics work well in practice. The treewidth directly determines junction tree inference complexity!
The running intersection property (RIP) is the key invariant that makes junction trees work. It ensures that information about any variable is contained in a connected subtree, preventing the double-counting problems of loopy BP.
Definition: Running Intersection Property
A tree of cliques T satisfies the running intersection property if:
For every variable X, the set of cliques containing X forms a connected subtree of T.
Equivalently: if variable X appears in cliques C₁ and C₂, then X appears in every clique on the path from C₁ to C₂.
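A standalone checker makes the definition concrete (a sketch; `satisfies_rip` is a hypothetical helper using a BFS restricted to cliques containing the variable):

```python
from collections import defaultdict, deque

def satisfies_rip(cliques, tree_edges):
    """cliques: id -> set of variables; tree_edges: list of (id, id) pairs."""
    adj = defaultdict(list)
    for a, b in tree_edges:
        adj[a].append(b)
        adj[b].append(a)
    for var in set().union(*cliques.values()):
        holders = [c for c, s in cliques.items() if var in s]
        # BFS from one holder, moving only through cliques that contain var
        seen, queue = {holders[0]}, deque([holders[0]])
        while queue:
            cur = queue.popleft()
            for nxt in adj[cur]:
                if nxt not in seen and var in cliques[nxt]:
                    seen.add(nxt)
                    queue.append(nxt)
        if len(seen) != len(holders):  # some holder unreachable: subtree disconnected
            return False
    return True

chain = [(0, 1), (1, 2)]
good = {0: {'A', 'B'}, 1: {'B', 'C'}, 2: {'C', 'D'}}
bad = {0: {'A', 'B'}, 1: {'C', 'D'}, 2: {'A', 'B', 'E'}}  # A appears in 0 and 2, not 1
print(satisfies_rip(good, chain), satisfies_rip(bad, chain))
```

In the failing example, variable A appears in cliques 0 and 2 but not in clique 1 on the path between them, so A's cliques do not form a connected subtree.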
Why RIP Matters:
- Correct marginalization: when summing over a variable, we only sum in cliques at the 'boundary' of its subtree
- No double-counting: a variable's factor contributions stay local to its subtree
- Consistent beliefs: all cliques containing a variable will agree on its marginal
- Message independence: messages between cliques carry genuinely disjoint information
Sepsets (Separators):
For adjacent cliques Cᵢ and Cⱼ in the junction tree, the sepset Sᵢⱼ = Cᵢ ∩ Cⱼ.
Sepsets are crucial: every message between adjacent cliques is a function over the sepset, so sepset size bounds message cost; and building the tree as a maximum-weight spanning tree over sepset sizes is exactly the construction that guarantees the running intersection property.
```python
# `Factor` is assumed to expose `.variables` (an ordered list of variable
# names) and `.values` (an np.ndarray), as in the earlier variable
# elimination code.

@dataclass
class JunctionTree:
    """
    Junction tree structure for exact inference.

    A tree of cliques with sepsets on edges, satisfying
    the running intersection property.
    """
    cliques: Dict[int, Set[str]]                # Clique ID -> variables
    edges: List[Tuple[int, int]]                # Edges between clique IDs
    sepsets: Dict[Tuple[int, int], Set[str]]    # Edge -> separator
    potentials: Dict[int, np.ndarray]           # Clique potentials
    variable_to_cliques: Dict[str, List[int]]   # Which cliques contain each var

    def verify_running_intersection(self) -> bool:
        """
        Verify the running intersection property holds.

        For each variable, check that its cliques form a connected subtree.
        """
        # Build adjacency list
        adj = defaultdict(list)
        for c1, c2 in self.edges:
            adj[c1].append(c2)
            adj[c2].append(c1)

        for var, clique_ids in self.variable_to_cliques.items():
            if len(clique_ids) <= 1:
                continue
            # Check connectivity of cliques containing var:
            # BFS from first clique, only traverse through cliques with var
            visited = {clique_ids[0]}
            queue = [clique_ids[0]]
            while queue:
                current = queue.pop(0)
                for neighbor in adj[current]:
                    if neighbor not in visited and var in self.cliques[neighbor]:
                        visited.add(neighbor)
                        queue.append(neighbor)
            # All cliques with var should be reachable
            if len(visited) != len(clique_ids):
                return False
        return True


def build_junction_tree(
    maximal_cliques: List[Set[str]],
    original_factors: List[Factor]
) -> JunctionTree:
    """
    Build a junction tree from maximal cliques.

    Uses maximum weight spanning tree on clique graph,
    where weight = |intersection| (maximizes sepset size).
    """
    n_cliques = len(maximal_cliques)

    # Build weighted clique graph
    # Weight of edge = |C_i ∩ C_j| (size of intersection)
    weights = {}
    for i in range(n_cliques):
        for j in range(i + 1, n_cliques):
            intersection = maximal_cliques[i] & maximal_cliques[j]
            if intersection:
                weights[(i, j)] = len(intersection)

    # Maximum weight spanning tree (Kruskal's algorithm)
    edges = sorted(weights.keys(), key=lambda e: -weights[e])

    # Union-find for cycle detection
    parent = list(range(n_cliques))

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    tree_edges = []
    for i, j in edges:
        pi, pj = find(i), find(j)
        if pi != pj:
            parent[pi] = pj
            tree_edges.append((i, j))
            if len(tree_edges) == n_cliques - 1:
                break

    # Compute sepsets
    sepsets = {}
    for i, j in tree_edges:
        sepsets[(i, j)] = maximal_cliques[i] & maximal_cliques[j]
        sepsets[(j, i)] = sepsets[(i, j)]

    # Map variables to cliques
    var_to_cliques = defaultdict(list)
    for clique_id, clique in enumerate(maximal_cliques):
        for var in clique:
            var_to_cliques[var].append(clique_id)

    # Assign factors to cliques:
    # each factor assigned to smallest clique containing its scope
    clique_potentials = {i: None for i in range(n_cliques)}
    for factor in original_factors:
        # Find smallest clique containing factor scope
        factor_scope = set(factor.variables)
        assigned = False
        for clique_id, clique in sorted(
            enumerate(maximal_cliques), key=lambda x: len(x[1])
        ):
            if factor_scope <= clique:
                # Initialize or multiply into clique potential
                if clique_potentials[clique_id] is None:
                    clique_potentials[clique_id] = _expand_factor_to_clique(
                        factor, clique
                    )
                else:
                    expanded = _expand_factor_to_clique(factor, clique)
                    clique_potentials[clique_id] *= expanded
                assigned = True
                break
        if not assigned:
            raise ValueError(f"Factor scope {factor_scope} not contained in any clique")

    # Initialize unassigned clique potentials to 1
    for clique_id in clique_potentials:
        if clique_potentials[clique_id] is None:
            shape = tuple(2 for _ in maximal_cliques[clique_id])  # Assume binary
            clique_potentials[clique_id] = np.ones(shape)

    return JunctionTree(
        cliques={i: c for i, c in enumerate(maximal_cliques)},
        edges=tree_edges,
        sepsets=sepsets,
        potentials=clique_potentials,
        variable_to_cliques=dict(var_to_cliques)
    )


def _expand_factor_to_clique(
    factor: Factor,
    clique: Set[str]
) -> np.ndarray:
    """
    Expand a factor to match clique dimensions.

    Add singleton dimensions for variables in clique but not in factor.
    """
    # This is a simplified version; production code handles variable ordering
    clique_vars = sorted(clique)
    factor_vars = list(factor.variables)
    result = factor.values
    for var in clique_vars:
        if var not in factor_vars:
            # Add new axis for this variable
            result = np.expand_dims(result, axis=-1)
    return result
```

With the junction tree constructed, inference proceeds via message passing between cliques—essentially belief propagation on the tree of cliques. The algorithm has two phases: collect (toward a root) and distribute (away from root).
Clique-to-Clique Messages:
Let Ψᵢ denote the potential of clique Cᵢ, and let Sᵢⱼ = Cᵢ ∩ Cⱼ be the sepset.
The message from clique Cᵢ to clique Cⱼ is:
μᵢ→ⱼ(Sᵢⱼ) = Σ_{Cᵢ\Sᵢⱼ} Ψᵢ(Cᵢ) × ∏_{k ∈ neighbors(i)\j} μₖ→ᵢ(Sₖᵢ)
In words: take the clique potential, multiply by all incoming messages (except from target), and marginalize down to the sepset.
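For a leaf clique with no other neighbors, the product over incoming messages is empty and the formula reduces to marginalizing the clique potential onto the sepset. A minimal numeric sketch with C₁ = {A, B}, C₂ = {B, C}, and sepset S₁₂ = {B} (potential values invented for illustration):

```python
import numpy as np

# psi_1(A, B): axis 0 = A, axis 1 = B
psi_1 = np.array([[1.0, 2.0],
                  [3.0, 4.0]])

# Message from C1 to C2: sum out everything in C1 \ S12, i.e. A
mu_12 = psi_1.sum(axis=0)     # shape (2,): a function of B only
print(mu_12)                  # [4. 6.]

# C2 absorbs the message: multiply into psi_2(B, C) along the B axis
psi_2 = np.array([[1.0, 1.0],
                  [2.0, 1.0]])
absorbed = psi_2 * mu_12[:, None]
print(absorbed)
```

Note that the message lives entirely on the sepset: its table has one entry per joint assignment of the sepset variables, which is why sepset size governs message cost.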
```python
class JunctionTreeInference:
    """
    Exact inference on junction trees via message passing.

    Implements the Shafer-Shenoy algorithm for computing marginals.
    """

    def __init__(self, junction_tree: JunctionTree):
        self.jt = junction_tree
        # Messages indexed by directed edge (from_clique, to_clique)
        self.messages: Dict[Tuple[int, int], np.ndarray] = {}
        # Build adjacency list
        self.adj = defaultdict(list)
        for c1, c2 in self.jt.edges:
            self.adj[c1].append(c2)
            self.adj[c2].append(c1)

    def run_inference(self, evidence: Dict[str, int] = None) -> Dict[str, np.ndarray]:
        """
        Run complete junction tree inference.

        Returns marginal distributions for all variables.
        """
        # Apply evidence by modifying clique potentials
        potentials = self._apply_evidence(evidence or {})
        # Choose arbitrary root
        root = 0
        # Phase 1: Collect (leaves toward root)
        self._collect(root, potentials, visited=set())
        # Phase 2: Distribute (root toward leaves)
        self._distribute(root, potentials, visited=set())
        # Compute marginals by marginalizing calibrated cliques
        return self._compute_marginals(potentials)

    def _apply_evidence(
        self,
        evidence: Dict[str, int]
    ) -> Dict[int, np.ndarray]:
        """
        Apply evidence by zeroing out inconsistent entries
        in clique potentials.
        """
        potentials = {i: p.copy() for i, p in self.jt.potentials.items()}
        for var, value in evidence.items():
            # Find a clique containing var
            if var not in self.jt.variable_to_cliques:
                continue
            clique_id = self.jt.variable_to_cliques[var][0]
            clique_vars = sorted(self.jt.cliques[clique_id])
            var_axis = clique_vars.index(var)
            # Zero out all entries where var != value
            slices = [slice(None)] * len(clique_vars)
            cardinality = potentials[clique_id].shape[var_axis]
            for v in range(cardinality):
                if v != value:
                    slices[var_axis] = v
                    potentials[clique_id][tuple(slices)] = 0
            slices[var_axis] = slice(None)
        return potentials

    def _compute_message(
        self,
        from_clique: int,
        to_clique: int,
        potentials: Dict[int, np.ndarray]
    ) -> np.ndarray:
        """
        Compute message from one clique to another.

        μ_{i→j} = Σ_{C_i \ S_ij} ψ_i × ∏_{k ∈ N(i)\j} μ_{k→i}
        """
        # Start with clique potential
        result = potentials[from_clique].copy()
        # Multiply incoming messages from other neighbors
        for neighbor in self.adj[from_clique]:
            if neighbor != to_clique:
                incoming_msg = self.messages.get((neighbor, from_clique))
                if incoming_msg is not None:
                    result = self._multiply_sepset_message(
                        result, incoming_msg,
                        self.jt.cliques[from_clique],
                        self.jt.sepsets[(neighbor, from_clique)]
                    )
        # Marginalize to sepset
        sepset = self.jt.sepsets[(from_clique, to_clique)]
        result = self._marginalize_to_sepset(
            result, self.jt.cliques[from_clique], sepset
        )
        return result

    def _multiply_sepset_message(
        self,
        clique_potential: np.ndarray,
        sepset_message: np.ndarray,
        clique_vars: Set[str],
        sepset_vars: Set[str]
    ) -> np.ndarray:
        """Multiply a sepset message into a clique potential."""
        clique_vars_list = sorted(clique_vars)
        sepset_vars_list = sorted(sepset_vars)
        # Expand message dimensions to match clique
        msg_expanded = sepset_message
        for var in clique_vars_list:
            if var not in sepset_vars_list:
                msg_expanded = np.expand_dims(msg_expanded, axis=-1)
        # Reorder axes to match clique variable ordering
        # (Simplified: assumes consistent ordering)
        return clique_potential * msg_expanded

    def _marginalize_to_sepset(
        self,
        potential: np.ndarray,
        clique_vars: Set[str],
        sepset_vars: Set[str]
    ) -> np.ndarray:
        """Marginalize a clique potential down to sepset variables."""
        clique_vars_list = sorted(clique_vars)
        sepset_vars_list = sorted(sepset_vars)
        # Sum out variables not in sepset
        result = potential
        for i, var in reversed(list(enumerate(clique_vars_list))):
            if var not in sepset_vars_list:
                result = np.sum(result, axis=i)
        return result

    def _collect(
        self,
        node: int,
        potentials: Dict[int, np.ndarray],
        visited: Set[int]
    ):
        """
        Collect phase: recursively gather messages toward root.

        Post-order traversal: children send to parent after
        receiving from their children.
        """
        visited.add(node)
        for neighbor in self.adj[node]:
            if neighbor not in visited:
                # Recurse to children first
                self._collect(neighbor, potentials, visited)
                # Child sends message to current node
                self.messages[(neighbor, node)] = self._compute_message(
                    neighbor, node, potentials
                )

    def _distribute(
        self,
        node: int,
        potentials: Dict[int, np.ndarray],
        visited: Set[int]
    ):
        """
        Distribute phase: send messages from root toward leaves.

        Pre-order traversal: parent sends to children before they recurse.
        """
        visited.add(node)
        for neighbor in self.adj[node]:
            if neighbor not in visited:
                # Current node sends message to child
                self.messages[(node, neighbor)] = self._compute_message(
                    node, neighbor, potentials
                )
                # Recurse to child
                self._distribute(neighbor, potentials, visited)

    def _compute_marginals(
        self,
        potentials: Dict[int, np.ndarray]
    ) -> Dict[str, np.ndarray]:
        """
        Compute marginals from calibrated junction tree.

        For each variable, find a clique containing it and marginalize.
        """
        marginals = {}
        for var, clique_ids in self.jt.variable_to_cliques.items():
            # Use first clique containing the variable
            clique_id = clique_ids[0]
            # Compute calibrated clique belief
            belief = potentials[clique_id].copy()
            for neighbor in self.adj[clique_id]:
                incoming = self.messages.get((neighbor, clique_id))
                if incoming is not None:
                    belief = self._multiply_sepset_message(
                        belief, incoming,
                        self.jt.cliques[clique_id],
                        self.jt.sepsets[(neighbor, clique_id)]
                    )
            # Marginalize to single variable
            clique_vars = sorted(self.jt.cliques[clique_id])
            var_axis = clique_vars.index(var)
            for i in reversed(range(len(clique_vars))):
                if i != var_axis:
                    belief = np.sum(belief, axis=i)
            # Normalize
            marginals[var] = belief / np.sum(belief)
        return marginals
```

After both phases complete, the junction tree is calibrated: for any adjacent cliques Cᵢ and Cⱼ, marginalizing either clique belief to the sepset yields the same result. This consistency is guaranteed by the running intersection property.
The computational complexity of junction tree inference depends directly on the size of the largest clique—which is determined by the graph's treewidth.
Key Complexity Parameters:
| Phase | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Moralization | O(n + m) | O(n²) edges | Adding marriage edges |
| Triangulation | O(n³) worst case | O(n²) fill edges | Min-fill heuristic; optimal is NP-hard |
| Clique identification | O(n + m) | O(c · w) | Finding maximal cliques in chordal graph |
| Tree construction | O(c² · w) | O(c) | Maximum spanning tree on clique graph |
| Message passing | O(c · k^{w+1}) | O(c · k^w) | Dominates total cost |
| Marginal computation | O(n · k^{w+1}) | O(n · k) | Marginalize from clique beliefs |
The Exponential Wall:
The k^{w+1} term means junction tree inference is exponential in treewidth. This is unavoidable for exact inference (under standard complexity assumptions): computing marginals in general graphical models is #P-hard, and no exact algorithm is believed to escape the exponential dependence on treewidth.
Practical Implications:
Junction trees are practical for models with treewidth up to about 20-30 (with binary variables). Beyond that, approximate methods become necessary.
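The 20–30 figure comes from simple table-size arithmetic: a clique over w + 1 binary variables needs 2^{w+1} entries. A quick back-of-the-envelope sketch (assuming 8-byte floats):

```python
# Size of one clique table for binary variables at various treewidths
for w in (10, 20, 30, 40):
    entries = 2 ** (w + 1)
    mib = entries * 8 / 2 ** 20  # 8-byte floats, reported in MiB
    print(f"treewidth {w}: {entries:,} entries, about {mib:,.0f} MiB per clique")
```

At treewidth 20 a single clique table is about 16 MiB; at treewidth 30 it is about 16 GiB, and at 40 it is far beyond any machine, which is why approximate methods take over.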
Many real-world models have high treewidth (dense connections, grid structures). For these, junction trees are impractical, and we must turn to approximate inference: loopy belief propagation, variational methods, or sampling techniques (covered in the next page).
Two main algorithmic architectures exist for junction tree inference: Shafer-Shenoy and Hugin. They compute the same marginals but differ in how they handle sepset beliefs and message computation.
Shafer-Shenoy Architecture:
Messages are stored on the tree's edges and clique potentials are never modified. A clique's belief is assembled on demand as its potential times all incoming messages (this is the variant implemented above). Because it uses only multiplication and marginalization, with no division, it also works in semirings where division is unavailable, such as max-product.
Hugin Architecture:
Hugin maintains sepset beliefs φ_S that are updated during message passing. When clique Cᵢ sends to neighbor Cⱼ across sepset S, Cⱼ absorbs from Cᵢ:
φ*_S = Σ_{Cᵢ\S} Ψᵢ,   Ψⱼ ← Ψⱼ × (φ*_S / φ_S),   φ_S ← φ*_S
Collect phase: absorptions flow from the leaves toward the root, each clique absorbing from its children before its parent absorbs from it.
Distribute phase: the same absorption runs from the root back toward the leaves, after which every clique potential equals its calibrated belief.
The key difference: Hugin modifies clique potentials in place, while Shafer-Shenoy keeps them unchanged.
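A minimal numeric sketch of Hugin absorption between two cliques (toy 2×2 potentials; `absorb` is a hypothetical helper, not part of the Shafer-Shenoy code above): each pass marginalizes the sender onto the sepset, divides out the old sepset belief, and multiplies the ratio into the receiver. After one pass in each direction, both cliques agree on the sepset marginal.

```python
import numpy as np

# Cliques C1 = {A, B} and C2 = {B, C}; sepset S = {B}
psi1 = np.array([[1.0, 2.0], [3.0, 4.0]])   # psi1[a, b]
psi2 = np.array([[1.0, 1.0], [2.0, 1.0]])   # psi2[b, c]
phi_S = np.ones(2)                           # sepset belief over B

def absorb(sender_marg_B, receiver, phi_S, receiver_B_axis):
    """Hugin update (2-D toy version): receiver *= new_phi / old_phi."""
    ratio = sender_marg_B / phi_S
    shape = [1, 1]
    shape[receiver_B_axis] = 2               # broadcast ratio along the B axis
    return receiver * ratio.reshape(shape), sender_marg_B

# C1 -> C2 (sum out A), then C2 -> C1 (sum out C)
psi2, phi_S = absorb(psi1.sum(axis=0), psi2, phi_S, receiver_B_axis=0)
psi1, phi_S = absorb(psi2.sum(axis=1), psi1, phi_S, receiver_B_axis=1)

# Calibration: both cliques now report the same marginal over B
print(psi1.sum(axis=0), psi2.sum(axis=1))
```

Note that the potentials themselves were overwritten by the absorptions, which is precisely the in-place behavior that distinguishes Hugin from Shafer-Shenoy.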
Which to Use? Hugin performs fewer multiplications per propagation and keeps calibrated beliefs materialized, so it is typically faster for repeated queries, at the cost of storing sepset beliefs and requiring a division operation. Shafer-Shenoy leaves the original potentials untouched and avoids division, making it the natural choice for max-product (MAP) inference and other semirings without division.
Building production-quality junction tree implementations involves several practical considerations beyond the core algorithm: computing in log-space (or renormalizing messages) to prevent floating-point underflow, maintaining a consistent variable ordering when multiplying and marginalizing tables, choosing good triangulation heuristics, and caching messages so that repeated queries reuse work.
For production use, leverage well-tested libraries: pgmpy (Python), libDAI (C++), or HUGIN/GeNIe (commercial). These handle numerical issues, sparse representations, and optimized algorithms that would take significant effort to implement correctly from scratch.
The junction tree algorithm is the definitive method for exact inference in probabilistic graphical models. By transforming any model into a tree of cliques satisfying the running intersection property, it enables exact belief propagation regardless of the original graph's structure.
Connection to What's Next:
Junction trees provide exact inference—but what about graphs with high treewidth where exact methods are intractable? The next page introduces Loopy Belief Propagation, which applies BP directly to cyclic graphs as an approximation. Though not exact, loopy BP often works remarkably well in practice and forms the basis for modern applications in coding theory, computer vision, and machine learning.
You now understand the Junction Tree Algorithm—the gold standard for exact inference in graphical models. You can construct junction trees from any graphical model, run message passing for exact marginals, and analyze when junction trees are practical vs. when approximation is needed.