KD-trees partition space using axis-aligned hyperplanes—vertical and horizontal cuts that slice the space into rectangular regions. This works beautifully for low-dimensional Euclidean data, but has two fundamental limitations:
Axis-alignment is arbitrary: Why should data care about our coordinate axes? Natural clusters are often spherical, not rectangular.
Hyperplanes require coordinates: What if we only have a distance function, not explicit feature vectors? Non-vectorial data like strings (edit distance), graphs (graph kernels), or probability distributions (KL divergence) can't use hyperplane splits.
The Ball Tree (also called metric tree or VP-tree in its vantage-point variant) addresses both limitations by partitioning space using hyperspheres instead of hyperplanes. Each node owns a ball—all points within some radius of a center point—and recursively subdivides into smaller balls.
By the end of this page, you will understand how Ball trees partition space using hyperspheres, implement construction and search algorithms, analyze their complexity and pruning conditions, understand when Ball trees outperform KD-trees, and recognize their role in general metric space search.
To appreciate Ball trees, let's examine why spherical partitions can be superior to rectangular partitions.
The Geometry of Nearest Neighbor Search:
When searching for the nearest neighbor of query $\mathbf{q}$, we need to determine which regions of space might contain a closer point than our current best at distance $r$. This means checking if a region intersects the query ball: the hypersphere centered at $\mathbf{q}$ with radius $r$.
For KD-tree rectangles:
A rectangle intersects a query ball if any part of the rectangle is within distance $r$ of $\mathbf{q}$. Computing this intersection requires checking multiple faces and edges—complexity grows with dimension.
Worse, in high dimensions, rectangles tend to be 'spiky'—much longer along some dimensions than others. The corners of a high-dimensional rectangle are far from its center: for a unit hypercube in $d$ dimensions, the center-to-corner distance is $\sqrt{d}/2$, so even when the center is distant from the query, a corner might be close to it.
For Ball tree hyperspheres:
A ball with center $\mathbf{c}$ and radius $R$ intersects the query ball of radius $r$ if and only if:
$$d(\mathbf{q}, \mathbf{c}) \leq r + R$$
This is just one distance computation—the same cost regardless of dimension. No corners, no edges, no geometric complexity.
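To make the contrast concrete, here is a minimal sketch (the helpers `balls_intersect` and `rect_min_dist` are illustrative names, not library functions):

```python
import numpy as np

def balls_intersect(q: np.ndarray, r: float, c: np.ndarray, R: float) -> bool:
    """The query ball (q, r) and a node ball (c, R) can overlap
    iff the center-to-center distance is at most r + R."""
    return np.linalg.norm(q - c) <= r + R

def rect_min_dist(q: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> float:
    """For contrast: the minimum distance from q to an axis-aligned box [lo, hi]
    requires clamping every coordinate before taking a norm."""
    return float(np.linalg.norm(q - np.clip(q, lo, hi)))
```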
| Property | KD-tree (Rectangles) | Ball Tree (Spheres) |
|---|---|---|
| Intersection test | Check distance to hyperplane | Check distance to center |
| Test complexity | $O(d)$ for all faces | $O(d)$ for one point |
| Shape matches query ball | Poor (corners protrude) | Perfect (sphere vs sphere) |
| Volume efficiency | Poor in high-d | Optimal for radial queries |
| Metric space support | Requires coordinates | Only needs distances |
A ball is the tightest convex container for radial queries. When we're searching for 'everything within distance r', the natural boundary is a sphere. Rectangular bounds waste space by including corners that are farther away than the rectangle's 'radius' suggests.
A Ball tree is a binary tree where each node represents a ball (hypersphere) containing a subset of data points.
Node Components:
| Component | Description | Purpose |
|---|---|---|
center | Center of the ball | Reference point for distance checks |
radius | Radius of the ball | Maximum distance from center to any contained point |
points | Data points in this ball (leaf only) | Actual data for search results |
left | Left child ball | First partition |
right | Right child ball | Second partition |
Key Invariants:
Covering property: For any node $v$ with center $\mathbf{c}_v$ and radius $R_v$: $$\forall \mathbf{x} \in v: d(\mathbf{x}, \mathbf{c}_v) \leq R_v$$
Hierarchical containment: Child balls are contained within (or equal to) parent balls
Partition property: Every point appears in exactly one leaf
```python
import numpy as np
from dataclasses import dataclass, field
from typing import Optional, List, Callable


@dataclass
class BallNode:
    """
    A node in a Ball tree.

    Attributes:
    -----------
    center : np.ndarray
        Center of the bounding ball (d-dimensional)
    radius : float
        Radius of the bounding ball (max distance to any contained point)
    points : np.ndarray, optional
        For leaf nodes, the actual data points (shape: n_points × d)
    indices : np.ndarray, optional
        Original indices of contained points
    left : BallNode, optional
        Left child (first partition)
    right : BallNode, optional
        Right child (second partition)
    n_points : int
        Number of points in this subtree
    """
    center: np.ndarray
    radius: float
    points: Optional[np.ndarray] = None
    indices: Optional[np.ndarray] = None
    left: Optional['BallNode'] = None
    right: Optional['BallNode'] = None
    n_points: int = 0

    def is_leaf(self) -> bool:
        """Check if this is a leaf node."""
        return self.left is None and self.right is None

    def min_distance_to_query(self, query: np.ndarray) -> float:
        """
        Compute minimum possible distance from query to any point in this ball.

        This is the crucial pruning criterion:
        - If query is inside the ball: min distance is 0
        - If query is outside: min distance is dist(query, center) - radius
        """
        dist_to_center = np.sqrt(np.sum((query - self.center) ** 2))
        return max(0.0, dist_to_center - self.radius)

    def max_distance_to_query(self, query: np.ndarray) -> float:
        """
        Compute maximum possible distance from query to any point in this ball.

        Useful for bounding computations.
        """
        dist_to_center = np.sqrt(np.sum((query - self.center) ** 2))
        return dist_to_center + self.radius


@dataclass
class BallTree:
    """
    Ball Tree for efficient nearest neighbor search.

    Attributes:
    -----------
    root : BallNode
        Root of the tree
    n_points : int
        Total number of indexed points
    n_dims : int
        Dimensionality of the data
    distance : Callable
        Distance function (default: Euclidean)
    leaf_size : int
        Maximum points in a leaf node
    """
    root: Optional[BallNode]
    n_points: int
    n_dims: int
    distance: Callable = field(default=lambda x, y: np.sqrt(np.sum((x - y) ** 2)))
    leaf_size: int = 20

    def __repr__(self) -> str:
        return f"BallTree(n_points={self.n_points}, n_dims={self.n_dims})"
```

The 'center' of a ball node can be computed different ways: (1) centroid of contained points, (2) center of minimum enclosing ball, or (3) an actual data point. Each choice has tradeoffs in construction cost, ball tightness, and search efficiency.
Ball tree construction involves recursively partitioning points into two groups and computing bounding balls. The key challenge is choosing how to split points.
Construction Steps:
1. Compute the bounding ball (center and radius) for the current set of points.
2. If the set contains at most leaf_size points, store them in a leaf node.
3. Otherwise, split the points into two groups.
4. Recursively build a child ball for each group.
Splitting Strategies:
| Strategy | Description | Cost per Node | Quality |
|---|---|---|---|
| Principal Axis | Split along direction of maximum variance | $O(nd)$ | Good |
| Ball Split | Pick two distant points, assign others to closer | $O(nd)$ | Good |
| Random | Randomly partition | $O(n)$ | Poor |
| Five-point | Sample 5 points, use extreme pair | $O(n)$ | Good |
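For comparison with the pivot method implemented below, the table's Principal Axis strategy might look like this (a rough sketch for Euclidean vectors; `principal_axis_split` is an illustrative name):

```python
import numpy as np

def principal_axis_split(points: np.ndarray):
    """Split points at the median of their projection onto the direction
    of maximum variance (the first principal component)."""
    centered = points - points.mean(axis=0)
    # Leading right singular vector = direction of maximum variance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    axis = vt[0]
    proj = centered @ axis
    median = np.median(proj)
    left_mask = proj <= median
    return points[left_mask], points[~left_mask]
```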
The Ball Split (or "pivoting") method is most common:
```python
import numpy as np
from typing import Optional, Tuple


def compute_bounding_ball(points: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Compute the bounding ball for a set of points.

    Uses centroid as center (fast approximation).
    For minimum enclosing ball, use Welzl's algorithm.

    Parameters:
    -----------
    points : np.ndarray, shape (n, d)
        Points to enclose

    Returns:
    --------
    center : np.ndarray
        Center of bounding ball
    radius : float
        Radius of bounding ball
    """
    # Center = centroid of points
    center = np.mean(points, axis=0)

    # Radius = distance to farthest point
    distances = np.sqrt(np.sum((points - center) ** 2, axis=1))
    radius = np.max(distances)

    return center, radius


def ball_split(
    points: np.ndarray,
    indices: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Split points into two groups using the pivot method.

    1. Pick a random point p1
    2. Find p2 = farthest point from p1
    3. Assign each point to closer of {p1, p2}

    Parameters:
    -----------
    points : np.ndarray, shape (n, d)
        Points to split
    indices : np.ndarray, shape (n,)
        Original indices of points

    Returns:
    --------
    left_points, left_indices, right_points, right_indices
    """
    n = len(points)

    # Step 1: Pick p1 (use first point; random would also work)
    p1_idx = 0
    p1 = points[p1_idx]

    # Step 2: Find p2 = farthest from p1
    dist_from_p1 = np.sqrt(np.sum((points - p1) ** 2, axis=1))
    p2_idx = np.argmax(dist_from_p1)
    p2 = points[p2_idx]

    # Step 3: Assign each point to closer pivot
    dist_to_p1 = np.sqrt(np.sum((points - p1) ** 2, axis=1))
    dist_to_p2 = np.sqrt(np.sum((points - p2) ** 2, axis=1))

    # Points closer to p1 go left, closer to p2 go right
    left_mask = dist_to_p1 <= dist_to_p2
    right_mask = ~left_mask

    # Handle edge case: ensure both partitions are non-empty
    if not np.any(left_mask):
        left_mask[0] = True
        right_mask[0] = False
    if not np.any(right_mask):
        right_mask[-1] = True
        left_mask[-1] = False

    return (
        points[left_mask], indices[left_mask],
        points[right_mask], indices[right_mask]
    )


def build_ball_tree(
    points: np.ndarray,
    indices: Optional[np.ndarray] = None,
    leaf_size: int = 20
) -> Optional[BallNode]:
    """
    Build a Ball tree from a set of points.

    Time Complexity: O(n log n) expected
    Space Complexity: O(n) for the tree

    Parameters:
    -----------
    points : np.ndarray, shape (n, d)
        Data points
    indices : np.ndarray, optional
        Original indices (for tracking)
    leaf_size : int
        Maximum number of points in a leaf

    Returns:
    --------
    BallNode
        Root of the constructed tree
    """
    n = len(points)
    if n == 0:
        return None

    if indices is None:
        indices = np.arange(n)

    # Compute bounding ball for this node
    center, radius = compute_bounding_ball(points)

    # Base case: create leaf node
    if n <= leaf_size:
        return BallNode(
            center=center,
            radius=radius,
            points=points.copy(),
            indices=indices.copy(),
            n_points=n
        )

    # Split points into two groups
    left_pts, left_idx, right_pts, right_idx = ball_split(points, indices)

    # Recursively build children
    left_child = build_ball_tree(left_pts, left_idx, leaf_size)
    right_child = build_ball_tree(right_pts, right_idx, leaf_size)

    return BallNode(
        center=center,
        radius=radius,
        left=left_child,
        right=right_child,
        n_points=n
    )


def build_ball_tree_wrapper(points: np.ndarray, leaf_size: int = 20) -> BallTree:
    """
    Public interface for Ball tree construction.
    """
    n, d = points.shape
    root = build_ball_tree(points, leaf_size=leaf_size)
    return BallTree(root=root, n_points=n, n_dims=d, leaf_size=leaf_size)
```

Complexity Analysis:
Time Complexity:
At each level of the tree, computing bounding balls and splitting the points each touch every point once, costing $O(nd)$ per level in total.
With $O(\log n)$ levels (assuming balanced splits):
$$\boxed{T_{\text{build}} = O(nd \log n)}$$
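To make the bound concrete: building a tree over $n = 10^6$ points in $d = 50$ dimensions requires on the order of

$$nd\log_2 n \approx 10^6 \times 50 \times 20 = 10^9$$

coordinate operations, a feasible one-time preprocessing cost.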
Space Complexity:
$$\boxed{S_{\text{tree}} = O(nd)}$$
Note: Ball trees use more space than KD-trees due to storing explicit ball centers at internal nodes.
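Before moving on to search, here is a quick smoke test of the construction code above (a hypothetical usage sketch; the assertion checks the covering invariant at the root):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 8))                 # 1,000 points in 8 dimensions

tree = build_ball_tree_wrapper(X, leaf_size=20)
print(tree)                                     # BallTree(n_points=1000, n_dims=8)

# Covering property: every indexed point lies within the root ball
root = tree.root
dists = np.sqrt(np.sum((X - root.center) ** 2, axis=1))
assert np.all(dists <= root.radius + 1e-9)
```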
Ball tree search uses the same branch-and-bound paradigm as KD-trees, but with simpler pruning geometry.
Pruning Criterion:
Given a query $\mathbf{q}$, a candidate ball with center $\mathbf{c}$ and radius $R$, and the current best distance $d_{\text{best}}$:
We can prune this ball if the minimum possible distance from $\mathbf{q}$ to any point in the ball exceeds $d_{\text{best}}$:
$$d_{\min} = \max(0, d(\mathbf{q}, \mathbf{c}) - R)$$
If $d_{\min} \geq d_{\text{best}}$, prune the entire subtree.
Intuition: If the query is at distance $D$ from the ball's center, the closest point in the ball is at least $D - R$ away (along the line from query to center). If even this closest possible point is farther than our current best, no point in the ball can be better. For example, with $D = 10$ and $R = 3$, no point in the ball can be closer than $7$; if the current best is $5$, the subtree is safely skipped.
```python
import numpy as np
import heapq
from typing import Tuple, List, Optional


def ball_tree_nn_search(
    node: Optional[BallNode],
    query: np.ndarray,
    best: Tuple[float, Optional[np.ndarray], int] = (float('inf'), None, -1)
) -> Tuple[float, Optional[np.ndarray], int]:
    """
    Find nearest neighbor in Ball tree using branch-and-bound.

    Parameters:
    -----------
    node : BallNode
        Current node in the search
    query : np.ndarray
        Query point
    best : Tuple[float, np.ndarray, int]
        Current best (distance, point, index)

    Returns:
    --------
    Tuple containing (distance, point, index) of nearest neighbor
    """
    if node is None:
        return best

    best_dist, best_point, best_idx = best

    # Pruning check: can this ball contain a closer point?
    min_dist = node.min_distance_to_query(query)
    if min_dist >= best_dist:
        # PRUNE: entire ball is too far
        return best

    # Leaf node: check all points
    if node.is_leaf():
        for i, point in enumerate(node.points):
            dist = np.sqrt(np.sum((point - query) ** 2))
            if dist < best_dist:
                best_dist = dist
                best_point = point
                best_idx = node.indices[i]
        return (best_dist, best_point, best_idx)

    # Internal node: recurse into children
    # Visit closer child first (better pruning)
    left_min = node.left.min_distance_to_query(query) if node.left else float('inf')
    right_min = node.right.min_distance_to_query(query) if node.right else float('inf')

    if left_min <= right_min:
        first, second = node.left, node.right
    else:
        first, second = node.right, node.left

    # Search first (closer) child
    best = ball_tree_nn_search(first, query, best)

    # Search second child if it might contain a closer point
    best = ball_tree_nn_search(second, query, best)

    return best


def ball_tree_knn_search(
    node: Optional[BallNode],
    query: np.ndarray,
    k: int,
    heap: List = None
) -> List[tuple]:
    """
    Find K nearest neighbors in Ball tree.

    Uses a max-heap to track the k closest points.

    Parameters:
    -----------
    node : BallNode
        Current node
    query : np.ndarray
        Query point
    k : int
        Number of neighbors
    heap : List
        Max-heap of (-distance, tie-breaker, point, index) entries

    Returns:
    --------
    The max-heap of (-distance, tie-breaker, point, index) entries
    """
    if heap is None:
        heap = []

    if node is None:
        return heap

    # Get current k-th best distance for pruning
    kth_best = -heap[0][0] if len(heap) >= k else float('inf')

    # Pruning check
    min_dist = node.min_distance_to_query(query)
    if min_dist >= kth_best:
        return heap  # PRUNE

    # Leaf node: check all points
    if node.is_leaf():
        for i, point in enumerate(node.points):
            dist = np.sqrt(np.sum((point - query) ** 2))
            # id(point) is a tie-breaker so heapq never compares arrays
            if len(heap) < k:
                heapq.heappush(heap, (-dist, id(point), point, node.indices[i]))
            elif dist < -heap[0][0]:
                heapq.heapreplace(heap, (-dist, id(point), point, node.indices[i]))
        return heap

    # Internal node: visit children in order of minimum distance
    left_min = node.left.min_distance_to_query(query) if node.left else float('inf')
    right_min = node.right.min_distance_to_query(query) if node.right else float('inf')

    if left_min <= right_min:
        first, second = node.left, node.right
    else:
        first, second = node.right, node.left

    ball_tree_knn_search(first, query, k, heap)
    ball_tree_knn_search(second, query, k, heap)

    return heap


def query_ball_tree(tree: BallTree, query: np.ndarray, k: int = 1):
    """
    Public interface for Ball tree KNN query.
    """
    if k == 1:
        dist, point, idx = ball_tree_nn_search(tree.root, query)
        return [(dist, point, idx)]
    else:
        heap = ball_tree_knn_search(tree.root, query, k)
        results = [(-d, point, idx) for d, _, point, idx in heap]
        # Sort by distance only, to avoid comparing arrays on ties
        return sorted(results, key=lambda r: r[0])
```

For even better pruning, use a priority queue to always expand the node with smallest minimum distance first. This 'best-first' strategy finds good candidates faster, enabling more aggressive pruning later. scikit-learn's BallTree uses this approach.
Ball trees share similar complexity characteristics with KD-trees but with different constants and failure modes.
Query Complexity: As with KD-trees, a nearest neighbor query takes $O(\log n)$ time when pruning is effective (low-dimensional or well-clustered data), but degrades toward $O(n)$ when most balls must be inspected.
When Ball Trees Outperform KD-Trees:
Ball trees tend to be faster than KD-trees when:
Data is clustered — Spherical partitions match cluster structure better than axis-aligned rectangles
Dimension is moderate (10-50) — Ball pruning degrades more gracefully than hyperplane pruning
Using non-Euclidean metrics — Ball trees work with any metric that satisfies the triangle inequality
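For example, scikit-learn's `BallTree` accepts alternative metrics directly (a short usage sketch):

```python
import numpy as np
from sklearn.neighbors import BallTree

X = np.random.rand(1000, 16)
# Any metric satisfying the triangle inequality works, e.g. Manhattan distance
tree = BallTree(X, leaf_size=20, metric='manhattan')
dist, ind = tree.query(X[:3], k=5)   # distances and indices of the 5 nearest neighbors
print(ind.shape)                      # (3, 5)
```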
When KD-Trees Win: In low dimensions (roughly $d \leq 5$) with fairly uniform data, the cheaper per-node comparisons of KD-trees usually give them the edge. The timings below illustrate how the balance shifts with dimension:
| Dimension | KD-Tree Query | Ball Tree Query | Winner |
|---|---|---|---|
| 2 | 0.02 ms | 0.03 ms | KD-Tree (1.5×) |
| 5 | 0.04 ms | 0.04 ms | Tie |
| 10 | 0.15 ms | 0.12 ms | Ball Tree (1.25×) |
| 20 | 1.2 ms | 0.8 ms | Ball Tree (1.5×) |
| 50 | 15 ms | 8 ms | Ball Tree (1.9×) |
| 100 | 95 ms | 75 ms | Ball Tree (1.3×) |
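Absolute numbers like these depend heavily on the data distribution and hardware; a sketch of how such a comparison can be reproduced with scikit-learn (timings will differ on your machine):

```python
import time
import numpy as np
from sklearn.neighbors import KDTree, BallTree

rng = np.random.default_rng(0)
for d in (2, 10, 50):
    X = rng.normal(size=(50_000, d))
    queries = rng.normal(size=(100, d))
    for name, cls in (("KDTree", KDTree), ("BallTree", BallTree)):
        tree = cls(X, leaf_size=40)
        start = time.perf_counter()
        tree.query(queries, k=5)
        per_query = (time.perf_counter() - start) / len(queries)
        print(f"d={d:3d}  {name:9s}  {per_query * 1e3:.3f} ms/query")
```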
Both KD-trees and Ball trees suffer from the curse of dimensionality. In very high dimensions (d > 100), both degrade to near-linear behavior. Ball trees just degrade slightly more gracefully. For truly high-dimensional data, approximate methods (LSH, HNSW) are necessary.
One of Ball trees' most powerful features is their applicability to general metric spaces. Unlike KD-trees, which require vector representations, Ball trees only need a distance function.
Definition: Metric Space
A metric space $(X, d)$ consists of a set $X$ and a distance function $d: X \times X \to \mathbb{R}_{\geq 0}$ satisfying, for all $x, y, z \in X$:
Identity of indiscernibles: $d(x, y) = 0$ if and only if $x = y$
Symmetry: $d(x, y) = d(y, x)$
Triangle inequality: $d(x, z) \leq d(x, y) + d(y, z)$
Why the Triangle Inequality Matters:
The Ball tree pruning criterion relies on:
$$d_{\min}(q, \text{ball}) = \max(0, d(q, c) - R)$$
This is valid because for any point $p$ in the ball: $$d(q, p) \geq d(q, c) - d(c, p) \geq d(q, c) - R$$
This uses the triangle inequality! Without it, pruning would be unsound.
```python
import heapq
import numpy as np
from typing import Callable, List, Any, Optional


class MetricBallTree:
    """
    Ball tree for general metric spaces.

    Works with any distance function satisfying the metric axioms.
    Points can be any hashable objects, not just vectors.
    """

    def __init__(
        self,
        points: List[Any],
        distance: Callable[[Any, Any], float],
        leaf_size: int = 20
    ):
        """
        Build a Ball tree for an arbitrary metric space.

        Parameters:
        -----------
        points : List[Any]
            Objects to index (strings, graphs, etc.)
        distance : Callable
            Metric distance function d(x, y) -> float
        leaf_size : int
            Maximum leaf size
        """
        self.distance = distance
        self.leaf_size = leaf_size
        self.points = points
        self.root = self._build(list(range(len(points))))

    def _build(self, indices: List[int]) -> Optional[dict]:
        """Recursive construction."""
        if len(indices) == 0:
            return None

        # Compute bounding ball
        # For general metrics, use a data point as center
        # (a centroid doesn't exist in general metric spaces)
        center_idx = indices[0]
        center = self.points[center_idx]

        # Radius = max distance to center
        radius = max(
            self.distance(center, self.points[i])
            for i in indices
        )

        if len(indices) <= self.leaf_size:
            return {
                'center_idx': center_idx,
                'center': center,
                'radius': radius,
                'indices': indices,
                'left': None,
                'right': None
            }

        # Split: find farthest point from center, partition by closer pivot
        farthest_idx = max(
            indices,
            key=lambda i: self.distance(center, self.points[i])
        )
        pivot1, pivot2 = center_idx, farthest_idx

        left_indices = []
        right_indices = []
        for i in indices:
            d1 = self.distance(self.points[pivot1], self.points[i])
            d2 = self.distance(self.points[pivot2], self.points[i])
            if d1 <= d2:
                left_indices.append(i)
            else:
                right_indices.append(i)

        return {
            'center_idx': center_idx,
            'center': center,
            'radius': radius,
            'indices': None,
            'left': self._build(left_indices),
            'right': self._build(right_indices)
        }

    def query(self, q: Any, k: int = 1) -> List[tuple]:
        """Find k nearest neighbors to a query object."""
        heap = []  # max-heap: (-dist, idx)
        self._search(self.root, q, k, heap)
        return sorted([(-d, i) for d, i in heap])

    def _search(self, node: Optional[dict], q: Any, k: int, heap: list):
        if node is None:
            return

        # Pruning check
        d_center = self.distance(q, node['center'])
        d_min = max(0, d_center - node['radius'])
        kth_best = -heap[0][0] if len(heap) >= k else float('inf')

        if d_min >= kth_best:
            return  # PRUNE

        if node['left'] is None:
            # Leaf: check every contained object
            for i in node['indices']:
                d = self.distance(q, self.points[i])
                if len(heap) < k:
                    heapq.heappush(heap, (-d, i))
                elif d < -heap[0][0]:
                    heapq.heapreplace(heap, (-d, i))
        else:
            # Visit closer child first
            d_left = self.distance(q, node['left']['center']) if node['left'] else float('inf')
            d_right = self.distance(q, node['right']['center']) if node['right'] else float('inf')

            if d_left <= d_right:
                self._search(node['left'], q, k, heap)
                self._search(node['right'], q, k, heap)
            else:
                self._search(node['right'], q, k, heap)
                self._search(node['left'], q, k, heap)


# Example: String similarity search using edit distance
def levenshtein_distance(s1: str, s2: str) -> int:
    """Compute Levenshtein (edit) distance between strings."""
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)
    if len(s2) == 0:
        return len(s1)

    prev_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        curr_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = prev_row[j + 1] + 1
            deletions = curr_row[j] + 1
            substitutions = prev_row[j] + (c1 != c2)
            curr_row.append(min(insertions, deletions, substitutions))
        prev_row = curr_row

    return prev_row[-1]


# Usage:
# words = ["apple", "application", "apply", "banana", "bandana", ...]
# tree = MetricBallTree(words, levenshtein_distance)
# similar = tree.query("aple", k=5)  # Find 5 most similar strings
```

Ball trees belong to a family of metric space data structures. Two important relatives are:
Vantage-Point Trees (VP-Trees):
VP-trees use a different partitioning scheme: pick a vantage point $v$, compute the median distance $\mu$ from $v$ to the remaining points, place points with $d(v, x) \leq \mu$ in the inside child, and place the rest in the outside child.
This creates complementary partitions: a close ball and a far shell.
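A compact sketch of that partition step for Euclidean vectors (`vp_partition` is an illustrative helper, not a full VP-tree):

```python
import numpy as np

def vp_partition(points: np.ndarray, vantage_idx: int = 0):
    """Split points around a vantage point: inside = within the median
    distance mu of the vantage point, outside = beyond it (the 'shell')."""
    v = points[vantage_idx]
    d = np.sqrt(np.sum((points - v) ** 2, axis=1))
    mu = np.median(d)                  # threshold = median distance to v
    inside = points[d <= mu]
    outside = points[d > mu]
    return v, mu, inside, outside
```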
VP-Tree Pruning:
Given query $q$, current best distance $r$, and vantage point $v$ with threshold $\mu$: if $d(q, v) + r < \mu$, the outside (shell) child cannot contain a closer point and is pruned; if $d(q, v) - r > \mu$, the inside (ball) child is pruned; otherwise both children must be searched.
Cover Trees:
Cover trees provide stronger theoretical guarantees through a multi-resolution hierarchy: nodes are arranged in levels of geometrically decreasing scale $2^i$, where each level covers the one below it (every point lies within $2^i$ of some node) while keeping its nodes well separated (pairwise distances exceed $2^i$). Query and update costs are bounded in terms of the data's expansion constant $c$ rather than its ambient dimension.
Cover trees have the best theoretical guarantees for intrinsically low-dimensional data in high-dimensional ambient space.
| Structure | Partition Style | Build Time | Query Time | Best For |
|---|---|---|---|---|
| Ball Tree | Nested balls | $O(n \log n)$ | $O(\log n)$* | General use |
| VP-Tree | Ball + shell | $O(n \log n)$ | $O(\log n)$* | Varies by data |
| Cover Tree | Multi-resolution | $O(n \log n)$ | $O(c^{12} \log n)$ | Low intrinsic dim |
| Metric Skip List | Probabilistic | $O(n \log n)$ | $O(\log n)$ expected | Dynamic data |

*Expected behavior when pruning is effective (low intrinsic dimension); worst-case query time is $O(n)$.
In practice, Ball trees and VP-trees perform similarly for most applications. Cover trees have better worst-case guarantees but higher constants. For production systems, start with scikit-learn's BallTree implementation and only switch if performance is insufficient.
What's Next:
We've seen how KD-trees and Ball trees partition space to enable efficient search. But there's a beautiful theoretical foundation underlying nearest neighbor search: Voronoi diagrams. The next page explores how Voronoi tessellation provides the theoretical optimum for 2D/3D exact search and illuminates the geometry of KNN decision boundaries.
You now understand Ball trees: their spherical partitioning strategy, construction and search algorithms, and application to general metric spaces. This completes our coverage of exact tree-based methods. Next, we explore Voronoi diagrams for a geometric perspective on nearest neighbor regions.