In 2019-2020, two methods revolutionized semi-supervised learning: UDA (Unsupervised Data Augmentation) and FixMatch. These approaches demonstrated that combining strong data augmentation with simple consistency regularization could achieve unprecedented performance—reaching 95%+ accuracy on CIFAR-10 with just 250 labeled examples (0.5% of the dataset).
What makes these methods remarkable is their simplicity. Neither relies on complex generative models, adversarial training, or elaborate loss functions. Instead, they carefully orchestrate known components—augmentation, pseudo-labeling, and consistency—into elegant, highly effective systems.
This page provides a comprehensive, implementation-level understanding of both methods, their differences, and why they work so well.
By the end of this page, you will understand: the complete algorithms for UDA and FixMatch, the key design decisions that make them effective, how to implement both methods from scratch, their performance characteristics and failure modes, and when to choose one over the other.
UDA (Unsupervised Data Augmentation) was introduced by Xie et al. (2019) at Google Research. The key insight of UDA is deceptively simple: advanced data augmentation that improves supervised learning also improves semi-supervised learning—and by a larger margin.
Core Idea: Enforce consistency between the model's predictions on a weakly augmented view and a strongly augmented view of the same unlabeled example, using the prediction on the weak view as a (sharpened) training target for the strong view. Strong, learned augmentation policies supply the hard examples that make this consistency constraint informative.
The Algorithm: The total loss combines a supervised term on labeled data with a confidence-masked consistency term on unlabeled data:
$$\mathcal{L}_{total} = \mathcal{L}_{sup}(x_l, y_l) + \lambda \cdot \mathbb{1}[\max(p_{weak}) \geq \tau] \cdot \mathrm{KL}(p_{weak} \,\|\, p_{strong})$$
where:
- $p_{weak}$ and $p_{strong}$ are the model's predicted distributions on the weakly and strongly augmented views of an unlabeled example,
- $\tau$ is the confidence threshold used to mask out low-confidence examples,
- $\lambda$ weights the unsupervised consistency term.
Key Components of UDA:
1. Training Signal Annealing (TSA):
UDA introduces a clever technique to prevent the model from overfitting to labeled data early in training:
$$\mathcal{L}_{sup} = \sum_{(x,y) \in \mathcal{D}_L} \mathbb{1}[p_{model}(y|x) < \eta_t] \cdot \ell(f_\theta(x), y)$$
where $\eta_t$ is a threshold that increases over training (following a linear, log, or exponential schedule). Labeled examples whose true-class probability already exceeds $\eta_t$ are dropped from the supervised loss, forcing the model to rely more on the unlabeled data.
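As a worked example of the linear schedule with $K = 10$ classes (the values follow directly from the formula above, not from the paper):

$$\eta_t = \frac{1}{K} + \left(1 - \frac{1}{K}\right)\frac{t}{T}, \qquad \eta_0 = 0.1, \quad \eta_{T/2} = 0.55, \quad \eta_T = 1.0,$$

so at the start only examples the model is still nearly clueless about contribute supervised gradient, and by the end of training no labeled example is filtered out.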
2. Sharpening:
Before using the soft pseudo-label, UDA sharpens the distribution:
$$\tilde{p}_k = \frac{p_k^{1/T}}{\sum_j p_j^{1/T}}$$
with temperature $T = 0.4$; since $T < 1$, this sharpens the distribution, pushing probability mass toward the predicted class. It encourages confident predictions on unlabeled data and acts as an implicit form of entropy minimization.
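A quick numeric check of the sharpening step (illustrative values, not from the paper): with $T = 0.4$, $p^{1/T} = p^{2.5}$, so the dominant class is boosted substantially.

```python
import torch

probs = torch.tensor([0.60, 0.30, 0.10])  # soft prediction on the weak view
T = 0.4
sharpened = probs.pow(1.0 / T)            # elementwise p^(1/T)
sharpened = sharpened / sharpened.sum()   # renormalize to a distribution
print(sharpened)  # ≈ tensor([0.842, 0.149, 0.010])
```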
3. Soft Pseudo-Labels + KL Divergence:
Unlike hard pseudo-labeling, UDA uses soft targets and KL divergence. This preserves uncertainty information and provides smoother gradients.
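The distinction between the two target styles is easiest to see in isolation. The snippet below is a minimal sketch, assuming `logits_weak` and `logits_strong` are the model's outputs on the weak and strong views of the same unlabeled batch; the full UDA module that follows integrates the KL version with TSA and confidence masking.

```python
import torch
import torch.nn.functional as F

logits_weak = torch.randn(8, 10)    # stand-in predictions on weak views
logits_strong = torch.randn(8, 10)  # stand-in predictions on strong views

# UDA-style target: sharpened soft distribution, trained with KL divergence
probs_weak = F.softmax(logits_weak, dim=-1).detach()
soft_target = probs_weak.pow(1.0 / 0.4)
soft_target = soft_target / soft_target.sum(dim=-1, keepdim=True)
loss_uda = F.kl_div(F.log_softmax(logits_strong, dim=-1),
                    soft_target, reduction="batchmean")

# FixMatch-style target: hard argmax label, trained with cross-entropy
hard_target = probs_weak.argmax(dim=-1)
loss_fixmatch = F.cross_entropy(logits_strong, hard_target)
```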
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, Callable, Optional
from dataclasses import dataclass


@dataclass
class UDAConfig:
    """Configuration for UDA training."""
    # Loss weights
    lambda_u: float = 1.0         # Unsupervised loss weight

    # Confidence thresholding
    threshold: float = 0.8        # Confidence threshold for pseudo-labels

    # Sharpening
    temperature: float = 0.4      # Temperature for sharpening (< 1 sharpens)

    # Training Signal Annealing
    use_tsa: bool = True          # Whether to use TSA
    tsa_schedule: str = "linear"  # 'linear', 'log', 'exp'

    # Training
    total_steps: int = 1000000    # Total training steps
    warmup_steps: int = 10000     # Steps to warmup lambda_u


class TrainingSignalAnnealing:
    """
    Training Signal Annealing (TSA) for UDA.

    Gradually releases labeled examples based on model's predictions.
    Prevents overfitting to labeled data early in training.
    """

    def __init__(
        self,
        schedule: str = "linear",
        num_classes: int = 10,
        num_steps: int = 1000000,
    ):
        self.schedule = schedule
        self.num_classes = num_classes
        self.num_steps = num_steps

    def get_threshold(self, step: int) -> float:
        """
        Get TSA threshold at current step.

        Returns threshold η_t. Examples with p(y|x) > η_t are removed.
        """
        progress = step / self.num_steps  # 0 to 1

        # Start at 1/K (random guess) and end at 1
        start = 1.0 / self.num_classes
        end = 1.0

        if self.schedule == "linear":
            threshold = start + (end - start) * progress
        elif self.schedule == "log":
            # Logarithmic: slow start, fast end
            threshold = start + (end - start) * (1 - (1 - progress) ** 5)
        elif self.schedule == "exp":
            # Exponential: fast start, slow end
            threshold = start + (end - start) * (progress ** 5)
        else:
            threshold = 1.0  # No TSA

        return threshold

    def compute_mask(
        self, logits: torch.Tensor, labels: torch.Tensor, step: int
    ) -> torch.Tensor:
        """
        Compute mask for supervised loss with TSA.

        Args:
            logits: Model predictions [B, K]
            labels: True labels [B]
            step: Current training step

        Returns:
            Boolean mask [B] - True means include in loss
        """
        probs = F.softmax(logits, dim=-1)

        # Get probability of correct class for each example
        correct_probs = probs[torch.arange(len(labels)), labels]

        # Mask out examples where model is already confident
        threshold = self.get_threshold(step)
        mask = correct_probs < threshold

        return mask


class UDA(nn.Module):
    """
    Unsupervised Data Augmentation (UDA).

    Paper: "Unsupervised Data Augmentation for Consistency Training"
           Xie et al., 2019

    Key innovations:
    1. Weak-to-strong augmentation paradigm
    2. Training Signal Annealing
    3. Sharpened soft pseudo-labels
    """

    def __init__(
        self,
        model: nn.Module,
        config: UDAConfig,
        num_classes: int = 10,
    ):
        super().__init__()
        self.model = model
        self.config = config
        self.num_classes = num_classes

        if config.use_tsa:
            self.tsa = TrainingSignalAnnealing(
                schedule=config.tsa_schedule,
                num_classes=num_classes,
                num_steps=config.total_steps,
            )
        else:
            self.tsa = None

    def sharpen(self, probs: torch.Tensor) -> torch.Tensor:
        """Sharpen probability distribution using temperature."""
        temp = self.config.temperature
        sharpened = probs.pow(1.0 / temp)
        return sharpened / sharpened.sum(dim=-1, keepdim=True)

    def forward(
        self,
        x_labeled: torch.Tensor,
        y_labeled: torch.Tensor,
        x_unlabeled_weak: torch.Tensor,
        x_unlabeled_strong: torch.Tensor,
        step: int,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute UDA loss.

        Args:
            x_labeled: Labeled inputs (already augmented)
            y_labeled: Labels
            x_unlabeled_weak: Weakly augmented unlabeled inputs
            x_unlabeled_strong: Strongly augmented unlabeled inputs
            step: Current training step

        Returns:
            (loss, metrics_dict)
        """
        batch_size_l = x_labeled.size(0)
        batch_size_u = x_unlabeled_weak.size(0)

        # =====================
        # Supervised loss
        # =====================
        logits_l = self.model(x_labeled)

        if self.tsa is not None:
            # Apply Training Signal Annealing
            tsa_mask = self.tsa.compute_mask(logits_l, y_labeled, step)
            if tsa_mask.sum() > 0:
                loss_sup = F.cross_entropy(
                    logits_l[tsa_mask], y_labeled[tsa_mask]
                )
            else:
                loss_sup = torch.tensor(0.0, device=x_labeled.device)
            tsa_removed = (~tsa_mask).float().mean().item()
        else:
            loss_sup = F.cross_entropy(logits_l, y_labeled)
            tsa_removed = 0.0

        # =====================
        # Unsupervised loss
        # =====================
        # Get pseudo-labels from weak augmentation
        with torch.no_grad():
            logits_weak = self.model(x_unlabeled_weak)
            probs_weak = F.softmax(logits_weak, dim=-1)

            # Sharpen the distribution
            probs_sharp = self.sharpen(probs_weak)

            # Confidence mask
            max_probs, _ = probs_weak.max(dim=-1)
            mask = max_probs >= self.config.threshold

        # Get predictions from strong augmentation
        logits_strong = self.model(x_unlabeled_strong)

        # KL divergence loss (only on confident samples)
        if mask.sum() > 0:
            log_probs_strong = F.log_softmax(logits_strong[mask], dim=-1)
            loss_unsup = F.kl_div(
                log_probs_strong,
                probs_sharp[mask],
                reduction='batchmean'
            )
        else:
            loss_unsup = torch.tensor(0.0, device=x_labeled.device)

        # =====================
        # Combined loss
        # =====================
        # Linear warmup for lambda
        if step < self.config.warmup_steps:
            lambda_eff = self.config.lambda_u * (step / self.config.warmup_steps)
        else:
            lambda_eff = self.config.lambda_u

        loss_total = loss_sup + lambda_eff * loss_unsup

        # Metrics
        metrics = {
            "loss_sup": loss_sup.item(),
            "loss_unsup": loss_unsup.item(),
            "loss_total": loss_total.item(),
            "mask_ratio": mask.float().mean().item(),
            "avg_confidence": max_probs.mean().item(),
            "tsa_removed_ratio": tsa_removed,
            "lambda_eff": lambda_eff,
        }

        return loss_total, metrics
```

FixMatch (Sohn et al., 2020) simplified UDA while achieving even better performance. The key insight: replace soft pseudo-labels and KL divergence with hard pseudo-labels and cross-entropy, combined with a higher confidence threshold.
The FixMatch Algorithm (Pseudocode):
```
for (x_l, y) in labeled_data, x_u in unlabeled_data:
    # Supervised loss on labeled data
    L_s = CrossEntropy(model(Aug_weak(x_l)), y)

    # Generate pseudo-label from weakly augmented unlabeled
    p = model(Aug_weak(x_u))
    y_pseudo = argmax(p)
    confidence = max(p)

    # Unsupervised loss on confident samples only
    if confidence >= 0.95:
        L_u = CrossEntropy(model(Aug_strong(x_u)), y_pseudo)
    else:
        L_u = 0

    # Total loss
    L = L_s + lambda * L_u
```
Why FixMatch's Simplifications Work:
| Aspect | UDA | FixMatch | Why FixMatch's Choice Works |
|---|---|---|---|
| Pseudo-labels | Soft (distribution) | Hard (one-hot) | Hard labels provide stronger supervision signal |
| Loss function | KL divergence | Cross-entropy | Cross-entropy is simpler, more stable with hard labels |
| Confidence threshold | 0.8 (typical) | 0.95 | Higher threshold ensures pseudo-label quality |
| Sharpening | Yes (T=0.4) | No (implicit with argmax) | argmax is extreme sharpening (T→0) |
| TSA | Yes | No | High threshold removes need for TSA |
| Augmentation | AutoAugment | RandAugment or CTAugment | RandAugment is simpler and equally effective |
The Magic of the 0.95 Threshold:
FixMatch's 0.95 confidence threshold might seem arbitrary, but it's carefully chosen:
Pseudo-label accuracy: At 95% confidence, pseudo-labels are highly accurate (>90% correct on CIFAR-10)
Learning progression: Early in training, few samples exceed 0.95. As the model improves, more samples qualify—creating a natural curriculum.
Reduced confirmation bias: Low-confidence (and therefore more likely incorrect) predictions are filtered out, so they don't reinforce themselves.
No need for TSA: The threshold already limits how much unlabeled signal enters training early on, playing a similar curriculum role to TSA.
The threshold effectively selects "easy" unlabeled samples where the model is already confident, providing reliable supervision.
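A small diagnostic makes this concrete. The sketch below assumes you have ground-truth labels for the unlabeled batch (only possible on a benchmark or while debugging) and reports what fraction of samples pass the threshold and how accurate their pseudo-labels are; watching `mask_ratio` rise over training is a direct view of the curriculum effect.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def pseudo_label_report(logits_weak: torch.Tensor,
                        true_labels: torch.Tensor,
                        threshold: float = 0.95) -> dict:
    """Fraction of samples above the threshold and their pseudo-label accuracy."""
    probs = F.softmax(logits_weak, dim=-1)
    confidence, pseudo = probs.max(dim=-1)
    mask = confidence >= threshold
    return {
        "mask_ratio": mask.float().mean().item(),
        "pseudo_acc": (pseudo[mask] == true_labels[mask]).float().mean().item()
                      if mask.any() else float("nan"),
    }
```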
FixMatch's effectiveness comes from the combination of: (1) High confidence threshold for pseudo-label quality, (2) Strong augmentations for consistency difficulty, (3) Hard pseudo-labels for strong supervision signal. None of these alone is sufficient—it's their combination that works.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict
from dataclasses import dataclass


@dataclass
class FixMatchConfig:
    """Configuration for FixMatch training."""
    # Loss weights
    lambda_u: float = 1.0     # Unsupervised loss weight

    # Confidence thresholding
    threshold: float = 0.95   # High confidence threshold

    # Training
    warmup_steps: int = 0     # Usually no warmup needed

    # Architecture
    use_ema: bool = False     # Optional: EMA teacher
    ema_decay: float = 0.999  # EMA decay rate


class EMAModel:
    """Exponential Moving Average of model parameters."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        # Initialize shadow parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Update shadow parameters with EMA."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (
                    self.decay * self.shadow[name]
                    + (1 - self.decay) * param.data
                )

    def apply_shadow(self):
        """Replace model parameters with shadow."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name]

    def restore(self):
        """Restore original model parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]


class FixMatch(nn.Module):
    """
    FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence.

    Paper: Sohn et al., NeurIPS 2020

    Key innovations:
    1. Hard pseudo-labels with high confidence threshold (0.95)
    2. Cross-entropy loss instead of KL divergence
    3. No sharpening or TSA needed
    4. RandAugment or CTAugment for strong augmentation
    """

    def __init__(
        self,
        model: nn.Module,
        config: FixMatchConfig,
        num_classes: int = 10,
    ):
        super().__init__()
        self.model = model
        self.config = config
        self.num_classes = num_classes

        # Optional EMA teacher
        if config.use_ema:
            self.ema = EMAModel(model, decay=config.ema_decay)
        else:
            self.ema = None

    def forward(
        self,
        x_labeled: torch.Tensor,
        y_labeled: torch.Tensor,
        x_unlabeled_weak: torch.Tensor,
        x_unlabeled_strong: torch.Tensor,
        step: int = 0,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute FixMatch loss.

        Args:
            x_labeled: Labeled inputs (with weak augmentation)
            y_labeled: Labels
            x_unlabeled_weak: Weakly augmented unlabeled inputs
            x_unlabeled_strong: Strongly augmented unlabeled inputs
            step: Current training step (for optional warmup)

        Returns:
            (loss, metrics_dict)
        """
        batch_size_l = x_labeled.size(0)
        batch_size_u = x_unlabeled_weak.size(0)

        # =====================
        # Supervised loss
        # =====================
        logits_l = self.model(x_labeled)
        loss_sup = F.cross_entropy(logits_l, y_labeled)

        # =====================
        # Unsupervised loss
        # =====================
        # Get pseudo-labels from weak augmentation
        with torch.no_grad():
            # Use EMA model for pseudo-labels if available
            if self.ema is not None:
                self.ema.apply_shadow()
                logits_weak = self.model(x_unlabeled_weak)
                self.ema.restore()
            else:
                logits_weak = self.model(x_unlabeled_weak)

            probs_weak = F.softmax(logits_weak, dim=-1)

            # Hard pseudo-labels
            max_probs, pseudo_labels = probs_weak.max(dim=-1)

            # Confidence mask (the key: high threshold)
            mask = max_probs >= self.config.threshold

        # Get predictions from strong augmentation
        logits_strong = self.model(x_unlabeled_strong)

        # Cross-entropy loss on confident samples only
        if mask.sum() > 0:
            loss_unsup = F.cross_entropy(
                logits_strong[mask],
                pseudo_labels[mask],
                reduction='mean'
            )
        else:
            loss_unsup = torch.tensor(0.0, device=x_labeled.device)

        # =====================
        # Combined loss
        # =====================
        # Optional warmup (usually not needed for FixMatch)
        if self.config.warmup_steps > 0 and step < self.config.warmup_steps:
            lambda_eff = self.config.lambda_u * (step / self.config.warmup_steps)
        else:
            lambda_eff = self.config.lambda_u

        loss_total = loss_sup + lambda_eff * loss_unsup

        # =====================
        # Metrics
        # =====================
        with torch.no_grad():
            # Pseudo-label accuracy would also be useful to monitor, but computing
            # it requires ground-truth labels for the unlabeled data, which are only
            # available when debugging on a benchmark.
            metrics = {
                "loss_sup": loss_sup.item(),
                "loss_unsup": loss_unsup.item(),
                "loss_total": loss_total.item(),
                "mask_ratio": mask.float().mean().item(),
                "mask_count": mask.sum().item(),
                "avg_confidence": max_probs.mean().item(),
                "max_confidence": max_probs.max().item(),
                "min_confidence_above_thresh": (
                    max_probs[mask].min().item() if mask.any() else 0.0
                ),
            }

        return loss_total, metrics

    def update_ema(self):
        """Update EMA parameters after optimizer step."""
        if self.ema is not None:
            self.ema.update()


def train_step_fixmatch(
    model: FixMatch,
    optimizer: torch.optim.Optimizer,
    x_labeled: torch.Tensor,
    y_labeled: torch.Tensor,
    x_unlabeled_weak: torch.Tensor,
    x_unlabeled_strong: torch.Tensor,
    step: int,
) -> Dict[str, float]:
    """
    Single training step for FixMatch.

    Args:
        model: FixMatch module
        optimizer: Optimizer
        x_labeled: Labeled batch
        y_labeled: Labels
        x_unlabeled_weak: Weakly augmented unlabeled batch
        x_unlabeled_strong: Strongly augmented unlabeled batch
        step: Current step

    Returns:
        Metrics dictionary
    """
    model.train()
    optimizer.zero_grad()

    loss, metrics = model(
        x_labeled, y_labeled,
        x_unlabeled_weak, x_unlabeled_strong,
        step
    )

    loss.backward()
    optimizer.step()

    # Update EMA after optimizer step
    model.update_ema()

    return metrics
```

Both UDA and FixMatch achieved remarkable results on standard benchmarks. Understanding their performance characteristics helps in practical application.
CIFAR-10 error rates (%) by number of labeled examples (lower is better):
| Method | 40 labels | 250 labels | 4000 labels |
|---|---|---|---|
| Π-Model | 45.74 ± 3.97 | 16.37 ± 0.63 | 6.32 ± 0.15 |
| Mean Teacher | 32.32 ± 2.30 | 10.36 ± 0.25 | 5.94 ± 0.11 |
| MixMatch | 11.08 ± 0.87 | 6.24 ± 0.06 | 4.95 ± 0.08 |
| UDA | 10.62 ± 3.75 | 5.29 ± 0.25 | 4.31 ± 0.08 |
| ReMixMatch | 6.27 ± 0.34 | 4.72 ± 0.13 | 4.14 ± 0.09 |
| FixMatch | 4.26 ± 0.05 | 4.86 ± 0.05 | 4.21 ± 0.08 |
Key Observations:
Low-label regime: FixMatch dominates with 40 labels (4 per class), achieving 4.26% error where MixMatch has 11.08%.
Consistent improvement: Unlike some methods that excel in specific regimes, FixMatch performs well across all label counts.
Low variance: FixMatch has remarkably low variance (±0.05 to ±0.08), indicating stable training.
Near-supervised performance: With 4000 labels, FixMatch (4.21%) approaches the fully supervised baseline trained on all 50,000 labels (~3.9% error).
Both UDA and FixMatch struggle when: (1) Class imbalance is severe — model becomes confident on majority classes. (2) Domains are very different — standard augmentations don't apply. (3) Label noise is high — wrong labels propagate through pseudo-labeling. (4) Classes are inherently hard to separate — threshold is never reached.
Understanding which components are essential helps in adapting these methods to new domains. The original papers provide detailed ablations.
| Modification | Error Rate | Δ from Baseline |
|---|---|---|
| Full FixMatch | 4.86% | — |
| Without strong augmentation | 36.21% | +31.35% |
| Threshold = 0.0 (all samples) | 7.98% | +3.12% |
| Threshold = 0.5 | 5.86% | +1.00% |
| Threshold = 0.8 | 5.24% | +0.38% |
| Soft pseudo-labels + MSE | 5.46% | +0.60% |
| CTAugment instead of RandAugment | 5.07% | +0.21% |
| λ_u = 0.5 | 5.64% | +0.78% |
| λ_u = 2.0 | 5.23% | +0.37% |
Critical insight: Without strong augmentation, error rate jumps from 4.86% to 36.21% — a 7× increase! This confirms that augmentation is the most critical component, not the specific loss formulation.
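Since so much rides on the strong branch, here is a minimal sketch of a weak/strong augmentation pair for CIFAR-10. It is an approximation rather than the papers' exact pipelines: it uses torchvision's built-in RandAugment and substitutes RandomErasing for Cutout, and the normalization statistics are the commonly used CIFAR-10 values.

```python
import torchvision.transforms as T

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Weak augmentation: flip + padded random crop
weak_augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4, padding_mode="reflect"),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Strong augmentation: the weak transforms plus RandAugment and a Cutout-like erasure
strong_augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4, padding_mode="reflect"),
    T.RandAugment(num_ops=2, magnitude=10),  # requires torchvision >= 0.11
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    T.RandomErasing(p=1.0, scale=(0.25, 0.25), ratio=(1.0, 1.0)),  # ~16x16 square cutout
])
```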
Given the similarity between UDA and FixMatch, how do you choose? Here's practical guidance based on empirical evidence and community experience.
For most image classification tasks, start with FixMatch. It's simpler, has fewer hyperparameters, and achieves state-of-the-art results. Only switch to UDA or more complex methods if FixMatch underperforms or you have specific requirements (e.g., NLP tasks where soft labels are more appropriate).
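To make that recommendation concrete, here is a bare-bones training loop wiring the `FixMatch` module and `train_step_fixmatch` from earlier to user-supplied dataloaders. It is a sketch under the assumption that `labeled_loader` yields `(x, y)` batches and `unlabeled_loader` yields `(weak_view, strong_view)` batches of the same images; the recommended hyperparameters below list the values typically used with this setup.

```python
import itertools
from typing import Iterable

import torch

def run_fixmatch(model: FixMatch,
                 labeled_loader: Iterable,
                 unlabeled_loader: Iterable,
                 epochs: int = 1) -> None:
    """Minimal FixMatch training loop; dataloaders are supplied by the caller."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9,
                                nesterov=True, weight_decay=5e-4)
    unlabeled_iter = itertools.cycle(unlabeled_loader)  # unlabeled data is reused freely
    step = 0
    for _ in range(epochs):
        for x_l, y_l in labeled_loader:
            x_u_weak, x_u_strong = next(unlabeled_iter)
            train_step_fixmatch(model, optimizer,
                                x_l, y_l, x_u_weak, x_u_strong, step)
            step += 1
```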
```python
# Recommended hyperparameters for FixMatch/UDA
# Based on original papers and community best practices

CIFAR10_FIXMATCH = {
    # Model
    "architecture": "WideResNet-28-2",
    "weight_decay": 5e-4,

    # Data
    "labeled_batch_size": 64,
    "unlabeled_batch_size": 64 * 7,  # 7x ratio
    "num_workers": 4,

    # Augmentation
    "weak_aug": "flip + random_crop (pad=4)",
    "strong_aug": "RandAugment(n=2, m=10) + Cutout",

    # FixMatch specific
    "threshold": 0.95,
    "lambda_u": 1.0,

    # Optimizer
    "optimizer": "SGD",
    "lr": 0.03,
    "momentum": 0.9,
    "nesterov": True,

    # Schedule
    "total_steps": 2**20,   # ~1M steps
    "warmup_steps": 0,      # No warmup needed
    "lr_schedule": "cosine",  # Cosine LR decay
}

CIFAR10_UDA = {
    # Same as FixMatch except:
    "temperature": 0.4,           # For sharpening
    "threshold": 0.8,             # Lower threshold
    "use_tsa": True,              # Enable TSA
    "tsa_schedule": "linear",     # Linear TSA schedule
    "strong_aug": "AutoAugment",  # Original UDA used AutoAugment
}

IMAGENET_FIXMATCH = {
    # ImageNet requires larger model, longer training
    "architecture": "ResNet-50",
    "labeled_batch_size": 2048,   # Distributed across GPUs
    "unlabeled_batch_size": 2048 * 5,
    "total_steps": 300000,
    "threshold": 0.7,             # Lower threshold for harder task
    "strong_aug": "RandAugment(n=2, m=10)",
}
```

UDA and FixMatch represent the maturation of consistency-based semi-supervised learning. By combining strong augmentations with careful pseudo-labeling, they achieve remarkable performance with minimal complexity.
What's Next:
Building on UDA and FixMatch, research has developed more sophisticated methods that incorporate additional techniques. The next page explores MixMatch—a method that combines consistency regularization with MixUp data augmentation and distribution alignment, achieving strong results through a different balance of techniques.
You now have complete understanding of UDA and FixMatch—from theory to implementation. These methods form the foundation of modern semi-supervised learning and provide a strong baseline for any low-label image classification task.