At every internal node of a regression tree, the algorithm faces a critical decision: Which feature and which threshold should define the split? The answer to this question determines the entire structure of the resulting tree and, consequently, its predictive performance. The splitting criterion is the mathematical objective that guides this decision.
For regression trees, the canonical choice is Mean Squared Error (MSE)—also known as variance reduction, squared error loss, or simply the L2 criterion. This choice is not arbitrary; it emerges from deep connections to maximum likelihood estimation, analysis of variance, and optimal prediction theory. Understanding MSE as a splitting criterion reveals why regression trees behave the way they do and illuminates their strengths and limitations.
By the end of this page, you will understand the MSE splitting criterion from multiple perspectives: as a variance reduction objective, as maximum likelihood under Gaussian assumptions, as an ANOVA decomposition, and as a special case of more general loss functions. You will also understand the computational techniques that make MSE practical and the theoretical properties that make it effective.
Let's establish the precise mathematical framework for MSE-based splitting.
Setup:
Consider a node $t$ containing $n$ training samples with target values $\{y_1, y_2, \ldots, y_n\}$. The impurity (or error) of this node under MSE is:
$$I(t) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \bar{y}_t)^2 = \text{Var}(y_t)$$
where $\bar{y}_t = \frac{1}{n} \sum_{i=1}^{n} y_i$ is the mean of targets at node $t$.
Split evaluation:
A candidate split $s$ partitions node $t$ into left child $t_L$ and right child $t_R$. The impurity after split is the weighted average of child impurities:
$$I(t, s) = \frac{n_L}{n} I(t_L) + \frac{n_R}{n} I(t_R) = \frac{n_L}{n} \text{Var}(y_{t_L}) + \frac{n_R}{n} \text{Var}(y_{t_R})$$
Impurity gain (reduction):
The gain from split $s$ is the reduction in impurity:
$$\Delta I(t, s) = I(t) - I(t, s) = \text{Var}(y_t) - \left[\frac{n_L}{n} \text{Var}(y_{t_L}) + \frac{n_R}{n} \text{Var}(y_{t_R})\right]$$
The optimal split $s^*$ maximizes this gain:
$$s^* = \arg\max_s \Delta I(t, s)$$
Variance and MSE are intimately related but distinct concepts. Variance measures dispersion around the mean: Var(y) = E[(y - E[y])²]. MSE measures error relative to a prediction: MSE = E[(y - ŷ)²]. When the prediction is the sample mean (as in regression tree leaves), MSE equals variance. The terms are used interchangeably in tree literature because the optimal leaf prediction is always the mean.
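As a minimal sketch of these definitions (the function name `naive_best_split_1d` and the toy data are illustrative, not from any library), the brute-force search below evaluates $\Delta I(t, s)$ at every candidate threshold of a single feature and returns the best one; the section on efficient computation later replaces this $O(n^2)$ scan with a much faster one.

```python
import numpy as np

def naive_best_split_1d(x, y):
    """Evaluate every midpoint threshold of one feature and return the
    split with the largest impurity gain ΔI(t, s) = I(t) - I(t, s)."""
    order = np.argsort(x)
    x_sorted, y_sorted = x[order], y[order]
    n = len(y_sorted)
    parent_impurity = np.var(y_sorted)          # I(t) = Var(y_t)

    best = None
    for i in range(1, n):
        if x_sorted[i] == x_sorted[i - 1]:      # no threshold separates tied values
            continue
        y_L, y_R = y_sorted[:i], y_sorted[i:]
        weighted = (len(y_L) * np.var(y_L) + len(y_R) * np.var(y_R)) / n
        gain = parent_impurity - weighted       # ΔI(t, s)
        if best is None or gain > best[1]:
            threshold = (x_sorted[i - 1] + x_sorted[i]) / 2
            best = (threshold, gain)
    return best

# Toy example: a clear jump between the first and second half of the targets
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
y = np.array([1.0, 2.0, 2.0, 3.0, 8.0, 9.0, 10.0, 11.0])
print(naive_best_split_1d(x, y))  # the threshold 4.5 yields the largest gain
```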
Equivalent formulations:
The impurity gain can be expressed in several equivalent ways:
Form 1: Variance reduction $$\Delta I = \text{Var}(y_t) - \frac{n_L}{n} \text{Var}(y_{t_L}) - \frac{n_R}{n} \text{Var}(y_{t_R})$$
Form 2: Between-group variance $$\Delta I = \frac{n_L n_R}{n^2} (\bar{y}_{t_L} - \bar{y}_{t_R})^2$$
This remarkable equivalence shows that maximizing variance reduction is equivalent to maximizing the separation between child means.
Form 3: Sum of squared residuals reduction $$\Delta I = \frac{1}{n} \left[ \sum_{i=1}^{n} (y_i - \bar{y}_t)^2 - \sum_{i \in t_L} (y_i - \bar{y}_{t_L})^2 - \sum_{i \in t_R} (y_i - \bar{y}_{t_R})^2 \right]$$
The MSE splitting criterion has deep connections to analysis of variance (ANOVA)—the classical statistical framework for understanding variability.
Total variance decomposition:
For any partition of data, total variance decomposes into within-group and between-group components:
$$\text{Var}_{\text{total}} = \text{Var}_{\text{within}} + \text{Var}_{\text{between}}$$
More precisely:
$$\underbrace{\frac{1}{n}\sum_{i}(y_i - \bar{y})^2}_{\text{SST}/n} = \underbrace{\frac{1}{n}\sum_{g}\sum_{i \in g}(y_i - \bar{y}_g)^2}_{\text{SSW}/n} + \underbrace{\frac{1}{n}\sum_{g} n_g (\bar{y}_g - \bar{y})^2}_{\text{SSB}/n}$$
where $g$ indexes the groups of the partition, $n_g$ is the number of samples in group $g$, $\bar{y}_g$ is the group mean, and $\bar{y}$ is the overall mean; SST, SSW, and SSB denote the total, within-group, and between-group sums of squares.
Split optimization through ANOVA lens:
Maximizing impurity gain is equivalent to minimizing the within-group variance of the induced two-group partition, or equivalently, maximizing its between-group variance (SSB).
This is exactly what one-way ANOVA does—except tree splitting considers all possible two-group partitions defined by feature thresholds.
```python
import numpy as np


def variance_decomposition(y, groups):
    """
    Decompose total variance into within and between components.
    This illustrates the ANOVA connection to MSE splitting.

    Parameters
    ----------
    y : array-like
        Target values
    groups : array-like
        Group membership for each sample (e.g., [0, 0, 1, 1, 1, 0])

    Returns
    -------
    dict with SST, SSW, SSB, and R² (proportion explained)
    """
    y = np.asarray(y)
    groups = np.asarray(groups)
    n = len(y)

    # Overall mean
    y_bar = np.mean(y)

    # Total sum of squares
    SST = np.sum((y - y_bar) ** 2)

    # Within-group sum of squares
    SSW = 0.0
    unique_groups = np.unique(groups)
    group_means = {}
    group_counts = {}
    for g in unique_groups:
        mask = groups == g
        y_g = y[mask]
        y_bar_g = np.mean(y_g)
        group_means[g] = y_bar_g
        group_counts[g] = len(y_g)
        SSW += np.sum((y_g - y_bar_g) ** 2)

    # Between-group sum of squares
    SSB = 0.0
    for g in unique_groups:
        n_g = group_counts[g]
        y_bar_g = group_means[g]
        SSB += n_g * (y_bar_g - y_bar) ** 2

    # Verify decomposition: SST = SSW + SSB
    assert np.isclose(SST, SSW + SSB), "Variance decomposition failed!"

    # R² is proportion of variance explained by grouping
    R_squared = SSB / SST if SST > 0 else 0.0

    return {
        'SST': SST,
        'SSW': SSW,
        'SSB': SSB,
        'variance_total': SST / n,
        'variance_within': SSW / n,
        'variance_between': SSB / n,
        'R_squared': R_squared,
        'impurity_gain': SSB / n  # This is what tree splitting maximizes
    }


def demonstrate_anova_equivalence():
    """
    Show that MSE split gain equals between-group variance.
    """
    # Example data
    y = np.array([1, 2, 2, 3, 8, 9, 10, 11])

    # Consider split at position 4 (after index 3)
    split_point = 4
    y_left = y[:split_point]
    y_right = y[split_point:]

    n = len(y)
    n_L, n_R = len(y_left), len(y_right)

    # Method 1: Direct variance reduction
    var_total = np.var(y)
    var_left = np.var(y_left)
    var_right = np.var(y_right)
    gain_method1 = var_total - (n_L/n * var_left + n_R/n * var_right)

    # Method 2: Between-group variance formula
    mean_left = np.mean(y_left)
    mean_right = np.mean(y_right)
    gain_method2 = (n_L * n_R / n**2) * (mean_left - mean_right)**2

    print(f"Total variance: {var_total:.4f}")
    print(f"Left variance: {var_left:.4f}, Right variance: {var_right:.4f}")
    print(f"Weighted child variance: {n_L/n * var_left + n_R/n * var_right:.4f}")
    print(f"\nGain (variance reduction): {gain_method1:.4f}")
    print(f"Gain (between-group formula): {gain_method2:.4f}")
    print(f"Methods match: {np.isclose(gain_method1, gain_method2)}")

    # Method 3: Full ANOVA decomposition
    groups = np.array([0]*split_point + [1]*(n - split_point))
    anova = variance_decomposition(y, groups)
    print("\nFull ANOVA decomposition:")
    print(f"  Total variance (SST/n): {anova['variance_total']:.4f}")
    print(f"  Within-group (SSW/n): {anova['variance_within']:.4f}")
    print(f"  Between-group (SSB/n): {anova['variance_between']:.4f}")
    print(f"  Impurity gain: {anova['impurity_gain']:.4f}")


# Run demonstration
demonstrate_anova_equivalence()
```

The ratio SSB/SST gives R²—the proportion of variance explained by the split. A split with R² = 0.3 means the binary grouping explains 30% of target variability. Successive splits in a tree progressively explain more variance, with the full tree's R² being the sum across all splits (accounting for the fraction of data reaching each node).
The MSE criterion has a rigorous statistical foundation through maximum likelihood estimation under Gaussian assumptions.
The statistical model:
Assume targets follow a Gaussian distribution with region-specific means:
$$y_i | \mathbf{x}_i \in R_m \sim \mathcal{N}(\mu_m, \sigma^2)$$
where $R_m$ is the region corresponding to leaf $m$, $\mu_m$ is the true mean in that region, and $\sigma^2$ is constant noise variance.
Log-likelihood for the tree:
For a tree with $M$ leaves defining regions $\{R_1, \ldots, R_M\}$:
$$\log L(\{\mu_m\}, \sigma^2) = -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\sum_{m=1}^{M}\sum_{i \in R_m}(y_i - \mu_m)^2$$
Maximum likelihood estimates:
For the region means: $\hat{\mu}_m = \bar{y}_m = \frac{1}{n_m}\sum_{i \in R_m} y_i$ (the sample mean)
For variance: $\hat{\sigma}^2 = \frac{1}{n}\sum_{m=1}^{M}\sum_{i \in R_m}(y_i - \bar{y}_m)^2$ (MSE over the tree)
Maximizing likelihood = Minimizing MSE:
Since $\sigma^2$ doesn't depend on tree structure, maximizing log-likelihood is equivalent to minimizing:
$$\sum_{m=1}^{M}\sum_{i \in R_m}(y_i - \bar{y}_m)^2 = \text{Total within-leaf sum of squares}$$
At each split, choosing the partition that most reduces this sum is precisely the MSE splitting criterion.
```python
import numpy as np
from scipy import stats


def gaussian_log_likelihood(y, predictions, sigma_sq=None):
    """
    Compute Gaussian log-likelihood for tree predictions.

    Parameters
    ----------
    y : array-like
        True target values
    predictions : array-like
        Predicted values (leaf means)
    sigma_sq : float or None
        Noise variance. If None, uses MLE estimate.

    Returns
    -------
    log_lik : float
        Log-likelihood value
    """
    y = np.asarray(y)
    predictions = np.asarray(predictions)
    n = len(y)

    residuals = y - predictions
    ss_residuals = np.sum(residuals ** 2)

    # MLE estimate of variance if not provided
    if sigma_sq is None:
        sigma_sq = ss_residuals / n

    # Log-likelihood under Gaussian model
    log_lik = (-n/2 * np.log(2 * np.pi * sigma_sq)
               - ss_residuals / (2 * sigma_sq))

    return log_lik


def compare_splits_via_likelihood(y, split_a_groups, split_b_groups):
    """
    Compare two candidate splits using log-likelihood.
    Shows that the split with higher likelihood has lower MSE.
    """
    def compute_predictions_and_mse(y, groups):
        """Get mean predictions for each group and overall MSE."""
        predictions = np.zeros_like(y, dtype=float)
        for g in np.unique(groups):
            mask = groups == g
            group_mean = np.mean(y[mask])
            predictions[mask] = group_mean
        mse = np.mean((y - predictions) ** 2)
        return predictions, mse

    # Split A
    pred_a, mse_a = compute_predictions_and_mse(y, split_a_groups)
    ll_a = gaussian_log_likelihood(y, pred_a)

    # Split B
    pred_b, mse_b = compute_predictions_and_mse(y, split_b_groups)
    ll_b = gaussian_log_likelihood(y, pred_b)

    print("Split comparison via maximum likelihood:")
    print(f"\nSplit A: MSE = {mse_a:.4f}, Log-Lik = {ll_a:.4f}")
    print(f"Split B: MSE = {mse_b:.4f}, Log-Lik = {ll_b:.4f}")

    if ll_a > ll_b:
        print("\nSplit A has higher likelihood (better)")
        assert mse_a < mse_b, "Higher LL should mean lower MSE!"
    else:
        print("\nSplit B has higher likelihood (better)")
        assert mse_b < mse_a, "Higher LL should mean lower MSE!"

    return {'split_a': {'mse': mse_a, 'log_lik': ll_a},
            'split_b': {'mse': mse_b, 'log_lik': ll_b}}


# Demonstration
np.random.seed(42)
y = np.array([1, 2, 3, 4, 10, 11, 12, 13])

# Split A: optimal (separates low and high values)
groups_a = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# Split B: suboptimal (mixes values)
groups_b = np.array([0, 1, 0, 1, 0, 1, 0, 1])

compare_splits_via_likelihood(y, groups_a, groups_b)
```

While MSE is optimal under Gaussian noise, it remains a reasonable choice for non-Gaussian distributions due to the central limit theorem. For large sample sizes, sample means are approximately normally distributed regardless of the underlying distribution. However, for heavy-tailed distributions or data with outliers, robust alternatives like MAE (Mean Absolute Error) may be preferable.
One of the most elegant results in regression tree theory is that impurity gain depends only on group sizes and group means—not on within-group variance.
The closed-form formula:
For a binary split into groups $L$ and $R$:
$$\Delta I = \frac{n_L n_R}{n^2} (\bar{y}_L - \bar{y}_R)^2$$
This formula has profound implications: the gain depends only on the two child sample counts and the difference between child means, so within-child dispersion never has to be examined; and the balance factor $n_L n_R / n^2$ rewards balanced partitions, reaching its maximum of $1/4$ at a 50/50 split.
Derivation:
We prove this starting from the variance reduction formula.
```python
import numpy as np


def prove_between_group_formula():
    """
    Rigorous derivation of the between-group variance formula.

    We prove: variance_reduction = (n_L * n_R / n²) * (ȳ_L - ȳ_R)²
    """
    # Let's work with concrete values
    y = np.array([1, 2, 3, 4, 10, 11, 12, 13], dtype=float)
    n = len(y)

    # Split into left (first 4) and right (last 4)
    y_L = y[:4]
    y_R = y[4:]
    n_L, n_R = len(y_L), len(y_R)

    # Overall mean
    y_bar = np.mean(y)

    # Group means
    y_bar_L = np.mean(y_L)
    y_bar_R = np.mean(y_R)

    print("=" * 60)
    print("DERIVATION: Between-Group Variance Formula")
    print("=" * 60)
    print(f"\nData: y = {y}")
    print(f"Left group: {y_L}, Right group: {y_R}")
    print(f"n = {n}, n_L = {n_L}, n_R = {n_R}")
    print(f"Overall mean: ȳ = {y_bar:.2f}")
    print(f"Left mean: ȳ_L = {y_bar_L:.2f}")
    print(f"Right mean: ȳ_R = {y_bar_R:.2f}")

    # Step 1: Total variance
    var_total = np.mean((y - y_bar) ** 2)
    print(f"\nStep 1: Total variance = {var_total:.4f}")

    # Step 2: Within-group variances
    var_L = np.mean((y_L - y_bar_L) ** 2) if n_L > 0 else 0
    var_R = np.mean((y_R - y_bar_R) ** 2) if n_R > 0 else 0
    weighted_within = (n_L/n) * var_L + (n_R/n) * var_R
    print("\nStep 2: Within-group variances")
    print(f"  Var(left)  = {var_L:.4f}")
    print(f"  Var(right) = {var_R:.4f}")
    print(f"  Weighted: (n_L/n)·Var(L) + (n_R/n)·Var(R) = {weighted_within:.4f}")

    # Step 3: Variance reduction (direct calculation)
    gain_direct = var_total - weighted_within
    print("\nStep 3: Direct variance reduction")
    print(f"  Δ = Var_total - Var_within = {gain_direct:.4f}")

    # Step 4: Between-group formula
    mean_diff_sq = (y_bar_L - y_bar_R) ** 2
    gain_formula = (n_L * n_R / n**2) * mean_diff_sq
    print("\nStep 4: Between-group formula")
    print(f"  (ȳ_L - ȳ_R)² = ({y_bar_L:.2f} - {y_bar_R:.2f})² = {mean_diff_sq:.4f}")
    print(f"  n_L·n_R/n² = {n_L}·{n_R}/{n}² = {n_L*n_R/n**2:.4f}")
    print(f"  Δ = (n_L·n_R/n²)·(ȳ_L - ȳ_R)² = {gain_formula:.4f}")

    # Verify equality
    print("\nVERIFICATION:")
    print(f"  Direct method:  {gain_direct:.6f}")
    print(f"  Formula method: {gain_formula:.6f}")
    print(f"  Match: {np.isclose(gain_direct, gain_formula)}")

    # Mathematical proof sketch
    print("\n" + "=" * 60)
    print("MATHEMATICAL PROOF:")
    print("=" * 60)
    print("""
    Key identity: For any partition into groups L and R,

        Var_total = Var_within + Var_between

    where:
      - Var_within  = (n_L/n)·Var(y_L) + (n_R/n)·Var(y_R)
      - Var_between = (n_L/n)·(ȳ_L - ȳ)² + (n_R/n)·(ȳ_R - ȳ)²

    Since ȳ = (n_L/n)·ȳ_L + (n_R/n)·ȳ_R (weighted mean property):

        ȳ_L - ȳ = ȳ_L - (n_L/n)·ȳ_L - (n_R/n)·ȳ_R = (n_R/n)·(ȳ_L - ȳ_R)
        ȳ_R - ȳ = -(n_L/n)·(ȳ_L - ȳ_R)

    Substituting into Var_between:

        Var_between = (n_L/n)·[(n_R/n)·(ȳ_L - ȳ_R)]² + (n_R/n)·[(n_L/n)·(ȳ_L - ȳ_R)]²
                    = (n_L·n_R²/n³ + n_R·n_L²/n³)·(ȳ_L - ȳ_R)²
                    = (n_L·n_R/n²)·((n_R + n_L)/n)·(ȳ_L - ȳ_R)²
                    = (n_L·n_R/n²)·(ȳ_L - ȳ_R)²   ∎
    """)


prove_between_group_formula()
```

| n_L | n_R | n_L·n_R/n² | Effect on Gain |
|---|---|---|---|
| 50% | 50% | 0.2500 | Maximum (for fixed mean difference) |
| 40% | 60% | 0.2400 | 96% of maximum |
| 30% | 70% | 0.2100 | 84% of maximum |
| 20% | 80% | 0.1600 | 64% of maximum |
| 10% | 90% | 0.0900 | 36% of maximum |
| 5% | 95% | 0.0475 | 19% of maximum |
| 1% | 99% | 0.0099 | 4% of maximum |
The balance factor n_L·n_R/n² means that very unbalanced splits require much larger mean differences to be selected. This naturally discourages creating tiny leaves with few samples—a form of implicit regularization. The min_samples_leaf constraint adds explicit protection beyond this natural tendency.
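To make this implicit regularization concrete, the short calculation below (a sketch mirroring the table above, not library code) uses the closed-form gain $p(1-p)(\bar{y}_L - \bar{y}_R)^2$ with $p = n_L/n$ to show how much larger the mean separation must be for an unbalanced split to match the gain of a 50/50 split.

```python
import numpy as np

# Balance factor n_L*n_R/n² written as p*(1-p), where p = n_L/n.
# To match the gain of a 50/50 split, a split with proportion p needs a
# mean difference that is larger by a factor of sqrt(0.25 / (p*(1-p))).
for p in [0.5, 0.4, 0.3, 0.2, 0.1, 0.05, 0.01]:
    balance = p * (1 - p)
    required_ratio = np.sqrt(0.25 / balance)
    print(f"p = {p:4.2f}: balance factor = {balance:.4f}, "
          f"needs {required_ratio:.2f}x the mean difference of a 50/50 split")
```

For example, a 10/90 split must separate the child means by about 1.67 times as much as a 50/50 split to be preferred, and a 1/99 split by about 5 times as much.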
Naive evaluation of all possible splits has quadratic complexity. The efficient approach uses incremental statistics with a sorted scan, achieving linearithmic complexity.
Key insight: Sufficient statistics
For variance computation, we only need three quantities: the sample count $n$, the sum $S = \sum_i y_i$, and the sum of squares $Q = \sum_i y_i^2$.
Variance follows from: $$\text{Var}(y) = \frac{1}{n}\sum_i (y_i - \bar{y})^2 = \frac{Q}{n} - \left(\frac{S}{n}\right)^2$$
Incremental updates:
As we scan through sorted data, moving one sample from right to left updates statistics:
$$n_L \leftarrow n_L + 1, \quad S_L \leftarrow S_L + y_i, \quad Q_L \leftarrow Q_L + y_i^2$$ $$n_R \leftarrow n_R - 1, \quad S_R \leftarrow S_R - y_i, \quad Q_R \leftarrow Q_R - y_i^2$$
Each update is O(1), enabling O(n) evaluation of all thresholds for a sorted feature.
```python
import numpy as np
from typing import Optional, Dict, Any


class EfficientMSESplitter:
    """
    Efficient MSE split finding using incremental statistics.

    This implementation achieves O(n log n) per feature through:
    1. Sorting feature values once: O(n log n)
    2. Scanning with O(1) updates: O(n)

    Total: O(p · n log n) for p features
    """

    def __init__(self, min_samples_leaf: int = 1):
        self.min_samples_leaf = min_samples_leaf

    def find_best_split(self, X: np.ndarray, y: np.ndarray) -> Optional[Dict[str, Any]]:
        """
        Find the best MSE-reducing split across all features.
        """
        n_samples, n_features = X.shape

        # Compute parent statistics (for gain calculation).
        # Normalize to per-sample variance so gains match ΔI = Var(parent) - weighted child variance.
        parent_sum = np.sum(y)
        parent_sum_sq = np.sum(y ** 2)
        parent_impurity = self._compute_impurity(n_samples, parent_sum, parent_sum_sq) / n_samples

        best_split = None
        best_impurity_after = parent_impurity

        for feature_idx in range(n_features):
            split = self._find_best_threshold(
                X[:, feature_idx], y, parent_impurity
            )
            if split is not None:
                if split['impurity_after'] < best_impurity_after:
                    best_impurity_after = split['impurity_after']
                    best_split = {
                        'feature': feature_idx,
                        'threshold': split['threshold'],
                        'gain': parent_impurity - split['impurity_after'],
                        'n_left': split['n_left'],
                        'n_right': split['n_right']
                    }

        return best_split

    def _find_best_threshold(self, feature: np.ndarray, y: np.ndarray,
                             parent_impurity: float) -> Optional[Dict[str, Any]]:
        """
        Find best threshold for a single feature using sorted scan.
        """
        n = len(y)

        # Sort by feature value
        sorted_indices = np.argsort(feature)
        sorted_feature = feature[sorted_indices]
        sorted_y = y[sorted_indices]

        # Initialize: all samples in right child
        n_left = 0
        sum_left = 0.0
        sum_sq_left = 0.0
        n_right = n
        sum_right = np.sum(sorted_y)
        sum_sq_right = np.sum(sorted_y ** 2)

        best_threshold = None
        best_impurity = float('inf')
        best_n_left = 0
        best_n_right = 0

        # Scan left to right, moving samples from right to left
        for i in range(n - 1):  # Stop before last sample
            yi = sorted_y[i]

            # Move sample i from right to left
            n_left += 1
            sum_left += yi
            sum_sq_left += yi ** 2
            n_right -= 1
            sum_right -= yi
            sum_sq_right -= yi ** 2

            # Skip if not a valid split point (duplicate feature values)
            if sorted_feature[i] == sorted_feature[i + 1]:
                continue

            # Check minimum samples constraint
            if n_left < self.min_samples_leaf:
                continue
            if n_right < self.min_samples_leaf:
                break  # All remaining splits will violate constraint

            # Compute weighted impurity after split
            impurity_left = self._compute_impurity(n_left, sum_left, sum_sq_left)
            impurity_right = self._compute_impurity(n_right, sum_right, sum_sq_right)

            # Weighted average (sum form, not divided by n yet)
            impurity_after = impurity_left + impurity_right

            if impurity_after < best_impurity:
                best_impurity = impurity_after
                # Threshold at midpoint between consecutive values
                best_threshold = (sorted_feature[i] + sorted_feature[i + 1]) / 2
                best_n_left = n_left
                best_n_right = n_right

        if best_threshold is None:
            return None

        return {
            'threshold': best_threshold,
            'impurity_after': best_impurity / n,  # Normalize
            'n_left': best_n_left,
            'n_right': best_n_right
        }

    def _compute_impurity(self, n: int, sum_y: float, sum_sq_y: float) -> float:
        """
        Compute (unnormalized) sum of squared deviations from mean.

        Using: Σ(y - ȳ)² = Σy² - (Σy)²/n = sum_sq - sum²/n
        Returns n * variance (unnormalized for stable computation)
        """
        if n == 0:
            return 0.0
        mean_sq = (sum_y / n) ** 2
        variance = sum_sq_y / n - mean_sq
        # Numerical stability: variance can be slightly negative due to precision
        variance = max(0.0, variance)
        return n * variance  # Return unnormalized for weighted combination


def demonstrate_efficiency():
    """
    Time the efficient implementation at increasing sample sizes.
    """
    import time

    np.random.seed(42)
    sizes = [100, 500, 1000, 5000, 10000]
    splitter = EfficientMSESplitter(min_samples_leaf=5)

    print("Timing comparison: Efficient MSE split finding")
    print("-" * 50)

    for n in sizes:
        X = np.random.randn(n, 10)  # 10 features
        y = np.random.randn(n)

        start = time.time()
        for _ in range(5):  # Average over 5 runs
            splitter.find_best_split(X, y)
        elapsed = (time.time() - start) / 5

        print(f"n = {n:5d}: {elapsed*1000:.2f} ms per split")


# demonstrate_efficiency()
```

The formula Var = E[X²] - E[X]² can suffer from catastrophic cancellation when values are large but similar. The implementation above guards against this with max(0, variance). Production implementations often use numerically stable algorithms like Welford's method for running variance.
While MSE is the standard choice, alternative criteria exist for different objectives or robustness requirements.
Mean Absolute Error (MAE) / L1 Criterion:
$$I_{\text{MAE}}(t) = \frac{1}{n}\sum_{i \in t} |y_i - \text{median}(y_t)|$$
Friedman's MSE Improvement:
$$\Delta I_{\text{Friedman}} = \frac{n_L n_R}{n_L + n_R}(\bar{y}_L - \bar{y}_R)^2$$
Since $n_L + n_R = n$ at the node being split, this is the between-group formula we derived, scaled by $n$; it is used in scikit-learn as the `friedman_mse` criterion.
Poisson Deviance (for count data):
$$I_{\text{Poisson}}(t) = \sum_{i \in t} (y_i \log(y_i / \hat{y}_t) - (y_i - \hat{y}_t))$$
where $\hat{y}_t = \bar{y}_t$. Appropriate when targets are counts.
Huber Loss (robust MSE):
$$L_{\delta}(y, \hat{y}) = \begin{cases} \frac{1}{2}(y - \hat{y})^2 & |y - \hat{y}| \leq \delta \\ \delta\left(|y - \hat{y}| - \frac{\delta}{2}\right) & \text{otherwise} \end{cases}$$
Quadratic for small errors, linear for large—balancing efficiency and robustness.
| Criterion | Optimal Prediction | Robustness | Computational Cost |
|---|---|---|---|
| MSE (L2) | Mean | Sensitive to outliers | O(1) update |
| MAE (L1) | Median | Robust to outliers | O(n) for median |
| Huber | M-estimator | Tunable robustness | O(n) iterative |
| Poisson | Mean (non-negative) | For count data | O(1) update |
| Quantile | τ-th quantile | Conditional quantiles | O(n) for quantile |
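The comparison script that follows focuses on MSE, MAE, and Huber; as a complementary sketch for count targets (the function name is mine, and the $0 \cdot \log 0$ term is taken to be zero by convention), the Poisson deviance impurity defined above could be computed like this:

```python
import numpy as np

def poisson_deviance_impurity(y):
    """Poisson deviance of a node, using the node mean as the prediction.
    Assumes non-negative counts; terms with y_i = 0 contribute only ŷ."""
    y = np.asarray(y, dtype=float)
    if len(y) == 0:
        return 0.0
    y_hat = np.mean(y)
    if y_hat <= 0:
        return 0.0  # an all-zero node is already pure
    # y * log(y / ŷ) with the convention 0 * log(0) = 0
    with np.errstate(divide='ignore', invalid='ignore'):
        term = np.where(y > 0, y * np.log(y / y_hat), 0.0)
    return float(np.sum(term - (y - y_hat)))

counts = np.array([0, 1, 1, 2, 3, 8, 9, 12])
print(poisson_deviance_impurity(counts))
```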
```python
import numpy as np


def mae_impurity(y):
    """Mean Absolute Error impurity using median."""
    if len(y) == 0:
        return 0.0
    median = np.median(y)
    return np.mean(np.abs(y - median))


def mse_impurity(y):
    """Mean Squared Error impurity using mean."""
    if len(y) == 0:
        return 0.0
    return np.var(y)


def huber_impurity(y, delta=1.0):
    """Huber loss impurity."""
    if len(y) == 0:
        return 0.0
    # M-estimator for Huber is computed iteratively
    # For simplicity, use median as approximation
    center = np.median(y)
    errors = y - center
    loss = np.where(
        np.abs(errors) <= delta,
        0.5 * errors ** 2,
        delta * (np.abs(errors) - 0.5 * delta)
    )
    return np.mean(loss)


def compare_criteria_robustness():
    """
    Demonstrate robustness differences between criteria.
    """
    np.random.seed(42)

    # Clean data
    y_clean = np.random.normal(5, 1, 100)

    # Data with outliers (10% contamination)
    y_outliers = np.concatenate([
        np.random.normal(5, 1, 90),
        np.random.normal(50, 10, 10)  # Outliers at 50 ± 10
    ])

    print("Impurity Comparison: MSE vs MAE vs Huber")
    print("=" * 55)

    print("\nClean data (N(5, 1)):")
    print(f"  MSE:   {mse_impurity(y_clean):.4f}")
    print(f"  MAE:   {mae_impurity(y_clean):.4f}")
    print(f"  Huber: {huber_impurity(y_clean):.4f}")

    print("\nData with outliers (10% at N(50, 10)):")
    print(f"  MSE:   {mse_impurity(y_outliers):.4f}")
    print(f"  MAE:   {mae_impurity(y_outliers):.4f}")
    print(f"  Huber: {huber_impurity(y_outliers):.4f}")

    print("\nRatio (contaminated / clean):")
    print(f"  MSE:   {mse_impurity(y_outliers)/mse_impurity(y_clean):.2f}x")
    print(f"  MAE:   {mae_impurity(y_outliers)/mae_impurity(y_clean):.2f}x")
    print(f"  Huber: {huber_impurity(y_outliers)/huber_impurity(y_clean):.2f}x")


compare_criteria_robustness()
```

Different criteria lead to different splits and different tree structures. MAE-based trees tend to be more balanced because the median is less influenced by extreme values at split boundaries. When data has outliers or heavy tails, MSE trees may waste splits trying to isolate extreme values rather than capturing the main structure.
The MSE criterion connects directly to the fundamental bias-variance trade-off in machine learning.
The decomposition:
For any prediction $\hat{f}(\mathbf{x})$ and true function $f(\mathbf{x})$ with noise $\epsilon \sim (0, \sigma^2)$:
$$\text{MSE} = \mathbb{E}[(y - \hat{f})^2] = \underbrace{(\mathbb{E}[\hat{f}] - f)^2}_{\text{Bias}^2} + \underbrace{\mathbb{E}[(\hat{f} - \mathbb{E}[\hat{f}])^2]}_{\text{Variance}} + \underbrace{\sigma^2}_{\text{Irreducible}}$$
How this manifests in trees:
Shallow trees (few splits): predictions are averages over large regions, so they are stable across training samples (low variance) but may miss genuine structure in the target (high bias).
Deep trees (many splits): each leaf mean is fit to only a few samples, so bias is low but predictions fluctuate strongly with the particular training sample (high variance).
The role of MSE minimization:
Minimizing MSE at each split greedily reduces total empirical error. However, training MSE decreases monotonically with every additional split, so the criterion alone cannot signal when further splits begin fitting noise; the greedy choice is also only locally optimal and need not yield the globally best tree. Controlling the bias-variance trade-off therefore falls to depth limits, minimum leaf sizes, and pruning rather than to the splitting criterion itself.
```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def bias_variance_analysis(X, y, depths=[1, 3, 5, 7, 10, 15, 20],
                           n_bootstrap=100, test_size=0.3):
    """
    Empirical bias-variance decomposition for trees at different depths.

    Uses bootstrap resampling to estimate:
    - Bias: How far average prediction is from true function
    - Variance: How much predictions vary across training samples
    """
    from sklearn.model_selection import train_test_split

    n_samples = len(y)
    results = []

    for depth in depths:
        # Store predictions from multiple bootstrap samples
        all_predictions = []

        # Generate bootstrap samples and train trees
        for _ in range(n_bootstrap):
            # Bootstrap sample
            indices = np.random.choice(n_samples, size=n_samples, replace=True)
            X_boot = X[indices]
            y_boot = y[indices]

            # Train tree
            tree = DecisionTreeRegressor(max_depth=depth, random_state=None)
            tree.fit(X_boot, y_boot)

            # Predict on original X
            predictions = tree.predict(X)
            all_predictions.append(predictions)

        all_predictions = np.array(all_predictions)  # Shape: (n_bootstrap, n_samples)

        # Expected prediction (average across bootstrap samples)
        mean_prediction = np.mean(all_predictions, axis=0)

        # Bias² = E[(E[f̂] - y)²] ≈ mean((mean_prediction - y)²)
        # Note: Using y as proxy for f(x) - this conflates noise
        bias_sq = np.mean((mean_prediction - y) ** 2)

        # Variance = E[(f̂ - E[f̂])²] = mean(var across bootstraps)
        variance = np.mean(np.var(all_predictions, axis=0))

        # Total MSE (averaged across bootstrap predictions)
        mse_per_sample = np.mean((all_predictions - y.reshape(1, -1)) ** 2, axis=0)
        total_mse = np.mean(mse_per_sample)

        results.append({
            'depth': depth,
            'bias_sq': bias_sq,
            'variance': variance,
            'total_mse': total_mse,
            'n_leaves_avg': np.mean([2**min(depth, int(np.log2(n_samples)))])
        })

        print(f"Depth {depth:2d}: Bias²={bias_sq:.4f}, Var={variance:.4f}, "
              f"MSE={total_mse:.4f}")

    return results


# Example usage with synthetic data
if __name__ == "__main__":
    np.random.seed(42)

    # Generate data: y = sin(x) + noise
    X = np.linspace(0, 10, 200).reshape(-1, 1)
    y_true = np.sin(X.ravel())
    y = y_true + np.random.randn(200) * 0.3

    print("Bias-Variance Analysis for Different Tree Depths")
    print("=" * 60)

    results = bias_variance_analysis(X, y, depths=[1, 2, 3, 5, 7, 10, 15, 20])

    # Find optimal depth (minimum MSE)
    optimal = min(results, key=lambda r: r['total_mse'])
    print(f"\nOptimal depth: {optimal['depth']} with MSE = {optimal['total_mse']:.4f}")
```

We've conducted a thorough examination of the Mean Squared Error splitting criterion—the mathematical engine that drives regression tree construction.
What's next:
With the splitting criterion established, the next page examines leaf predictions—how regression trees make predictions within each region, the optimality of the mean, and extensions to more sophisticated leaf models.
You now have a deep understanding of the MSE splitting criterion from multiple perspectives: mathematical, statistical, computational, and theoretical. This knowledge is essential for understanding why regression trees behave as they do and for making informed choices about when to use alternatives.