Decision trees are not designed by hand—they are grown from data through an elegant algorithmic process called recursive partitioning. This process embodies the classic divide-and-conquer paradigm: split the problem into smaller subproblems, solve each recursively, and combine the results.
At the heart of recursive partitioning is a beautifully simple idea: given a set of samples at a node, find the best way to split them into two groups, then repeat on each group until some stopping condition is met. This recursive refinement naturally builds the hierarchical tree structure we've studied.
Understanding recursive partitioning is essential: it is the training algorithm behind every decision tree, it determines how hyperparameters shape the final model, and it governs the computational cost of fitting.
By the end of this page, you will master: (1) the recursive tree-growing algorithm in complete detail, (2) the role of stopping conditions in controlling growth, (3) the depth-first construction process, (4) computational complexity analysis, and (5) the relationship between recursion depth and tree properties.
The essence of recursive partitioning is captured in a deceptively simple recursive function. Let's build it step by step.
```
function BuildTree(samples, depth):
    // Create a node for these samples
    node = CreateNode(samples)

    // Check stopping conditions
    if ShouldStop(samples, depth):
        node.prediction = ComputeLeafPrediction(samples)
        return node  // This is a leaf

    // Find the best split
    (feature, threshold) = FindBestSplit(samples)
    if NoGoodSplitFound(feature, threshold):
        node.prediction = ComputeLeafPrediction(samples)
        return node  // No improvement possible; make leaf

    // Partition samples
    left_samples  = {s ∈ samples : s.x[feature] ≤ threshold}
    right_samples = {s ∈ samples : s.x[feature] > threshold}

    // Recurse on children
    node.feature = feature
    node.threshold = threshold
    node.left = BuildTree(left_samples, depth + 1)
    node.right = BuildTree(right_samples, depth + 1)
    return node

// Entry point
root = BuildTree(all_training_samples, depth=0)
```
The algorithm's power comes from its recursive structure:
Each recursive call operates on a subset of the parent's samples. The samples are never copied—typically indices are passed, pointing into the original dataset.
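A minimal NumPy sketch of this idea (the variable names are illustrative, not any library's internals): the feature matrix is allocated once, and each split produces two new index arrays that view the same data.

```python
import numpy as np

# Toy dataset: 8 samples, 1 feature; the data itself is allocated once
X = np.array([[3.0], [1.0], [4.0], [1.5], [5.0], [0.5], [2.0], [6.0]])

# The root works on an index array, not on a copy of X
indices = np.arange(len(X))

# Splitting produces two new index arrays that point into the same X
mask = X[indices, 0] <= 2.0
left_indices = indices[mask]    # samples routed to the left child
right_indices = indices[~mask]  # samples routed to the right child

print(left_indices)   # -> [1 3 5 6]
print(right_indices)  # -> [0 2 4 7]
assert len(left_indices) + len(right_indices) == len(indices)
```

Because only small integer arrays are created per node, the total extra memory at any moment is $O(N)$: each sample's index lives in exactly one active subset.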
Consider a tiny dataset with 100 samples:
- `BuildTree(samples[0:100], depth=0)` — root node, all 100 samples
- `BuildTree(samples[0:40], depth=1)` — left child
- `BuildTree(samples[0:15], depth=2)` — left-left grandchild
- ...and so on: as each call returns, the corresponding right subtrees are built.
The algorithm proceeds depth-first, fully constructing the left subtree before starting the right subtree.
Standard implementations use depth-first construction (recursion). This is memory-efficient—only one path from root to current node is on the stack. Breadth-first construction (level-by-level) is possible but uses more memory as it must track all nodes at the current level. Some implementations use best-first growth (expand the most impure leaf next), which can find important structure faster but is more complex.
Without stopping conditions, recursive partitioning would continue until every leaf contains a single sample—massive overfitting. Stopping conditions (also called pre-pruning or early stopping) control tree complexity during growth.
1. Maximum Depth (max_depth)
```
if depth >= max_depth:
    return CreateLeaf(samples)
```
The most direct complexity control. Limits the tree to at most $2^{\text{max\_depth}}$ leaves.
2. Minimum Samples to Split (min_samples_split)
```
if len(samples) < min_samples_split:
    return CreateLeaf(samples)
```
Nodes with fewer samples than this threshold become leaves. Prevents tiny splits.
3. Minimum Samples per Leaf (min_samples_leaf)
```
if len(left_samples) < min_samples_leaf OR len(right_samples) < min_samples_leaf:
    skip this split (try next best, or make leaf if none valid)
```
Ensures each child will have sufficient samples. Often more important than min_samples_split.
4. Minimum Impurity Decrease (min_impurity_decrease)
```
if gain < min_impurity_decrease:
    return CreateLeaf(samples)
```
Only split if the improvement exceeds a threshold. Prevents marginal/noise-driven splits.
5. Pure Node
```
if all samples have the same label (classification) OR
   variance(targets) ≈ 0 (regression):
    return CreateLeaf(samples)
```
No split can improve a pure node. This is always checked in practice.
6. Maximum Leaf Nodes (max_leaf_nodes)
// Global constraint across entire tree
if current_leaf_count >= max_leaf_nodes:
stop growing
Direct control over tree complexity. Requires best-first growth to use effectively (otherwise, depth-first would fill leaves unevenly).
7. Maximum Features (max_features)
Not a stopping condition per se, but limits features considered at each split:
```
features_to_try = RandomSample(all_features, max_features)
best_split = FindBestSplit(samples, features_to_try)
```
Used in Random Forests to decorrelate trees. Can lead to earlier stopping if no good split is found among sampled features.
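The conditions that depend only on the current node can be combined into a single predicate. A minimal sketch (the function name, defaults, and signature are illustrative; `min_samples_leaf` and `min_impurity_decrease` are checked later, after a candidate split is found):

```python
def should_stop(labels, depth, max_depth=5, min_samples_split=2):
    """Illustrative combination of stopping conditions 1, 2, and 5."""
    if depth >= max_depth:               # condition 1: maximum depth reached
        return True
    if len(labels) < min_samples_split:  # condition 2: too few samples to split
        return True
    if len(set(labels)) == 1:            # condition 5: pure node
        return True
    return False

print(should_stop([0, 0, 0], depth=2))     # pure node -> True
print(should_stop([0, 1, 0, 1], depth=5))  # depth limit -> True
print(should_stop([0, 1, 0, 1], depth=2))  # keep splitting -> False
```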
| Parameter | Typical Default | Effect if Increased | Effect if Decreased |
|---|---|---|---|
| `max_depth` | None (∞) | Deeper, more complex tree; higher variance | Shallower, simpler tree; less overfitting |
| `min_samples_split` | 2 | Fewer splits, smaller tree | More splits, larger tree |
| `min_samples_leaf` | 1 | Larger leaves, more stable | Smaller leaves, more variance |
| `min_impurity_decrease` | 0.0 | Fewer low-value splits | More splits (including noise) |
| `max_leaf_nodes` | None (∞) | Larger tree allowed | Tighter explicit size limit |
Start with max_depth control (try 3, 5, 10) as it's most intuitive. Add min_samples_leaf (try 5, 10, 20) for stability. Use cross-validation to tune. For Random Forests, defaults often work well since ensemble averaging handles variance. For single trees, aggressive stopping is usually needed.
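The tuning recipe above can be run directly with scikit-learn's `GridSearchCV`; the sketch below uses a synthetic dataset and the grid values suggested in the text (all names follow scikit-learn's public API).

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic binary classification problem
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

# Grid from the guidance above: max_depth in {3, 5, 10}, min_samples_leaf in {5, 10, 20}
param_grid = {
    "max_depth": [3, 5, 10],
    "min_samples_leaf": [5, 10, 20],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)  # the cross-validated winner from the grid
```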
Let's present the full recursive partitioning algorithm with all practical details included.
```
class DecisionTree:
    def __init__(self, max_depth, min_samples_split, min_samples_leaf,
                 min_impurity_decrease, max_features):
        self.config = store all hyperparameters
        self.root = None

    def fit(self, X, y):
        self.n_features = X.shape[1]
        self.n_classes = len(unique(y))  # for classification
        self.root = self._build_tree(X, y, sample_indices=range(len(y)), depth=0)

    def _build_tree(self, X, y, sample_indices, depth):
        node = Node()
        node.n_samples = len(sample_indices)
        node.impurity = compute_impurity(y[sample_indices])

        # Compute prediction value for this node (used if it becomes a leaf)
        node.value = compute_prediction(y[sample_indices])

        # Check stopping conditions
        if self._should_stop(node, depth, y[sample_indices]):
            node.is_leaf = True
            return node

        # Find best split
        best_feature, best_threshold, best_gain = self._find_best_split(
            X, y, sample_indices
        )
        if best_gain <= self.config.min_impurity_decrease:
            node.is_leaf = True
            return node

        # Perform the split
        left_indices, right_indices = self._partition(
            X, sample_indices, best_feature, best_threshold
        )

        # Check min_samples_leaf constraint
        if (len(left_indices) < self.config.min_samples_leaf or
                len(right_indices) < self.config.min_samples_leaf):
            node.is_leaf = True
            return node

        # Record split information
        node.is_leaf = False
        node.feature = best_feature
        node.threshold = best_threshold

        # Recurse
        node.left = self._build_tree(X, y, left_indices, depth + 1)
        node.right = self._build_tree(X, y, right_indices, depth + 1)
        return node

    def _should_stop(self, node, depth, labels):
        if depth >= self.config.max_depth:
            return True
        if node.n_samples < self.config.min_samples_split:
            return True
        if node.impurity == 0:  # Pure node
            return True
        return False
```
```python
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class Node:
    """Decision tree node."""
    n_samples: int = 0
    impurity: float = 0.0
    value: Optional[np.ndarray] = None  # Class counts or mean
    is_leaf: bool = False
    feature: int = -1
    threshold: float = 0.0
    left: Optional['Node'] = None
    right: Optional['Node'] = None

class SimpleDecisionTree:
    """Simplified decision tree for educational purposes."""

    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.root = None
        self.n_classes = 0

    def fit(self, X: np.ndarray, y: np.ndarray) -> 'SimpleDecisionTree':
        """Fit the decision tree using recursive partitioning."""
        self.n_classes = len(np.unique(y))
        sample_indices = np.arange(len(y))
        print("Starting recursive partitioning...")
        print("=" * 50)
        self.root = self._build_tree(X, y, sample_indices, depth=0)
        print("=" * 50)
        print(f"Tree complete! Total nodes: {self._count_nodes(self.root)}")
        print(f"Total leaves: {self._count_leaves(self.root)}")
        return self

    def _build_tree(self, X, y, indices, depth) -> Node:
        """Recursively build the tree."""
        indent = "  " * depth
        n_samples = len(indices)

        # Create node and compute statistics
        node = Node()
        node.n_samples = n_samples
        node.value = np.bincount(y[indices], minlength=self.n_classes)
        node.impurity = self._gini(y[indices])
        print(f"{indent}Node at depth {depth}: {n_samples} samples, "
              f"impurity={node.impurity:.3f}, class_dist={node.value}")

        # Check stopping conditions
        if self._should_stop(node, depth, y[indices]):
            node.is_leaf = True
            print(f"{indent}  -> LEAF (prediction: class {np.argmax(node.value)})")
            return node

        # Find best split
        best_feat, best_thresh, best_gain = self._find_best_split(X, y, indices)
        if best_gain <= 0:
            node.is_leaf = True
            print(f"{indent}  -> LEAF (no good split found)")
            return node

        # Partition the data
        left_mask = X[indices, best_feat] <= best_thresh
        left_indices = indices[left_mask]
        right_indices = indices[~left_mask]

        # Check min_samples_leaf
        if len(left_indices) < self.min_samples_leaf or \
           len(right_indices) < self.min_samples_leaf:
            node.is_leaf = True
            print(f"{indent}  -> LEAF (min_samples_leaf constraint)")
            return node

        print(f"{indent}  -> SPLIT on feature {best_feat} <= {best_thresh:.3f} "
              f"(gain={best_gain:.4f})")
        print(f"{indent}     Left: {len(left_indices)}, Right: {len(right_indices)}")

        # Record split and recurse
        node.is_leaf = False
        node.feature = best_feat
        node.threshold = best_thresh
        node.left = self._build_tree(X, y, left_indices, depth + 1)
        node.right = self._build_tree(X, y, right_indices, depth + 1)
        return node

    def _should_stop(self, node, depth, labels) -> bool:
        """Check stopping conditions."""
        if depth >= self.max_depth:
            return True
        if node.n_samples < self.min_samples_split:
            return True
        if node.impurity == 0:  # Pure node
            return True
        return False

    def _gini(self, y: np.ndarray) -> float:
        """Compute Gini impurity."""
        if len(y) == 0:
            return 0.0
        proportions = np.bincount(y, minlength=self.n_classes) / len(y)
        return 1.0 - np.sum(proportions ** 2)

    def _find_best_split(self, X, y, indices) -> Tuple[int, float, float]:
        """Find the best split for given samples."""
        best_gain = -np.inf
        best_feature = -1
        best_threshold = 0.0
        parent_impurity = self._gini(y[indices])
        n_samples = len(indices)

        for feature in range(X.shape[1]):
            feature_values = X[indices, feature]
            thresholds = np.unique(feature_values)
            for i in range(len(thresholds) - 1):
                threshold = (thresholds[i] + thresholds[i + 1]) / 2
                left_mask = feature_values <= threshold
                left_y = y[indices[left_mask]]
                right_y = y[indices[~left_mask]]
                if len(left_y) == 0 or len(right_y) == 0:
                    continue
                # Weighted impurity of children
                weighted_impurity = (
                    (len(left_y) / n_samples) * self._gini(left_y)
                    + (len(right_y) / n_samples) * self._gini(right_y)
                )
                gain = parent_impurity - weighted_impurity
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold
        return best_feature, best_threshold, best_gain

    def _count_nodes(self, node) -> int:
        if node is None:
            return 0
        return 1 + self._count_nodes(node.left) + self._count_nodes(node.right)

    def _count_leaves(self, node) -> int:
        if node is None:
            return 0
        if node.is_leaf:
            return 1
        return self._count_leaves(node.left) + self._count_leaves(node.right)

# Demonstration
np.random.seed(42)
X = np.random.randn(50, 3)
y = (X[:, 0] + X[:, 1] > 0).astype(int)

tree = SimpleDecisionTree(max_depth=3, min_samples_split=5, min_samples_leaf=3)
tree.fit(X, y)
```

The recursive nature of tree building has practical implications for memory and system limits.
Each recursive call to `_build_tree` adds a frame to the call stack. Typical limits: CPython caps recursion at roughly 1000 frames by default (adjustable via `sys.setrecursionlimit`), while native code is bounded by the OS thread stack size, commonly 1–8 MB. Each frame holds the call's locals—the node being built, a reference to its sample indices, and the depth counter. For a tree of depth $h$, peak stack usage is $O(h)$ frames.
To avoid stack overflow for very deep trees, implementations can use explicit stacks:
```
def build_tree_iterative(X, y):
    root = create_pending_node(all_indices, depth=0, parent=None, is_left=None)
    stack = [root]
    while stack:
        node = stack.pop()
        if should_stop(node):
            finalize_as_leaf(node)
            continue
        feature, threshold = find_best_split(node.indices)
        if no_good_split_found(feature, threshold):
            finalize_as_leaf(node)
            continue
        left_indices, right_indices = partition(node.indices, feature, threshold)
        node.feature = feature
        node.threshold = threshold
        node.left = create_pending_node(left_indices, node.depth + 1, node, True)
        node.right = create_pending_node(right_indices, node.depth + 1, node, False)
        stack.append(node.right)  # Push right first so left pops first (depth-first, left-to-right)
        stack.append(node.left)
    return root
```
This achieves the same depth-first traversal without recursive calls.
While rare, training on pathological data (e.g., each sample unique, no stopping conditions) could theoretically create trees of depth N for N samples, causing stack overflow. Production libraries like scikit-learn use depth limits and iterative implementations to prevent this. Always set a reasonable max_depth as a safety measure.
Understanding the computational cost of recursive partitioning helps predict training time and identify bottlenecks.
At each node with $n$ samples, the algorithm evaluates up to $d$ features; with naive per-node sorting, each feature costs $O(n \log n)$.

Total per-node: $O(d \cdot n \log n)$

The dominant cost is finding the best split.
Key observation: Each sample passes through exactly one node at each level. If the tree has height $h$, each sample is processed in $O(d \log n)$ per level, across $h$ levels.
Balanced tree case (height $h \approx \log N$):
$$T_{\text{total}} = O\left( \sum_{\text{level }l=0}^{h} \text{(work at level } l) \right)$$
At each level, total samples across all nodes is $N$. Work per sample: $O(d \log N)$ (amortized sorting cost).
$$T_{\text{total}} = O(N \cdot d \cdot \log N \cdot h) = O(N \cdot d \cdot \log^2 N)$$
Unbalanced tree case (height $h \approx N$):
In the worst case (chain-like tree): $$T_{\text{total}} = O(N^2 \cdot d \cdot \log N)$$
This is why max_depth limits are important for training efficiency as well as generalization.
| Scenario | Time Complexity | Space Complexity |
|---|---|---|
| Balanced tree | $O(N \cdot d \cdot \log^2 N)$ | $O(N + L)$ where $L$ = leaves |
| Unbalanced tree | $O(N^2 \cdot d \cdot \log N)$ | $O(N)$ |
| Max depth = $k$ | $O(N \cdot d \cdot k \cdot \log N)$ | $O(2^k)$ nodes max |
| Prediction (one sample) | $O(h)$ where $h$ = height | $O(1)$ |
| Prediction (batch of $M$) | $O(M \cdot h)$ | $O(M)$ for the outputs |
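Prediction cost is easy to see in code: a single root-to-leaf walk, one comparison per level. A minimal sketch with a hypothetical hand-built node structure (not the class from the listing above):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    is_leaf: bool = False
    feature: int = -1
    threshold: float = 0.0
    value: int = 0
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def predict_one(root: Node, x) -> int:
    """Walk root-to-leaf: one comparison per level, O(h) time, O(1) space."""
    node = root
    while not node.is_leaf:
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.value

# Hand-built depth-1 tree: split on feature 0 at 0.5, constant leaves
leaf0 = Node(is_leaf=True, value=0)
leaf1 = Node(is_leaf=True, value=1)
root = Node(feature=0, threshold=0.5, left=leaf0, right=leaf1)

print(predict_one(root, [0.2]))  # -> 0
print(predict_one(root, [0.9]))  # -> 1
```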
During training:
Final tree:
For a tree with $L$ leaves, memory is $O(L \cdot C)$ where $C$ is the cost per node.
Production implementations employ several optimizations to accelerate recursive partitioning.
Instead of sorting at each node, sort once at the root:
```
# At the root, sort once per feature:
for each feature j:
    sorted_indices[j] = argsort(X[:, j])

# At each node:
# use sorted_indices to sweep through candidates without re-sorting
```
This converts $O(N \log N)$ sorting per node per feature to $O(N)$ traversal.
Trade-off: Requires $O(N \cdot d)$ storage for sorted indices.
Bin continuous features into discrete buckets:
```
# Preprocessing:
bins[j] = create_bins(X[:, j], max_bins=256)
X_binned[:, j] = digitize(X[:, j], bins[j])

# At each node:
# build a histogram of gradient sums per bin      (O(n))
# find the best split by scanning the histogram   (O(max_bins))
```
This reduces the per-node, per-feature scan from $O(n)$ candidate thresholds (after sorting) to an $O(n)$ histogram build plus an $O(\text{max\_bins})$ scan—dramatically faster for large $n$, since max_bins is a small constant.
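A simplified, self-contained sketch of histogram-based split finding for one feature (assumptions: binary labels, equal-width bins, plain Gini impurity—real libraries use quantile bins and gradient statistics):

```python
import numpy as np

def best_split_histogram(x, y, max_bins=16):
    """Find the best threshold on one feature by scanning bin statistics."""
    # Equal-width internal bin edges (max_bins - 1 of them)
    edges = np.linspace(x.min(), x.max(), max_bins + 1)[1:-1]
    bins = np.digitize(x, edges)                       # O(n) binning
    # Per-bin sample and positive-class counts: O(n)
    count = np.bincount(bins, minlength=max_bins).astype(float)
    pos = np.bincount(bins, weights=y, minlength=max_bins)

    n, n_pos = count.sum(), pos.sum()

    def gini(c, p):
        if c == 0:
            return 0.0
        q = p / c
        return 2 * q * (1 - q)  # binary Gini

    parent = gini(n, n_pos)
    best_gain, best_edge = -np.inf, None
    cl, pl = 0.0, 0.0
    for b in range(max_bins - 1):                      # O(max_bins) scan
        cl += count[b]; pl += pos[b]
        cr, pr = n - cl, n_pos - pl
        if cl == 0 or cr == 0:
            continue
        gain = parent - (cl / n) * gini(cl, pl) - (cr / n) * gini(cr, pr)
        if gain > best_gain:
            best_gain, best_edge = gain, edges[b]
    return best_edge, best_gain

rng = np.random.default_rng(0)
x = rng.normal(size=500)
y = (x > 0.3).astype(float)
edge, gain = best_split_histogram(x, y)
print(edge, gain)
```

On this data the recovered threshold lands near the true cut at 0.3, up to the resolution of the bin edges.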
Splits for different features are independent:
```
# Parallel over features
best_splits = parallel_map(features, lambda j: find_best_split_for_feature(j))
best_overall = max(best_splits, key=lambda s: s.gain)
```
With $P$ processors and $d$ features, speedup is $\min(P, d)$.
For computing impurity statistics, maintain running sums:
```
# Instead of recomputing from scratch:
class_counts = [count(y == k) for k in classes]   # O(n)

# Incrementally update when one sample crosses the threshold:
class_counts[y[moved_sample]] -= 1        # O(1)
class_counts_other[y[moved_sample]] += 1  # O(1)
```
The sorted sweep algorithm exploits this for $O(1)$ instead of $O(n)$ per candidate threshold.
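A concrete sketch of that sweep for one feature (illustrative names; binary or multi-class labels): sort once, then move samples from the right-hand counts to the left-hand counts one at a time.

```python
import numpy as np

def sweep_best_threshold(x, y, n_classes=2):
    """Scan sorted values once, updating class counts in O(1) per step."""
    order = np.argsort(x)
    xs, ys = x[order], y[order]
    n = len(ys)

    left = np.zeros(n_classes)                                 # counts left of the sweep
    right = np.bincount(ys, minlength=n_classes).astype(float)  # everything starts right

    def gini(counts):
        total = counts.sum()
        if total == 0:
            return 0.0
        p = counts / total
        return 1.0 - np.sum(p ** 2)

    parent = gini(right)
    best_gain, best_thresh = -np.inf, None
    for i in range(n - 1):
        left[ys[i]] += 1                  # O(1) incremental update
        right[ys[i]] -= 1
        if xs[i] == xs[i + 1]:            # cannot split between equal values
            continue
        gain = (parent
                - ((i + 1) / n) * gini(left)
                - ((n - i - 1) / n) * gini(right))
        if gain > best_gain:
            best_gain = gain
            best_thresh = (xs[i] + xs[i + 1]) / 2
    return best_thresh, best_gain

x = np.array([0.1, 0.4, 0.35, 0.8, 0.9, 0.7])
y = np.array([0, 0, 0, 1, 1, 1])
print(sweep_best_threshold(x, y))  # best threshold ≈ 0.55 with gain 0.5
```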
Libraries like XGBoost, LightGBM, and CatBoost use histogram-based algorithms that are orders of magnitude faster than naive implementations. They achieve near-linear scaling with data size. For large datasets (>100k samples), these optimizations are essential—vanilla recursive partitioning would be impractically slow.
The recursive partitioning process creates a specific kind of partition with well-defined mathematical properties.
At each split: $$|\mathcal{D}_{\text{left}}| + |\mathcal{D}_{\text{right}}| = |\mathcal{D}_{\text{parent}}|$$
No samples are lost or duplicated. The partition is exhaustive and disjoint.
After a split, the weighted average impurity of children is never greater than the parent's impurity:
$$\frac{n_L}{n} \mathcal{I}(\mathcal{D}_L) + \frac{n_R}{n} \mathcal{I}(\mathcal{D}_R) \leq \mathcal{I}(\mathcal{D})$$
This follows from the convexity of impurity measures (Gini, entropy).
Consequence: Impurity can only decrease (or stay constant) as we descend the tree. Leaves have minimal impurity given the data that reached them.
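This inequality can be checked numerically: for any partition of a label set—even a random one—the weighted child Gini never exceeds the parent's. A small self-contained verification:

```python
import numpy as np

def gini(y):
    """Gini impurity of a label array."""
    if len(y) == 0:
        return 0.0
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return 1.0 - np.sum(p ** 2)

rng = np.random.default_rng(0)
y = rng.integers(0, 3, size=200)  # 3-class labels
parent = gini(y)

# Any partition -- even a random one -- cannot increase weighted impurity
for _ in range(1000):
    mask = rng.random(200) < rng.random()
    yl, yr = y[mask], y[~mask]
    weighted = (len(yl) / 200) * gini(yl) + (len(yr) / 200) * gini(yr)
    assert weighted <= parent + 1e-12  # convexity guarantees this

print("inequality held for 1000 random splits")
```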
Every sample follows exactly one path from root to leaf. The path is determined entirely by the sample's feature values and the thresholds along the way.
$$\text{path}(\mathbf{x}) = (r, v_1, v_2, \ldots, \ell)$$
Where each transition $v_i \to v_{i+1}$ is determined by: $$v_{i+1} = \begin{cases} v_i.\text{left} & \text{if } x_{j(v_i)} \leq \theta(v_i) \\ v_i.\text{right} & \text{otherwise} \end{cases}$$
Each level of the tree refines the previous level's partition:
The partition becomes progressively finer as depth increases.
We have thoroughly examined the recursive partitioning algorithm that builds decision trees: a simple recursive function, governed by stopping conditions, that grows the tree depth-first and determines both its predictive behavior and its training cost.
What's next:
With the algorithm understood, we now turn to one of decision trees' most valuable properties: interpretability. The next page explores why decision trees are uniquely explainable, how to extract rules and insights, and the tradeoffs between model complexity and human understanding.
You now have a complete understanding of how decision trees are grown through recursive partitioning. This algorithmic foundation enables you to understand training behavior, predict computational costs, and appreciate why hyperparameter choices matter. Next, we explore the interpretability that makes trees uniquely valuable in many applications.