The structure of a decision tree emerges from a sequence of splitting decisions. At each internal node, the algorithm must answer a seemingly simple question: Which feature should we split on, and at what threshold? This choice—repeated recursively—determines everything about the resulting tree: its depth, its accuracy, its interpretability, and its generalization ability.
The challenge is that the space of possible splits is enormous. For a dataset with $d$ features and $N$ samples, each node potentially has $O(d \times N)$ candidate splits to evaluate. Making the wrong choice early can doom the entire subtree to suboptimal performance. Making the right choice requires a principled measure of split quality—a way to quantify how much a split improves our predictions.
This page lays the foundation for understanding split selection. We will explore what makes a good split, how to enumerate candidate splits, and the exhaustive search algorithm that decision trees use to find optimal splits. This sets the stage for the next topic: impurity measures like Gini and entropy that quantify split quality.
By the end of this page, you will understand: (1) the mathematical formulation of the split selection problem, (2) how candidate splits are generated from training data, (3) the relationship between splits and data partitioning, (4) the exhaustive search approach for optimal split selection, and (5) the computational complexity of split finding.
At each internal node of a decision tree, we face an optimization problem: find the split that best separates the data according to some criterion. Let us formalize this precisely.
Given: a node $v$ with training subset $\mathcal{D}_v$, features indexed $j \in \{1, \ldots, d\}$, and an impurity measure $\mathcal{I}(\cdot)$ that quantifies how mixed a set of labels is.
Find: $$ (j^*, \theta^*) = \arg\max_{j \in \{1, \ldots, d\},\ \theta \in \Theta_j} \Delta(v, j, \theta) $$
Where $\Theta_j$ is the set of candidate thresholds for feature $j$.
The gain from a split measures how much it reduces impurity. For a split on feature $j$ at threshold $\theta$:
$$\Delta(v, j, \theta) = \mathcal{I}(\mathcal{D}_v) - \left( \frac{|\mathcal{D}_L|}{|\mathcal{D}_v|} \mathcal{I}(\mathcal{D}_L) + \frac{|\mathcal{D}_R|}{|\mathcal{D}_v|} \mathcal{I}(\mathcal{D}_R) \right)$$
Where $\mathcal{I}(\cdot)$ is the impurity measure, $\mathcal{D}_L = \{(\mathbf{x}, y) \in \mathcal{D}_v : x_j \leq \theta\}$ is the data routed to the left child, and $\mathcal{D}_R = \mathcal{D}_v \setminus \mathcal{D}_L$ is the data routed to the right child.
The gain is the impurity of the current node minus the weighted average impurity of the children. A higher gain means a more valuable split.
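To make the gain formula concrete, here is a minimal sketch that computes $\Delta$ for a single hand-made node using Gini impurity (formally introduced on the next page); the feature values, labels, and threshold are invented for illustration.

```python
import numpy as np

def gini(labels):
    """Gini impurity of a label array."""
    if len(labels) == 0:
        return 0.0
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)

# Hypothetical node: one feature column x and binary labels y
x = np.array([0.2, 0.7, 1.1, 1.8, 2.5, 3.0])
y = np.array([0,   0,   1,   0,   1,   1])

theta = 1.45  # candidate threshold
left, right = y[x <= theta], y[x > theta]

gain = gini(y) - (len(left) / len(y)) * gini(left) - (len(right) / len(y)) * gini(right)
print(f"Parent impurity: {gini(y):.3f}, gain at theta={theta}: {gain:.3f}")
```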
Decision trees use greedy optimization—each split is chosen to maximize local gain without considering future splits. This is computationally tractable but not globally optimal. The globally optimal tree (minimizing overall error) is NP-complete to find. Greedy splitting is a heuristic that works remarkably well in practice.
Despite being locally optimal, greedy splitting often produces near-optimal trees because:
Good early splits simplify later decisions: A feature that perfectly separates two classes at the root eliminates the need for complex downstream splits
Recursive refinement corrects mistakes: Even if an early split is suboptimal, subsequent splits can compensate by focusing on mislabeled regions
Pruning post-hoc: After growing a full tree greedily, pruning can remove suboptimal subtrees, partially recovering from greedy errors
Ensemble methods average out errors: Random Forests and boosting combine many greedy trees, reducing the impact of any single suboptimal split
Before optimizing, we must enumerate the set of candidate splits. For continuous features, what thresholds should we consider?
For a continuous feature $x_j$, we could theoretically split at any real-valued threshold $\theta \in \mathbb{R}$. However, only finitely many distinct splits produce different partitions of the training data.
Key insight: If threshold $\theta$ falls between consecutive sorted feature values $x_{j,(i)}$ and $x_{j,(i+1)}$, the resulting partition is identical to splitting at any other point in that interval.
Consequence: We only need to consider thresholds at the midpoints between consecutive unique values:
$$\Theta_j = \left\{ \frac{x_{j,(i)} + x_{j,(i+1)}}{2} : i \in \{1, \ldots, m_j - 1\} \right\}$$
Where $x_{j,(1)} < x_{j,(2)} < \cdots < x_{j,(m_j)}$ are the sorted unique values of feature $j$ in the current node's data, and $m_j$ is the number of unique values.
function GenerateCandidateThresholds(feature_values):
sorted_unique = Sort(Unique(feature_values))
thresholds = []
for i in range(len(sorted_unique) - 1):
midpoint = (sorted_unique[i] + sorted_unique[i+1]) / 2
thresholds.append(midpoint)
return thresholds
Number of candidate thresholds per feature: a feature with $m_j$ unique values at the node yields $m_j - 1$ candidate thresholds, so there are at most $N - 1$ per feature when every value is distinct.
import numpy as np

def generate_all_candidate_splits(X, node_samples):
    """
    Generate all candidate splits for samples at current node.

    Parameters:
    -----------
    X : np.ndarray of shape (n_samples, n_features)
        Full feature matrix
    node_samples : np.ndarray
        Indices of samples at the current node

    Returns:
    --------
    candidates : list of (feature_idx, threshold)
        All candidate splits to evaluate
    """
    X_node = X[node_samples]  # Samples at this node
    n_features = X_node.shape[1]
    candidates = []

    for feature_idx in range(n_features):
        # Extract feature values at this node
        feature_values = X_node[:, feature_idx]

        # Sort unique values
        unique_sorted = np.unique(feature_values)

        # Generate midpoint thresholds
        if len(unique_sorted) < 2:
            continue  # Cannot split on constant feature

        for i in range(len(unique_sorted) - 1):
            threshold = (unique_sorted[i] + unique_sorted[i + 1]) / 2
            candidates.append((feature_idx, threshold))

    return candidates

def count_candidate_splits(X):
    """Analyze candidate split counts per feature."""
    n_samples, n_features = X.shape

    print("=" * 60)
    print("CANDIDATE SPLIT ANALYSIS")
    print("=" * 60)
    print(f"Total samples: {n_samples}")
    print(f"Total features: {n_features}")
    print()

    total_candidates = 0
    for j in range(n_features):
        n_unique = len(np.unique(X[:, j]))
        n_thresholds = max(0, n_unique - 1)
        total_candidates += n_thresholds
        print(f"Feature {j}: {n_unique} unique values -> {n_thresholds} candidate thresholds")

    print()
    print(f"TOTAL CANDIDATE SPLITS: {total_candidates}")
    print(f"Maximum possible: {n_features * (n_samples - 1)} = {n_features}×{n_samples-1}")

    return total_candidates

# Example with synthetic data
np.random.seed(42)
X = np.column_stack([
    np.random.randn(100),           # Continuous: many unique values
    np.random.randint(0, 5, 100),   # Discrete: 5 unique values
    np.random.choice([0, 1], 100),  # Binary: 2 unique values
    np.ones(100) * 3.14             # Constant: no splits possible
])

count_splits = count_candidate_splits(X)

Real datasets often have features with few unique values (categorical encoded as integers, quantized measurements, etc.). This dramatically reduces the number of candidate splits compared to the theoretical maximum. Feature preprocessing that creates many unique values (e.g., adding noise for privacy) can significantly slow tree training.
For each candidate split, we must compute its gain. This involves partitioning the samples and calculating impurity for both children. Understanding the mechanics reveals opportunities for computational optimization.
Given a node with samples $\mathcal{D}_v$ and candidate split $(j, \theta)$:
function EvaluateSplit(samples, feature_j, threshold):
    left_samples = [sample for sample in samples if sample.x[feature_j] <= threshold]
    right_samples = [sample for sample in samples if sample.x[feature_j] > threshold]
impurity_left = ComputeImpurity(left_samples)
impurity_right = ComputeImpurity(right_samples)
n_total = len(samples)
n_left = len(left_samples)
n_right = len(right_samples)
weighted_child_impurity = (n_left/n_total) * impurity_left +
(n_right/n_total) * impurity_right
gain = ComputeImpurity(samples) - weighted_child_impurity
return gain
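A direct, runnable rendering of this naive evaluation in Python might look like the following sketch; it uses Gini impurity as a stand-in for ComputeImpurity, and the toy arrays are invented for illustration.

```python
import numpy as np

def gini(labels):
    if len(labels) == 0:
        return 0.0
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)

def evaluate_split(X_node, y_node, feature_j, threshold):
    """Naive O(n) evaluation of a single candidate split."""
    left_mask = X_node[:, feature_j] <= threshold
    y_left, y_right = y_node[left_mask], y_node[~left_mask]
    n = len(y_node)
    weighted = (len(y_left) / n) * gini(y_left) + (len(y_right) / n) * gini(y_right)
    return gini(y_node) - weighted

# Toy usage with made-up data
X_node = np.array([[0.2], [0.9], [1.4], [2.2], [3.0]])
y_node = np.array([0, 0, 1, 1, 1])
print(evaluate_split(X_node, y_node, feature_j=0, threshold=1.15))
```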
Per-split evaluation cost: partitioning the samples and computing the impurity of both children each take $O(|\mathcal{D}_v|)$ time, so evaluating a single candidate split costs $O(|\mathcal{D}_v|)$.

Total per node: with up to $|\mathcal{D}_v| - 1$ candidate thresholds per feature and $d$ features, the naive approach costs $O(d \times |\mathcal{D}_v|^2)$.
This quadratic cost in samples per feature is problematic for large datasets. Fortunately, there's a better way.
By sorting samples by feature value once, we can evaluate all thresholds for that feature in a single pass:
Key observation: As threshold increases, samples move from right child to left child one at a time. We can update impurity incrementally rather than recomputing from scratch.
Algorithm:
function EfficientSplitSearch(samples, feature_j):
# Sort samples by feature value
    sorted_samples = Sort(samples, key=lambda s: s.x[feature_j])
# Initialize: all samples in right child
left_stats = EmptyStats()
right_stats = ComputeStats(sorted_samples)
best_gain = -infinity
best_threshold = None
for i in range(len(sorted_samples) - 1):
sample = sorted_samples[i]
# Move sample from right to left
left_stats.Add(sample)
right_stats.Remove(sample)
# Only evaluate if this creates a distinct split
        if sorted_samples[i].x[feature_j] == sorted_samples[i+1].x[feature_j]:
continue # Same value, skip duplicate threshold
# Compute gain from incremental stats
gain = ComputeGainFromStats(left_stats, right_stats)
if gain > best_gain:
best_gain = gain
            best_threshold = (sorted_samples[i].x[feature_j] + sorted_samples[i+1].x[feature_j]) / 2
return best_threshold, best_gain
Improved complexity: sorting costs $O(|\mathcal{D}_v| \log |\mathcal{D}_v|)$ per feature and the sweep itself is $O(|\mathcal{D}_v|)$, so the best split at a node is found in $O(d \cdot |\mathcal{D}_v| \log |\mathcal{D}_v|)$ total.
This is a dramatic improvement over the naive $O(d \times |\mathcal{D}_v|^2)$ approach.
import numpy as np
from dataclasses import dataclass

@dataclass
class SplitStats:
    """Incremental statistics for Gini impurity computation."""
    class_counts: np.ndarray
    total: int

    @classmethod
    def from_labels(cls, y: np.ndarray, n_classes: int):
        """Initialize from label array."""
        counts = np.bincount(y, minlength=n_classes)
        return cls(class_counts=counts.astype(float), total=len(y))

    @classmethod
    def empty(cls, n_classes: int):
        """Create empty stats."""
        return cls(class_counts=np.zeros(n_classes), total=0)

    def add(self, label: int):
        """Add one sample to the stats."""
        self.class_counts[label] += 1
        self.total += 1

    def remove(self, label: int):
        """Remove one sample from the stats."""
        self.class_counts[label] -= 1
        self.total -= 1

    def gini_impurity(self) -> float:
        """Compute Gini impurity from current stats."""
        if self.total == 0:
            return 0.0
        proportions = self.class_counts / self.total
        return 1.0 - np.sum(proportions ** 2)

def efficient_best_split_for_feature(X_feature: np.ndarray, y: np.ndarray, n_classes: int):
    """
    Find best split for a single feature using sorted sweep.

    Returns:
    --------
    best_threshold : float or None
    best_gain : float
    """
    n_samples = len(X_feature)

    # Sort samples by feature value
    sort_idx = np.argsort(X_feature)
    sorted_features = X_feature[sort_idx]
    sorted_labels = y[sort_idx]

    # Parent impurity (for gain computation)
    parent_stats = SplitStats.from_labels(sorted_labels, n_classes)
    parent_impurity = parent_stats.gini_impurity()

    # Initialize: left empty, right has all
    left_stats = SplitStats.empty(n_classes)
    right_stats = SplitStats.from_labels(sorted_labels, n_classes)

    best_gain = -np.inf
    best_threshold = None

    # Sweep from left to right
    for i in range(n_samples - 1):
        # Move sample i from right to left
        label = sorted_labels[i]
        left_stats.add(label)
        right_stats.remove(label)

        # Skip if next sample has same feature value
        if sorted_features[i] == sorted_features[i + 1]:
            continue

        # Compute weighted child impurity
        n_left = left_stats.total
        n_right = right_stats.total
        weighted_impurity = (
            (n_left / n_samples) * left_stats.gini_impurity() +
            (n_right / n_samples) * right_stats.gini_impurity()
        )

        gain = parent_impurity - weighted_impurity

        if gain > best_gain:
            best_gain = gain
            best_threshold = (sorted_features[i] + sorted_features[i + 1]) / 2

    return best_threshold, best_gain

# Demonstration
np.random.seed(42)
X_single = np.array([1.5, 2.3, 0.8, 3.1, 2.1, 0.5, 3.8, 1.9])
y = np.array([0, 0, 1, 1, 0, 1, 1, 0])  # Binary classification

threshold, gain = efficient_best_split_for_feature(X_single, y, n_classes=2)
print(f"Best threshold: {threshold:.3f}")
print(f"Best gain: {gain:.4f}")

# Verify by checking multiple thresholds manually
print("\n--- Verification by enumeration ---")
unique_vals = np.unique(X_single)

def gini(labels):
    if len(labels) == 0:
        return 0
    probs = np.bincount(labels, minlength=2) / len(labels)
    return 1 - np.sum(probs**2)

for i in range(len(unique_vals) - 1):
    t = (unique_vals[i] + unique_vals[i+1]) / 2
    left_mask = X_single <= t
    right_mask = ~left_mask

    # Compute Gini for each side
    parent_gini = gini(y)
    left_gini = gini(y[left_mask])
    right_gini = gini(y[right_mask])

    n = len(y)
    n_l, n_r = np.sum(left_mask), np.sum(right_mask)
    weighted = (n_l/n)*left_gini + (n_r/n)*right_gini
    g = parent_gini - weighted
    print(f"Threshold {t:.2f}: gain = {g:.4f}")

At each node, the decision tree algorithm performs an exhaustive search over all features and all candidate thresholds to find the optimal split. Let us trace through this process in detail.
function FindBestSplit(node):
if StoppingConditionMet(node):
return None # Make this a leaf
best_gain = -infinity
best_feature = None
best_threshold = None
for feature_j in range(d): # For each feature
# Efficient sorted sweep for this feature
threshold, gain = EfficientSplitSearch(node.samples, feature_j)
if gain > best_gain:
best_gain = gain
best_feature = feature_j
best_threshold = threshold
if best_gain <= 0: # No beneficial split found
return None
return (best_feature, best_threshold)
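As a sketch of how this per-node loop could look in Python, the following reuses the `efficient_best_split_for_feature` helper from the earlier code listing; the function name and the strictly-positive-gain convention are illustrative choices, not a library API.

```python
import numpy as np

def find_best_split(X_node, y_node, n_classes):
    """Exhaustive search over features; returns (feature, threshold, gain) or None."""
    best = None
    best_gain = 0.0  # require strictly positive gain, mirroring the pseudocode
    for j in range(X_node.shape[1]):
        threshold, gain = efficient_best_split_for_feature(X_node[:, j], y_node, n_classes)
        if threshold is not None and gain > best_gain:
            best_gain = gain
            best = (j, threshold, gain)
    return best  # None means: make this node a leaf
```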
The exhaustive search is not performed when stopping conditions are met:

- The node has reached the maximum allowed depth (the `max_depth` parameter)
- The node contains fewer samples than `min_samples_split`
- Every candidate split would leave a child with fewer than `min_samples_leaf` samples

These conditions prevent overfitting by limiting tree complexity.
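For reference, these stopping conditions correspond directly to hyperparameters of scikit-learn's `DecisionTreeClassifier`; the particular values below are arbitrary examples rather than recommendations.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=8, random_state=0)

# Each argument corresponds to one of the stopping conditions above
clf = DecisionTreeClassifier(
    max_depth=4,           # cap tree depth
    min_samples_split=20,  # do not split nodes smaller than this
    min_samples_leaf=5,    # each child must keep at least this many samples
    random_state=0,
)
clf.fit(X, y)
print(f"Depth: {clf.get_depth()}, leaves: {clf.get_n_leaves()}")
```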
When multiple splits achieve the same gain (common for discrete features), implementations typically break the tie deterministically, for example by keeping the first best split encountered in feature order, or randomly among the tied candidates (often controlled by a random seed so that training is reproducible).
| Component | Complexity | Notes |
|---|---|---|
| Sorting per feature | $O(N \log N)$ | Precomputed once per feature |
| Sweep per feature | $O(N)$ | Incremental stats update |
| All features at one node | $O(d \cdot N \log N)$ | Dominant cost |
| All nodes in balanced tree | $O(d \cdot N \log^2 N)$ | Each sample in $\log N$ nodes |
| All nodes in unbalanced tree | $O(d \cdot N^2 \log N)$ | Worst case: chain-like tree |
For typical datasets, decision tree training is fast, often comparable in cost to fitting a simple linear model on the same data. The $O(d \cdot N \log N)$ per-node cost is moderate, and trees rarely grow to extreme depth on real data. This efficiency is one reason trees (and tree-based ensembles) are so widely used.
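One rough way to see this speed in practice is to time tree fitting on synthetic datasets of increasing size, as in the sketch below; absolute times depend on hardware, so only the growth trend is meaningful.

```python
import time
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
for n in [1_000, 10_000, 100_000]:
    X = rng.normal(size=(n, 20))
    y = (X[:, 0] + X[:, 1] > 0).astype(int)
    start = time.perf_counter()
    DecisionTreeClassifier(max_depth=8).fit(X, y)
    print(f"N={n:>7}: {time.perf_counter() - start:.3f} s")
```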
Decision trees handle different feature types through variations of the splitting mechanism. Understanding these variations is essential for proper feature engineering.
For continuous features, splits are threshold-based: $$\text{condition: } x_j \leq \theta$$
As discussed, optimal thresholds are found via sorted sweep over midpoints between consecutive values. This is the standard CART approach.
Ordinal features (e.g., education level: high school < bachelor's < master's < PhD) are treated identically to continuous features. The sorted sweep respects the natural ordering, and threshold splits preserve ordinal relationships.
$$\text{if education} \leq \text{bachelor's} \Rightarrow \{\text{high school, bachelor's}\} \text{ vs } \{\text{master's, PhD}\}$$
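A small sketch of this treatment: map the ordered levels to integers that respect the ordering and fit a shallow tree, which then places a threshold between adjacent levels (the encoding and toy labels below are illustrative assumptions).

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical ordinal encoding: higher number = higher education level
levels = {"high school": 0, "bachelor's": 1, "master's": 2, "PhD": 3}

education = np.array(["high school", "bachelor's", "master's", "PhD",
                      "bachelor's", "PhD", "high school", "master's"])
X = np.array([[levels[e]] for e in education], dtype=float)
y = np.array([0, 0, 1, 1, 0, 1, 0, 1])  # toy target

tree = DecisionTreeClassifier(max_depth=1).fit(X, y)
print(export_text(tree, feature_names=["education_level"]))
```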
Nominal features have no natural ordering (e.g., color: red, blue, green). Two approaches exist:
Approach 1: One-hot encoding (used by CART-style implementations, such as scikit-learn, that require numeric features)
Convert each category to its own 0/1 indicator feature (for example, color becomes color_red, color_blue, color_green) and use standard threshold splits on those indicators.
Splits become: "Is color_red ≤ 0.5?" (equivalently: "Is it not red?")
- Pros: Conceptually simple, works with the standard algorithm
- Cons: Creates many features; splits can only isolate one category at a time
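A brief sketch of the one-hot route, using pandas dummies and a depth-1 tree; the column values and toy target are invented for illustration.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

df = pd.DataFrame({"color": ["red", "blue", "green", "red", "green", "blue"]})
y = [1, 0, 0, 1, 0, 0]  # toy target: "is it red?"

X = pd.get_dummies(df["color"], prefix="color")  # columns: color_blue, color_green, color_red
tree = DecisionTreeClassifier(max_depth=1).fit(X, y)
print(export_text(tree, feature_names=list(X.columns)))
```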
Approach 2: Subset splits (native categorical handling, as in classic CART)
Consider all possible subsets of categories: $$\text{condition: } x_j \in S \text{ vs } x_j \notin S$$
For a feature with $k$ categories, there are $2^{k-1} - 1$ possible subset splits.
- Pros: Can find optimal groupings; a single split can separate multiple categories
- Cons: Exponential in the number of categories; computationally expensive for $k > 10$
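The sketch below enumerates the $2^{k-1} - 1$ distinct binary partitions for small category sets, which makes the exponential growth easy to see; the helper function is illustrative, not taken from any library.

```python
from itertools import combinations

def subset_splits(categories):
    """Enumerate distinct binary partitions {S, complement} of a category set."""
    cats = sorted(categories)
    anchor = cats[0]  # fix one category on the left to avoid counting mirror splits twice
    rest = cats[1:]
    splits = []
    for r in range(len(rest) + 1):
        for combo in combinations(rest, r):
            left = {anchor, *combo}
            right = set(cats) - left
            if right:  # skip the trivial split with an empty side
                splits.append((left, right))
    return splits

for k in (3, 4, 10):
    cats = [f"c{i}" for i in range(k)]
    print(f"k={k}: {len(subset_splits(cats))} subset splits (expected {2**(k-1) - 1})")
```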
High-cardinality categorical features (e.g., user_id, zip_code) require special handling. Pure ID3-style subset enumeration becomes infeasible. Common approaches include: target encoding (convert to numeric), frequency encoding, hashing tricks, or treating as ordinal by frequency. Modern gradient boosting libraries like LightGBM and CatBoost have specialized categorical feature handlers.
Binary features (0/1, true/false) are trivial: only one possible split exists: $$x_j \leq 0.5 \Leftrightarrow x_j = 0$$
Either this split is beneficial (applied) or not (feature ignored at this node).
Missing values pose a challenge: which child should a sample with $x_j = \text{NaN}$ follow?
CART approach (surrogate splits): after the primary split is chosen, CART searches for surrogate splits on other features that best reproduce the primary partition; a sample whose primary feature is missing is routed by the best available surrogate, falling back to the majority direction if no surrogate can be evaluated.
XGBoost/LightGBM approach (default direction): while evaluating a split, samples with missing values are tried in both the left and the right child, and the direction yielding the higher gain is stored as that split's default; at prediction time, missing values simply follow the learned default direction.
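The following toy sketch illustrates the default-direction idea (it is not the actual XGBoost or LightGBM implementation): evaluate the gain with the missing-value samples routed left versus right, and keep whichever direction scores higher.

```python
import numpy as np

def gini(labels):
    if len(labels) == 0:
        return 0.0
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)

def gain_with_default(x, y, threshold):
    """Return (gain, default_direction) for a split, trying NaNs in both children."""
    missing = np.isnan(x)
    best = (-np.inf, None)
    for direction in ("left", "right"):
        left = (x <= threshold) & ~missing
        right = (x > threshold) & ~missing
        if direction == "left":
            left = left | missing
        else:
            right = right | missing
        n = len(y)
        weighted = (left.sum() / n) * gini(y[left]) + (right.sum() / n) * gini(y[right])
        gain = gini(y) - weighted
        if gain > best[0]:
            best = (gain, direction)
    return best

x = np.array([0.3, 0.8, 2.1, 2.5, np.nan, np.nan])
y = np.array([0,   0,   1,   1,   1,      1])
print(gain_with_default(x, y, threshold=1.5))  # missing samples are best sent right
```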
Before diving into formal impurity measures (next page), let's build intuition about what makes a split good.
A perfect split completely separates classes: for example, every class-0 sample is sent to the left child and every class-1 sample to the right child.
Both children are pure (impurity = 0). The split has maximum gain.
A useless split preserves the parent's class distribution in both children: for example, a parent that is half class 0 and half class 1 produces two children that are each still half-and-half.
Children are just as impure as parent. The split has zero gain.
Consider a 2D classification problem where we want to separate red and blue points:

Case 1 (good split): a threshold on one feature places most of the red points on one side and most of the blue points on the other, so both children are substantially purer than the parent.

Case 2 (bad split): the threshold cuts through the middle of both point clouds, leaving each child with roughly the parent's mix of red and blue, so the gain is near zero.

Case 3 (unbalanced split): the threshold peels off only a handful of points into one tiny child while the other child remains nearly identical to the parent.
A split that creates one perfectly pure child of size 1 and leaves the rest unchanged is not valuable. The weighted average impurity is almost unchanged because the tiny child contributes proportionally to its sample count. Good splits create substantial pure regions, not just token ones. This is why minimum sample constraints exist.
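The contrast among these cases can be quantified with a short sketch that computes Gini gain from made-up class counts for a perfect, a useless, and an unbalanced split.

```python
import numpy as np

def gini_from_counts(counts):
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    if total == 0:
        return 0.0
    p = counts / total
    return 1.0 - np.sum(p ** 2)

def gain(parent, left, right):
    n = sum(parent)
    n_l, n_r = sum(left), sum(right)
    return gini_from_counts(parent) - (n_l / n) * gini_from_counts(left) - (n_r / n) * gini_from_counts(right)

parent = [50, 50]  # 50 samples of each class
print("Perfect:    ", gain(parent, [50, 0], [0, 50]))    # both children pure -> gain 0.5
print("Useless:    ", gain(parent, [25, 25], [25, 25]))  # children mirror the parent -> gain 0
print("Unbalanced: ", gain(parent, [1, 0], [49, 50]))    # one pure child of size 1 -> tiny gain
```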
We have established a comprehensive understanding of how decision trees select splits. Let us consolidate the key concepts:

- Split selection is a local optimization: at each node, choose the feature and threshold that maximize the impurity gain $\Delta(v, j, \theta)$.
- For continuous features, only midpoints between consecutive sorted unique values need to be considered as thresholds.
- A sorted sweep with incremental statistics evaluates all thresholds for one feature in $O(N \log N)$ time instead of $O(N^2)$.
- Stopping conditions such as maximum depth and minimum sample counts determine when the search is skipped and a leaf is created.
- Ordinal, nominal, binary, and missing-valued features each require their own variation of the splitting mechanism.
- Good splits create substantially purer children; splits that merely peel off a few samples add little value.
What's next:
We now understand how splits are searched and what makes a split valuable at an intuitive level. The next critical question is: how do we quantify split quality precisely? The next page answers it with impurity measures such as Gini and entropy; later pages explore leaf predictions, recursive partitioning, and the interpretability properties that make decision trees uniquely valuable in machine learning.
You now understand the mechanics of split selection in decision trees, from candidate generation through exhaustive search. These splitting rules are the engine that drives tree construction, determining how the feature space is recursively partitioned. In the pages ahead, we also see how leaf nodes make predictions once the tree is grown.