Random sampling, despite its theoretical elegance, can fail dramatically in practice. Consider a fraud detection dataset with 1% fraud cases: a purely random 20% test split might, by chance, contain only 0.3% fraud—or 2%. The model's performance on such a skewed split tells you little about its true capability.
Stratification solves this problem by ensuring that each data partition mirrors the statistical properties of the whole dataset. It's the bridge between random sampling theory and the messy reality of finite, often imbalanced, datasets.
This isn't just a nice-to-have—stratification can mean the difference between a reliable model assessment and a result that fails catastrophically in production because it was evaluated on an unrepresentative split.
This page covers the theory of stratified sampling, implementation for classification and regression tasks, handling of multi-label and multi-output scenarios, grouped stratification, continuous variable binning strategies, and quality verification methods. You'll master stratification at the level expected of senior data scientists.
To understand why stratification matters, we must first understand how random sampling can fail—and quantify when this failure becomes problematic.
Random Sampling and the Small Sample Problem
When we split data randomly, each sample has equal probability of appearing in training or test. For large datasets, the law of large numbers ensures proportions converge. But for finite samples, substantial deviations are possible.
Quantifying the Deviation
Consider a binary classification problem with minority class proportion $p$. If we draw $n$ samples randomly, the number of minority samples follows a binomial distribution: $$N_{minority} \sim \text{Binomial}(n, p)$$
The expected proportion is $p$, but the standard deviation is: $$\sigma = \sqrt{\frac{p(1-p)}{n}}$$
For $p = 0.05$ (5% minority) and $n = 100$ (test set size): $$\sigma = \sqrt{\frac{0.05 \times 0.95}{100}} \approx 0.022$$
A 2-sigma deviation (covering roughly 95% of random splits) gives observed proportions anywhere from 0.6% to 9.4%—from about one-eighth of the true 5% to nearly double it. The table below shows the ±2σ range of the observed minority proportion for various minority rates and test-set sizes:
| Minority % | n = 100 | n = 500 | n = 1000 | n = 5000 |
|---|---|---|---|---|
| 50% | ±10% | ±4.5% | ±3.2% | ±1.4% |
| 10% | ±6.0% | ±2.7% | ±1.9% | ±0.8% |
| 5% | ±4.4% | ±1.9% | ±1.4% | ±0.6% |
| 1% | ±2.0% | ±0.9% | ±0.6% | ±0.3% |
| 0.1% | ±0.6% | ±0.3% | ±0.2% | ±0.1% |
For very rare classes (< 1%), random splitting can leave zero examples in the test set. With 0.5% minority and 200 test samples, there's a 37% chance of getting zero minority samples—making evaluation of minority class performance impossible.
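These figures follow directly from the binomial model. As a quick sanity check, the sketch below recomputes the 2σ range and the zero-minority probability with scipy; the helper name is illustrative, not a library function.

```python
import numpy as np
from scipy.stats import binom

def split_deviation(p, n, z=2.0):
    """Return the +/- z-sigma range of the observed minority proportion."""
    sigma = np.sqrt(p * (1 - p) / n)
    return max(0.0, p - z * sigma), p + z * sigma

# 5% minority class, 100-sample test set (the example above)
lo, hi = split_deviation(p=0.05, n=100)
print(f"2-sigma range: {lo:.1%} to {hi:.1%}")           # ~0.6% to ~9.4%

# 0.5% minority class, 200-sample test set:
# probability the test set contains zero minority samples
p_zero = binom.pmf(0, n=200, p=0.005)
print(f"P(no minority samples in test): {p_zero:.1%}")  # ~36.7%
```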
Consequences of Unrepresentative Splits
Unreliable Metrics: Metrics computed on unrepresentative splits don't reflect true performance. A model might appear excellent on a split with 1% fraud, but actually perform poorly on the true 5% fraud rate.
Metric Variance: Different random seeds produce wildly different results, making model comparison unreliable.
Impossible Evaluation: If a class is absent from test, you simply can't evaluate that class.
False Confidence: A lucky split might suggest a model is ready for production when it's not.
Training Issues: For very imbalanced data, training might receive even fewer minority examples than expected, degrading learning.
The Solution: Stratified Sampling
Stratification forces the random split to maintain specified proportions exactly (or as close as possible given integer constraints). Instead of hoping randomness produces a representative split, we guarantee it.
Stratified sampling has a rigorous foundation in survey sampling theory, with direct applications to machine learning. Understanding the theory helps us apply stratification correctly and know its limitations.
Formal Definition
Given a population partitioned into $K$ non-overlapping strata $S_1, S_2, \ldots, S_K$, stratified sampling selects a simple random sample from each stratum independently.
Let $N_k$ denote the number of population units in stratum $S_k$, $N = \sum_k N_k$ the total population size, $W_k = N_k / N$ the stratum weight, $n_k$ the number of samples drawn from stratum $k$, and $n = \sum_k n_k$ the total sample size.
Proportional Allocation
The most common approach: allocate sample sizes proportionally to stratum sizes: $$n_k = n \cdot W_k = n \cdot \frac{N_k}{N}$$
This ensures each stratum is represented in the sample in proportion to its population representation.
Why Stratification Reduces Variance
For estimating a population mean $\bar{Y}$, the variance of the stratified estimator is: $$\text{Var}(\bar{y}_{str}) = \sum_{k=1}^{K} W_k^2 \cdot \frac{S_k^2}{n_k}$$
where $S_k^2$ is the variance within stratum $k$.
Compare to simple random sampling: $$\text{Var}(\bar{y}_{srs}) = \frac{S^2}{n}$$
where $S^2$ is total population variance.
Key Insight: Stratification partitions total variance into within-stratum and between-stratum components: $$S^2 = \sum_k W_k S_k^2 + \sum_k W_k (\bar{Y}_k - \bar{Y})^2$$
Stratified sampling eliminates the between-stratum variance from our estimation error. When strata are homogeneous internally but different from each other, the gains are substantial.
For Classification: Stratifying by class labels eliminates variance due to class proportion differences between splits—exactly the problem we identified.
Stratification helps most when strata differ significantly. For class-based stratification, this is almost always true: fraud cases are very different from legitimate cases, so ensuring both are represented reduces evaluation variance. If strata were identical, stratification wouldn't hurt but wouldn't help either.
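A small Monte Carlo sketch makes the variance reduction concrete. The population below is invented for illustration: two internally homogeneous strata with very different means, sampled either by simple random sampling or by proportional stratified sampling.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two internally homogeneous strata with very different means
# (e.g. legitimate vs. fraudulent transaction amounts)
stratum_a = rng.normal(10, 1, size=9_000)   # 90% of the population
stratum_b = rng.normal(50, 1, size=1_000)   # 10% of the population
population = np.concatenate([stratum_a, stratum_b])

n, trials = 100, 5_000
srs_means, strat_means = [], []
for _ in range(trials):
    # Simple random sample of n
    srs_means.append(rng.choice(population, size=n, replace=False).mean())
    # Proportional stratified sample: 90 from A, 10 from B
    sa = rng.choice(stratum_a, size=90, replace=False)
    sb = rng.choice(stratum_b, size=10, replace=False)
    strat_means.append(np.concatenate([sa, sb]).mean())

print(f"Std of SRS estimate:        {np.std(srs_means):.3f}")
print(f"Std of stratified estimate: {np.std(strat_means):.3f}")
```

Because the between-stratum variance dominates here, the stratified estimator's spread is an order of magnitude smaller—exactly the decomposition in the formula above.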
Integer Constraints and Rounding
In practice, we can't have fractional samples. If stratum $k$ should contribute $n_k = 5.7$ samples, we must round.
Common strategies include simple flooring, nearest-integer rounding with a correction on the largest stratum to hit the total, and largest-remainder rounding (floor every allocation, then give the leftover samples to the strata with the largest fractional parts). For ML purposes, either of the latter two works well. Most libraries implement this automatically.
Minimum Samples per Stratum
A critical constraint: every stratum needs at least one sample (preferably more). For rare classes: $$n_k = \max(\text{minimum}, \lfloor n \cdot W_k \rfloor)$$
This may require larger test sets than planned, or accepting that rare classes are over-represented in test relative to training.
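A minimal sketch of proportional allocation under these constraints, assuming largest-remainder rounding and a per-stratum minimum (the helper name and rounding policy are illustrative, not a standard library routine):

```python
import numpy as np

def allocate_proportional(stratum_sizes, n_total, min_per_stratum=1):
    """Proportional allocation with largest-remainder rounding and a
    per-stratum minimum. Illustrative sketch, not a library function."""
    sizes = np.asarray(stratum_sizes, dtype=float)
    ideal = n_total * sizes / sizes.sum()          # n_k = n * W_k (fractional)
    alloc = np.floor(ideal).astype(int)
    # Hand the remaining samples to the strata with the largest fractional parts
    remainder = ideal - alloc
    for k in np.argsort(-remainder)[: n_total - alloc.sum()]:
        alloc[k] += 1
    # Enforce the minimum for rare strata (may slightly exceed n_total)
    alloc = np.maximum(alloc, min_per_stratum)
    return alloc

# 90% / 8% / 2% classes, 200-sample test set
print(allocate_proportional([900, 80, 20], n_total=200))  # e.g. [180, 16, 4]
```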
Classification is the most straightforward case for stratification: use class labels as strata. Let's examine implementation in detail.
```python
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from collections import Counter

# ============================================
# Basic Stratified Split
# ============================================

# Generate imbalanced data
np.random.seed(42)
n_samples = 1000
X = np.random.randn(n_samples, 10)
y = np.array([0]*900 + [1]*80 + [2]*20)  # 90%, 8%, 2% distribution

print("Original class distribution:")
print(Counter(y))

# Non-stratified split (WRONG for imbalanced data)
X_train_wrong, X_test_wrong, y_train_wrong, y_test_wrong = train_test_split(
    X, y, test_size=0.2, random_state=42, shuffle=True
)
print("\nNon-stratified test distribution:")
print(Counter(y_test_wrong))

# Stratified split (CORRECT)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print("\nStratified test distribution:")
print(Counter(y_test))

# ============================================
# Verifying Stratification Quality
# ============================================

def verify_stratification(y_full, y_train, y_test, tolerance=0.02):
    """
    Verify that train and test sets maintain class proportions.

    Parameters:
    -----------
    tolerance : Maximum acceptable deviation from original proportions

    Returns:
    --------
    bool : True if stratification is acceptable
    """
    full_dist = {cls: count/len(y_full) for cls, count in Counter(y_full).items()}
    train_dist = {cls: count/len(y_train) for cls, count in Counter(y_train).items()}
    test_dist = {cls: count/len(y_test) for cls, count in Counter(y_test).items()}

    print(f"{'Class':<10} {'Original':<12} {'Train':<12} {'Test':<12} {'Status'}")
    print("-" * 60)

    all_ok = True
    for cls in sorted(full_dist.keys()):
        orig = full_dist.get(cls, 0)
        train = train_dist.get(cls, 0)
        test = test_dist.get(cls, 0)
        train_dev = abs(train - orig)
        test_dev = abs(test - orig)
        status = "✓" if max(train_dev, test_dev) < tolerance else "⚠"
        all_ok = all_ok and (status == "✓")
        print(f"{cls:<10} {orig:<12.4f} {train:<12.4f} {test:<12.4f} {status}")

    return all_ok

is_valid = verify_stratification(y, y_train, y_test)
print(f"\nStratification valid: {is_valid}")

# ============================================
# Multiple Stratified Splits (for Monte Carlo)
# ============================================

sss = StratifiedShuffleSplit(
    n_splits=5,        # Generate 5 different stratified splits
    test_size=0.2,
    random_state=42
)

for i, (train_idx, test_idx) in enumerate(sss.split(X, y)):
    test_dist = Counter(y[test_idx])
    print(f"Split {i+1}: {dict(test_dist)}")
```

Multi-Class Stratification
For multi-class problems, stratification extends naturally: each class becomes a stratum. The key considerations:
Many classes: With 100+ classes, some may have very few samples. Ensure minimum representation.
Class hierarchy: If classes have a hierarchy (e.g., species → genus → family), stratify by the most appropriate level.
Rare classes: Classes with fewer samples than the number of partitions need special handling; see the note and sketch below.
For stratification to work, each class needs at least K samples where K = num_splits (for cross-validation) or K = 2 (for single train-test split). With fewer samples, stratification fails. Either merge rare classes, use leave-one-out, or accept that some classes won't be in all folds.
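One way to handle this, sketched below under the assumption that rare classes can reasonably be pooled: count the samples per class and fold anything below the threshold into a single "other" stratum used only for splitting (the helper and the other_label value are hypothetical).

```python
import numpy as np
from collections import Counter

def merge_rare_classes(y, min_count, other_label=-1):
    """Relabel classes with fewer than `min_count` samples into a single
    'other' stratum so stratified splitting doesn't fail. Illustrative sketch."""
    counts = Counter(y)
    rare = {cls for cls, c in counts.items() if c < min_count}
    if rare:
        print(f"Merging rare classes {sorted(rare)} into stratum {other_label}")
    return np.where(np.isin(y, list(rare)), other_label, y)

# Example: class 7 has only 1 sample—too few for 5-fold stratified CV
y = np.array([0] * 50 + [1] * 30 + [2] * 19 + [7])
y_strata = merge_rare_classes(y, min_count=5)
print(Counter(y_strata))   # class 7 folded into the 'other' stratum
```

The merged labels are passed as the stratify= argument; the original y remains the modeling target.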
Regression targets are continuous, so direct stratification isn't possible. The solution: bin the target into discrete strata and stratify by bins. This ensures the target distribution is preserved across splits.
Binning Strategies
1. Quantile-Based Binning Divide target into percentile-based bins. Each bin has approximately equal samples.
y_binned = pd.qcut(y, q=10, labels=False, duplicates='drop')
Advantages: every bin is well populated even for heavily skewed targets, so each stratum has enough samples to split reliably, and no bin ends up empty.
2. Fixed-Width Binning Divide target range into equal-width intervals.
y_binned = pd.cut(y, bins=10, labels=False)
Advantages: simple to compute, and the bin edges are easy to interpret (equal-width intervals in the target's own units).
Disadvantages: with skewed targets, most samples pile into a few bins while others are nearly or completely empty, making stratification unstable or impossible.
3. Custom Binning Define bins based on domain knowledge.
bins = [0, 10, 50, 100, 500, float('inf')]
y_binned = pd.cut(y, bins=bins, labels=False)
Advantages: bin edges align with domain-meaningful thresholds, so each stratum corresponds to a range stakeholders actually care about (e.g., small, medium, and large transaction values).
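To see why quantile bins are usually preferred for skewed targets, the sketch below (synthetic log-normal data, mirroring the example that follows) counts how many samples land in each stratum under fixed-width versus quantile binning:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
y = np.exp(rng.normal(size=1_000))   # skewed, log-normal target

# Fixed-width bins: most samples pile into the lowest bins,
# upper bins are nearly (or completely) empty
fixed = pd.cut(y, bins=10, labels=False).astype(int)
print("Fixed-width bin counts:", np.bincount(fixed, minlength=10))

# Quantile bins: every bin holds ~100 samples, so every stratum
# is large enough to split reliably
quant = pd.qcut(y, q=10, labels=False, duplicates='drop').astype(int)
print("Quantile bin counts:   ", np.bincount(quant, minlength=10))
```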
```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from scipy import stats

# ============================================
# Create Skewed Regression Target
# ============================================

np.random.seed(42)
n_samples = 1000
X = np.random.randn(n_samples, 10)
y = np.exp(np.random.randn(n_samples))  # Log-normal distribution (skewed)

print(f"Target statistics:")
print(f" Mean: {y.mean():.2f}, Median: {np.median(y):.2f}")
print(f" Min: {y.min():.2f}, Max: {y.max():.2f}")
print(f" Skewness: {stats.skew(y):.2f}")

# ============================================
# Quantile-Based Stratification
# ============================================

n_bins = 10
y_binned_quantile = pd.qcut(y, q=n_bins, labels=False, duplicates='drop')

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y_binned_quantile
)

# Verify distribution preservation
print(f"\nOriginal target - Mean: {y.mean():.3f}, Std: {y.std():.3f}")
print(f"Training target - Mean: {y_train.mean():.3f}, Std: {y_train.std():.3f}")
print(f"Test target - Mean: {y_test.mean():.3f}, Std: {y_test.std():.3f}")

# ============================================
# Quantile Comparison (More Thorough Check)
# ============================================

def compare_distributions(y_full, y_train, y_test):
    """Compare quantile distributions across splits."""
    quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]

    print(f"\n{'Quantile':<12} {'Original':<12} {'Train':<12} {'Test':<12}")
    print("-" * 50)
    for q in quantiles:
        orig_q = np.quantile(y_full, q)
        train_q = np.quantile(y_train, q)
        test_q = np.quantile(y_test, q)
        print(f"{q:<12} {orig_q:<12.3f} {train_q:<12.3f} {test_q:<12.3f}")

    # Statistical test for distribution equality
    ks_stat, ks_pval = stats.ks_2samp(y_train, y_test)
    print(f"\nKS test (train vs test): statistic={ks_stat:.4f}, p-value={ks_pval:.4f}")
    if ks_pval > 0.05:
        print("Distributions are not significantly different (good!)")
    else:
        print("WARNING: Distributions differ significantly")

compare_distributions(y, y_train, y_test)

# ============================================
# Handling Extreme Values and Outliers
# ============================================

def robust_quantile_bins(y, n_bins=10, outlier_percentile=5):
    """
    Create bins that handle outliers gracefully.

    Strategy: Clip extreme values before binning, ensuring
    outliers don't dominate bin edges.
    """
    lower = np.percentile(y, outlier_percentile)
    upper = np.percentile(y, 100 - outlier_percentile)
    y_clipped = np.clip(y, lower, upper)

    try:
        bins = pd.qcut(y_clipped, q=n_bins, labels=False, duplicates='drop')
    except ValueError:
        # Fall back to fewer bins if duplicates cause issues
        bins = pd.qcut(y_clipped, q=n_bins//2, labels=False, duplicates='drop')

    return bins

y_binned_robust = robust_quantile_bins(y, n_bins=10)
```

More bins = finer stratification = better distribution matching, BUT each bin needs adequate samples. A rule of thumb: n_bins ≤ n_samples / (10 × num_splits). For 1000 samples and 5-fold CV, use ≤ 20 bins. For a single train-test split, up to 50 bins may work.
Multi-label classification—where each sample can belong to multiple classes simultaneously—presents a unique stratification challenge. Simple label-based stratification doesn't work because samples exist in a combinatorial space of label sets.
The Challenge
With $K$ possible labels, there are $2^K$ possible label combinations. For $K = 20$ labels, that's over 1 million combinations. Most combinations appear rarely or never, making stratification by exact label set infeasible.
The Iterative Stratification Approach
The gold standard for multi-label stratification is iterative stratification (Sechidis et al., 2011), which assigns samples to folds one label at a time: at each step it selects the label with the fewest remaining unassigned samples, distributes the samples carrying that label to the folds whose quota for it is least satisfied, and repeats until every sample is assigned.
This ensures that even rare labels appear in all folds proportionally, while common labels naturally distribute well.
```python
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.datasets import make_multilabel_classification
from collections import Counter

# ============================================
# Generate Multi-Label Data
# ============================================

X, Y = make_multilabel_classification(
    n_samples=1000,
    n_features=20,
    n_classes=10,
    n_labels=3,   # Average labels per sample
    random_state=42
)

print(f"Multi-label data shape: X={X.shape}, Y={Y.shape}")
print(f"Label frequencies: {Y.sum(axis=0)}")  # Sum per label

# ============================================
# Iterative Stratified Split
# ============================================

msss = MultilabelStratifiedShuffleSplit(
    n_splits=1,
    test_size=0.2,
    random_state=42
)

for train_idx, test_idx in msss.split(X, Y):
    X_train, X_test = X[train_idx], X[test_idx]
    Y_train, Y_test = Y[train_idx], Y[test_idx]

# ============================================
# Verify Per-Label Proportions
# ============================================

def verify_multilabel_stratification(Y_full, Y_train, Y_test):
    """Check that each label is proportionally represented."""
    n_labels = Y_full.shape[1]

    print(f"{'Label':<8} {'Original':<12} {'Train':<12} {'Test':<12} {'Diff':<10}")
    print("-" * 55)

    max_diff = 0
    for i in range(n_labels):
        orig_rate = Y_full[:, i].mean()
        train_rate = Y_train[:, i].mean()
        test_rate = Y_test[:, i].mean()
        diff = abs(train_rate - orig_rate) + abs(test_rate - orig_rate)
        max_diff = max(max_diff, diff)
        print(f"{i:<8} {orig_rate:<12.4f} {train_rate:<12.4f} {test_rate:<12.4f} {diff:<10.4f}")

    print(f"\nMaximum deviation from original: {max_diff:.4f}")
    return max_diff

verify_multilabel_stratification(Y, Y_train, Y_test)

# ============================================
# Alternative: Label Powerset for Small Label Sets
# ============================================

def label_powerset_stratify(Y, n_bins=50):
    """
    Convert multi-label to single-label by treating each unique
    label combination as a class. Works for small numbers of
    unique combinations.
    """
    # Convert each row to a tuple (hashable)
    label_tuples = [tuple(row) for row in Y]

    # Map unique combinations to integers
    unique_combos = list(set(label_tuples))
    combo_to_idx = {c: i for i, c in enumerate(unique_combos)}
    y_powerset = np.array([combo_to_idx[c] for c in label_tuples])

    print(f"Found {len(unique_combos)} unique label combinations")

    # If too many combinations, reduce via clustering
    if len(unique_combos) > n_bins:
        print(f"Too many combinations, falling back to iterative stratification")
        return None

    return y_powerset

y_powerset = label_powerset_stratify(Y)
```

The iterative-stratification library provides production-ready multi-label stratification: pip install iterative-stratification. It handles edge cases like samples with no labels and adapts gracefully when exact stratification is impossible.
Multi-Output Regression Stratification
For multiple continuous targets, the challenge is similar: bin each output and combine the bin indices into joint strata:
# Example: Combine binned targets for stratification
y1_bins = pd.qcut(y[:, 0], q=5, labels=False)
y2_bins = pd.qcut(y[:, 1], q=5, labels=False)
combined_strata = y1_bins * 5 + y2_bins # Creates 25 unique strata
The curse of dimensionality applies: with many outputs, the combinatorial explosion of joint strata makes exact stratification impossible. Prioritize the outputs that matter most for evaluation and stratify on those, as sketched below.
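One possible way to prioritize, shown as a sketch with made-up data rather than a prescribed recipe: bin the most important output finely, a secondary output coarsely, and leave the remaining outputs out of the stratification key.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
Y = rng.normal(size=(1_000, 4))   # four continuous outputs (illustrative)

# Fine bins on the primary output, coarse bins on the secondary,
# ignore the rest to keep the number of strata manageable
primary = pd.qcut(Y[:, 0], q=5, labels=False)
secondary = pd.qcut(Y[:, 1], q=2, labels=False)
combined_strata = primary * 2 + secondary      # 10 strata instead of 5**4

print(f"Strata: {np.unique(combined_strata).size}, "
      f"smallest stratum: {np.bincount(combined_strata.astype(int)).min()} samples")
```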
When data has natural grouping structure—and you want both group integrity AND stratification—the problem becomes significantly more complex. This is grouped stratification.
Common Scenarios: multiple records per patient in medical data, multiple transactions per customer, multiple images per subject, or repeated measurements from the same device or session—anywhere evaluating on a group seen in training would leak information.
The Grouped Stratification Problem
We need to: (1) keep every sample from a group in the same partition, and (2) preserve the class (or target) distribution in each partition.
These objectives can conflict. If all positive-class samples come from two groups, we can't put one group in training and one in testing while maintaining proportions.
Algorithm: Group-Level Stratification
1. Compute a group-level label for each group (the majority class for classification, or the binned mean for regression).
2. Stratify the groups (not the samples) based on these group-level labels.
3. Assign entire groups to train or test according to the stratified group assignment.
This ensures groups stay together while approximately maintaining stratification.
```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import defaultdict

def grouped_stratified_split(
    X, y, groups,
    test_size=0.2,
    random_state=42,
    stratify_method='majority'
):
    """
    Stratified split that respects group boundaries.

    Parameters:
    -----------
    X : Feature matrix
    y : Target vector
    groups : Group membership for each sample
    test_size : Fraction for test set
    stratify_method : 'majority' (class majority) or 'mean' (for regression)

    Returns:
    --------
    train_indices, test_indices
    """
    np.random.seed(random_state)

    # Step 1: Compute group-level labels
    unique_groups = np.unique(groups)
    group_labels = {}
    group_indices = defaultdict(list)

    for idx, (yi, gi) in enumerate(zip(y, groups)):
        group_indices[gi].append(idx)

    for g in unique_groups:
        g_y = y[groups == g]
        if stratify_method == 'majority':
            # Use majority class for classification
            group_labels[g] = np.bincount(g_y).argmax()
        elif stratify_method == 'mean':
            # Use mean for regression (after binning)
            group_labels[g] = np.mean(g_y)

    # Step 2: Create group-level arrays for stratified splitting
    group_array = np.array(unique_groups)
    group_y = np.array([group_labels[g] for g in unique_groups])

    # For regression, bin the group means
    if stratify_method == 'mean':
        group_y = pd.qcut(group_y, q=5, labels=False, duplicates='drop')

    # Step 3: Stratified split at group level
    train_groups, test_groups = train_test_split(
        group_array,
        test_size=test_size,
        random_state=random_state,
        stratify=group_y
    )

    # Step 4: Convert group assignments to sample indices
    train_indices = []
    test_indices = []
    for g in train_groups:
        train_indices.extend(group_indices[g])
    for g in test_groups:
        test_indices.extend(group_indices[g])

    return np.array(train_indices), np.array(test_indices)

# ============================================
# Example Usage
# ============================================

np.random.seed(42)

# Generate grouped data
n_groups = 100
samples_per_group = np.random.randint(5, 20, n_groups)
n_samples = samples_per_group.sum()

groups = np.repeat(np.arange(n_groups), samples_per_group)
X = np.random.randn(n_samples, 10)

# Make labels correlated with groups (some groups are mostly positive)
y = np.zeros(n_samples, dtype=int)
for g in np.arange(n_groups):
    group_mask = groups == g
    # Each group has consistent class bias
    bias = np.random.rand()
    y[group_mask] = (np.random.rand(group_mask.sum()) < bias).astype(int)

# Perform grouped stratified split
train_idx, test_idx = grouped_stratified_split(X, y, groups, test_size=0.2)

# Verify
print(f"Total samples: {n_samples}")
print(f"Training samples: {len(train_idx)}")
print(f"Test samples: {len(test_idx)}")

# Check group integrity
train_groups = set(groups[train_idx])
test_groups = set(groups[test_idx])
overlap = train_groups & test_groups
print(f"\nGroup overlap between train and test: {len(overlap)} (should be 0)")

# Check stratification
print(f"\nClass proportions:")
print(f" Original: {y.mean():.4f}")
print(f" Training: {y[train_idx].mean():.4f}")
print(f" Test: {y[test_idx].mean():.4f}")
```

With grouped data, exact stratification is often impossible. If one group contains all rare-class samples, you can't split that class across folds while keeping the group intact. Accept approximate stratification and verify the deviation is acceptable for your use case.
Stratification should be verified, not assumed. A comprehensive verification process catches implementation errors and identifies cases where stratification is only approximate.
```python
import numpy as np
from scipy import stats
from typing import Dict, List, Tuple

class StratificationVerifier:
    """
    Comprehensive verification of stratification quality.

    Checks class proportions, statistical tests, and potential issues.
    """

    def __init__(self, tolerance: float = 0.05):
        """
        Parameters:
        -----------
        tolerance : Maximum acceptable deviation from original proportions
        """
        self.tolerance = tolerance

    def verify_classification(
        self,
        y_full: np.ndarray,
        y_train: np.ndarray,
        y_test: np.ndarray
    ) -> Dict:
        """
        Verify stratification for classification targets.
        """
        results = {
            'is_valid': True,
            'class_analysis': {},
            'warnings': [],
            'errors': []
        }

        # Compute proportions
        classes = np.unique(np.concatenate([y_full, y_train, y_test]))

        for cls in classes:
            full_prop = (y_full == cls).mean()
            train_prop = (y_train == cls).mean() if len(y_train) > 0 else 0
            test_prop = (y_test == cls).mean() if len(y_test) > 0 else 0

            train_dev = abs(train_prop - full_prop)
            test_dev = abs(test_prop - full_prop)

            results['class_analysis'][cls] = {
                'full_prop': full_prop,
                'train_prop': train_prop,
                'test_prop': test_prop,
                'train_deviation': train_dev,
                'test_deviation': test_dev,
                'within_tolerance': max(train_dev, test_dev) <= self.tolerance
            }

            if max(train_dev, test_dev) > self.tolerance:
                results['warnings'].append(
                    f"Class {cls}: deviation {max(train_dev, test_dev):.4f} > tolerance {self.tolerance}"
                )
                results['is_valid'] = False

            # Check for missing classes
            if (y_train == cls).sum() == 0:
                results['errors'].append(f"Class {cls} missing from training set!")
                results['is_valid'] = False
            if (y_test == cls).sum() == 0:
                results['errors'].append(f"Class {cls} missing from test set!")
                results['is_valid'] = False

        return results

    def verify_regression(
        self,
        y_full: np.ndarray,
        y_train: np.ndarray,
        y_test: np.ndarray
    ) -> Dict:
        """
        Verify stratification for regression targets.

        Uses statistical tests to compare distributions.
        """
        results = {
            'is_valid': True,
            'statistics': {},
            'tests': {},
            'warnings': []
        }

        # Basic statistics
        for name, arr in [('full', y_full), ('train', y_train), ('test', y_test)]:
            results['statistics'][name] = {
                'mean': float(np.mean(arr)),
                'median': float(np.median(arr)),
                'std': float(np.std(arr)),
                'min': float(np.min(arr)),
                'max': float(np.max(arr)),
                'skewness': float(stats.skew(arr))
            }

        # Kolmogorov-Smirnov tests
        ks_train_full = stats.ks_2samp(y_train, y_full)
        ks_test_full = stats.ks_2samp(y_test, y_full)
        ks_train_test = stats.ks_2samp(y_train, y_test)

        results['tests'] = {
            'ks_train_vs_full': {'statistic': ks_train_full.statistic, 'pvalue': ks_train_full.pvalue},
            'ks_test_vs_full': {'statistic': ks_test_full.statistic, 'pvalue': ks_test_full.pvalue},
            'ks_train_vs_test': {'statistic': ks_train_test.statistic, 'pvalue': ks_train_test.pvalue}
        }

        # Flag significant differences (p < 0.01)
        if ks_train_full.pvalue < 0.01:
            results['warnings'].append("Training distribution significantly differs from full")
            results['is_valid'] = False
        if ks_test_full.pvalue < 0.01:
            results['warnings'].append("Test distribution significantly differs from full")
            results['is_valid'] = False

        return results

    def verify_multilabel(
        self,
        Y_full: np.ndarray,
        Y_train: np.ndarray,
        Y_test: np.ndarray
    ) -> Dict:
        """
        Verify stratification for multi-label targets.

        Checks per-label proportions and label co-occurrence.
        """
        results = {
            'is_valid': True,
            'per_label_analysis': {},
            'warnings': []
        }

        n_labels = Y_full.shape[1]

        for i in range(n_labels):
            full_prop = Y_full[:, i].mean()
            train_prop = Y_train[:, i].mean()
            test_prop = Y_test[:, i].mean()

            train_dev = abs(train_prop - full_prop)
            test_dev = abs(test_prop - full_prop)

            results['per_label_analysis'][i] = {
                'full_prop': full_prop,
                'train_prop': train_prop,
                'test_prop': test_prop,
                'max_deviation': max(train_dev, test_dev)
            }

            if max(train_dev, test_dev) > self.tolerance:
                results['warnings'].append(
                    f"Label {i}: deviation {max(train_dev, test_dev):.4f}"
                )
                # Don't mark as invalid for small deviations in multi-label
                if max(train_dev, test_dev) > 2 * self.tolerance:
                    results['is_valid'] = False

        return results

    def generate_report(self, results: Dict) -> str:
        """Generate human-readable verification report."""
        lines = ["=" * 60, "STRATIFICATION VERIFICATION REPORT", "=" * 60]

        lines.append(f"\nOverall Status: {'✓ VALID' if results['is_valid'] else '✗ ISSUES DETECTED'}")

        if results.get('warnings'):
            lines.append("\nWarnings:")
            for w in results['warnings']:
                lines.append(f" ⚠ {w}")

        if results.get('errors'):
            lines.append("\nErrors:")
            for e in results['errors']:
                lines.append(f" ✗ {e}")

        lines.append("\n" + "=" * 60)
        return "\n".join(lines)

# Usage example
verifier = StratificationVerifier(tolerance=0.03)
results = verifier.verify_classification(y, y_train, y_test)
print(verifier.generate_report(results))
```

Stratification transforms random sampling from a hope into a guarantee. The essential principles: stratify classification splits by class label, bin continuous targets before stratifying, use iterative stratification for multi-label data, stratify at the group level when groups must stay intact, and always verify the resulting proportions rather than assuming them.
When NOT to Stratify
Stratification is almost always beneficial, but there are edge cases: with very large, roughly balanced datasets, the random deviations shown in the table above are already negligible; when you deliberately want to evaluate under a different distribution than the training data (e.g., simulating deployment shift), forcing matched proportions defeats the purpose; and when the stratification variable has so many levels that some strata contain fewer than two samples, exact stratification simply isn't possible.
The Next Step: Understanding Randomness
Stratification controls one source of variability—distribution mismatch between splits. But randomness still affects results through the choice of which samples within each stratum end up in which partition. The next page explores random seeds and reproducibility.
You now understand stratification at a production depth—when to use it, how to implement it for various data types, and how to verify quality. Next, we'll explore random seeds and reproducibility: ensuring your experiments can be replicated exactly.