In the landscape of semi-supervised learning, one principle has emerged as the cornerstone of the most successful modern approaches: the consistency assumption. This deceptively simple idea—that small perturbations to an input should not change its predicted label—has revolutionized how we leverage unlabeled data, enabling state-of-the-art performance across computer vision, natural language processing, and beyond.
Before we explore specific algorithms like FixMatch or MixMatch, we must deeply understand this foundational assumption. Why does it work? What are its theoretical justifications? What are its limitations? This page provides the rigorous conceptual grounding necessary to master consistency-based semi-supervised learning.
By the end of this page, you will understand the consistency assumption's formal definition, its relationship to the smoothness assumption, the mathematical framework for consistency regularization, and why this principle enables effective learning from unlabeled data. You'll develop the intuition required to understand and implement modern semi-supervised methods.
Consider a photograph of a cat. If we slightly rotate the image by 5 degrees, add a tiny amount of Gaussian noise, or shift the brightness marginally, the image still clearly depicts the same cat. A robust classifier should recognize all these variations as "cat" with high confidence.
This observation forms the core intuition behind the consistency assumption:
Consistency Assumption: If an input $x$ and a perturbed version $\tilde{x}$ are semantically equivalent, then a good classifier should produce the same (or highly similar) predictions for both.
Mathematically, for a classifier $f$ and a perturbation function $\mathcal{T}$:
$$f(x) \approx f(\mathcal{T}(x))$$
where $\mathcal{T}(x)$ represents a semantically-preserving transformation of $x$.
The consistency assumption is more than just noise robustness. It's a statement about the structure of decision boundaries: a good classifier should place decision boundaries in low-density regions of the input space, far from any data points. Perturbed versions of an input should stay on the same side of the decision boundary.
Why is this useful for semi-supervised learning?
The key insight is that enforcing consistency does not require labels. We can demand that $f(x) \approx f(\mathcal{T}(x))$ for any input $x$—labeled or unlabeled. This gives us a supervision signal from unlabeled data:
The unlabeled consistency constraint regularizes the model, preventing it from learning decision boundaries that pass through regions with high data density.
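To make this concrete before the formal treatment, here is a minimal sketch of the idea (assuming a generic PyTorch classifier `model` and an augmentation function `augment`, both placeholder names): the quantity below can be computed on any batch of inputs because it never reads a label.

```python
import torch
import torch.nn.functional as F

def consistency_gap(model: torch.nn.Module,
                    unlabeled: torch.Tensor,
                    augment) -> torch.Tensor:
    """Mean squared disagreement between predictions on x and on T(x)."""
    with torch.no_grad():
        p_clean = F.softmax(model(unlabeled), dim=-1)     # f(x), used as target
    p_aug = F.softmax(model(augment(unlabeled)), dim=-1)  # f(T(x))
    return F.mse_loss(p_aug, p_clean)                     # small value => consistent
```

Minimizing this gap on unlabeled batches is, in essence, all that consistency regularization adds on top of standard supervised training.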
| Aspect | Traditional Regularization (L2, Dropout) | Consistency Regularization |
|---|---|---|
| Primary Goal | Prevent overfitting to training data | Learn smooth decision boundaries |
| Uses Unlabeled Data | No | Yes—explicitly |
| Mechanism | Penalize model complexity | Enforce prediction stability |
| Data Dependency | Independent of input distribution | Depends on data manifold structure |
| Semantic Awareness | Purely mathematical constraint | Semantically meaningful perturbations |
The consistency assumption is intimately connected to the smoothness assumption, one of the fundamental assumptions in semi-supervised learning theory. Understanding this connection reveals why consistency regularization works.
The Smoothness Assumption (Formal Definition):
If two points $x_1$ and $x_2$ lie in a high-density region of the input space and are close to each other, then their corresponding labels $y_1$ and $y_2$ should be the same (or similar).
Equivalently:
The label function $f: \mathcal{X} \rightarrow \mathcal{Y}$ should vary smoothly in high-density regions of the input space.
This assumption implies that decision boundaries should lie in low-density regions of the input space. If a decision boundary cuts through a high-density region, it means nearby points with similar features receive different labels—violating smoothness.
Connection to Consistency:
The consistency assumption is a local version of smoothness. When we perturb an input $x$ to get $\mathcal{T}(x)$, we're exploring the local neighborhood around $x$. If $x$ lies in a high-density region (which most data points do—they're sampled from the data distribution), then $\mathcal{T}(x)$ likely remains in the same high-density region.
By enforcing $f(x) \approx f(\mathcal{T}(x))$, we're enforcing local smoothness of the decision function. Over many samples, this local constraint translates to global smoothness—decision boundaries that respect the data manifold structure.
The smoothness assumption is closely related to the 'cluster assumption', which says that data points cluster by class: points in the same cluster (a high-density region connected by paths through high-density regions) should share the same label. This is crucial: it means the data distribution itself contains information about the correct labeling.
Let's formalize consistency regularization mathematically. This framework underlies all modern consistency-based semi-supervised methods.
Problem Setup:

We assume a small labeled set $\mathcal{D}_l = \{(x_i, y_i)\}_{i=1}^{n_l}$, a much larger unlabeled set $\mathcal{D}_u = \{u_j\}_{j=1}^{n_u}$ with $n_u \gg n_l$, a classifier $f_\theta$ with parameters $\theta$ that outputs a probability distribution over classes, and a distribution $p(\mathcal{T})$ over semantics-preserving perturbations.
The Consistency Regularization Objective:
$$\mathcal{L}_{total} = \mathcal{L}_{sup} + \lambda \cdot \mathcal{L}_{unsup}$$
where:
$$\mathcal{L}_{sup} = \frac{1}{n_l} \sum_{i=1}^{n_l} \ell(f_\theta(x_i), y_i)$$

$$\mathcal{L}_{unsup} = \frac{1}{n_u} \sum_{j=1}^{n_u} \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} \left[ d(f_\theta(u_j), f_\theta(\mathcal{T}(u_j))) \right]$$
Here, $\ell$ is the supervised loss (typically cross-entropy), $d$ is a distance function between probability distributions, and $\lambda$ is a hyperparameter balancing supervised and unsupervised losses.
```python
import torch
import torch.nn.functional as F
from typing import Callable


def consistency_loss(
    model: torch.nn.Module,
    inputs: torch.Tensor,
    augment_fn: Callable[[torch.Tensor], torch.Tensor],
    distance: str = "mse",
    num_augmentations: int = 2
) -> torch.Tensor:
    """
    Compute consistency regularization loss.

    Args:
        model: Neural network classifier
        inputs: Batch of unlabeled inputs [B, C, H, W]
        augment_fn: Data augmentation function
        distance: Distance metric ('mse', 'kl', 'cross_entropy')
        num_augmentations: Number of augmentations to average over

    Returns:
        Consistency loss scalar
    """
    model.train()  # Enable dropout, batch norm updates

    # Get predictions for original inputs (or first augmentation)
    with torch.no_grad():
        # Optionally use an EMA teacher for the target
        original_logits = model(augment_fn(inputs))
        target_probs = F.softmax(original_logits, dim=-1)

    # Compute consistency across augmentations
    total_loss = 0.0
    for _ in range(num_augmentations):
        augmented = augment_fn(inputs)
        aug_logits = model(augmented)
        aug_probs = F.softmax(aug_logits, dim=-1)

        if distance == "mse":
            # Mean Squared Error (used in Pi-Model, Mean Teacher)
            loss = F.mse_loss(aug_probs, target_probs)
        elif distance == "kl":
            # KL divergence, asymmetric (used in UDA, VAT)
            loss = F.kl_div(
                F.log_softmax(aug_logits, dim=-1),
                target_probs,
                reduction='batchmean'
            )
        elif distance == "cross_entropy":
            # Cross-entropy with soft targets
            loss = -(target_probs * F.log_softmax(aug_logits, dim=-1)).sum(dim=-1).mean()
        else:
            raise ValueError(f"Unknown distance: {distance}")

        total_loss += loss

    return total_loss / num_augmentations


def semi_supervised_loss(
    model: torch.nn.Module,
    labeled_inputs: torch.Tensor,
    labels: torch.Tensor,
    unlabeled_inputs: torch.Tensor,
    augment_fn: Callable,
    lambda_u: float = 1.0,
    warmup_steps: int = 0,
    current_step: int = 0
) -> tuple[torch.Tensor, dict]:
    """
    Combined semi-supervised loss with consistency regularization.

    Args:
        model: Neural network classifier
        labeled_inputs: Batch of labeled inputs
        labels: Ground truth labels
        unlabeled_inputs: Batch of unlabeled inputs
        augment_fn: Data augmentation function
        lambda_u: Weight for unsupervised loss
        warmup_steps: Steps to linearly ramp up lambda_u
        current_step: Current training step

    Returns:
        Total loss and dict of individual loss components
    """
    # Supervised loss on labeled data
    labeled_logits = model(augment_fn(labeled_inputs))
    sup_loss = F.cross_entropy(labeled_logits, labels)

    # Consistency loss on unlabeled data
    unsup_loss = consistency_loss(
        model, unlabeled_inputs, augment_fn, distance="mse"
    )

    # Linear warmup for unsupervised loss weight
    if current_step < warmup_steps:
        lambda_eff = lambda_u * (current_step / warmup_steps)
    else:
        lambda_eff = lambda_u

    # Combined loss
    total_loss = sup_loss + lambda_eff * unsup_loss

    return total_loss, {
        "supervised_loss": sup_loss.item(),
        "unsupervised_loss": unsup_loss.item(),
        "lambda_effective": lambda_eff,
        "total_loss": total_loss.item(),
    }
```

Choice of Distance Function:
The distance function $d$ measures how different two probability distributions are. Common choices include:
| Distance | Formula | Properties |
|---|---|---|
| MSE | $\lVert p - q \rVert_2^2$ | Symmetric, bounded, smooth gradients |
| KL Divergence | $\sum_k p_k \log(p_k/q_k)$ | Asymmetric, unbounded; penalizes $q_k \to 0$ where $p_k > 0$ |
| JS Divergence | $\frac{1}{2}\mathrm{KL}(p \,\Vert\, m) + \frac{1}{2}\mathrm{KL}(q \,\Vert\, m)$ with $m = \frac{p+q}{2}$ | Symmetric, bounded |
| Cross-Entropy | $-\sum_k p_k \log q_k$ | Asymmetric, standard for soft labels |
MSE is most commonly used because it provides smooth gradients and is symmetric. The choice matters especially when predictions are very confident—KL divergence can explode when $q_k \rightarrow 0$ but $p_k > 0$.
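The following small numerical illustration (with made-up distributions, not tied to any dataset) shows this behavior: as the predicted probability $q_k$ approaches zero for a class the target still believes in, the KL term $p_k \log(p_k/q_k)$ grows without bound, while MSE stays bounded.

```python
import torch

# Hypothetical target p and a very confident prediction q (illustrative values only)
p = torch.tensor([0.70, 0.20, 0.10])
q = torch.tensor([0.98, 0.019, 0.001])
m = 0.5 * (p + q)

mse = ((p - q) ** 2).sum()                          # ||p - q||_2^2
kl = (p * (p / q).log()).sum()                      # sum_k p_k log(p_k / q_k)
js = 0.5 * (p * (p / m).log()).sum() + 0.5 * (q * (q / m).log()).sum()
ce = -(p * q.log()).sum()                           # cross-entropy with soft target p

print(f"MSE={mse:.3f}  KL={kl:.3f}  JS={js:.3f}  CE={ce:.3f}")
# Pushing q[2] further toward 0 (e.g. 1e-6) inflates KL and CE but barely moves MSE.
```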
Understanding why consistency regularization is effective requires examining multiple perspectives: geometric, information-theoretic, and empirical.
Geometric Perspective: Margin Maximization
Consider a classifier's decision boundary. When we enforce consistency, we're essentially saying: "Moving around locally shouldn't change the prediction." This is equivalent to demanding a margin around each data point—a region where the classifier is confident.
For a linear classifier, maximizing margins is exactly what SVMs do. Consistency regularization extends this idea to deep networks: it implicitly maximizes the margin by penalizing predictions that are sensitive to small perturbations.
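One way to see this connection empirically is to probe how much a model's predicted distribution moves under small random perturbations. The sketch below is a rough diagnostic, not a formal margin measure; the function name, noise radius, and number of probes are illustrative choices.

```python
import torch
import torch.nn.functional as F

def local_sensitivity(model: torch.nn.Module,
                      x: torch.Tensor,
                      eps: float = 0.05,
                      n_probes: int = 8) -> torch.Tensor:
    """Average prediction shift under small random input perturbations."""
    with torch.no_grad():
        p = F.softmax(model(x), dim=-1)
        shifts = []
        for _ in range(n_probes):
            r = torch.randn_like(x)
            # Scale each sample's perturbation to norm eps
            r = eps * r / r.flatten(1).norm(dim=1).view(-1, *[1] * (x.dim() - 1))
            p_perturbed = F.softmax(model(x + r), dim=-1)
            shifts.append((p_perturbed - p).abs().sum(dim=-1))  # per-sample shift
    # Large values suggest x sits close to a decision boundary (a small "margin")
    return torch.stack(shifts).mean()
```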
Information-Theoretic Perspective:
From an information-theoretic view, consistency regularization implements a form of the information bottleneck principle. The model must:

- discard information about the applied transformation (nuisance factors such as rotation angle or noise pattern), and
- retain the information needed to predict the label.
This forces the model to learn representations that capture semantically meaningful features (class identity) while discarding transformation-specific details (rotation, noise, color jitter).
$$I(Z; T) \approx 0 \quad \text{(invariant to transformations)}$$

$$I(Z; Y) \approx H(Y) \quad \text{(predictive of labels)}$$
where $Z$ is the learned representation, $T$ encodes the applied transformation, and $Y$ is the label.
Data often lies on a low-dimensional manifold in high-dimensional space. Consistency regularization encourages the model to learn this manifold structure: points connected by paths along the manifold (through augmentations) should have consistent predictions. The decision boundary should cut across the manifold in low-density regions, not along it.
The choice of perturbation function $\mathcal{T}$ is crucial for effective consistency regularization. Perturbations must be semantics-preserving—they should transform the input without changing its true label. Different domains require different perturbation strategies.
Input-Space Perturbations (Data Augmentation):
These operate directly on the raw input:
| Domain | Perturbation | Semantic Preservation |
|---|---|---|
| Images | Rotation, flip, crop, color jitter | Objects remain identifiable |
| Images | Gaussian noise, blur | Fine details vary, objects preserved |
| Text | Synonym replacement, back-translation | Meaning preserved |
| Text | Word dropout, shuffling | Context recoverable |
| Audio | Time stretching, pitch shifting | Speech/content preserved |
| Tabular | Gaussian noise, feature dropout | Sample identity preserved |
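For domains other than images, the perturbations are often just as simple to write down. Below is a sketch for the text and tabular rows of the table; the function names and parameter values are illustrative, and real pipelines would typically use library implementations (e.g., back-translation systems or domain-specific augmenters).

```python
import random
import torch

def word_dropout(tokens: list[str], p: float = 0.1) -> list[str]:
    """Randomly drop tokens; the remaining context usually preserves meaning."""
    kept = [t for t in tokens if random.random() > p]
    return kept if kept else tokens  # never return an empty sequence

def tabular_perturbation(x: torch.Tensor, sigma: float = 0.05, drop_p: float = 0.1) -> torch.Tensor:
    """Gaussian noise plus random feature dropout for tabular rows [B, D]."""
    noisy = x + sigma * torch.randn_like(x)
    mask = (torch.rand_like(x) > drop_p).float()
    return noisy * mask

# Two stochastic views of the same sentence, to be fed to the model for a consistency loss:
tokens = "the movie was surprisingly good".split()
view_a, view_b = word_dropout(tokens), word_dropout(tokens)
```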
Feature-Space Perturbations:
Perturbations can also be applied in the hidden representation space: adding Gaussian noise to intermediate activations, applying dropout to feature vectors, or interpolating between hidden representations (as in Manifold Mixup). A minimal sketch follows.
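This sketch assumes the network is split into an `encoder` and a classification `head`; both names, and the noise and dropout parameters, are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def feature_consistency_loss(encoder: nn.Module,
                             head: nn.Module,
                             x: torch.Tensor,
                             sigma: float = 0.1,
                             drop_p: float = 0.2) -> torch.Tensor:
    z = encoder(x)                                   # hidden representation
    with torch.no_grad():
        target = F.softmax(head(z), dim=-1)          # prediction from clean features
    z_noisy = F.dropout(z + sigma * torch.randn_like(z), p=drop_p)
    pred = F.softmax(head(z_noisy), dim=-1)          # prediction from perturbed features
    return F.mse_loss(pred, target)
```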
Stochastic Perturbations from Model Itself:
Some perturbations come from model stochasticity rather than explicit input transformations: different dropout masks across forward passes (as in the Π-Model), stochastic depth, or comparing a student to an exponential-moving-average (EMA) teacher (as in Mean Teacher). A sketch appears below.
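Here is a minimal Π-Model-style sketch that relies only on dropout stochasticity; the function name and the choice to detach one branch (treating it as the target) are common implementation conventions rather than requirements of the original method.

```python
import torch
import torch.nn.functional as F

def stochastic_consistency_loss(model: torch.nn.Module,
                                unlabeled: torch.Tensor) -> torch.Tensor:
    model.train()                                   # keep dropout active
    p1 = F.softmax(model(unlabeled), dim=-1)        # first stochastic forward pass
    p2 = F.softmax(model(unlabeled), dim=-1)        # second pass, different dropout mask
    return F.mse_loss(p1, p2.detach())              # treat one pass as the target
```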
```python
import torch
import torchvision.transforms as T
from typing import Callable, List, Tuple
import random

# ========================================
# Image Augmentation for Consistency
# ========================================

def get_weak_augmentation() -> Callable:
    """
    Weak augmentation: simple, minimal perturbations.
    Used as baseline/target in some methods.
    """
    return T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomAffine(
            degrees=0,
            translate=(0.125, 0.125),  # Small shifts
        ),
    ])


def get_strong_augmentation() -> Callable:
    """
    Strong augmentation: aggressive perturbations.
    Used for consistency source in methods like FixMatch.
    """
    return T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomAffine(
            degrees=30,             # Rotation up to 30°
            translate=(0.2, 0.2),   # Larger shifts
            scale=(0.8, 1.2),       # Scale variation
            shear=15,               # Shearing
        ),
        T.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1
        ),
        T.RandomGrayscale(p=0.2),
        # RandAugment or AutoAugment can be added here
    ])


class RandAugment:
    """
    RandAugment: Random application of N augmentations with magnitude M.

    Key insight: Rather than learning augmentation policies,
    use random selection with tunable intensity.
    """

    def __init__(self, n_ops: int = 2, magnitude: int = 10):
        self.n_ops = n_ops
        self.magnitude = magnitude

        # Define augmentation operations
        self.operations = [
            "identity", "autocontrast", "equalize", "rotate",
            "solarize", "color", "posterize", "contrast",
            "brightness", "sharpness", "shear_x", "shear_y",
            "translate_x", "translate_y"
        ]

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        # Randomly select n_ops operations
        ops = random.sample(self.operations, self.n_ops)

        for op in ops:
            image = self._apply_operation(image, op, self.magnitude)

        return image

    def _apply_operation(
        self,
        image: torch.Tensor,
        operation: str,
        magnitude: int
    ) -> torch.Tensor:
        # Implementation of each operation with magnitude scaling
        # (Simplified for illustration)
        M = magnitude / 10.0  # Normalize to [0, 1]

        if operation == "rotate":
            angle = M * 30  # Max 30 degrees
            return T.functional.rotate(image, random.uniform(-angle, angle))
        elif operation == "brightness":
            factor = 1.0 + random.uniform(-M, M) * 0.9
            return T.functional.adjust_brightness(image, factor)
        # ... other operations

        return image


class CTAugment:
    """
    Control Theory Augment: Learns augmentation weights online.

    Tracks which augmentations preserve model predictions
    and focuses on those during training.
    """

    def __init__(self, n_bins: int = 17):
        self.n_bins = n_bins
        # Weight bins for each augmentation type and magnitude
        self.weights = {}  # Populated during initialization
        self.decay = 0.99

    def update_weights(
        self,
        augmentation: str,
        magnitude_bin: int,
        was_correct: bool
    ):
        """
        Update augmentation weights based on whether
        the model's prediction was preserved.
        """
        key = (augmentation, magnitude_bin)
        if key not in self.weights:
            self.weights[key] = 0.5

        # Update with exponential moving average
        target = 1.0 if was_correct else 0.0
        self.weights[key] = (
            self.decay * self.weights[key] +
            (1 - self.decay) * target
        )

    def sample_augmentation(self, augmentation: str) -> int:
        """Sample magnitude bin proportional to weights."""
        bins_weights = [
            self.weights.get((augmentation, b), 0.5)
            for b in range(self.n_bins)
        ]

        # Sample proportional to weights
        total = sum(bins_weights)
        probs = [w / total for w in bins_weights]

        return random.choices(range(self.n_bins), weights=probs)[0]
```

Using inappropriate augmentations can hurt performance. For example, vertical flipping preserves semantics for objects but not for digits (6 becomes 9).
Back-translation works for sentiment analysis but may change meaning in legal documents. Always validate that augmentations are semantics-preserving for your specific task.
Raw consistency regularization has a subtle issue: if the model is uncertain about an unlabeled sample, enforcing consistency on uncertain predictions can be harmful. The model might learn to be consistently wrong.
Two techniques address this:
Sharpening (Temperature Scaling):
Before computing consistency loss, we can "sharpen" the target distribution by lowering its temperature:
$$\text{Sharpen}(p, T)_k = \frac{p_k^{1/T}}{\sum_{k'} p_{k'}^{1/T}}$$
where $T < 1$ sharpens (makes more confident) and $T > 1$ softens. Sharpening encourages the model toward confident predictions, fighting entropy collapse toward uniform distributions.
```python
import torch
import torch.nn.functional as F


def sharpen(probs: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """
    Sharpen a probability distribution by applying temperature scaling.

    Args:
        probs: Probability distribution [B, K]
        temperature: Sharpening temperature (< 1 sharpens, > 1 softens)

    Returns:
        Sharpened probability distribution

    Example:
        >>> probs = torch.tensor([[0.3, 0.4, 0.3]])
        >>> sharpen(probs, temperature=0.5)
        tensor([[0.2647, 0.4706, 0.2647]])  # More peaked at class 1
    """
    # Raise to power 1/T and renormalize
    sharpened = probs.pow(1.0 / temperature)
    return sharpened / sharpened.sum(dim=-1, keepdim=True)


def confidence_threshold_mask(
    probs: torch.Tensor,
    threshold: float = 0.95
) -> torch.Tensor:
    """
    Create mask for samples exceeding confidence threshold.

    Args:
        probs: Probability distribution [B, K]
        threshold: Minimum confidence to include sample

    Returns:
        Boolean mask [B] indicating high-confidence samples
    """
    max_probs, _ = probs.max(dim=-1)
    return max_probs >= threshold


def thresholded_consistency_loss(
    model: torch.nn.Module,
    unlabeled: torch.Tensor,
    augment_weak: callable,
    augment_strong: callable,
    threshold: float = 0.95,
    sharpen_temp: float = 0.5
) -> tuple[torch.Tensor, dict]:
    """
    Consistency loss with confidence thresholding (FixMatch style).

    Only applies consistency loss to samples where the model
    is confident on the weakly-augmented version.

    Args:
        model: Neural network classifier
        unlabeled: Batch of unlabeled inputs
        augment_weak: Weak augmentation function
        augment_strong: Strong augmentation function
        threshold: Confidence threshold for including samples
        sharpen_temp: Temperature for sharpening pseudo-labels

    Returns:
        Loss and statistics dictionary
    """
    # Get pseudo-labels from weakly augmented inputs
    weak_aug = augment_weak(unlabeled)
    with torch.no_grad():
        weak_logits = model(weak_aug)
        weak_probs = F.softmax(weak_logits, dim=-1)

        # Apply sharpening
        sharp_probs = sharpen(weak_probs, temperature=sharpen_temp)

        # Get pseudo-labels and confidence
        max_probs, pseudo_labels = sharp_probs.max(dim=-1)

        # Create mask for high-confidence predictions
        mask = max_probs >= threshold

    # Get predictions for strongly augmented inputs
    strong_aug = augment_strong(unlabeled)
    strong_logits = model(strong_aug)

    # Compute cross-entropy only for high-confidence samples
    if mask.sum() > 0:
        loss = F.cross_entropy(
            strong_logits[mask],
            pseudo_labels[mask],
            reduction='mean'
        )
    else:
        loss = torch.tensor(0.0, device=unlabeled.device)

    # Statistics
    stats = {
        "above_threshold_ratio": mask.float().mean().item(),
        "avg_confidence": max_probs.mean().item(),
        "num_sampled": mask.sum().item(),
    }

    return loss, stats
```

Confidence Thresholding:
Instead of (or in addition to) sharpening, we can simply ignore low-confidence predictions:
$$\mathcal{L}_{unsup} = \frac{1}{|\mathcal{B}_u|} \sum_{u \in \mathcal{B}_u} \mathbf{1}[\max(f_\theta(u)) \geq \tau] \cdot \ell(f_\theta(\mathcal{T}(u)), \hat{y})$$
where $\tau$ is the confidence threshold (typically 0.95) and $\hat{y} = \arg\max f_\theta(u)$ is the pseudo-label.
This is the core mechanism in FixMatch: only compute consistency loss when the model is confident about the (weakly-augmented) input's class.
| Technique | Mechanism | Effect | When to Use |
|---|---|---|---|
| Sharpening | Raise probs to power 1/T | Makes all predictions more confident | When model is well-calibrated but cautious |
| Thresholding | Ignore low-confidence samples | Only uses high-confidence predictions | When model may be wrong on uncertain samples |
| Combined | Sharpen first, then threshold | Confident pseudo-labels only when confident | Most robust approach (used in FixMatch) |
Consistency regularization has solid theoretical underpinnings. Understanding these foundations helps us know when consistency-based methods will (and won't) work.
PAC-Bayes Perspective:
From a PAC-Bayes perspective, consistency regularization can be viewed as controlling the complexity of learned hypotheses. By enforcing that the classifier is stable under perturbations, we limit the effective capacity of the hypothesis class.
The generalization bound includes a term:
$$\mathbb{E}_{x \sim p(x)} \, \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} \left[ \text{loss}(f(x), f(\mathcal{T}(x))) \right]$$
Minimizing this term reduces the gap between training and test error.
Consistency regularization assumes the smoothness/cluster assumption holds. It fails when: (1) Classes are not well-separated in input space (overlapping distributions), (2) The perturbations don't respect class boundaries (e.g., rotating a 6 into a 9), (3) The unlabeled data comes from a different distribution than test data, or (4) Early training phases produce poor pseudo-labels that propagate errors.
Connection to Virtual Adversarial Training:
Virtual Adversarial Training (VAT) can be seen as a form of consistency regularization where the perturbation is chosen adversarially to maximize the KL divergence:
$$r_{adv} = \arg\max_{\|r\|_2 \leq \epsilon} \text{KL}\left(f_\theta(x) \,\|\, f_\theta(x + r)\right)$$
Then consistency is enforced by minimizing:
$$\mathcal{L}_{VAT} = \text{KL}\left(f_\theta(x) \,\|\, f_\theta(x + r_{adv})\right)$$
This focuses the consistency constraint on the directions where the model is most sensitive—exactly where decision boundaries might pass through data.
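A compact sketch of VAT with a single power-iteration step follows; the hyperparameters `xi` and `epsilon`, the function name, and the per-sample normalization details are illustrative choices, not values taken from the original paper.

```python
import torch
import torch.nn.functional as F

def vat_loss(model: torch.nn.Module,
             x: torch.Tensor,
             xi: float = 1e-6,
             epsilon: float = 2.0) -> torch.Tensor:
    with torch.no_grad():
        p = F.softmax(model(x), dim=-1)              # reference prediction f(x)

    # Start from a random direction and refine it toward the most sensitive direction
    d = torch.randn_like(x)
    d = d / d.flatten(1).norm(dim=1).view(-1, *[1] * (x.dim() - 1))
    d.requires_grad_(True)

    p_hat = F.log_softmax(model(x + xi * d), dim=-1)
    adv_distance = F.kl_div(p_hat, p, reduction='batchmean')   # KL(f(x) || f(x + xi*d))
    grad = torch.autograd.grad(adv_distance, d)[0]

    # Adversarial perturbation r_adv = epsilon * normalized gradient direction
    with torch.no_grad():
        norms = grad.flatten(1).norm(dim=1).view(-1, *[1] * (x.dim() - 1)).clamp_min(1e-8)
        r_adv = epsilon * grad / norms

    # Enforce consistency against the adversarially perturbed input
    p_adv = F.log_softmax(model(x + r_adv), dim=-1)
    return F.kl_div(p_adv, p, reduction='batchmean')
```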
Label Propagation View:
Consistency regularization can be interpreted as implicit label propagation. If $x_1$ and $x_2$ can be connected by a chain of small perturbations (each within the perturbation distribution), enforcing consistency at each step propagates label information between them:
$$f(x_1) \approx f(\mathcal{T}_1(x_1)) \approx f(\mathcal{T}_2(\mathcal{T}_1(x_1))) \approx \cdots \approx f(x_2)$$
This is particularly powerful when labeled and unlabeled points lie on the same data manifold.
We've established the theoretical and conceptual foundations of consistency regularization. This principle—that predictions should be stable under semantics-preserving perturbations—is the cornerstone of modern semi-supervised learning.
Let's consolidate the key insights:

- The consistency assumption states that semantically equivalent inputs should receive (nearly) identical predictions, which pushes decision boundaries into low-density regions.
- Consistency constraints require no labels, turning unlabeled data into a usable training signal.
- The standard objective combines a supervised loss with a weighted consistency term: $\mathcal{L}_{total} = \mathcal{L}_{sup} + \lambda \cdot \mathcal{L}_{unsup}$.
- The perturbation distribution must be semantics-preserving for the task at hand; inappropriate augmentations can actively hurt.
- Sharpening and confidence thresholding prevent the model from reinforcing uncertain or confidently wrong predictions.
What's Next:
With the consistency assumption firmly established, we're ready to explore how it's applied in practice. The next page examines Data Augmentation Consistency—how the choice of augmentation strategies dramatically affects the effectiveness of consistency regularization, and why strong augmentations have proven so crucial for state-of-the-art performance.
You now understand the consistency assumption—the foundational principle underlying modern semi-supervised learning. This conceptual grounding will make the specific techniques (UDA, FixMatch, MixMatch) that follow much more intuitive. Each is a creative application of this core idea.