You're building a disease classifier from medical images. Your dataset has multiple images per patient, and the disease rate varies sharply from patient to patient.
Using Group K-Fold prevents patient leakage, but random group assignment might put all high-disease-rate patients in one fold. Using Stratified K-Fold balances classes, but might split patient images across folds.
You need both: group integrity AND class balance. This is the problem of grouped stratification (also called "stratified group k-fold")—partitioning groups into folds such that each fold has similar class proportions while never splitting groups.
This is fundamentally harder than either problem alone. With standard stratification, you can freely assign samples to balance classes. With groups, your assignment unit is the entire group, and groups have varying internal class distributions. Perfect stratification is often impossible; we seek the best approximation.
By the end of this page, you will understand why grouped stratification is NP-hard, master approximate algorithms that work well in practice, implement production-ready solutions, and develop intuition for when perfect stratification is achievable versus when approximation must suffice.
Why Grouped Stratification Is Hard
In standard stratification, we distribute individual samples. Each sample has a single class label, and we can place it in any fold. The problem reduces to counting.
In grouped stratification, each group is a fixed "bundle" of samples with a specific class distribution. We must assign entire bundles to folds without breaking them apart.
Formal Problem Statement
Given:

- $N$ samples, each labeled with one of $C$ classes
- a partition of the samples into $G$ disjoint groups (e.g., patients)
- a number of folds $k$

Objective: Assign groups to folds such that each fold's class distribution is as close to proportional as possible, with every group contained entirely in a single fold.
This is a variant of the multi-way number partitioning problem, which is NP-hard. With $G$ groups and $k$ folds, there are $k^G$ possible assignments—exponential in the number of groups.
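To make the combinatorics concrete, here is a tiny brute-force solver over hypothetical per-group minority-class counts. It enumerates all $k^G$ assignments, which is feasible only at toy sizes (64 assignments here; $5^{100}$ for a realistic dataset):

```python
from itertools import product

# Hypothetical per-group counts of the minority class, one entry per group.
group_pos = [4, 3, 3, 2, 1, 1]        # G = 6 groups, 14 positives total
k = 2                                  # folds
target = sum(group_pos) / k            # ideally 7 positives per fold

# Enumerate all k**G possible assignments of groups to folds.
best_dev, best_assign = float("inf"), None
for assign in product(range(k), repeat=len(group_pos)):
    fold_pos = [0] * k
    for g, f in zip(group_pos, assign):
        fold_pos[f] += g
    dev = max(abs(p - target) for p in fold_pos)
    if dev < best_dev:
        best_dev, best_assign = dev, assign

print(best_dev)  # 0.0 -- these counts happen to admit a perfect split (4+3 vs 3+2+1+1)
```

Exhaustive search answers the feasibility question exactly for tiny $G$; beyond a few dozen groups, the approximate algorithms below are the only practical option.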
When Perfect Stratification Is Possible
Perfect stratification (exact proportions in every fold) is possible only when:

- every group's internal class distribution matches the overall distribution (so any size-balanced assignment of groups preserves proportions), or
- the groups' class-count vectors happen to admit a partition into $k$ subsets with identical class totals.

In practice, neither condition typically holds. Groups have varying sizes and class mixes, making perfect stratification impossible for most real datasets.
Unlike standard stratification, which guarantees near-exact proportions, grouped stratification is an optimization problem. We minimize deviation from ideal proportions, accepting that some deviation is inevitable. Success is measured by the magnitude of deviation, not its absence.
Several algorithms approximate optimal grouped stratification. We'll examine three main approaches:
Approach 1: Greedy Class-Priority Assignment
Assign groups to folds greedily, prioritizing balance of the minority class:

1. Sort groups by their minority-class count, descending, so the most constrained groups are placed first.
2. For each group, score every fold by how much of the group's class counts it still needs to reach its target, weighting rare classes more heavily.
3. Assign the group to the highest-scoring fold, breaking ties toward the smaller fold.
This prioritizes the most constrained class (minority) while letting majority class balance follow.
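A minimal sketch of the greedy scheme, stripped down to a single minority class (`greedy_minority_split` and its dict inputs are illustrative, not part of the page's main implementation, which adds rarity weighting and shuffling):

```python
def greedy_minority_split(minority_counts, sizes, k):
    """Greedy fold assignment balancing the minority class first.

    minority_counts: group -> minority-class samples in that group
    sizes: group -> total samples in that group
    """
    fold_minority = [0] * k
    fold_size = [0] * k
    assignment = {}
    # Most constrained groups first: largest minority-class count.
    for g in sorted(minority_counts, key=minority_counts.get, reverse=True):
        # Fold with the fewest minority samples so far; ties go to the smaller fold.
        f = min(range(k), key=lambda i: (fold_minority[i], fold_size[i]))
        assignment[g] = f
        fold_minority[f] += minority_counts[g]
        fold_size[f] += sizes[g]
    return assignment

counts = {"A": 5, "B": 4, "C": 3, "D": 2, "E": 1, "F": 1}
sizes = {g: 20 for g in counts}
print(greedy_minority_split(counts, sizes, 2))
```

With these hypothetical counts the two folds end up with 8 minority samples each, even though no pair of groups sums to exactly half.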
Approach 2: Multi-Objective Optimization
Define a loss function that penalizes deviation from target proportions:
$$L(\text{assignment}) = \sum_{\text{fold } f} \sum_{\text{class } c} w_c \left| \frac{n_{f,c}}{n_f} - p_c \right|$$
where $w_c$ is a weight (higher for rare classes), $n_{f,c}$ is class $c$ count in fold $f$, $n_f$ is total size of fold $f$, and $p_c$ is the target proportion.
Use greedy, simulated annealing, or genetic algorithms to minimize this loss.
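A simulated-annealing sketch that minimizes this loss over group-to-fold assignments (`anneal_assignment` and its inputs are illustrative; real inputs would be derived from your `y`/`groups` arrays):

```python
import math
import random

def anneal_assignment(group_counts, k, weights, steps=5000, seed=0):
    """Simulated annealing over group -> fold assignments.

    group_counts: one {class: count} dict per group
    weights: per-class weights w_c (higher for rare classes)
    """
    rng = random.Random(seed)
    classes = list(weights)
    total = {c: sum(g.get(c, 0) for g in group_counts) for c in classes}
    n = sum(total.values())
    target = {c: total[c] / n for c in classes}           # p_c

    def loss(assign):
        out = 0.0
        for f in range(k):
            fold = [g for g, a in zip(group_counts, assign) if a == f]
            nf = sum(sum(g.values()) for g in fold) or 1  # avoid /0 on empty folds
            for c in classes:
                prop = sum(g.get(c, 0) for g in fold) / nf
                out += weights[c] * abs(prop - target[c])
        return out

    assign = [rng.randrange(k) for _ in group_counts]
    cur = loss(assign)
    best, best_assign = cur, list(assign)
    for step in range(steps):
        temp = max(1.0 - step / steps, 1e-9)              # linear cooling
        i = rng.randrange(len(assign))
        old, assign[i] = assign[i], rng.randrange(k)
        new = loss(assign)
        # Always accept improvements; accept worse moves with Boltzmann odds.
        if new <= cur or rng.random() < math.exp(-(new - cur) / temp):
            cur = new
            if cur < best:
                best, best_assign = cur, list(assign)
        else:
            assign[i] = old                               # revert rejected move
    return best_assign, best

# Hypothetical per-group class counts: {class: count}
demo_groups = [{0: 9, 1: 1}, {0: 10}, {0: 5, 1: 5},
               {0: 8, 1: 2}, {0: 10}, {0: 7, 1: 3}]
assignment, final_loss = anneal_assignment(demo_groups, k=2, weights={0: 1.0, 1: 5.0})
print(assignment, round(final_loss, 3))
```

Annealing escapes local minima that a pure greedy pass can get stuck in, at the cost of more loss evaluations; for a few hundred groups it typically runs in seconds.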
Approach 3: Iterative Stratification (Multi-Label Style)
Adapt the iterative stratification algorithm from multi-label classification to grouped data: at each step, take the class with the fewest unassigned samples (the most constrained), select the unassigned group containing the most of that class, and assign it to the fold with the greatest remaining need for that class; repeat until every group is placed.
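A sketch of iterative stratification adapted to groups (`iterative_group_stratify` is an illustrative helper, not the page's main implementation):

```python
import numpy as np

def iterative_group_stratify(y, groups, k, seed=0):
    """Iterative-stratification sketch for grouped data.

    At each step: pick the class with the fewest unassigned samples,
    take the unassigned group holding most of that class, and give it
    to the fold with the greatest remaining need for that class.
    """
    rng = np.random.default_rng(seed)
    classes = np.unique(y)
    gids = np.unique(groups)
    # Per-group class counts.
    counts = {g: {c: int(((groups == g) & (y == c)).sum()) for c in classes}
              for g in gids}
    desired = {c: (y == c).sum() / k for c in classes}    # per-fold targets
    need = [{c: desired[c] for c in classes} for _ in range(k)]
    unassigned = set(gids)
    fold_of = {}
    while unassigned:
        # Most constrained class: fewest remaining unassigned samples.
        remaining = {c: sum(counts[g][c] for g in unassigned) for c in classes}
        live = [c for c in classes if remaining[c] > 0] or list(classes)
        c_star = min(live, key=lambda c: remaining[c])
        # Group with the most of that class (random tiebreak), neediest fold.
        g = max(unassigned, key=lambda g: (counts[g][c_star], rng.random()))
        f = max(range(k), key=lambda f: need[f][c_star])
        fold_of[g] = f
        for c in classes:
            need[f][c] -= counts[g][c]
        unassigned.discard(g)
    return fold_of

# Six groups of 10 samples with positive counts 1, 0, 5, 2, 0, 3.
groups_arr = np.repeat(np.arange(6), 10)
y_arr = np.concatenate([[1] * c + [0] * (10 - c) for c in (1, 0, 5, 2, 0, 3)])
print(iterative_group_stratify(y_arr, groups_arr, k=2))
```

Handling the rarest class first is what makes this scheme robust to imbalance: by the time the plentiful majority class is distributed, the scarce classes are already pinned down.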
The implementation below combines the greedy class-priority strategy with rarity weighting and a tier-wise shuffle:

```python
import numpy as np
from collections import defaultdict
from typing import List, Dict, Tuple, Any

class StratifiedGroupKFold:
    """
    Stratified Group K-Fold cross-validator.

    Balances class distributions across folds while keeping
    all samples from the same group together.
    """

    def __init__(
        self,
        n_splits: int = 5,
        shuffle: bool = True,
        random_state: int = None
    ):
        self.n_splits = n_splits
        self.shuffle = shuffle
        self.random_state = random_state

    def split(
        self,
        X: np.ndarray,
        y: np.ndarray,
        groups: np.ndarray
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Generate stratified group k-fold splits.
        """
        if self.random_state is not None:
            np.random.seed(self.random_state)

        n_samples = len(y)
        unique_groups = np.unique(groups)
        n_groups = len(unique_groups)
        classes = np.unique(y)
        n_classes = len(classes)
        k = self.n_splits

        if n_groups < k:
            raise ValueError(
                f"Cannot have more splits ({k}) than groups ({n_groups})"
            )

        # Compute per-group class counts
        group_class_counts = {}  # group -> {class: count}
        group_indices = {}       # group -> array of sample indices

        for group in unique_groups:
            mask = groups == group
            group_indices[group] = np.where(mask)[0]
            y_group = y[mask]
            group_class_counts[group] = {
                c: int((y_group == c).sum()) for c in classes
            }

        # Target class counts per fold
        total_class_counts = {
            c: int((y == c).sum()) for c in classes
        }
        target_per_fold = {
            c: total_class_counts[c] / k for c in classes
        }

        # Initialize fold assignments and current counts
        fold_class_counts = [{c: 0 for c in classes} for _ in range(k)]
        fold_groups = [[] for _ in range(k)]

        # Sort groups by minority class count (descending)
        # so the most constrained groups are handled first
        minority_class = min(classes, key=lambda c: total_class_counts[c])
        sorted_groups = sorted(
            unique_groups,
            key=lambda g: group_class_counts[g].get(minority_class, 0),
            reverse=True
        )

        if self.shuffle:
            # Partial shuffle: shuffle within size tiers
            tier_size = max(1, len(sorted_groups) // 4)
            for i in range(0, len(sorted_groups), tier_size):
                tier = sorted_groups[i:i + tier_size]
                np.random.shuffle(tier)
                sorted_groups[i:i + tier_size] = tier

        # Greedy assignment
        for group in sorted_groups:
            group_counts = group_class_counts[group]

            # Find fold with greatest need for this group's class distribution
            # Score = sum over classes of weighted min(need, contribution)
            best_fold = None
            best_score = float('-inf')

            for fold_idx in range(k):
                score = 0
                for c in classes:
                    need = target_per_fold[c] - fold_class_counts[fold_idx][c]
                    contribution = group_counts.get(c, 0)
                    # Weight by class rarity
                    weight = 1.0 / (total_class_counts[c] + 1)
                    score += weight * min(need, contribution)

                # Small tiebreaker: prefer smaller folds
                score -= 0.001 * sum(fold_class_counts[fold_idx].values())

                if score > best_score:
                    best_score = score
                    best_fold = fold_idx

            # Assign group to best fold
            fold_groups[best_fold].append(group)
            for c in classes:
                fold_class_counts[best_fold][c] += group_counts.get(c, 0)

        # Store diagnostics
        self.fold_class_counts_ = fold_class_counts
        self.target_per_fold_ = target_per_fold

        # Convert to train/test indices
        splits = []
        for fold_idx in range(k):
            test_groups = set(fold_groups[fold_idx])
            test_indices = []
            train_indices = []
            for group in unique_groups:
                indices = group_indices[group].tolist()
                if group in test_groups:
                    test_indices.extend(indices)
                else:
                    train_indices.extend(indices)
            splits.append((
                np.array(train_indices),
                np.array(test_indices)
            ))

        return splits

    def get_diagnostics(self) -> Dict[str, Any]:
        """Return stratification quality diagnostics."""
        if not hasattr(self, 'fold_class_counts_'):
            return {}

        diagnostics = {
            'target_per_fold': self.target_per_fold_,
            'actual_per_fold': self.fold_class_counts_,
            'deviations': []
        }

        for fold_idx, actual in enumerate(self.fold_class_counts_):
            fold_total = sum(actual.values())
            fold_devs = {}
            for c in actual:
                target_prop = self.target_per_fold_[c] / sum(self.target_per_fold_.values())
                actual_prop = actual[c] / fold_total if fold_total > 0 else 0
                fold_devs[c] = actual_prop - target_prop
            diagnostics['deviations'].append(fold_devs)

        return diagnostics

def evaluate_stratification_quality(
    y: np.ndarray,
    groups: np.ndarray,
    splits: List[Tuple[np.ndarray, np.ndarray]]
) -> Dict[str, Any]:
    """
    Evaluate how well stratification was achieved.
    """
    classes = np.unique(y)
    overall_dist = {c: (y == c).mean() for c in classes}

    fold_metrics = []
    max_deviations = {c: 0.0 for c in classes}

    for fold_idx, (train_idx, test_idx) in enumerate(splits):
        y_test = y[test_idx]
        fold_dist = {c: (y_test == c).mean() if len(y_test) > 0 else 0
                     for c in classes}
        deviations = {c: abs(fold_dist[c] - overall_dist[c]) for c in classes}

        for c in classes:
            max_deviations[c] = max(max_deviations[c], deviations[c])

        # Test group integrity
        test_groups = set(groups[test_idx])
        train_groups = set(groups[train_idx])
        overlap = test_groups.intersection(train_groups)

        fold_metrics.append({
            'fold': fold_idx + 1,
            'test_size': len(test_idx),
            'test_groups': len(test_groups),
            'distribution': fold_dist,
            'deviations': deviations,
            'group_leak': len(overlap) > 0
        })

    return {
        'overall_distribution': overall_dist,
        'fold_metrics': fold_metrics,
        'max_class_deviations': max_deviations,
        'any_group_leaks': any(fm['group_leak'] for fm in fold_metrics)
    }

# Demonstration
if __name__ == "__main__":
    np.random.seed(42)

    # Create grouped imbalanced dataset
    n_groups = 30
    X_list, y_list, groups_list = [], [], []

    for group_id in range(n_groups):
        group_size = np.random.randint(20, 100)
        # Vary class balance per group
        group_pos_rate = np.random.beta(1, 9)  # Most groups have low positive rate

        X_g = np.random.randn(group_size, 10)
        y_g = (np.random.random(group_size) < group_pos_rate).astype(int)

        X_list.append(X_g)
        y_list.append(y_g)
        groups_list.extend([f"G{group_id:02d}"] * group_size)

    X = np.vstack(X_list)
    y = np.concatenate(y_list)
    groups = np.array(groups_list)

    print(f"Dataset: {len(y)} samples, {n_groups} groups")
    print(f"Overall class distribution: {dict(zip(*np.unique(y, return_counts=True)))}")
    print(f"Overall positive rate: {y.mean():.1%}")
    print()

    # Compare with non-stratified GroupKFold
    from sklearn.model_selection import GroupKFold

    print("=" * 60)
    print("STANDARD GROUP K-FOLD (NO STRATIFICATION)")
    print("=" * 60)
    gkf = GroupKFold(n_splits=5)
    gkf_splits = list(gkf.split(X, y, groups))
    gkf_quality = evaluate_stratification_quality(y, groups, gkf_splits)

    for fm in gkf_quality['fold_metrics']:
        print(f"Fold {fm['fold']}: n={fm['test_size']}, pos_rate={fm['distribution'][1]:.1%}, "
              f"deviation={fm['deviations'][1]:.1%}")
    print(f"Max deviation (positive class): {gkf_quality['max_class_deviations'][1]:.1%}")
    print()

    print("=" * 60)
    print("STRATIFIED GROUP K-FOLD")
    print("=" * 60)
    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    sgkf_splits = sgkf.split(X, y, groups)
    sgkf_quality = evaluate_stratification_quality(y, groups, sgkf_splits)

    for fm in sgkf_quality['fold_metrics']:
        print(f"Fold {fm['fold']}: n={fm['test_size']}, pos_rate={fm['distribution'][1]:.1%}, "
              f"deviation={fm['deviations'][1]:.1%}")
    print(f"Max deviation (positive class): {sgkf_quality['max_class_deviations'][1]:.1%}")
    print(f"Any group leaks: {sgkf_quality['any_group_leaks']}")
```

Since perfect stratification is often impossible, we need metrics to quantify how close we got.
Metric 1: Maximum Absolute Deviation (MAD)
For each class $c$, the maximum deviation from target proportion across all folds:
$$\text{MAD}_c = \max_{f \in \text{folds}} \left| \frac{n_{f,c}}{n_f} - p_c \right|$$
Lower is better. MAD = 0 means perfect stratification for that class.
Metric 2: Root Mean Square Deviation (RMSD)
Average deviation across all classes and folds:
$$\text{RMSD} = \sqrt{\frac{1}{k \cdot C} \sum_{f} \sum_{c} \left( \frac{n_{f,c}}{n_f} - p_c \right)^2}$$
Penalizes large deviations more heavily than MAD.
Metric 3: Class-Weighted Deviation
Weight deviations by class rarity (rare class deviations matter more):
$$\text{WD} = \sum_{f} \sum_{c} \frac{1}{p_c} \left| \frac{n_{f,c}}{n_f} - p_c \right|$$
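To see why the $1/p_c$ weighting matters, take illustrative numbers: a rare class with target $p_c = 0.05$ observed at $0.07$ in a fold, and the majority class with target $0.95$ observed at $0.93$:

$$\frac{1}{0.05}\left|0.07 - 0.05\right| = 0.40, \qquad \frac{1}{0.95}\left|0.93 - 0.95\right| \approx 0.021$$

The same $2\%$ absolute deviation is penalized roughly $19\times$ more heavily on the rare class.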
Metric 4: Group Integrity Check
Binary check: Does any group appear in both train and test for any fold?
$$\text{Integrity} = \begin{cases} 1 & \text{if no leakage} \\ 0 & \text{if any leakage} \end{cases}$$
This should always be 1; any failure is critical.
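The check is cheap enough to run on every split; a minimal sketch over (train, test) index pairs:

```python
import numpy as np

def assert_group_integrity(groups, splits):
    """Raise if any group appears in both train and test of any fold."""
    for i, (train_idx, test_idx) in enumerate(splits):
        leaked = set(groups[train_idx]) & set(groups[test_idx])
        if leaked:
            raise AssertionError(f"fold {i}: leaked groups {sorted(leaked)}")

groups = np.array(["A", "A", "B", "B", "C", "C"])
good = [(np.array([0, 1, 2, 3]), np.array([4, 5]))]   # C fully held out
bad = [(np.array([0, 1, 2]), np.array([3, 4, 5]))]    # B on both sides
assert_group_integrity(groups, good)                   # passes silently
```

Wire a check like this into your CV pipeline so a leaking split fails loudly instead of quietly inflating your metrics.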
| Metric | Formula | Interpretation | Target |
|---|---|---|---|
| MAD (per class) | max_f \|actual_prop - target_prop\| | Worst-case deviation | < 0.05 (5%) |
| RMSD | sqrt(mean(deviations²)) | Average deviation | < 0.02 (2%) |
| Weighted deviation | sum(deviation / class_prop) | Rare-class-focused | Minimize |
| Group integrity | 1 if no leakage | Validity check | = 1 (mandatory) |
The metrics above can be computed directly from the splits:

```python
import numpy as np
from typing import Dict, List, Tuple

def compute_stratification_metrics(
    y: np.ndarray,
    groups: np.ndarray,
    splits: List[Tuple[np.ndarray, np.ndarray]]
) -> Dict[str, float]:
    """
    Compute comprehensive stratification quality metrics.
    """
    classes = np.unique(y)
    k = len(splits)
    n_total = len(y)

    # Target proportions
    target_props = {c: (y == c).sum() / n_total for c in classes}

    # Collect deviations
    all_deviations = {c: [] for c in classes}
    group_integrity_violations = 0

    for train_idx, test_idx in splits:
        y_test = y[test_idx]
        n_test = len(y_test)

        # Check group integrity
        train_groups = set(groups[train_idx])
        test_groups = set(groups[test_idx])
        if train_groups.intersection(test_groups):
            group_integrity_violations += 1

        # Compute deviations
        for c in classes:
            actual_prop = (y_test == c).sum() / n_test if n_test > 0 else 0
            deviation = actual_prop - target_props[c]
            all_deviations[c].append(deviation)

    # Compute metrics
    metrics = {}

    # MAD per class
    for c in classes:
        metrics[f'mad_class_{c}'] = max(abs(d) for d in all_deviations[c])
    metrics['mad_max'] = max(metrics[f'mad_class_{c}'] for c in classes)

    # RMSD
    all_sq_devs = []
    for c in classes:
        all_sq_devs.extend([d**2 for d in all_deviations[c]])
    metrics['rmsd'] = np.sqrt(np.mean(all_sq_devs))

    # Weighted deviation (by rarity)
    weighted_sum = 0
    for c in classes:
        weight = 1 / target_props[c] if target_props[c] > 0 else 0
        weighted_sum += sum(abs(d) * weight for d in all_deviations[c])
    metrics['weighted_deviation'] = weighted_sum / (k * len(classes))

    # Group integrity
    metrics['group_integrity'] = 1.0 if group_integrity_violations == 0 else 0.0
    metrics['n_integrity_violations'] = group_integrity_violations

    return metrics

def print_quality_report(metrics: Dict[str, float]) -> None:
    """Print formatted quality report."""
    print("Stratification Quality Report")
    print("-" * 40)
    print(f"Group Integrity: {'✓ PASSED' if metrics['group_integrity'] == 1 else '✗ FAILED'}")
    if metrics['n_integrity_violations'] > 0:
        print(f"  ⚠️ {metrics['n_integrity_violations']} folds have group leakage!")

    print("Class Deviation Metrics:")
    print(f"  Max Absolute Deviation: {metrics['mad_max']:.2%}")
    print(f"  RMSD: {metrics['rmsd']:.4f}")
    print(f"  Weighted Deviation: {metrics['weighted_deviation']:.4f}")

    # Quality rating
    if metrics['mad_max'] < 0.03:
        quality = "Excellent"
    elif metrics['mad_max'] < 0.05:
        quality = "Good"
    elif metrics['mad_max'] < 0.10:
        quality = "Acceptable"
    else:
        quality = "Poor"
    print(f"Overall Quality: {quality}")
```

For most applications, MAD < 10% is acceptable, MAD < 5% is good, and MAD < 3% is excellent (matching the ratings in `print_quality_report`). If you're seeing MAD > 10%, either your groups are too heterogeneous or you have too few groups. Consider increasing k (fewer groups per fold) or collecting more data.
As of scikit-learn 1.0+, StratifiedGroupKFold is available natively. Let's explore its usage and compare with our custom implementation.
```python
from sklearn.model_selection import StratifiedGroupKFold, cross_validate
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import numpy as np

def production_stratified_group_cv(
    model,
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
    n_splits: int = 5,
    random_state: int = 42,
    verbose: bool = True
):
    """
    Production-ready stratified group cross-validation.
    """
    # Create stratified group k-fold splitter
    sgkf = StratifiedGroupKFold(
        n_splits=n_splits,
        shuffle=True,
        random_state=random_state
    )

    # Create preprocessing pipeline
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('model', model)
    ])

    # Define scoring
    scoring = {
        'accuracy': 'accuracy',
        'f1': 'f1',
        'roc_auc': 'roc_auc',
        'precision': 'precision',
        'recall': 'recall'
    }

    # Perform cross-validation
    cv_results = cross_validate(
        pipeline, X, y,
        cv=sgkf,
        groups=groups,  # Required for group-aware CV
        scoring=scoring,
        return_train_score=True,
        n_jobs=-1
    )

    if verbose:
        # Print fold details
        print("=" * 60)
        print("STRATIFIED GROUP K-FOLD CROSS-VALIDATION")
        print("=" * 60)

        # Analyze each fold
        overall_pos_rate = y.mean()
        print(f"Overall positive rate: {overall_pos_rate:.2%}")
        print("Fold Analysis:")
        for fold_idx, (train_idx, test_idx) in enumerate(sgkf.split(X, y, groups)):
            test_pos_rate = y[test_idx].mean()
            n_test_groups = len(np.unique(groups[test_idx]))
            deviation = abs(test_pos_rate - overall_pos_rate)
            print(f"  Fold {fold_idx + 1}: n={len(test_idx):4d}, "
                  f"groups={n_test_groups:2d}, "
                  f"pos_rate={test_pos_rate:.2%} "
                  f"(deviation: {deviation:.2%})")

        print("Metrics Summary:")
        for metric in scoring:
            test_scores = cv_results[f'test_{metric}']
            train_scores = cv_results[f'train_{metric}']
            print(f"  {metric:12s}: {test_scores.mean():.4f} ± {test_scores.std():.4f} "
                  f"(train: {train_scores.mean():.4f})")

    return cv_results

# Demonstration
if __name__ == "__main__":
    np.random.seed(42)

    # Create realistic medical imaging dataset
    n_patients = 60
    X_list, y_list, groups_list = [], [], []

    for patient_id in range(n_patients):
        # Number of scans per patient
        n_scans = np.random.randint(5, 30)

        # Patient-level disease probability:
        # some patients are definitely diseased, some healthy
        if patient_id % 10 == 0:  # 10% of patients have disease
            patient_disease_prob = 0.8
        else:
            patient_disease_prob = 0.05

        # Generate scans
        X_p = np.random.randn(n_scans, 20)
        y_p = (np.random.random(n_scans) < patient_disease_prob).astype(int)

        X_list.append(X_p)
        y_list.append(y_p)
        groups_list.extend([f"patient_{patient_id:03d}"] * n_scans)

    X = np.vstack(X_list)
    y = np.concatenate(y_list)
    groups = np.array(groups_list)

    print(f"Dataset: {len(y)} scans from {n_patients} patients")
    print(f"Class distribution: {np.bincount(y)} (positive rate: {y.mean():.2%})")
    print()

    model = RandomForestClassifier(n_estimators=100, random_state=42)
    results = production_stratified_group_cv(model, X, y, groups)
```

StratifiedGroupKFold requires both y (for stratification) AND groups (for group separation). Forgetting to pass groups= to cross_validate(), or swapping in plain StratifiedKFold for convenience, costs you group separation, and the resulting leakage is invisible in your metrics. Always verify with a diagnostic check.
Grouped stratification encounters several challenging edge cases:
Handling Homogeneous Groups
Groups that are internally homogeneous (all samples have the same class) are the primary cause of imperfect stratification. Here's how to diagnose and handle them:
```python
import numpy as np
from collections import Counter

def analyze_group_homogeneity(
    y: np.ndarray,
    groups: np.ndarray
) -> dict:
    """
    Analyze group-level class distributions to identify
    stratification constraints.
    """
    unique_groups = np.unique(groups)
    classes = np.unique(y)

    homogeneous_groups = {c: [] for c in classes}
    mixed_groups = []

    for group in unique_groups:
        mask = groups == group
        y_group = y[mask]
        unique_classes = np.unique(y_group)

        if len(unique_classes) == 1:
            # Homogeneous group
            homogeneous_groups[unique_classes[0]].append(group)
        else:
            # Mixed group
            mixed_groups.append({
                'group': group,
                'distribution': dict(Counter(y_group)),
                'n_samples': len(y_group)
            })

    # Compute stratifiability score (lower = more constrained)
    n_total = len(unique_groups)
    n_mixed = len(mixed_groups)
    stratifiability = n_mixed / n_total if n_total > 0 else 0

    return {
        'n_groups': len(unique_groups),
        'n_classes': len(classes),
        'homogeneous_groups': homogeneous_groups,
        'n_homogeneous': {c: len(gs) for c, gs in homogeneous_groups.items()},
        'mixed_groups': mixed_groups,
        'n_mixed': len(mixed_groups),
        'stratifiability_score': stratifiability,
        'recommendation': get_recommendation(stratifiability)
    }

def get_recommendation(stratifiability: float) -> str:
    """Get recommendation based on stratifiability score."""
    if stratifiability > 0.7:
        return "High stratifiability: Standard algorithms should work well."
    elif stratifiability > 0.4:
        return "Moderate stratifiability: Expect some deviation, but manageable."
    elif stratifiability > 0.2:
        return "Low stratifiability: Consider using more folds or fewer groups."
    else:
        return "Very low stratifiability: Consider binned target or relaxed constraints."

# Example usage
if __name__ == "__main__":
    np.random.seed(42)

    # Simulate varying homogeneity
    n_groups = 50
    groups_list = []
    y_list = []

    for g in range(n_groups):
        size = np.random.randint(10, 50)
        if g < 15:    # 30% pure positive
            y_g = np.ones(size, dtype=int)
        elif g < 35:  # 40% pure negative
            y_g = np.zeros(size, dtype=int)
        else:         # 30% mixed
            y_g = (np.random.random(size) < 0.3).astype(int)
        groups_list.extend([f"G{g:02d}"] * size)
        y_list.extend(y_g)

    groups = np.array(groups_list)
    y = np.array(y_list)

    analysis = analyze_group_homogeneity(y, groups)
    print("Group Homogeneity Analysis")
    print("-" * 40)
    print(f"Total groups: {analysis['n_groups']}")
    print(f"Homogeneous groups: {sum(analysis['n_homogeneous'].values())}")
    for c, count in analysis['n_homogeneous'].items():
        print(f"  Class {c}: {count} pure groups")
    print(f"Mixed groups: {analysis['n_mixed']}")
    print(f"Stratifiability score: {analysis['stratifiability_score']:.2f}")
    print(f"Recommendation: {analysis['recommendation']}")
```

Based on the theoretical and practical considerations discussed, here are best practices for production use:
When conflicts arise, the priority order is: 1) group integrity (non-negotiable), 2) minority class stratification (most impactful), 3) majority class stratification (least constrained). Never sacrifice group integrity for better class balance: the result is optimistically biased.
We've covered the challenging problem of combining class stratification with group-based cross-validation. The essential takeaways:

- Grouped stratification is an NP-hard assignment problem; perfect class balance across folds is usually unattainable, so aim for minimal deviation rather than zero.
- Greedy minority-first assignment, loss-based optimization (e.g., simulated annealing), and iterative stratification all give good approximations in practice.
- Quantify quality with MAD, RMSD, and class-weighted deviation; group integrity is a hard constraint, never a trade-off.
- scikit-learn 1.0+ ships a native StratifiedGroupKFold: pass both y and groups, and verify the resulting splits.
- Homogeneous (single-class) groups are the main obstacle to stratification; diagnose them before blaming the algorithm.
You now understand how to handle both class imbalance and group structure. But there's a deeper issue lurking: data leakage. Even with perfect group separation, information can leak through preprocessing, feature engineering, or target encoding. The next page covers Data Leakage Prevention—a critical topic for trustworthy ML evaluation.