MixMatch (Berthelot et al., 2019) takes a different approach than UDA or FixMatch. Rather than focusing on one core idea, MixMatch combines multiple semi-supervised learning techniques into a unified framework: consistency regularization, entropy minimization via sharpening, and MixUp augmentation.
The result is a method that is greater than the sum of its parts—achieving strong performance across diverse benchmarks through the synergistic combination of complementary techniques. Understanding MixMatch provides insight into how different semi-supervised principles interact and reinforce each other.
By the end of this page, you will understand: the complete MixMatch algorithm, how MixUp augmentation works and why it helps, the role of distribution alignment and temperature sharpening, how to implement MixMatch from scratch, and how MixMatch compares to FixMatch and UDA in different scenarios.
MixMatch is built on the observation that several semi-supervised learning techniques, while developed independently, address complementary aspects of the problem:
1. Consistency Regularization: Enforces that the model's predictions are stable under input perturbations (augmentations).
2. Entropy Minimization: Encourages the model to make confident predictions, pushing decision boundaries away from data points.
3. MixUp Regularization: Interpolates between examples in both input and label space, encouraging smooth decision boundaries.
MixMatch unifies these into a single algorithm:
$$\mathcal{L}_{\text{MixMatch}} = \mathcal{L}_{\mathcal{X}'} + \lambda_U \mathcal{L}_{\mathcal{U}'}$$
where $\mathcal{L}_{\mathcal{X}'}$ is the supervised loss on mixed labeled data, $\mathcal{L}_{\mathcal{U}'}$ is the consistency loss on mixed unlabeled data, and both use processed labels that incorporate sharpening and averaging.
MixMatch's key insight is that MixUp provides a natural way to 'bridge' labeled and unlabeled data. By mixing examples from both sets, the model is forced to interpolate between labeled supervision and consistency regularization, creating a more unified learning signal.
Before diving into MixMatch, we need to understand MixUp—a simple yet powerful regularization technique that forms the core of MixMatch's approach.
MixUp Algorithm:
Given two examples $(x_i, y_i)$ and $(x_j, y_j)$, MixUp creates a virtual example:
$$\tilde{x} = \lambda x_i + (1 - \lambda) x_j$$ $$\tilde{y} = \lambda y_i + (1 - \lambda) y_j$$
where $\lambda \sim \text{Beta}(\alpha, \alpha)$ and $\alpha$ is a hyperparameter controlling interpolation strength.
Why MixUp Works:
Linear behavior: MixUp trains the model to predict linear interpolations between classes, encouraging simpler, more generalizable decision boundaries.
Regularization: By creating "in-between" examples, MixUp prevents the model from memorizing training data.
Soft labels: Mixed labels provide uncertainty information, similar to label smoothing.
Data augmentation: MixUp effectively creates infinite training examples from finite data.
```python
import torch
import torch.nn.functional as F
from typing import Tuple
import numpy as np


def mixup_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 0.75,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Apply MixUp augmentation to a batch.

    Args:
        x: Input batch [B, ...]
        y: Labels (one-hot or soft) [B, K]
        alpha: Beta distribution parameter

    Returns:
        (mixed_x, mixed_y, shuffled_indices, lambda)
    """
    if alpha > 0:
        # Sample mixing coefficient from Beta distribution
        lam = np.random.beta(alpha, alpha)
        # Ensure labeled examples dominate (optional)
        lam = max(lam, 1 - lam)
    else:
        lam = 1

    batch_size = x.size(0)
    # Random permutation for mixing partners
    index = torch.randperm(batch_size, device=x.device)

    # Mix inputs
    mixed_x = lam * x + (1 - lam) * x[index]
    # Mix labels
    mixed_y = lam * y + (1 - lam) * y[index]

    return mixed_x, mixed_y, index, lam


def interleave(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Interleave batch elements for proper batch normalization.

    When mixing labeled and unlabeled data, we need to ensure
    batch normalization sees a representative mix in each forward pass.

    Args:
        x: Concatenated batch [B_l + B_u * K, ...]
        batch_size: Size of a single group

    Returns:
        Interleaved batch with same total size
    """
    # Reshape to [num_groups, batch_size, ...]
    num_groups = x.size(0) // batch_size
    shape = [num_groups, batch_size] + list(x.shape[1:])
    x = x.view(*shape)
    # Transpose to interleave
    x = x.transpose(0, 1).contiguous()
    # Flatten back
    return x.view(-1, *x.shape[2:])


def de_interleave(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reverse interleaving operation."""
    num_groups = x.size(0) // batch_size
    shape = [batch_size, num_groups] + list(x.shape[1:])
    x = x.view(*shape)
    x = x.transpose(0, 1).contiguous()
    return x.view(-1, *x.shape[2:])


class MixUpBatch:
    """
    Efficient MixUp for semi-supervised learning.

    Handles mixing of labeled and unlabeled data with proper
    handling of different label types (one-hot vs soft).
    """

    def __init__(
        self,
        alpha: float = 0.75,
        mix_labeled_only: bool = False,
    ):
        self.alpha = alpha
        self.mix_labeled_only = mix_labeled_only

    def __call__(
        self,
        x_labeled: torch.Tensor,
        y_labeled: torch.Tensor,
        x_unlabeled: torch.Tensor,
        q_unlabeled: torch.Tensor,  # Guessed soft labels
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply MixUp to labeled and unlabeled batches.

        MixMatch-style: Concatenate all data, shuffle, then mix
        labeled with batch, unlabeled with batch.

        Args:
            x_labeled: Labeled inputs [B_l, ...]
            y_labeled: Labels (one-hot) [B_l, K]
            x_unlabeled: Unlabeled inputs [B_u * K, ...]
            q_unlabeled: Guessed labels [B_u * K, K]

        Returns:
            (x_labeled_mixed, y_labeled_mixed,
             x_unlabeled_mixed, q_unlabeled_mixed)
        """
        batch_size_l = x_labeled.size(0)
        batch_size_u = x_unlabeled.size(0)

        # Concatenate all inputs and labels
        all_inputs = torch.cat([x_labeled, x_unlabeled], dim=0)
        all_targets = torch.cat([y_labeled, q_unlabeled], dim=0)

        # Shuffle
        idx = torch.randperm(all_inputs.size(0), device=all_inputs.device)
        shuffled_inputs = all_inputs[idx]
        shuffled_targets = all_targets[idx]

        # Sample lambda
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
            lam = max(lam, 1 - lam)  # Ensure original dominates
        else:
            lam = 1.0

        # Mix labeled: original labeled + shuffled (any)
        x_labeled_mixed = (
            lam * x_labeled + (1 - lam) * shuffled_inputs[:batch_size_l]
        )
        y_labeled_mixed = (
            lam * y_labeled + (1 - lam) * shuffled_targets[:batch_size_l]
        )

        # Mix unlabeled: original unlabeled + shuffled (any)
        x_unlabeled_mixed = (
            lam * x_unlabeled + (1 - lam) * shuffled_inputs[batch_size_l:]
        )
        q_unlabeled_mixed = (
            lam * q_unlabeled + (1 - lam) * shuffled_targets[batch_size_l:]
        )

        return (
            x_labeled_mixed, y_labeled_mixed,
            x_unlabeled_mixed, q_unlabeled_mixed
        )
```

| Alpha (α) | Distribution | Effect | Typical Use |
|---|---|---|---|
| 0.0 | Point mass at 1 | No mixing (standard training) | Baseline |
| 0.2 | Peaked at extremes | Light mixing | When mixing hurts |
| 0.5 | U-shaped | Moderate mixing | Default for many tasks |
| 0.75 | Flatter U | Strong mixing | MixMatch default |
| 1.0 | Uniform [0,1] | Maximum mixing | Very aggressive regularization |
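To build intuition for the table above, here is a tiny, self-contained sketch (illustrative values only) that samples $\lambda \sim \text{Beta}(\alpha, \alpha)$ for several $\alpha$ and applies the same $\max(\lambda, 1-\lambda)$ trick used in the code above:

```python
import numpy as np

# Illustrative only: sample mixing coefficients for different alpha values
# and report how often lambda lands near 0.5 (maximum mixing).
rng = np.random.default_rng(0)

for alpha in [0.2, 0.5, 0.75, 1.0]:
    lam = rng.beta(alpha, alpha, size=100_000)
    # MixMatch keeps the "primary" example dominant via max(lam, 1 - lam)
    lam = np.maximum(lam, 1 - lam)
    print(f"alpha={alpha:4.2f}  mean lambda={lam.mean():.3f}  "
          f"P(lambda < 0.6)={np.mean(lam < 0.6):.2f}")
```

Smaller $\alpha$ pushes $\lambda$ toward 1 (light mixing); $\alpha = 1$ makes mixed examples that sit roughly halfway between the two inputs much more common.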
Let's walk through the complete MixMatch algorithm step by step.
Input: a batch of labeled examples $\mathcal{X} = \{(x_b, y_b)\}_{b=1}^{B}$ with one-hot labels, a batch of unlabeled examples $\mathcal{U} = \{u_b\}_{b=1}^{B}$, sharpening temperature $T$, number of augmentations $K$, MixUp Beta parameter $\alpha$, and unsupervised loss weight $\lambda_U$.
Step 1: Augment and Label Guess
For each labeled example, apply one random augmentation: $$\hat{x}_b = \text{Augment}(x_b), \quad \hat{y}_b = y_b$$
For each unlabeled example, apply $K$ random augmentations and compute the average prediction: $$\hat{u}_{b,k} = \text{Augment}(u_b) \quad \text{for } k = 1, ..., K$$ $$\bar{q}_b = \frac{1}{K} \sum_{k=1}^{K} p_{\text{model}}(\hat{u}_{b,k})$$
Step 2: Sharpening
Sharpen the guessed distribution: $$\text{Sharpen}(\bar{q}, T)_i = \frac{\bar{q}_i^{1/T}}{\sum_j \bar{q}_j^{1/T}}$$
Set $\hat{q}_b = \text{Sharpen}(\bar{q}_b, T)$ as the pseudo-label for all $K$ augmentations.
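A minimal numeric sketch of Steps 1-2, using made-up predictions for a single unlabeled image under $K = 2$ augmentations and 3 classes:

```python
import torch

def sharpen(p: torch.Tensor, T: float) -> torch.Tensor:
    """Raise probabilities to 1/T and renormalize (lower T => more peaked)."""
    p = p.pow(1.0 / T)
    return p / p.sum(dim=-1, keepdim=True)

# Pretend model predictions for one unlabeled image under two augmentations.
preds = torch.tensor([[0.6, 0.3, 0.1],
                      [0.4, 0.4, 0.2]])

q_bar = preds.mean(dim=0)        # average over augmentations: [0.50, 0.35, 0.15]
q_hat = sharpen(q_bar, T=0.5)    # sharpened pseudo-label, still a valid distribution
print(q_bar, q_hat)              # q_hat ≈ [0.63, 0.31, 0.06] (more confident)
```

Averaging smooths out augmentation noise; sharpening then pushes the averaged guess toward a confident distribution, which is where the entropy-minimization effect comes from.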
Step 3: MixUp
Collect all augmented examples: $$\hat{\mathcal{X}} = \{(\hat{x}_b, \hat{y}_b)\}_{b=1}^{B}$$ $$\hat{\mathcal{U}} = \{(\hat{u}_{b,k}, \hat{q}_b)\}_{b=1, k=1}^{B, K}$$
Shuffle the combined set $\mathcal{W} = \text{Shuffle}(\hat{\mathcal{X}} \cup \hat{\mathcal{U}})$
Apply MixUp separately to labeled and unlabeled: $$\mathcal{X}' = \text{MixUp}(\hat{\mathcal{X}}, \mathcal{W}[:B])$$ $$\mathcal{U}' = \text{MixUp}(\hat{\mathcal{U}}, \mathcal{W}[B:])$$
Step 4: Loss Computation
$$\mathcal{L}_{\mathcal{X}'} = \frac{1}{|\mathcal{X}'|} \sum_{(x', p') \in \mathcal{X}'} H(p', p_{\text{model}}(x'))$$
$$\mathcal{L}_{\mathcal{U}'} = \frac{1}{|\mathcal{U}'|} \sum_{(u', q') \in \mathcal{U}'} \| q' - p_{\text{model}}(u') \|_2^2$$
$$\mathcal{L} = \mathcal{L}_{\mathcal{X}'} + \lambda_U \mathcal{L}_{\mathcal{U}'}$$
where $H$ is cross-entropy.
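Before the full implementation below, here is a minimal sketch of just these two loss terms with placeholder tensors (the complete training logic follows):

```python
import torch
import torch.nn.functional as F

def mixmatch_losses(logits_x, targets_x, logits_u, targets_u, lambda_u):
    """Cross-entropy with soft labels on mixed labeled data,
    mean squared error on mixed unlabeled data."""
    # Supervised term: H(p', p_model(x')) with soft targets
    loss_x = -(targets_x * F.log_softmax(logits_x, dim=-1)).sum(dim=-1).mean()
    # Unsupervised term: ||q' - p_model(u')||^2 (targets come from no_grad label guessing)
    loss_u = F.mse_loss(F.softmax(logits_u, dim=-1), targets_u)
    return loss_x + lambda_u * loss_u, loss_x, loss_u

# Toy shapes: 4 labeled and 8 unlabeled examples, 10 classes.
logits_x, logits_u = torch.randn(4, 10), torch.randn(8, 10)
targets_x = F.softmax(torch.randn(4, 10), dim=-1)   # mixed soft labels
targets_u = F.softmax(torch.randn(8, 10), dim=-1)   # mixed guessed labels
total, lx, lu = mixmatch_losses(logits_x, targets_x, logits_u, targets_u, lambda_u=75.0)
```

Using MSE rather than cross-entropy on the unlabeled term is deliberate: it is bounded and therefore more tolerant of wrong pseudo-labels.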
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, List, Callable
from dataclasses import dataclass
import numpy as np


@dataclass
class MixMatchConfig:
    """Configuration for MixMatch training."""
    # Number of augmentations per unlabeled sample
    K: int = 2
    # Sharpening temperature
    temperature: float = 0.5
    # MixUp alpha
    alpha: float = 0.75
    # Unsupervised loss weight
    lambda_u: float = 75.0  # Note: much higher than FixMatch
    # Ramp-up period for lambda_u
    rampup_length: int = 16000
    # Number of classes
    num_classes: int = 10


class MixMatch(nn.Module):
    """
    MixMatch: A Holistic Approach to Semi-Supervised Learning.

    Paper: Berthelot et al., NeurIPS 2019

    Combines:
    1. Consistency regularization (K augmentations + averaging)
    2. Entropy minimization (sharpening)
    3. MixUp regularization
    """

    def __init__(
        self,
        model: nn.Module,
        config: MixMatchConfig,
        augment_fn: Callable[[torch.Tensor], torch.Tensor],
    ):
        super().__init__()
        self.model = model
        self.config = config
        self.augment_fn = augment_fn

    def sharpen(self, probs: torch.Tensor, temperature: float) -> torch.Tensor:
        """Sharpen probability distribution."""
        temp_probs = probs.pow(1.0 / temperature)
        return temp_probs / temp_probs.sum(dim=-1, keepdim=True)

    def guess_labels(
        self,
        unlabeled: torch.Tensor,
        K: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate guessed labels for unlabeled data.

        1. Apply K different augmentations
        2. Get model predictions for each
        3. Average predictions
        4. Sharpen averaged predictions

        Args:
            unlabeled: Raw unlabeled batch [B, C, H, W]
            K: Number of augmentations

        Returns:
            (augmented_inputs [B*K, C, H, W],
             guessed_labels [B*K, num_classes])
        """
        batch_size = unlabeled.size(0)

        # Generate K augmentations for each unlabeled sample
        all_augmented = []
        all_predictions = []

        with torch.no_grad():
            for _ in range(K):
                aug = self.augment_fn(unlabeled)
                all_augmented.append(aug)

                logits = self.model(aug)
                probs = F.softmax(logits, dim=-1)
                all_predictions.append(probs)

        # Stack augmented versions: [K, B, C, H, W]
        augmented_stacked = torch.stack(all_augmented, dim=0)

        # Average predictions across K augmentations: [K, B, num_classes] -> [B, num_classes]
        predictions_stacked = torch.stack(all_predictions, dim=0)
        avg_predictions = predictions_stacked.mean(dim=0)

        # Sharpen the average
        sharpened = self.sharpen(avg_predictions, self.config.temperature)

        # Reshape augmented to [B*K, C, H, W]
        augmented_flat = augmented_stacked.transpose(0, 1).contiguous()
        augmented_flat = augmented_flat.view(-1, *unlabeled.shape[1:])

        # Repeat sharpened labels K times: [B, num_classes] -> [B*K, num_classes]
        guessed_labels = sharpened.unsqueeze(1).repeat(1, K, 1)
        guessed_labels = guessed_labels.view(-1, self.config.num_classes)

        return augmented_flat, guessed_labels

    def mixup(
        self,
        x1: torch.Tensor,
        y1: torch.Tensor,
        x2: torch.Tensor,
        y2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply MixUp between two sets of examples.

        Args:
            x1, y1: Primary examples (these dominate due to max(lam, 1-lam))
            x2, y2: Secondary examples (mixed in)

        Returns:
            (mixed_x, mixed_y)
        """
        # Sample lambda
        if self.config.alpha > 0:
            lam = np.random.beta(self.config.alpha, self.config.alpha)
            lam = max(lam, 1 - lam)  # Ensure primary dominates
        else:
            lam = 1.0

        # MixUp
        mixed_x = lam * x1 + (1 - lam) * x2
        mixed_y = lam * y1 + (1 - lam) * y2

        return mixed_x, mixed_y

    def get_lambda_u(self, step: int) -> float:
        """Linear ramp-up for unsupervised loss weight."""
        if step >= self.config.rampup_length:
            return self.config.lambda_u
        return self.config.lambda_u * (step / self.config.rampup_length)

    def forward(
        self,
        x_labeled: torch.Tensor,
        y_labeled: torch.Tensor,
        x_unlabeled: torch.Tensor,
        step: int,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute MixMatch loss.

        Args:
            x_labeled: Labeled inputs [B_l, C, H, W]
            y_labeled: Labels (indices) [B_l]
            x_unlabeled: Unlabeled inputs [B_u, C, H, W]
            step: Current training step

        Returns:
            (loss, metrics_dict)
        """
        batch_size_l = x_labeled.size(0)
        batch_size_u = x_unlabeled.size(0)
        K = self.config.K

        # =====================
        # Step 1: Augment labeled data (once)
        # =====================
        x_labeled_aug = self.augment_fn(x_labeled)

        # Convert labels to one-hot
        y_labeled_onehot = F.one_hot(
            y_labeled, num_classes=self.config.num_classes
        ).float()

        # =====================
        # Step 2: Generate guessed labels for unlabeled
        # =====================
        x_unlabeled_aug, q_guessed = self.guess_labels(x_unlabeled, K)

        # =====================
        # Step 3: Concatenate and shuffle for MixUp
        # =====================
        all_inputs = torch.cat([x_labeled_aug, x_unlabeled_aug], dim=0)
        all_targets = torch.cat([y_labeled_onehot, q_guessed], dim=0)

        # Shuffle
        idx = torch.randperm(all_inputs.size(0), device=all_inputs.device)
        shuffled_inputs = all_inputs[idx]
        shuffled_targets = all_targets[idx]

        # =====================
        # Step 4: MixUp labeled and unlabeled separately
        # =====================
        # For labeled: mix with first batch_size_l shuffled samples
        x_labeled_mixed, y_labeled_mixed = self.mixup(
            x_labeled_aug, y_labeled_onehot,
            shuffled_inputs[:batch_size_l], shuffled_targets[:batch_size_l]
        )

        # For unlabeled: mix with remaining shuffled samples
        x_unlabeled_mixed, q_mixed = self.mixup(
            x_unlabeled_aug, q_guessed,
            shuffled_inputs[batch_size_l:], shuffled_targets[batch_size_l:]
        )

        # =====================
        # Step 5: Forward pass through model
        # =====================
        # Note: the official implementation interleaves labeled and unlabeled
        # samples (see interleave/de_interleave above) so BatchNorm statistics
        # stay representative; here we simply concatenate for clarity.
        all_mixed = torch.cat([x_labeled_mixed, x_unlabeled_mixed], dim=0)
        all_logits = self.model(all_mixed)

        logits_labeled = all_logits[:batch_size_l]
        logits_unlabeled = all_logits[batch_size_l:]

        # =====================
        # Step 6: Compute losses
        # =====================
        # Supervised: Cross-entropy with soft labels
        log_probs_labeled = F.log_softmax(logits_labeled, dim=-1)
        loss_labeled = -(y_labeled_mixed * log_probs_labeled).sum(dim=-1).mean()

        # Unsupervised: MSE between prediction and guessed label
        probs_unlabeled = F.softmax(logits_unlabeled, dim=-1)
        loss_unlabeled = F.mse_loss(probs_unlabeled, q_mixed)

        # Combined loss with ramp-up
        lambda_u = self.get_lambda_u(step)
        loss_total = loss_labeled + lambda_u * loss_unlabeled

        # =====================
        # Metrics
        # =====================
        metrics = {
            "loss_labeled": loss_labeled.item(),
            "loss_unlabeled": loss_unlabeled.item(),
            "loss_total": loss_total.item(),
            "lambda_u": lambda_u,
            "avg_guess_confidence": q_guessed.max(dim=-1)[0].mean().item(),
            "guess_entropy": -(q_guessed * (q_guessed + 1e-8).log()).sum(dim=-1).mean().item(),
        }

        return loss_total, metrics
```

An important extension of MixMatch is distribution alignment, introduced in ReMixMatch. The observation is that pseudo-labels from a model may not match the true label distribution, especially early in training.
The Problem:
If the model is biased toward certain classes, pseudo-labels will over-represent those classes. This creates a feedback loop where the bias is reinforced.
The Solution: Distribution Alignment
Align the pseudo-label distribution to match the expected class distribution (e.g., uniform for balanced datasets):
$$\tilde{q} = \text{Normalize}(q \cdot \frac{p(y)}{\tilde{p}(y)})$$
where $q$ is the model's averaged prediction for an unlabeled example, $p(y)$ is the target (marginal) class distribution, and $\tilde{p}(y)$ is a running average of the model's predictions on unlabeled data.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

# Assumes MixMatch and MixMatchConfig from the previous listing are in scope.


class DistributionAlignment:
    """
    Distribution alignment for semi-supervised learning.

    Aligns pseudo-label distribution to match target class distribution,
    preventing confirmation bias in pseudo-labeling.
    """

    def __init__(
        self,
        num_classes: int,
        momentum: float = 0.999,
        target_distribution: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            num_classes: Number of classes
            momentum: EMA momentum for tracking model distribution
            target_distribution: Target class distribution (uniform if None)
        """
        self.num_classes = num_classes
        self.momentum = momentum

        # Target distribution (default: uniform)
        if target_distribution is None:
            self.target_dist = torch.ones(num_classes) / num_classes
        else:
            self.target_dist = target_distribution

        # Running estimate of model's prediction distribution
        self.running_dist = torch.ones(num_classes) / num_classes

    def update(self, predictions: torch.Tensor):
        """
        Update running distribution estimate.

        Args:
            predictions: Batch of soft predictions [B, K]
        """
        # Current batch distribution
        batch_dist = predictions.mean(dim=0).detach()

        # EMA update
        self.running_dist = (
            self.momentum * self.running_dist
            + (1 - self.momentum) * batch_dist.cpu()
        )

    def align(self, predictions: torch.Tensor) -> torch.Tensor:
        """
        Align predictions to target distribution.

        Args:
            predictions: Model predictions [B, K]

        Returns:
            Aligned predictions [B, K]
        """
        # Move target and running to same device as predictions
        device = predictions.device
        target = self.target_dist.to(device)
        running = self.running_dist.to(device)

        # Scaling factor: how much to adjust each class
        # If running[k] < target[k], we want to upweight class k
        scale = target / (running + 1e-8)

        # Apply scaling
        aligned = predictions * scale.unsqueeze(0)

        # Renormalize to valid probability distribution
        aligned = aligned / aligned.sum(dim=-1, keepdim=True)

        return aligned


class ReMixMatch(MixMatch):
    """
    ReMixMatch: Semi-Supervised Learning with Distribution Alignment
    and Augmentation Anchoring.

    Paper: Berthelot et al., ICLR 2020

    Extends MixMatch with:
    1. Distribution alignment
    2. Augmentation anchoring (not implemented here)
    3. Stronger augmentation via CTAugment
    """

    def __init__(
        self,
        model: nn.Module,
        config: MixMatchConfig,
        augment_fn: callable,
        strong_augment_fn: callable,  # Additional strong augmentation
    ):
        super().__init__(model, config, augment_fn)
        self.strong_augment_fn = strong_augment_fn
        self.dist_align = DistributionAlignment(config.num_classes)

    def guess_labels(
        self,
        unlabeled: torch.Tensor,
        K: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Guess labels with distribution alignment.
        """
        batch_size = unlabeled.size(0)

        all_augmented = []
        all_predictions = []

        with torch.no_grad():
            for k in range(K):
                # Use strong augmentation for all but first
                if k == 0:
                    aug = self.augment_fn(unlabeled)          # Weak
                else:
                    aug = self.strong_augment_fn(unlabeled)   # Strong
                all_augmented.append(aug)

                logits = self.model(aug)
                probs = F.softmax(logits, dim=-1)
                all_predictions.append(probs)

        # Average and sharpen
        predictions_stacked = torch.stack(all_predictions, dim=0)
        avg_predictions = predictions_stacked.mean(dim=0)

        # Distribution alignment before sharpening
        aligned = self.dist_align.align(avg_predictions)

        # Update running distribution
        self.dist_align.update(avg_predictions)

        # Sharpen aligned predictions
        sharpened = self.sharpen(aligned, self.config.temperature)

        # Format output
        augmented_stacked = torch.stack(all_augmented, dim=0)
        augmented_flat = augmented_stacked.transpose(0, 1).contiguous()
        augmented_flat = augmented_flat.view(-1, *unlabeled.shape[1:])

        guessed_labels = sharpened.unsqueeze(1).repeat(1, K, 1)
        guessed_labels = guessed_labels.view(-1, self.config.num_classes)

        return augmented_flat, guessed_labels
```

Distribution alignment helps most when: (1) classes are imbalanced, since it prevents the model from ignoring minority classes; (2) the model has a strong initial bias, since it corrects for pre-training effects; and (3) training is in its early stages, since it stabilizes the pseudo-label distribution before the model converges. For balanced datasets with good initialization, the improvement is usually modest (1-2%).
Understanding when to use MixMatch versus FixMatch or UDA requires analyzing their architectural and performance differences.
| Aspect | MixMatch | FixMatch | UDA |
|---|---|---|---|
| Core technique | MixUp + Consistency | Strong aug + Hard labels | Strong aug + Soft labels |
| Pseudo-labels | Soft (averaged over K aug) | Hard (argmax) | Soft (sharpened) |
| Loss function | CE (labeled) + MSE (unlabeled) | CE for both | CE + KL divergence |
| Augmentations per sample | K (typically 2) | 2 (weak + strong) | 2 (weak + strong) |
| Uses MixUp | Yes (core component) | No | No |
| Confidence threshold | No (implicit via sharpening) | Yes (0.95) | Yes (0.8) |
| λ_u typical value | 75-100 | 1 | 1 |
| Training complexity | Higher (MixUp, K forwards) | Lower (2 forwards) | Medium |
Performance Comparison (CIFAR-10 Error Rate %):
| Labels | MixMatch | FixMatch | ReMixMatch |
|---|---|---|---|
| 40 | 11.08 ± 0.87 | 4.26 ± 0.05 | 6.27 ± 0.34 |
| 250 | 6.24 ± 0.06 | 4.86 ± 0.05 | 4.72 ± 0.13 |
| 4000 | 4.95 ± 0.08 | 4.21 ± 0.08 | 4.14 ± 0.09 |
Key Observations:
FixMatch wins at extreme low-label settings (40 labels) — the high confidence threshold prevents noisy pseudo-labels from dominating
MixMatch is more stable across different label counts — less variance in results
ReMixMatch improves on MixMatch through distribution alignment and stronger augmentation
Computational cost differs — MixMatch requires K forward passes for label guessing, making it slower
Successfully implementing MixMatch requires attention to several practical details that significantly impact performance.
```python
# Recommended MixMatch hyperparameters for CIFAR-10

MIXMATCH_CIFAR10 = {
    # Model
    "architecture": "WideResNet-28-2",
    "weight_decay": 0.04,            # Higher than FixMatch

    # MixMatch specific
    "K": 2,                          # Number of augmentations
    "temperature": 0.5,              # Sharpening temperature
    "alpha": 0.75,                   # MixUp parameter
    "lambda_u": 75.0,                # Unsupervised loss weight (high!)
    "rampup_length": 16000,          # Steps to ramp up lambda_u

    # Data
    "labeled_batch_size": 64,
    "unlabeled_batch_size": 64,
    "num_workers": 4,

    # Augmentation (standard, not strong like FixMatch)
    "augmentation": "random_crop (pad=4) + horizontal_flip",

    # Optimizer
    "optimizer": "Adam",             # Note: MixMatch uses Adam, not SGD
    "lr": 0.002,
    "betas": (0.9, 0.999),

    # Training
    "total_steps": 1000000,          # 1M steps
    "eval_every": 1000,

    # EMA for evaluation
    "use_ema": True,
    "ema_decay": 0.999,
}

# Adjustments for different label counts
ADJUSTMENTS = {
    "40_labels": {
        "lambda_u": 100.0,           # Higher weight compensates for fewer labels
        "rampup_length": 65536,      # Longer ramp-up for stability
    },
    "250_labels": {
        "lambda_u": 75.0,
        "rampup_length": 16384,
    },
    "4000_labels": {
        "lambda_u": 25.0,            # Lower weight since labeled data provides signal
        "rampup_length": 4096,
    },
}
```

MixMatch demonstrates the power of combining multiple semi-supervised techniques into a cohesive framework. While not the simplest method, its holistic approach provides stable, strong performance across diverse settings.
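To tie the pieces together before moving on, here is a minimal sketch of a single training step that wires the `MixMatch` module from earlier to the Adam settings above; `model`, `weak_augment`, and the batches are placeholders you would supply:

```python
import torch.optim as optim

# Placeholders (not defined here): `model` is any classifier returning logits,
# `weak_augment` is a weak augmentation function on image batches.
config = MixMatchConfig(K=2, temperature=0.5, alpha=0.75, lambda_u=75.0)
mixmatch = MixMatch(model, config, augment_fn=weak_augment)
optimizer = optim.Adam(model.parameters(), lr=0.002, betas=(0.9, 0.999))

def train_step(step, x_labeled, y_labeled, x_unlabeled):
    """One MixMatch update: compute the combined loss and take an Adam step."""
    loss, metrics = mixmatch(x_labeled, y_labeled, x_unlabeled, step)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return metrics
```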
What's Next:
We've examined the major consistency regularization methods: UDA, FixMatch, and MixMatch. The next page explores Pseudo-Labeling—a simpler but foundational technique that predates these methods and continues to be effective. Understanding pseudo-labeling provides insight into why modern methods work and how to diagnose issues when they don't.
You now understand MixMatch's holistic approach to semi-supervised learning. This knowledge enables you to choose the right method for your specific setting and to combine techniques creatively when tackling new semi-supervised learning challenges.