Imagine you're a teacher grading exams, but you only have answer keys for half the questions. What if you could use your own confident answers on the graded questions to help figure out the rest? This seemingly circular logic—using your own predictions to improve yourself—forms the foundation of self-training, one of the oldest and most intuitive semi-supervised learning algorithms.
Self-training, also known as self-labeling or bootstrapping, represents a remarkably simple yet powerful idea: train a model on labeled data, use it to predict labels for unlabeled data, and then retrain on this expanded dataset. Despite its simplicity, this technique has powered breakthroughs in speech recognition, natural language processing, and computer vision for decades.
By the end of this page, you will understand the complete self-training algorithm, its theoretical foundations in pseudo-labeling, confidence estimation strategies, convergence properties, and practical implementation patterns. You'll develop intuition for when self-training succeeds and when it fails catastrophically.
Self-training follows a straightforward iterative procedure that alternates between prediction and retraining. Let's formalize this algorithm precisely.
Problem Setup:
Given:

- A labeled dataset $\mathcal{D}_L = \{(x_i, y_i)\}_{i=1}^{n_L}$ with ground-truth labels
- An unlabeled dataset $\mathcal{D}_U = \{x_j\}_{j=1}^{n_U}$, typically with $n_U \gg n_L$
Goal: Leverage $\mathcal{D}_U$ to improve the classifier trained only on $\mathcal{D}_L$.
```python
def self_training(D_L, D_U, base_classifier, max_iterations=10, confidence_threshold=0.95):
    """
    Classic Self-Training Algorithm

    Args:
        D_L: Labeled dataset [(x, y), ...]
        D_U: Unlabeled dataset [x, ...]
        base_classifier: Any classifier with fit() and predict_proba()
        max_iterations: Maximum number of self-training iterations
        confidence_threshold: Minimum confidence to accept pseudo-label

    Returns:
        Trained classifier f
    """
    # Initialize with labeled data
    X_L, y_L = zip(*D_L)
    X_L, y_L = list(X_L), list(y_L)
    X_U = list(D_U)

    for iteration in range(max_iterations):
        # Step 1: Train classifier on current labeled set
        f = base_classifier.fit(X_L, y_L)

        # Step 2: Predict labels and confidences for unlabeled data
        if len(X_U) == 0:
            break

        probabilities = f.predict_proba(X_U)
        predictions = probabilities.argmax(axis=1)
        confidences = probabilities.max(axis=1)

        # Step 3: Select high-confidence predictions
        high_confidence_mask = confidences >= confidence_threshold
        selected_indices = [i for i, mask in enumerate(high_confidence_mask) if mask]

        if len(selected_indices) == 0:
            # No confident predictions, stop or lower threshold
            print(f"Iteration {iteration}: No high-confidence samples found")
            break

        # Step 4: Add pseudo-labeled examples to labeled set
        for idx in sorted(selected_indices, reverse=True):
            X_L.append(X_U[idx])
            y_L.append(predictions[idx])
            X_U.pop(idx)

        print(f"Iteration {iteration}: Added {len(selected_indices)} samples, "
              f"Labeled: {len(X_L)}, Unlabeled: {len(X_U)}")

    # Final training on expanded dataset
    return base_classifier.fit(X_L, y_L)
```

Algorithm Analysis:
The algorithm makes several key design choices:
Confidence-based Selection: Only predictions exceeding the threshold are converted to pseudo-labels. This is crucial—accepting all predictions would amplify errors.
Iterative Refinement: Each iteration potentially adds more labeled data, allowing the model to improve and make more confident predictions on remaining unlabeled data.
Monotonic Labeled Set Growth: Once a sample is pseudo-labeled, it permanently joins the labeled set. This is both a strength (efficient) and weakness (errors propagate).
| Step | Action | Mathematical Formulation |
|---|---|---|
| 1 | Train on labeled data | $f^{(0)} = \text{Train}(\mathcal{D}_L)$ |
| 2 | Get predictions on unlabeled data | $\hat{y}_j = f^{(t)}(x_j), \; c_j = \max_k P(y=k \mid x_j)$ |
| 3 | Filter by confidence | $\mathcal{D}_{\text{new}} = \{(x_j, \hat{y}_j) : c_j \geq \tau\}$ |
| 4 | Add to labeled set | $\mathcal{D}_L \leftarrow \mathcal{D}_L \cup \mathcal{D}_{\text{new}}$ |
| 5 | Iterate until convergence | $f^{(t+1)} = \text{Train}(\mathcal{D}_L)$ |
Self-training's effectiveness rests on several theoretical assumptions about the data distribution and the classifier's behavior. Understanding these foundations helps us predict when the algorithm will succeed or fail.
The Cluster Assumption:
Self-training implicitly relies on the cluster assumption: points that are close together (in input space or learned representation) belong to the same class. More formally:
$$P(y | x) \text{ varies smoothly across high-density regions of } P(x)$$
This means decision boundaries should pass through low-density regions of the data distribution. When this holds, a classifier trained on a subset of labeled points will correctly generalize to unlabeled points in the same cluster.
Closely related to the cluster assumption is the smoothness assumption: if two points x₁ and x₂ are close in a high-density region, then their labels y₁ and y₂ should be the same. Self-training succeeds when the learned representation respects this smoothness, allowing confident predictions to propagate correctly through the data manifold.
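To see the cluster assumption in action, here is a minimal sketch using scikit-learn's `SelfTrainingClassifier` on the two-moons dataset, where each class forms its own cluster; the dataset size, label fraction, and SVC hyperparameters are illustrative choices, not values prescribed by this page.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.metrics import accuracy_score
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.svm import SVC

# Two-moons data: two well-separated clusters, one per class,
# so the cluster assumption holds by construction.
X, y = make_moons(n_samples=500, noise=0.1, random_state=0)

# Hide most labels: scikit-learn marks unlabeled samples with -1.
rng = np.random.RandomState(0)
y_partial = np.copy(y)
unlabeled_mask = rng.rand(len(y)) < 0.95
y_partial[unlabeled_mask] = -1

# Self-training wrapped around an SVC that exposes predict_proba.
base = SVC(probability=True, gamma="scale", random_state=0)
self_trained = SelfTrainingClassifier(base, threshold=0.95, max_iter=10)
self_trained.fit(X, y_partial)

# Baseline: train on the small labeled subset alone.
labeled_only = SVC(probability=True, gamma="scale", random_state=0)
labeled_only.fit(X[~unlabeled_mask], y[~unlabeled_mask])

print("labeled-only accuracy:", accuracy_score(y, labeled_only.predict(X)))
print("self-training accuracy:", accuracy_score(y, self_trained.predict(X)))
```

When the boundary between the two moons passes through a low-density gap, the pseudo-labels propagate along each cluster and the self-trained model typically matches or beats the labeled-only baseline.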
Pseudo-Labeling as Entropy Minimization:
Lee (2013) provided a compelling interpretation of self-training as entropy minimization. When we assign pseudo-labels based on maximum confidence:
$$\tilde{y}_j = \arg\max_k P(y=k | x_j; \theta)$$
We are implicitly encouraging the model to make low-entropy (confident) predictions on unlabeled data. The cross-entropy loss on pseudo-labels:
$$\mathcal{L}_{\text{pseudo}} = -\sum_{j \in U} \sum_k \mathbb{1}[\tilde{y}_j = k] \log P(y=k | x_j; \theta)$$

pushes the decision boundaries away from unlabeled points, implementing a form of maximum-margin separation.
Connection to EM Algorithm:
Self-training can be viewed as a hard variant of the Expectation-Maximization (EM) algorithm: the E-step assigns each unlabeled point its most likely label under the current model, and the M-step retrains the model on the union of labeled and pseudo-labeled data.
The "hard" nature comes from using argmax instead of maintaining a distribution over possible labels. This makes the algorithm efficient but can lead to premature commitment to incorrect labels.
```python
import numpy as np
import torch
import torch.nn.functional as F


def soft_self_training_loss(model, X_labeled, y_labeled, X_unlabeled,
                            labeled_weight=1.0, unlabeled_weight=0.1,
                            temperature=0.5):
    """
    Soft pseudo-label training (related to self-training).
    Instead of hard labels, uses sharpened probability distribution.

    Args:
        model: Neural network with forward() method
        X_labeled: Labeled inputs
        y_labeled: True labels (one-hot or indices)
        X_unlabeled: Unlabeled inputs
        labeled_weight: Weight for supervised loss
        unlabeled_weight: Weight for unsupervised loss
        temperature: Sharpening temperature (lower = harder labels)

    Returns:
        Total loss combining supervised and pseudo-label terms
    """
    # Supervised loss on labeled data
    logits_labeled = model(X_labeled)
    loss_supervised = F.cross_entropy(logits_labeled, y_labeled)

    # Generate soft pseudo-labels for unlabeled data
    with torch.no_grad():
        logits_unlabeled = model(X_unlabeled)
        probs_unlabeled = F.softmax(logits_unlabeled, dim=1)

        # Temperature sharpening: T < 1 makes distribution peakier
        # As T -> 0, this approaches hard pseudo-labels
        sharpened_probs = F.softmax(logits_unlabeled / temperature, dim=1)

    # Cross-entropy with soft targets
    # Note: Using soft targets allows gradient flow through probability mass
    logits_unlabeled_fresh = model(X_unlabeled)  # Fresh forward for gradients
    log_probs = F.log_softmax(logits_unlabeled_fresh, dim=1)
    loss_pseudo = -(sharpened_probs * log_probs).sum(dim=1).mean()

    # Combine losses
    total_loss = labeled_weight * loss_supervised + unlabeled_weight * loss_pseudo

    return total_loss, {
        'supervised': loss_supervised.item(),
        'pseudo': loss_pseudo.item(),
        'pseudo_entropy': -(probs_unlabeled * torch.log(probs_unlabeled + 1e-10)).sum(1).mean().item()
    }
```

Self-training suffers from confirmation bias: the model confirms its own biases by training on its predictions. If the initial model is biased toward a particular class, it will produce confident (but wrong) predictions for that class, reinforcing the bias. This is why initial classifier quality is critical.
The effectiveness of self-training critically depends on accurate confidence estimation. If a model is overconfident on incorrect predictions, those errors will be incorporated as training data, degrading performance. Let's examine various confidence estimation strategies.
Maximum Class Probability (MCP):
The simplest approach uses the maximum predicted probability:
$$c(x) = \max_k P(y=k | x; \theta)$$
However, modern neural networks are notoriously overconfident—they often output high probabilities even for incorrect predictions. This overconfidence is a major weakness of naive self-training.
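One standard way to quantify this miscalibration is the expected calibration error (ECE), which compares average confidence to actual accuracy within confidence bins. Below is a minimal sketch; the equal-width binning scheme and the helper name are our own choices, not part of the method described above.

```python
import numpy as np


def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: weighted average of |accuracy - confidence| over confidence bins.

    Args:
        confidences: array of max softmax probabilities, one per sample
        correct: float array, 1.0 where the prediction was right, else 0.0
        n_bins: number of equal-width bins over [0, 1]
    """
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap  # weight by fraction of samples in bin
    return ece

# Usage sketch:
# confidences = probs.max(axis=1)
# correct = (probs.argmax(axis=1) == labels).astype(float)
# print(expected_calibration_error(confidences, correct))
```

A large ECE on a held-out set is a warning sign that raw MCP thresholds will let many wrong pseudo-labels through.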
Temperature Scaling for Calibration:
Before using MCP, we can calibrate predictions using temperature scaling:
$$P(y=k | x; \theta, T) = \frac{\exp(z_k / T)}{\sum_j \exp(z_j / T)}$$
where $z_k$ are logits and $T > 1$ softens the distribution (reducing overconfidence). The temperature is learned on a held-out calibration set to minimize negative log-likelihood.
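As a rough illustration of that calibration step, the sketch below fits a single temperature on held-out logits by minimizing negative log-likelihood with L-BFGS; the function name, the log-parameterization of $T$, and the optimizer settings are assumptions rather than a prescribed implementation.

```python
import torch
import torch.nn.functional as F


def fit_temperature(logits_val, labels_val, max_iter=50):
    """Learn a scalar temperature T on held-out (calibration) logits by minimizing NLL."""
    # Optimize log(T) so the temperature stays positive.
    log_t = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        loss = F.cross_entropy(logits_val / log_t.exp(), labels_val)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().item()

# Usage sketch (logits_val: [N, K] float tensor, labels_val: [N] long tensor):
# T = fit_temperature(logits_val, labels_val)
# calibrated_probs = F.softmax(test_logits / T, dim=1)
```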
Monte Carlo Dropout:
Gal and Ghahramani (2016) proposed using dropout at inference time to approximate Bayesian uncertainty:
$$\hat{\sigma}^2(x) = \frac{1}{T} \sum_{t=1}^{T} P(y|x, \theta_t)^2 - \left(\frac{1}{T} \sum_{t=1}^{T} P(y|x, \theta_t)\right)^2$$
where $\theta_t$ represents different dropout masks. High variance indicates uncertainty.
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import softmax


class ConfidenceEstimator:
    """
    Collection of confidence estimation strategies for self-training.
    """

    @staticmethod
    def maximum_class_probability(logits):
        """Basic MCP - maximum softmax probability."""
        probs = softmax(logits, axis=-1)
        return np.max(probs, axis=-1)

    @staticmethod
    def margin_confidence(logits):
        """
        Margin between top two class probabilities.
        Higher margin = more confident distinction between classes.
        """
        probs = softmax(logits, axis=-1)
        sorted_probs = np.sort(probs, axis=-1)
        return sorted_probs[:, -1] - sorted_probs[:, -2]

    @staticmethod
    def entropy_confidence(logits):
        """
        Negative entropy as confidence.
        Lower entropy = more peaked distribution = higher confidence.
        """
        probs = softmax(logits, axis=-1)
        entropy = -np.sum(probs * np.log(probs + 1e-10), axis=-1)
        # Normalize to [0, 1] range (assuming num_classes)
        num_classes = probs.shape[-1]
        max_entropy = np.log(num_classes)
        return 1 - (entropy / max_entropy)

    @staticmethod
    def mc_dropout_confidence(model, x, num_samples=30, dropout_rate=0.5):
        """
        Monte Carlo Dropout for uncertainty estimation.
        Returns mean prediction and predictive uncertainty.

        Args:
            model: PyTorch model with dropout layers
            x: Input tensor
            num_samples: Number of stochastic forward passes
            dropout_rate: Not used if dropout is already in model

        Returns:
            mean_probs: Average predicted probabilities
            uncertainty: Predictive variance (epistemic uncertainty)
        """
        model.train()  # Enable dropout
        predictions = []

        with torch.no_grad():
            for _ in range(num_samples):
                logits = model(x)
                probs = F.softmax(logits, dim=-1)
                predictions.append(probs.cpu().numpy())

        model.eval()
        predictions = np.stack(predictions, axis=0)  # [T, batch, classes]
        mean_probs = np.mean(predictions, axis=0)

        # Predictive entropy
        predictive_entropy = -np.sum(mean_probs * np.log(mean_probs + 1e-10), axis=-1)

        # Expected entropy (aleatoric uncertainty)
        expected_entropy = np.mean(
            -np.sum(predictions * np.log(predictions + 1e-10), axis=-1), axis=0
        )

        # Mutual information (epistemic uncertainty)
        epistemic_uncertainty = predictive_entropy - expected_entropy

        return mean_probs, {
            'predictive_entropy': predictive_entropy,
            'aleatoric': expected_entropy,
            'epistemic': epistemic_uncertainty
        }

    @staticmethod
    def temperature_scaled_confidence(logits, temperature=1.5):
        """
        Apply temperature scaling before computing MCP.
        Temperature > 1 reduces overconfidence.

        Args:
            logits: Raw model outputs (pre-softmax)
            temperature: Scaling factor (learn on calibration set)
        """
        scaled_logits = logits / temperature
        probs = softmax(scaled_logits, axis=-1)
        return np.max(probs, axis=-1)
```

In practice, combine multiple confidence measures. Use margin confidence as a primary filter, temperature scaling for calibration, and optionally MC Dropout for samples near the threshold. This multi-layered approach significantly reduces the incorporation of incorrect pseudo-labels.
Understanding the convergence behavior of self-training helps us design better stopping criteria and identify failure modes before they cause significant damage.
Convergence Conditions:
Self-training converges (in the sense of terminating) when one of these conditions is met: the unlabeled pool is exhausted, no remaining prediction exceeds the confidence threshold, or the maximum number of iterations is reached.
The Dynamics of Error Accumulation:
Let $\epsilon_t$ be the error rate of pseudo-labels at iteration $t$. A simplified analysis shows:
$$\epsilon_{t+1} \approx \epsilon_t \cdot (1 + \alpha \cdot n_{\text{pseudo},t} / n_L)$$
where $\alpha$ reflects how much noisy labels degrade the classifier. This suggests exponential error growth if early iterations introduce errors!
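A toy iteration of this recursion makes the compounding concrete; the values of $\epsilon_0$, $\alpha$, $n_L$, and the per-iteration pseudo-label count below are arbitrary illustrations, not estimates from any real run.

```python
import numpy as np

# Toy illustration of eps_{t+1} = eps_t * (1 + alpha * n_pseudo / n_L)
# with the labeled-set size held fixed, as in the formula above.
eps, alpha, n_L, n_pseudo = 0.05, 0.5, 100, 100

trajectory = [eps]
for t in range(6):
    eps = eps * (1 + alpha * n_pseudo / n_L)
    trajectory.append(eps)

print(np.round(trajectory, 3))
# With these numbers the modeled pseudo-label error rate grows by a factor
# of 1.5 per round, from 5% to roughly 57% after six iterations.
```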
```python
import numpy as np
from dataclasses import dataclass
from typing import List, Optional
import matplotlib.pyplot as plt


@dataclass
class SelfTrainingMetrics:
    """Metrics tracked during self-training for monitoring convergence."""
    iteration: int
    labeled_count: int
    unlabeled_count: int
    samples_added: int
    avg_confidence: float
    min_confidence_added: float
    max_confidence_added: float
    pseudo_label_accuracy: Optional[float] = None  # If ground truth available
    model_accuracy_labeled: Optional[float] = None
    class_distribution: Optional[dict] = None


class SelfTrainingMonitor:
    """
    Monitor self-training dynamics and detect potential failure modes.
    """

    def __init__(self, true_labels_unlabeled=None):
        """
        Args:
            true_labels_unlabeled: Ground truth for unlabeled data (for analysis only)
        """
        self.true_labels = true_labels_unlabeled
        self.history: List[SelfTrainingMetrics] = []

    def log_iteration(self, iteration, X_L, y_L, X_U, samples_added_indices,
                      confidences, predictions):
        """Log metrics for a self-training iteration."""
        # Compute pseudo-label accuracy if ground truth available
        pseudo_acc = None
        if self.true_labels is not None and len(samples_added_indices) > 0:
            predicted = [predictions[i] for i in samples_added_indices]
            true = [self.true_labels[i] for i in samples_added_indices]
            pseudo_acc = np.mean(np.array(predicted) == np.array(true))

        # Class distribution of pseudo-labels
        if len(samples_added_indices) > 0:
            added_preds = [predictions[i] for i in samples_added_indices]
            unique, counts = np.unique(added_preds, return_counts=True)
            class_dist = dict(zip(unique.tolist(), counts.tolist()))
        else:
            class_dist = {}

        metrics = SelfTrainingMetrics(
            iteration=iteration,
            labeled_count=len(X_L),
            unlabeled_count=len(X_U),
            samples_added=len(samples_added_indices),
            avg_confidence=np.mean(confidences) if len(confidences) > 0 else 0,
            min_confidence_added=min([confidences[i] for i in samples_added_indices])
                if len(samples_added_indices) > 0 else 0,
            max_confidence_added=max([confidences[i] for i in samples_added_indices])
                if len(samples_added_indices) > 0 else 0,
            pseudo_label_accuracy=pseudo_acc,
            class_distribution=class_dist
        )
        self.history.append(metrics)
        return metrics

    def detect_failure_modes(self) -> List[str]:
        """Analyze history to detect potential failure modes."""
        warnings = []

        if len(self.history) < 2:
            return warnings

        # 1. Diminishing returns: samples added decreasing rapidly
        recent = self.history[-3:]
        if len(recent) >= 3:
            additions = [m.samples_added for m in recent]
            if additions[-1] < additions[0] * 0.1 and additions[-1] < 10:
                warnings.append("Diminishing returns: very few samples being added")

        # 2. Confidence collapse: average confidence dropping
        if len(self.history) >= 5:
            early_conf = np.mean([m.avg_confidence for m in self.history[:3]])
            late_conf = np.mean([m.avg_confidence for m in self.history[-3:]])
            if late_conf < early_conf * 0.8:
                warnings.append("Confidence collapse: model becoming less certain")

        # 3. Class imbalance in pseudo-labels
        if self.history[-1].class_distribution:
            dist = self.history[-1].class_distribution
            if len(dist) > 1:
                counts = list(dist.values())
                if max(counts) > 5 * min(counts):
                    warnings.append("Severe class imbalance in pseudo-labels")

        # 4. Degrading pseudo-label accuracy (if available)
        accs = [m.pseudo_label_accuracy for m in self.history
                if m.pseudo_label_accuracy is not None]
        if len(accs) >= 3 and accs[-1] < accs[0] * 0.9:
            warnings.append("Pseudo-label accuracy degrading over iterations")

        return warnings

    def plot_dynamics(self):
        """Visualize self-training dynamics."""
        if len(self.history) < 2:
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        iterations = [m.iteration for m in self.history]

        # Plot 1: Labeled/Unlabeled counts
        axes[0, 0].plot(iterations, [m.labeled_count for m in self.history], 'b-', label='Labeled')
        axes[0, 0].plot(iterations, [m.unlabeled_count for m in self.history], 'r-', label='Unlabeled')
        axes[0, 0].set_xlabel('Iteration')
        axes[0, 0].set_ylabel('Sample Count')
        axes[0, 0].legend()
        axes[0, 0].set_title('Dataset Size Evolution')

        # Plot 2: Samples added per iteration
        axes[0, 1].bar(iterations, [m.samples_added for m in self.history])
        axes[0, 1].set_xlabel('Iteration')
        axes[0, 1].set_ylabel('Samples Added')
        axes[0, 1].set_title('Pseudo-Labels Added per Iteration')

        # Plot 3: Confidence statistics
        axes[1, 0].plot(iterations, [m.avg_confidence for m in self.history], 'g-', label='Avg Confidence')
        axes[1, 0].fill_between(
            iterations,
            [m.min_confidence_added for m in self.history],
            [m.max_confidence_added for m in self.history],
            alpha=0.3, label='Added Range'
        )
        axes[1, 0].set_xlabel('Iteration')
        axes[1, 0].set_ylabel('Confidence')
        axes[1, 0].legend()
        axes[1, 0].set_title('Confidence Dynamics')

        # Plot 4: Pseudo-label accuracy (if available)
        accs = [m.pseudo_label_accuracy for m in self.history]
        if any(a is not None for a in accs):
            valid_iters = [it for it, a in zip(iterations, accs) if a is not None]
            valid_accs = [a for a in accs if a is not None]
            axes[1, 1].plot(valid_iters, valid_accs, 'mo-')
            axes[1, 1].set_xlabel('Iteration')
            axes[1, 1].set_ylabel('Accuracy')
            axes[1, 1].set_title('Pseudo-Label Accuracy')
            axes[1, 1].set_ylim([0, 1])

        plt.tight_layout()
        return fig
```

Modern self-training has evolved significantly from the basic algorithm. Let's examine key variants that address fundamental limitations.
Curriculum Self-Training:
Instead of accepting all samples above a fixed threshold, curriculum self-training progressively lowers the threshold, starting with the most confident samples:
$$\tau_t = \tau_{\max} \cdot \lambda^t$$
where $\lambda < 1$ is a decay factor. This implements an easy-to-hard curriculum, ensuring the model learns from reliable examples before tackling harder ones.
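A quick sketch of such a schedule follows; the specific $\tau_{\max}$, $\lambda$, and the minimum-threshold floor are illustrative additions, not values from this page.

```python
# Curriculum threshold schedule: tau_t = tau_max * lambda**t (values illustrative)
tau_max, decay, floor = 0.99, 0.97, 0.70
schedule = [max(floor, tau_max * decay ** t) for t in range(15)]
print([round(tau, 3) for tau in schedule])
# Starts near 0.99 and relaxes toward the floor, admitting harder samples later.
```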
Self-Training with Noise (Noisy Student):
Xie et al. (2020) introduced Noisy Student Training, which wraps self-training in a teacher-student loop: a teacher trained on labeled data pseudo-labels the unlabeled data without noise, an equal-or-larger student is trained on both sets with noise injected, and the student then becomes the next teacher.
Critically, the student trains with data augmentation and dropout while the teacher generates clean labels. This prevents the student from simply memorizing teacher predictions.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset
import copy


class NoisyStudentTrainer:
    """
    Implementation of Noisy Student Training (Xie et al., 2020).

    Key differences from basic self-training:
    1. Teacher generates labels WITHOUT noise
    2. Student trains WITH noise (augmentation, dropout, stochastic depth)
    3. Student can be equal or larger than teacher
    4. Iterative refinement: student becomes new teacher
    """

    def __init__(self, model_fn, augmentation_fn, device='cuda'):
        """
        Args:
            model_fn: Function that returns a new model instance
            augmentation_fn: Data augmentation function
            device: Training device
        """
        self.model_fn = model_fn
        self.augment = augmentation_fn
        self.device = device

    def train_teacher(self, model, labeled_loader, epochs=100):
        """Train teacher on labeled data only."""
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in range(epochs):
            model.train()
            for x, y in labeled_loader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = F.cross_entropy(model(x), y)
                loss.backward()
                optimizer.step()

        return model

    def generate_pseudo_labels(self, teacher, unlabeled_loader, confidence_threshold=0.0):
        """
        Generate pseudo-labels using teacher (NO noise/augmentation).

        Args:
            teacher: Trained teacher model
            unlabeled_loader: DataLoader for unlabeled data
            confidence_threshold: Minimum confidence to include (0.0 = include all)

        Returns:
            List of (x, pseudo_y, confidence) tuples
        """
        teacher.eval()
        pseudo_labeled = []

        with torch.no_grad():
            for x in unlabeled_loader:
                if isinstance(x, (list, tuple)):
                    x = x[0]  # Handle case where loader returns tuple
                x = x.to(self.device)

                logits = teacher(x)
                probs = F.softmax(logits, dim=1)
                confidences, predictions = probs.max(dim=1)

                for i in range(len(x)):
                    if confidences[i] >= confidence_threshold:
                        pseudo_labeled.append({
                            'x': x[i].cpu(),
                            'y': predictions[i].cpu().item(),
                            'confidence': confidences[i].cpu().item()
                        })

        print(f"Generated {len(pseudo_labeled)} pseudo-labels "
              f"(avg confidence: {sum(p['confidence'] for p in pseudo_labeled)/len(pseudo_labeled):.3f})")
        return pseudo_labeled

    def train_student_with_noise(self, student, labeled_loader, pseudo_labeled_data,
                                 epochs=100, dropout_rate=0.5, noise_strength=0.1):
        """
        Train student on combined data WITH noise.

        Noise sources:
        1. Data augmentation (RandAugment, etc.)
        2. Dropout during training
        3. Optional input noise
        """
        student.to(self.device)
        student.train()

        # Enable all noise mechanisms
        for module in student.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_rate

        optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

        for epoch in range(epochs):
            # Iterate through labeled data
            for x, y in labeled_loader:
                x, y = x.to(self.device), y.to(self.device)

                # Apply augmentation (noise source #1)
                x_aug = self.augment(x)

                # Optional: Add input noise (noise source #3)
                x_noisy = x_aug + noise_strength * torch.randn_like(x_aug)

                optimizer.zero_grad()
                # Dropout is automatically applied during training (noise source #2)
                loss = F.cross_entropy(student(x_noisy), y)
                loss.backward()
                optimizer.step()

            # Iterate through pseudo-labeled data
            # Note: We shuffle and create mini-batches
            pseudo_batch_size = 64
            indices = torch.randperm(len(pseudo_labeled_data))

            for i in range(0, len(indices), pseudo_batch_size):
                batch_indices = indices[i:i + pseudo_batch_size]
                batch = [pseudo_labeled_data[idx] for idx in batch_indices]

                x = torch.stack([item['x'] for item in batch]).to(self.device)
                y = torch.tensor([item['y'] for item in batch]).to(self.device)

                # Apply augmentation to pseudo-labeled data too
                x_aug = self.augment(x)
                x_noisy = x_aug + noise_strength * torch.randn_like(x_aug)

                optimizer.zero_grad()
                loss = F.cross_entropy(student(x_noisy), y)
                loss.backward()
                optimizer.step()

        return student

    def iterative_training(self, labeled_loader, unlabeled_loader,
                           num_iterations=3, teacher_epochs=100, student_epochs=150):
        """
        Full iterative Noisy Student Training.

        Each iteration:
        1. Train/retrieve teacher
        2. Generate pseudo-labels (no noise)
        3. Train student (with noise)
        4. Student becomes new teacher
        """
        teacher = self.model_fn()
        teacher = self.train_teacher(teacher, labeled_loader, epochs=teacher_epochs)

        for iteration in range(num_iterations):
            print(f"\n=== Noisy Student Iteration {iteration + 1} ===")

            # Generate pseudo-labels with current teacher
            pseudo_data = self.generate_pseudo_labels(teacher, unlabeled_loader)

            # Create student (potentially larger architecture)
            student = self.model_fn()

            # Train student with noise
            student = self.train_student_with_noise(
                student, labeled_loader, pseudo_data, epochs=student_epochs
            )

            # Student becomes new teacher
            teacher = student

        return teacher
```

Meta Pseudo Labels (MPL):
Pham et al. (2021) introduced Meta Pseudo Labels, where the teacher is adaptive—it learns to generate better pseudo-labels based on student feedback:
$$\theta_T \leftarrow \theta_T - \alpha \nabla_{\theta_T} \mathcal{L}_{\text{labeled}}(f_{\theta_S})$$
The teacher's parameters are updated to minimize the student's loss on labeled data, creating a meta-learning loop where good pseudo-labels lead to a better student.
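The sketch below shows a heavily simplified, single-batch version of this feedback loop, loosely following the first-order practical approximation described by Pham et al.; the function name, the scalar feedback signal `h`, and the omission of the paper's additional consistency and warm-up terms are simplifications of our own, not the authors' full method.

```python
import torch
import torch.nn.functional as F


def mpl_step(teacher, student, opt_teacher, opt_student, x_l, y_l, x_u):
    """One simplified Meta Pseudo Labels update (illustrative, not the full method)."""
    # Teacher proposes hard pseudo-labels for the unlabeled batch.
    with torch.no_grad():
        pseudo = teacher(x_u).argmax(dim=1)

    # Student's labeled loss BEFORE its update (baseline for the feedback signal).
    with torch.no_grad():
        labeled_loss_before = F.cross_entropy(student(x_l), y_l).item()

    # Student takes one gradient step on the teacher's pseudo-labels.
    opt_student.zero_grad()
    F.cross_entropy(student(x_u), pseudo).backward()
    opt_student.step()

    # Feedback: how much did that step help the student on real labels?
    with torch.no_grad():
        labeled_loss_after = F.cross_entropy(student(x_l), y_l).item()
    h = labeled_loss_before - labeled_loss_after  # positive if pseudo-labels helped

    # Teacher update: reinforce its pseudo-labels in proportion to h,
    # plus an ordinary supervised term so the teacher stays grounded.
    opt_teacher.zero_grad()
    teacher_loss = h * F.cross_entropy(teacher(x_u), pseudo) + F.cross_entropy(teacher(x_l), y_l)
    teacher_loss.backward()
    opt_teacher.step()
    return h
```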
Self-Training with Soft Labels:
Instead of hard pseudo-labels, some methods retain the full probability distribution:
$$\mathcal{L} = -\sum_j \sum_k P_T(y=k|x_j) \log P_S(y=k|x_j)$$
This preserves uncertainty information but requires careful temperature scaling to prevent overconfident teacher predictions from dominating.
| Method | Key Innovation | When to Use |
|---|---|---|
| Classic Self-Training | High-confidence pseudo-labeling | Small datasets, simple problems |
| Curriculum Self-Training | Progressive threshold decay | When confidence varies widely |
| Noisy Student | Noise during student training | Image classification at scale |
| Meta Pseudo Labels | Adaptive teacher from student feedback | When pseudo-label quality is critical |
| Soft Self-Training | Probability distribution as labels | When uncertainty matters |
Deploying self-training in production requires careful attention to engineering details that papers often gloss over. Let's examine practical patterns for robust implementation.
Hyperparameter Selection:
The key hyperparameters and their typical ranges:
| Parameter | Description | Typical Range | Selection Strategy |
|---|---|---|---|
| $\tau$ (threshold) | Confidence cutoff | 0.7 - 0.99 | Start high (0.95), decay if needed |
| $K$ (per-iteration) | Max samples per iteration | 10 - 1000 | 1-5% of remaining unlabeled |
| $\alpha$ (weight) | Pseudo-label loss weight | 0.1 - 1.0 | Increase as training progresses |
| $T$ (temperature) | Calibration temperature | 1.0 - 3.0 | Learn on calibration set |
```python
import copy

import numpy as np
from sklearn.base import clone
from sklearn.model_selection import train_test_split
from collections import Counter
from typing import Callable, Optional


class RobustSelfTraining:
    """
    Production-ready self-training with common failure mode mitigations.

    Features:
    1. Class-balanced pseudo-labeling
    2. Confidence calibration
    3. Validation-based early stopping
    4. Rollback on performance degradation
    """

    def __init__(
        self,
        base_classifier,
        confidence_threshold: float = 0.9,
        threshold_decay: float = 0.99,
        min_threshold: float = 0.7,
        max_samples_per_iter: int = 100,
        class_balance_strategy: str = 'equal',  # 'equal', 'proportional', 'none'
        max_iterations: int = 20,
        early_stopping_patience: int = 3,
        early_stopping_metric: str = 'accuracy',
        validation_fraction: float = 0.1,
        random_state: int = 42
    ):
        self.base_classifier = base_classifier
        self.threshold = confidence_threshold
        self.threshold_decay = threshold_decay
        self.min_threshold = min_threshold
        self.max_samples_per_iter = max_samples_per_iter
        self.class_balance = class_balance_strategy
        self.max_iterations = max_iterations
        self.patience = early_stopping_patience
        self.metric = early_stopping_metric
        self.val_fraction = validation_fraction
        self.random_state = random_state

        self.classifier_ = None
        self.training_history_ = []

    def _get_class_balanced_samples(self, indices, predictions, confidences,
                                    max_samples, existing_class_counts):
        """
        Select samples while maintaining class balance.

        Args:
            indices: Indices of samples above threshold
            predictions: Predicted labels for those samples
            confidences: Confidence scores
            max_samples: Maximum samples to select
            existing_class_counts: Current class distribution in labeled set
        """
        # Group by predicted class
        class_groups = {}
        for idx, pred, conf in zip(indices, predictions, confidences):
            if pred not in class_groups:
                class_groups[pred] = []
            class_groups[pred].append((idx, conf))

        # Sort each group by confidence (descending)
        for pred in class_groups:
            class_groups[pred].sort(key=lambda x: -x[1])

        selected = []
        if self.class_balance == 'equal':
            # Take equal samples from each class
            num_classes = len(class_groups)
            per_class = max_samples // num_classes
            for pred, group in class_groups.items():
                selected.extend([g[0] for g in group[:per_class]])

        elif self.class_balance == 'proportional':
            # Match existing class distribution
            total_existing = sum(existing_class_counts.values())
            for pred, group in class_groups.items():
                # Historical proportion of this class
                prop = existing_class_counts.get(pred, 1) / total_existing
                n_samples = int(max_samples * prop)
                selected.extend([g[0] for g in group[:n_samples]])

        else:  # 'none'
            # Take top by confidence regardless of class
            all_samples = [(idx, conf) for group in class_groups.values() for idx, conf in group]
            all_samples.sort(key=lambda x: -x[1])
            selected = [s[0] for s in all_samples[:max_samples]]

        return selected

    def fit(self, X_labeled, y_labeled, X_unlabeled, X_val=None, y_val=None):
        """
        Fit the self-training classifier.

        Args:
            X_labeled: Labeled features
            y_labeled: Labels
            X_unlabeled: Unlabeled features
            X_val, y_val: Optional validation set (will split from labeled if not provided)
        """
        np.random.seed(self.random_state)

        # Create validation set if not provided
        if X_val is None:
            X_labeled, X_val, y_labeled, y_val = train_test_split(
                X_labeled, y_labeled, test_size=self.val_fraction,
                random_state=self.random_state, stratify=y_labeled
            )

        # Initialize
        X_L, y_L = list(X_labeled), list(y_labeled)
        X_U = list(X_unlabeled)

        current_threshold = self.threshold
        best_score = -np.inf
        best_classifier = None
        patience_counter = 0

        for iteration in range(self.max_iterations):
            # Train classifier
            self.classifier_ = clone(self.base_classifier)
            self.classifier_.fit(X_L, y_L)

            # Evaluate on validation set
            val_score = self.classifier_.score(X_val, y_val)

            # Early stopping check
            if val_score > best_score:
                best_score = val_score
                # deepcopy keeps the fitted state; clone() would discard it
                best_classifier = copy.deepcopy(self.classifier_)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    print(f"Early stopping at iteration {iteration}")
                    break

            # Check termination
            if len(X_U) == 0:
                print("All samples pseudo-labeled")
                break

            # Get predictions on unlabeled data
            probas = self.classifier_.predict_proba(X_U)
            predictions = probas.argmax(axis=1)
            confidences = probas.max(axis=1)

            # Find samples above threshold
            high_conf_mask = confidences >= current_threshold
            high_conf_indices = np.where(high_conf_mask)[0]

            if len(high_conf_indices) == 0:
                # Decay threshold
                current_threshold = max(
                    self.min_threshold,
                    current_threshold * self.threshold_decay
                )
                print(f"Iteration {iteration}: No samples above threshold, "
                      f"decaying to {current_threshold:.3f}")
                continue

            # Class-balanced selection
            class_counts = Counter(y_L)
            selected_indices = self._get_class_balanced_samples(
                high_conf_indices,
                predictions[high_conf_indices],
                confidences[high_conf_indices],
                self.max_samples_per_iter,
                class_counts
            )

            # Add pseudo-labeled samples
            for idx in sorted(selected_indices, reverse=True):
                X_L.append(X_U[idx])
                y_L.append(predictions[idx])
                del X_U[idx]

            # Log progress
            self.training_history_.append({
                'iteration': iteration,
                'labeled_count': len(X_L),
                'unlabeled_count': len(X_U),
                'samples_added': len(selected_indices),
                'val_score': val_score,
                'threshold': current_threshold
            })

            print(f"Iteration {iteration}: Added {len(selected_indices)} samples, "
                  f"Val score: {val_score:.4f}, Threshold: {current_threshold:.3f}")

        # Use best classifier
        self.classifier_ = best_classifier
        return self

    def predict(self, X):
        return self.classifier_.predict(X)

    def predict_proba(self, X):
        return self.classifier_.predict_proba(X)
```

Before deploying self-training: (1) Ensure labeled data is representative of all classes, (2) Calibrate model predictions with temperature scaling, (3) Implement validation-based early stopping, (4) Monitor class distribution drift in pseudo-labels, (5) Set up rollback mechanisms to previous best model.
Understanding the conditions under which self-training succeeds or fails is crucial for practitioners. Let's synthesize insights from theory and empirical studies.
Conditions Favoring Self-Training:

- The cluster assumption holds, so decision boundaries lie in low-density regions and confident predictions generalize within clusters
- The initial classifier trained on $\mathcal{D}_L$ is reasonably accurate and its labeled data covers every class
- Confidence estimates are well calibrated, or are calibrated before pseudo-labels are accepted
- The unlabeled data is drawn from the same distribution as the labeled data
Empirical Observations from Research:
| Study | Finding |
|---|---|
| Chapelle et al. (2006) | Self-training effective when labeled data is representative of cluster structure |
| Zhu (2005) | Performance degrades when label noise exceeds ~20% |
| Ruder & Plank (2018) | NLP tasks: self-training helps low-resource languages by 2-10% F1 |
| Xie et al. (2020) | ImageNet: Noisy Student achieved SOTA with 3% improvement |
| Cascante-Bonilla (2021) | Self-training fails when confident predictions cluster in easy regions |
The "Easy First" Problem:
Self-training naturally selects the "easy" samples first, the ones the model is already confident about. This creates a potentially problematic dynamic: high-density, easy regions accumulate pseudo-labels quickly, while hard examples near decision boundaries and under-represented classes rarely clear the confidence threshold, so the model keeps reinforcing what it already knows.
This can lead to improved average accuracy but degraded performance on minority classes or edge cases.
Modern neural networks are systematically overconfident. A model predicting 98% confidence is often only 85% accurate. Without calibration, self-training rapidly incorporates ~15% error rate pseudo-labels, leading to compounding degradation. Always calibrate before deploying self-training with neural networks.
We've explored self-training from the basic algorithm through advanced variants and failure modes. The key insights: pseudo-labeling is only as good as the confidence estimates behind it; confirmation bias and error accumulation are the central failure modes; calibration, class-balanced selection, and validation-based early stopping are the main safeguards; and noise-injected variants such as Noisy Student scale the idea to state-of-the-art results.
What's Next:
Self-training teaches each model in isolation. But what if we had multiple models that could teach each other, leveraging different views of the same data? This is the intuition behind co-training, which we'll explore in the next page. Co-training provides complementary perspectives that can break the confirmation bias cycle inherent in single-model self-training.
You now understand self-training from first principles: the algorithm, theoretical foundations, confidence estimation, convergence dynamics, and practical implementation patterns. This foundational method sets the stage for understanding more sophisticated semi-supervised approaches.