Binary search is one of the most elegant algorithms in computer science: by repeatedly halving a sorted list, we transform O(n) search into O(log n). But binary search works on one-dimensional data. How do we extend this idea to multidimensional points living in $\mathbb{R}^d$?
The KD-tree (k-dimensional tree), invented by Jon Bentley in 1975, is the answer. It generalizes binary search to arbitrary dimensions by recursively partitioning space using axis-aligned hyperplanes. Just as binary search bisects an interval, a KD-tree bisects space—alternating between dimensions to create a balanced partition.
KD-trees power nearest neighbor search in everything from computer graphics collision detection to geographic information systems. Understanding them deeply provides insights that transfer to other spatial data structures.
By the end of this page, you will understand how KD-trees partition space, be able to implement both construction and search algorithms, analyze their time and space complexity, understand when they excel and when they fail, and recognize the connection to binary search and decision trees.
Before diving into algorithms, let's understand the geometric intuition behind KD-trees.
The Problem with Sorting in Multiple Dimensions:
In one dimension, there's a natural ordering: 1 < 2 < 3. We can sort points and use binary search. But in two or more dimensions, there's no single natural order. Is (3, 1) greater than (1, 3)? It depends on which dimension you prioritize.
The KD-Tree Solution:
Instead of defining a total order, KD-trees create a spatial hierarchy through recursive partitioning: pick a dimension, split the points at the median value along that dimension, and recurse on each half using the next dimension.
This creates a binary tree where each node stores one point and an axis-aligned splitting hyperplane, the left subtree holds the points on one side of that hyperplane, and the right subtree holds the points on the other side.
Visualizing the Partition (2D Example):
Consider 8 points in 2D space. A KD-tree partitions them as follows:
```
Level 0 (x-axis): split all 8 points at the median x-coordinate
Level 1 (y-axis): split each half at its median y-coordinate
Level 2 (x-axis): split each quarter at its median x-coordinate
Level 3 (leaves): each of the 8 resulting cells contains exactly one point
```
Each split halves the remaining points, achieving $O(\log n)$ depth for $n$ points (when balanced). A query point follows the tree structure, eliminating half of remaining candidates at each level—just like binary search, but in multiple dimensions.
Cycling through dimensions ensures balanced partitioning across all axes. If we always split on the same dimension, the tree would only partition along one axis, leaving the other dimensions unconstrained. Cycling distributes the partitioning effort evenly, which is crucial for query efficiency.
Let's formalize the KD-tree data structure. Each node in a KD-tree contains:
Node Components:
| Component | Description | Purpose |
|---|---|---|
| `point` | The data point stored at this node | Reference for splitting and result |
| `split_dim` | The dimension used for splitting | Determines which coordinate to compare |
| `split_value` | The value of `point[split_dim]` | The partition threshold |
| `left` | Subtree with points where `x[split_dim] <= split_value` | Left partition |
| `right` | Subtree with points where `x[split_dim] > split_value` | Right partition |
Key Invariant:
For any node $v$ splitting on dimension $i$:
- All points in `left(v)` satisfy $x_i \leq v.\text{split_value}$
- All points in `right(v)` satisfy $x_i > v.\text{split_value}$
```python
import numpy as np
from dataclasses import dataclass
from typing import Optional, List, Tuple


@dataclass
class KDNode:
    """
    A node in a KD-tree.

    Attributes:
    -----------
    point : np.ndarray
        The d-dimensional point stored at this node
    split_dim : int
        The dimension used for partitioning at this node
    split_value : float
        The splitting threshold (value of point[split_dim])
    left : Optional[KDNode]
        Left subtree (points with x[split_dim] <= split_value)
    right : Optional[KDNode]
        Right subtree (points with x[split_dim] > split_value)
    index : int
        Original index of this point in the dataset
    """
    point: np.ndarray
    split_dim: int
    split_value: float
    left: Optional['KDNode'] = None
    right: Optional['KDNode'] = None
    index: int = -1

    def is_leaf(self) -> bool:
        """Check if this node is a leaf (no children)."""
        return self.left is None and self.right is None

    @property
    def dimension(self) -> int:
        """Return the dimensionality of points in this tree."""
        return len(self.point)


@dataclass
class KDTree:
    """
    KD-Tree for efficient nearest neighbor search.

    Attributes:
    -----------
    root : Optional[KDNode]
        Root node of the tree
    n_points : int
        Total number of points in the tree
    n_dims : int
        Dimensionality of the data
    """
    root: Optional[KDNode]
    n_points: int
    n_dims: int

    def __repr__(self) -> str:
        return f"KDTree(n_points={self.n_points}, n_dims={self.n_dims})"
```

For memory efficiency, production implementations often use implicit tree representations where the tree structure is encoded in array indices rather than explicit pointers. This improves cache locality and reduces memory overhead, but makes the logic harder to follow. We use explicit nodes here for clarity.
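To make these fields concrete, here is a tiny hand-assembled tree over three 2-D points (the coordinates are arbitrary, chosen only to satisfy the invariant above); the builder in the next section produces such trees automatically.

```python
import numpy as np

# Illustrative values only -- a hand-assembled tree over three 2-D points.
left = KDNode(point=np.array([2.0, 3.0]), split_dim=1, split_value=3.0, index=0)
right = KDNode(point=np.array([9.0, 6.0]), split_dim=1, split_value=6.0, index=2)
root = KDNode(point=np.array([5.0, 4.0]), split_dim=0, split_value=5.0,
              left=left, right=right, index=1)

tree = KDTree(root=root, n_points=3, n_dims=2)
print(tree)            # KDTree(n_points=3, n_dims=2)
print(root.is_leaf())  # False -- the root has two children
print(left.point[root.split_dim] <= root.split_value)   # True: invariant holds
print(right.point[root.split_dim] > root.split_value)   # True: invariant holds
```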
Constructing a KD-tree involves recursively partitioning the data. The key decisions are which dimension to split on at each node and where along that dimension to place the split.
Dimension Selection Strategies:
| Strategy | Description | Trade-off |
|---|---|---|
| Round-robin | Cycle through dimensions 0, 1, ..., d-1, 0, ... | Simple, balanced depth |
| Maximum variance | Split on dimension with highest spread | Better partitioning, O(d) per node |
| Maximum extent | Split on dimension with largest range | Similar to variance, simpler |
Split Point Selection:
| Strategy | Description | Trade-off |
|---|---|---|
| Median | Use the median point along split dimension | Balanced tree, O(n) selection |
| Mean | Use the mean value | Fast (O(1) if precomputed), can be unbalanced |
| Midpoint | Use midpoint of range | Fast, can be very unbalanced |
```python
import numpy as np
from typing import Optional, List


def build_kd_tree(
    points: np.ndarray,
    indices: Optional[np.ndarray] = None,
    depth: int = 0,
    split_strategy: str = "round_robin"
) -> Optional[KDNode]:
    """
    Build a KD-tree from a set of points.

    Time Complexity: O(n log n) with median selection
    Space Complexity: O(n) for the tree structure

    Parameters:
    -----------
    points : np.ndarray
        Array of points, shape (n, d)
    indices : np.ndarray, optional
        Original indices of points (for tracking)
    depth : int
        Current depth in the tree (for dimension cycling)
    split_strategy : str
        How to select split dimension:
        "round_robin", "max_variance", "max_extent"

    Returns:
    --------
    KDNode or None
        Root of the constructed (sub)tree
    """
    n, d = points.shape

    # Base case: no points
    if n == 0:
        return None

    # Initialize indices if not provided
    if indices is None:
        indices = np.arange(n)

    # Select splitting dimension
    if split_strategy == "round_robin":
        split_dim = depth % d
    elif split_strategy == "max_variance":
        # Split on dimension with maximum variance
        variances = np.var(points, axis=0)
        split_dim = int(np.argmax(variances))
    elif split_strategy == "max_extent":
        # Split on dimension with maximum range
        extents = np.ptp(points, axis=0)  # peak-to-peak
        split_dim = int(np.argmax(extents))
    else:
        split_dim = depth % d

    # Base case: single point
    if n == 1:
        return KDNode(
            point=points[0],
            split_dim=split_dim,
            split_value=points[0, split_dim],
            index=indices[0]
        )

    # Find median along split dimension
    # O(n) using a selection algorithm, or O(n log n) using sort (as here)
    sorted_idx = np.argsort(points[:, split_dim])
    median_idx = n // 2
    # For even n, median_idx is right-of-center: the left subtree gets
    # floor(n/2) points and the right subtree gets ceil(n/2) - 1 points.

    median_global_idx = sorted_idx[median_idx]
    median_point = points[median_global_idx]

    # Partition points
    left_mask = sorted_idx[:median_idx]
    right_mask = sorted_idx[median_idx + 1:]

    # Recursively build subtrees
    left_child = build_kd_tree(
        points[left_mask], indices[left_mask],
        depth + 1, split_strategy
    ) if len(left_mask) > 0 else None

    right_child = build_kd_tree(
        points[right_mask], indices[right_mask],
        depth + 1, split_strategy
    ) if len(right_mask) > 0 else None

    return KDNode(
        point=median_point,
        split_dim=split_dim,
        split_value=median_point[split_dim],
        left=left_child,
        right=right_child,
        index=indices[median_global_idx]
    )


def build_kd_tree_wrapper(points: np.ndarray) -> KDTree:
    """
    Public interface for KD-tree construction.
    """
    n, d = points.shape
    root = build_kd_tree(points)
    return KDTree(root=root, n_points=n, n_dims=d)
```

Complexity Analysis:
Time Complexity:
At each level of the tree, every point belongs to exactly one subproblem, and each subproblem finds its median and partitions its points in linear time (using a linear-time selection algorithm; the reference code above sorts instead, which adds a logarithmic factor). This gives $O(n)$ total work per level.
With $O(\log n)$ levels and $O(n)$ work per level:
$$T(n) = 2T(n/2) + O(n)$$
This is the same recurrence as merge sort, giving:
$$\boxed{T_{\text{build}} = O(n \log n)}$$
Space Complexity:
Each of the $n$ points becomes a node:
$$\boxed{S_{\text{tree}} = O(n)}$$
Note: The tree doesn't use $O(nd)$ space because points are stored once (or referenced by index), not duplicated.
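A quick sanity check of the builder, assuming the code above; `tree_depth` is a small helper added here for illustration, not part of the reference implementation.

```python
import numpy as np

def tree_depth(node) -> int:
    """Number of nodes on the longest root-to-leaf path (0 for an empty tree)."""
    if node is None:
        return 0
    return 1 + max(tree_depth(node.left), tree_depth(node.right))

rng = np.random.default_rng(0)
points = rng.random((1000, 3))          # 1,000 random points in 3-D

tree = build_kd_tree_wrapper(points)
print(tree)                             # KDTree(n_points=1000, n_dims=3)
print(tree_depth(tree.root))            # 10 -- the balanced O(log n) depth for n = 1000
```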
The nearest neighbor search algorithm is where KD-trees truly shine. The key insight is branch-and-bound pruning: we can often eliminate entire subtrees without examining their contents.
Algorithm Overview:
1. Descend from the root to a leaf, at each node entering the side of the splitting hyperplane that contains the query.
2. Track the closest point found so far as the current best.
3. Backtrack: at each node on the way up, check whether the far side of the hyperplane could contain a closer point; if it cannot, prune that subtree, otherwise search it too.
The Pruning Condition:
At a node splitting on dimension $i$ with value $s$:
Let $d_{\text{best}}$ be the distance to our current best candidate. Let $d_{\text{plane}} = |q_i - s|$ be the query's distance to the splitting hyperplane.
If $d_{\text{plane}} \geq d_{\text{best}}$, then all points on the other side of the hyperplane are further than our current best. We can prune that entire subtree.
This is because the closest possible point on the other side is at distance $d_{\text{plane}}$, and all actual points are at least that far.
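For example, suppose the query is $q = (2, 7)$, the current node splits on $x$ at $s = 5$, and the best distance found so far is $d_{\text{best}} = 1.5$. Then $d_{\text{plane}} = |2 - 5| = 3 \geq 1.5$, so every point in the right half-space is at least 3 away from $q$ and the right subtree can be skipped entirely. Had $d_{\text{best}}$ been 4 instead, the right subtree would still have to be searched.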
```python
import numpy as np
from typing import Tuple, Optional
import heapq


def nn_search(
    node: Optional[KDNode],
    query: np.ndarray,
    best: Tuple[float, Optional[KDNode]] = (float('inf'), None)
) -> Tuple[float, Optional[KDNode]]:
    """
    Find nearest neighbor in KD-tree using branch-and-bound.

    Time Complexity: O(log n) average, O(n) worst case

    Parameters:
    -----------
    node : KDNode
        Current node in the search
    query : np.ndarray
        Query point
    best : Tuple[float, KDNode]
        Current best (distance, node) pair

    Returns:
    --------
    Tuple[float, KDNode]
        Distance and node of nearest neighbor
    """
    if node is None:
        return best

    best_dist, best_node = best

    # Compute distance from query to current node's point
    curr_dist = np.sqrt(np.sum((node.point - query) ** 2))

    # Update best if current node is closer
    if curr_dist < best_dist:
        best_dist = curr_dist
        best_node = node

    # Determine which subtree to search first
    split_dim = node.split_dim
    split_val = node.split_value

    if query[split_dim] <= split_val:
        first, second = node.left, node.right
    else:
        first, second = node.right, node.left

    # Search the subtree containing the query first
    best_dist, best_node = nn_search(first, query, (best_dist, best_node))

    # Check if we need to search the other subtree
    # Distance from query to the splitting hyperplane
    dist_to_plane = abs(query[split_dim] - split_val)

    if dist_to_plane < best_dist:
        # The other subtree might contain closer points
        best_dist, best_node = nn_search(second, query, (best_dist, best_node))
    # else: PRUNED! The entire other subtree is guaranteed to be farther

    return best_dist, best_node


def knn_search(
    node: Optional[KDNode],
    query: np.ndarray,
    k: int,
    heap: list = None,
    depth: int = 0
) -> list:
    """
    Find K nearest neighbors in KD-tree.

    Uses a max-heap of size k to track the k closest points found so far.
    The max-heap allows O(1) access to the farthest of the k candidates,
    enabling efficient pruning.

    Parameters:
    -----------
    node : KDNode
        Current node in the search
    query : np.ndarray
        Query point
    k : int
        Number of neighbors to find
    heap : list
        Max-heap of (-distance, node) pairs

    Returns:
    --------
    list of (distance, node) pairs
        The k nearest neighbors
    """
    if heap is None:
        heap = []

    if node is None:
        return heap

    # Compute distance from query to current node's point
    dist = np.sqrt(np.sum((node.point - query) ** 2))

    # Update heap
    if len(heap) < k:
        # Heap not full, add directly
        heapq.heappush(heap, (-dist, id(node), node))
    elif dist < -heap[0][0]:
        # Current node closer than farthest in heap
        heapq.heapreplace(heap, (-dist, id(node), node))

    # Determine search order
    split_dim = node.split_dim
    split_val = node.split_value

    if query[split_dim] <= split_val:
        first, second = node.left, node.right
    else:
        first, second = node.right, node.left

    # Search first subtree
    knn_search(first, query, k, heap)

    # Pruning check: distance to splitting plane
    dist_to_plane = abs(query[split_dim] - split_val)

    # Only prune if heap is full AND plane is farther than k-th closest
    if len(heap) < k or dist_to_plane < -heap[0][0]:
        knn_search(second, query, k, heap)

    return heap


def query_kd_tree(tree: KDTree, query: np.ndarray, k: int = 1):
    """
    Public interface for KNN query.
    """
    if k == 1:
        dist, node = nn_search(tree.root, query)
        return [(dist, node.point, node.index)]
    else:
        heap = knn_search(tree.root, query, k)
        results = [(-d, node.point, node.index) for d, _, node in heap]
        return sorted(results)  # Sort by distance ascending
```

We always search the subtree containing the query first.
This finds a good candidate quickly, making the pruning condition more likely to succeed when we consider the other subtree. Searching in the wrong order degrades to near-linear performance.
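A small end-to-end check of the search code, assuming the builder and `nn_search` defined above, comparing against brute force on random data:

```python
import numpy as np

rng = np.random.default_rng(1)
points = rng.random((500, 3))
queries = rng.random((20, 3))

tree = build_kd_tree_wrapper(points)

for q in queries:
    dist, node = nn_search(tree.root, q)                 # KD-tree answer
    brute = np.min(np.linalg.norm(points - q, axis=1))   # exhaustive answer
    assert np.isclose(dist, brute), (dist, brute)

print("KD-tree results match brute force on all queries.")
```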
Understanding when KD-trees provide logarithmic search versus when they degrade is crucial for practical application.
Best Case: O(log n)
When pruning is effective, the search visits $O(\log n)$ nodes: the initial descent touches one node per level, and during backtracking nearly every "other" subtree is pruned.
In total: $O(\log n)$ distance computations, each costing $O(d)$:
$$T_{\text{best}} = O(d \log n)$$
Worst Case: O(n)
When pruning fails everywhere, we visit all $n$ nodes:
$$T_{\text{worst}} = O(dn)$$
This is no better than brute force!
When Does Pruning Fail?
Pruning fails when $d_{\text{plane}} < d_{\text{best}}$ at most nodes, meaning the query is close to many splitting hyperplanes. This happens when the dimensionality is high (in high dimensions the nearest neighbor is typically barely closer than the splitting planes, so the pruning test rarely succeeds), when the data is heavily clustered, or when the query lies far outside the data distribution so that $d_{\text{best}}$ stays large.
The Curse of Dimensionality for KD-Trees:
Friedman, Bentley, and Finkel (1977) showed that the expected number of nodes visited is:
$$E[\text{nodes visited}] \approx 2^d \cdot n^{1/d}$$
For this to be sublinear in $n$, we need:
$$2^d \cdot n^{1/d - 1} \ll 1 \quad\Longleftrightarrow\quad 2^d \ll n^{1 - 1/d}$$
This is satisfied roughly when:
$$d \lesssim \log n$$
Rule of Thumb: KD-trees provide logarithmic search only when:
$$\boxed{d \lesssim 2\log_2 n}$$
| $n$ | Max effective $d$ |
|---|---|
| 1,000 | ~20 |
| 10,000 | ~26 |
| 100,000 | ~33 |
| 1,000,000 | ~40 |
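The table values follow directly from the rule of thumb; a tiny helper makes the calculation explicit (the threshold is a heuristic, not a hard cutoff):

```python
import math

def max_effective_dim(n: int) -> float:
    """Rough dimensionality limit below which KD-tree pruning stays effective,
    per the d <~ 2 * log2(n) rule of thumb."""
    return 2 * math.log2(n)

for n in (1_000, 10_000, 100_000, 1_000_000):
    print(f"n = {n:>9,}  ->  d <~ {max_effective_dim(n):.1f}")
```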
Typical query behavior as dimensionality grows:
| Dimension ($d$) | Avg. Nodes Visited | Query Time | vs. Brute Force |
|---|---|---|---|
| 2 | ~15 | 0.02 ms | 250× faster |
| 5 | ~40 | 0.05 ms | 100× faster |
| 10 | ~200 | 0.25 ms | 20× faster |
| 20 | ~2,000 | 2.5 ms | 2× faster |
| 50 | ~30,000 | 35 ms | ~same |
| 100 | ~80,000 | 100 ms | slower (overhead) |
Several optimizations can improve KD-tree performance in practice, even if they don't change asymptotic complexity.
1. Leaf Size > 1 (Bucket KD-Tree)
Instead of splitting down to single points, stop when a node has ≤ leaf_size points and store them all in that leaf. During search, do brute-force on leaf contents.
A `leaf_size` of 10-40 is typical.
```python
@dataclass
class BucketKDNode:
    """
    KD-tree node that can store multiple points in leaves.
    """
    points: np.ndarray = None      # For leaf nodes: the actual points
    indices: np.ndarray = None     # Original indices of points
    split_dim: int = -1            # For internal nodes
    split_value: float = 0.0
    left: 'BucketKDNode' = None
    right: 'BucketKDNode' = None

    def is_leaf(self) -> bool:
        return self.points is not None


def build_bucket_kd_tree(
    points: np.ndarray,
    indices: np.ndarray = None,
    depth: int = 0,
    leaf_size: int = 20
) -> BucketKDNode:
    """
    Build KD-tree with bucket leaves for improved cache performance.
    """
    n, d = points.shape

    if indices is None:
        indices = np.arange(n)

    # Base case: create leaf if small enough
    if n <= leaf_size:
        return BucketKDNode(points=points.copy(), indices=indices.copy())

    # Split on cycling dimension
    split_dim = depth % d

    # Find median
    sorted_idx = np.argsort(points[:, split_dim])
    median_idx = n // 2
    split_value = points[sorted_idx[median_idx], split_dim]

    # Partition
    left_mask = sorted_idx[:median_idx]
    right_mask = sorted_idx[median_idx:]

    return BucketKDNode(
        split_dim=split_dim,
        split_value=split_value,
        left=build_bucket_kd_tree(
            points[left_mask], indices[left_mask], depth + 1, leaf_size
        ),
        right=build_bucket_kd_tree(
            points[right_mask], indices[right_mask], depth + 1, leaf_size
        )
    )
```

2. Sliding Midpoint Split
Standard median splits can create long, thin cells on skewed data. The sliding midpoint variant instead splits each cell at the midpoint of its longest side; if that trial plane would leave one side empty, the plane slides toward the points until it touches the nearest one, which becomes the sole occupant of that child. This prevents empty partitions that waste tree depth while avoiding long, thin cells. A sketch of the split selection follows.
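The sketch below shows one way to choose the split under this rule; the exact sliding behavior varies between implementations, so treat it as illustrative rather than canonical. `cell_min` and `cell_max` are the bounds of the current cell inherited from earlier splits (not the points' own bounding box).

```python
import numpy as np

def sliding_midpoint_split(points: np.ndarray,
                           cell_min: np.ndarray,
                           cell_max: np.ndarray):
    """Choose (split_dim, split_value) with the sliding-midpoint rule (sketch)."""
    # Split perpendicular to the cell's longest side, at its midpoint.
    split_dim = int(np.argmax(cell_max - cell_min))
    split_value = 0.5 * (cell_min[split_dim] + cell_max[split_dim])

    coords = points[:, split_dim]
    if coords.min() > split_value:
        # Trial plane leaves one side empty: slide it until it touches the
        # nearest point, which then sits alone in the otherwise-empty child.
        split_value = float(coords.min())
    elif coords.max() < split_value:
        split_value = float(coords.max())

    return split_dim, split_value
```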
3. Implied Bounding Boxes
Instead of storing explicit bounding boxes at each node, compute them on-the-fly during search. The root covers the entire data range; each split narrows one dimension. This saves memory and cache space.
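A sketch of the idea: keep the current cell's bounds in two arrays that are narrowed as the search descends, and use them to lower-bound distances for pruning. The helper name here is ours, not from any particular library.

```python
import numpy as np

def min_dist_to_box(query: np.ndarray, mins: np.ndarray, maxs: np.ndarray) -> float:
    """Smallest possible distance from query to any point inside the
    axis-aligned box [mins, maxs] -- a valid lower bound for pruning."""
    nearest = np.clip(query, mins, maxs)   # closest point of the box to the query
    return float(np.sqrt(np.sum((query - nearest) ** 2)))

# During search, the box is never stored in the nodes. Starting from the
# data's overall bounds, descending into the left child of a node that
# splits dimension i at value s tightens maxs[i] = min(maxs[i], s);
# descending into the right child tightens mins[i] = max(mins[i], s).
```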
4. Best-Bin-First (BBF) Search
For approximate nearest neighbor, use a priority queue to always explore the most promising unexplored subtree first. Limit total nodes visited to achieve guaranteed query time at the cost of exactness.
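A compact sketch of BBF on the single-point nodes defined earlier, using the distance to each skipped splitting plane as the priority (a common simplification); `max_checks` caps the number of nodes examined, trading exactness for bounded query time.

```python
import heapq
import numpy as np

def bbf_search(root: KDNode, query: np.ndarray, max_checks: int = 200):
    """Approximate nearest neighbor via Best-Bin-First (sketch)."""
    best_dist, best_node = float('inf'), None
    pq = [(0.0, 0, root)]   # (lower bound on distance to region, tiebreak, node)
    counter = 1
    checks = 0

    while pq and checks < max_checks:
        bound, _, node = heapq.heappop(pq)
        if bound >= best_dist:
            break           # no queued region can contain a closer point
        # Descend toward the query's leaf, queuing each far branch as we go.
        while node is not None and checks < max_checks:
            checks += 1
            d = float(np.sqrt(np.sum((node.point - query) ** 2)))
            if d < best_dist:
                best_dist, best_node = d, node
            diff = query[node.split_dim] - node.split_value
            near, far = (node.left, node.right) if diff <= 0 else (node.right, node.left)
            if far is not None:
                heapq.heappush(pq, (abs(diff), counter, far))
                counter += 1
            node = near

    return best_dist, best_node
```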
scikit-learn's `KDTree` uses bucket leaves (default `leaf_size=40`). For batch queries it also offers a dual-tree mode, which builds a second tree over the query points and can be significantly faster than answering each query independently.
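A minimal usage sketch of scikit-learn's implementation on random data (the parameter values are illustrative):

```python
import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.default_rng(0)
X = rng.random((10_000, 3))             # reference points
X_query = rng.random((5, 3))            # query points

tree = KDTree(X, leaf_size=40)          # bucket leaves, as discussed above
dist, ind = tree.query(X_query, k=3)    # distances and indices of the 3 nearest neighbors
print(dist.shape, ind.shape)            # (5, 3) (5, 3)
```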
How does the KD-tree compare to alternatives? The answer depends heavily on data characteristics.
| Method | Build Time | Query Time | Space | Best For |
|---|---|---|---|---|
| KD-Tree | $O(n \log n)$ | $O(\log n)$ to $O(n)$ | $O(n)$ | Low-dim exact NN |
| Ball Tree | $O(n \log n)$ | $O(\log n)$ to $O(n)$ | $O(n)$ | General metrics |
| Voronoi | $O(n \log n)$ | $O(\log n)$ | $O(n^{\lceil d/2 \rceil})$ | 2D/3D exact NN |
| R-Tree | $O(n \log n)$ | $O(\log n)$ | $O(n)$ | Range queries |
| LSH | $O(n)$ | $O(1)$ to $O(n)$ | $O(n)$ | High-dim approx NN |
| HNSW | $O(n \log n)$ | $O(\log n)$ | $O(n)$ | High-dim approx NN |
Most modern ML embeddings (BERT: 768, ResNet: 2048, CLIP: 512) far exceed the effective dimensionality limit for KD-trees. This is why production systems like Spotify, Pinterest, and Google use approximate methods like HNSW or FAISS, which we cover in later pages.
We've explored KD-trees in depth—from their geometric foundations through construction and search algorithms to their practical limitations. Let's consolidate the key insights:
- KD-trees generalize binary search by recursively splitting space with axis-aligned hyperplanes, cycling through dimensions.
- Construction with median splits takes $O(n \log n)$ time and $O(n)$ space.
- Nearest neighbor search uses branch-and-bound pruning: a subtree is skipped when the query's distance to the splitting hyperplane is at least the current best distance.
- Queries are $O(\log n)$ on average in low dimensions but degrade toward $O(n)$ as dimensionality grows; a common rule of thumb is $d \lesssim 2\log_2 n$.
What's Next:
KD-trees partition space using axis-aligned hyperplanes, which is simple but not always optimal. The next page explores Ball Trees, which partition space using hyperspheres. Ball trees handle certain data distributions and distance metrics better than KD-trees, and understanding both illuminates the broader design space of spatial data structures.
You now understand how KD-trees work, when they're effective, and why they struggle in high dimensions. This foundation is essential for understanding Ball trees, Voronoi diagrams, and the approximate methods that power modern similarity search at scale. Next, we explore Ball trees and their hyperspherical partitioning approach.