If the consistency assumption is the principle behind modern semi-supervised learning, then data augmentation is its engine. The remarkable success of methods like FixMatch, UDA, and MixMatch stems not just from enforcing consistency, but from doing so with carefully designed augmentation strategies.
This page provides an exhaustive examination of data augmentation for consistency regularization. We'll explore why augmentation matters so much, how different augmentation strategies affect learning, and the paradigm-shifting insight that combining weak and strong augmentations unlocks dramatically better performance.
By the end of this page, you will understand: why augmentation quality is paramount for semi-supervised learning, the weak-to-strong augmentation paradigm, how to design augmentations for different data modalities, RandAugment and AutoAugment strategies, and how augmentation diversity connects to consistency regularization effectiveness.
Consistency regularization enforces that $f(x) \approx f(\mathcal{T}(x))$. The choice of perturbation function $\mathcal{T}$ determines what invariances the model learns. Different augmentations encode different prior knowledge about the task:
Rotation invariance: "Objects remain the same object regardless of orientation"
Translation invariance: "Objects remain the same when shifted in space"
Color invariance: "Object identity is independent of lighting conditions"
Noise robustness: "Minor pixel variations don't change semantics"
Weak augmentations (like small translations) create easy consistency targets—the model can trivially satisfy them. Strong augmentations create harder targets that force the model to learn robust, generalizable representations.
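This gap can be made concrete with a toy sketch. Everything here is illustrative: the "model" is just a fixed random linear classifier and the perturbations are Gaussian noise, but the pattern holds: weak perturbations yield a near-zero consistency loss, while strong ones produce a much larger training signal.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Toy "model": a fixed random linear classifier followed by softmax.
W = rng.normal(size=(8, 3))

def f(x):
    return softmax(x @ W)

def consistency_loss(x, perturb):
    """Mean squared distance between f(x) and f(T(x))."""
    return float(np.mean((f(x) - f(perturb(x))) ** 2))

x = rng.normal(size=(16, 8))
weak = lambda x: x + 0.01 * rng.normal(size=x.shape)   # tiny noise
strong = lambda x: x + 1.0 * rng.normal(size=x.shape)  # large noise

weak_loss = consistency_loss(x, weak)      # model trivially consistent
strong_loss = consistency_loss(x, strong)  # much harder target
```

The weak-perturbation loss is close to zero by construction, so on its own it teaches the model almost nothing; the strong-perturbation loss is orders of magnitude larger and actually constrains the learned function.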
Research has consistently shown that augmentation quality is often the bottleneck in semi-supervised learning. A strong method with weak augmentations underperforms a simple method with strong augmentations. This is why RandAugment and AutoAugment transformed the field.
| Method | Weak Aug Only | Medium Aug | Strong Aug (RandAugment) | Improvement |
|---|---|---|---|---|
| Pseudo-Label | 51.2% | 73.4% | 79.1% | +27.9% |
| Mean Teacher | 47.3% | 78.2% | 89.3% | +42.0% |
| MixMatch | 67.1% | 88.9% | 93.6% | +26.5% |
| FixMatch | N/A | N/A | 95.7% | State-of-the-art |
Why Strong Augmentation Helps:
Strong augmentations provide a richer supervisory signal from unlabeled data: they simulate having more training data by creating plausible, aggressively varied versions of existing samples.
A breakthrough in semi-supervised learning came from recognizing that weak and strong augmentations should play different roles in training. This insight, crystallized in FixMatch, separates:
Weak augmentation (for pseudo-labels): Create a minimally perturbed version to get a reliable prediction
Strong augmentation (for consistency target): Create an aggressively perturbed version that must match the pseudo-label
This asymmetry is crucial: the pseudo-label comes from a "clean" version where the model is more likely to be correct, while consistency is enforced on the "corrupted" version where the model must work harder.
Why This Works:
Pseudo-label quality: Weak augmentation minimizes the chance of corrupting semantics, giving more accurate pseudo-labels
Consistency difficulty: Strong augmentation makes the consistency task non-trivial, forcing representation learning
Asymmetric roles: The pseudo-label provides stable supervision; the strong augmentation provides challenging examples
Bootstrapping: As the model improves, pseudo-labels become more accurate, enabling progressively harder consistency tasks
```python
import random
from typing import Tuple

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageEnhance, ImageOps


class WeakStrongAugmentation:
    """
    Implements the weak-to-strong augmentation paradigm.

    - Weak: Standard minimal augmentation for reliable pseudo-labels
    - Strong: Aggressive augmentation for challenging consistency targets
    """

    def __init__(
        self,
        image_size: int = 32,
        padding: int = 4,
        mean: Tuple[float, ...] = (0.4914, 0.4822, 0.4465),
        std: Tuple[float, ...] = (0.2470, 0.2435, 0.2616),
    ):
        self.image_size = image_size

        # Weak augmentation: flip + random crop (standard supervised aug)
        self.weak = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomCrop(image_size, padding=padding, padding_mode='reflect'),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])

        # Strong augmentation: RandAugment + Cutout
        self.strong = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomCrop(image_size, padding=padding, padding_mode='reflect'),
            RandAugmentTransform(n=2, m=10),  # 2 ops, magnitude 10
            T.ToTensor(),
            T.Normalize(mean, std),
            Cutout(n_holes=1, length=16),  # Cutout regularization
        ])

    def __call__(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns both weak and strong augmented versions.

        Args:
            image: PIL Image

        Returns:
            (weak_augmented, strong_augmented) tensor tuple
        """
        return self.weak(image), self.strong(image)


class RandAugmentTransform:
    """
    RandAugment: Random selection of N augmentations with magnitude M.

    Paper: "RandAugment: Practical Automated Data Augmentation with a
    Reduced Search Space" (Cubuk et al., 2020)

    Key insight: Simple random selection with uniform magnitude matches
    or exceeds learned augmentation policies.
    """

    # Available augmentation operations
    OPERATIONS = [
        'identity', 'autocontrast', 'equalize', 'rotate', 'solarize',
        'color', 'posterize', 'contrast', 'brightness', 'sharpness',
        'shear_x', 'shear_y', 'translate_x', 'translate_y'
    ]

    def __init__(self, n: int = 2, m: int = 10):
        """
        Args:
            n: Number of operations to apply sequentially
            m: Magnitude of operations (0-30 scale, 10 is standard)
        """
        self.n = n
        self.m = m

    def __call__(self, image):
        """Apply n random augmentations."""
        ops = random.choices(self.OPERATIONS, k=self.n)
        for op in ops:
            image = self._apply_op(image, op, self.m)
        return image

    def _apply_op(self, image, operation: str, magnitude: int):
        """Apply single operation with given magnitude."""
        # Normalize magnitude to [0, 1] range
        M = magnitude / 30.0

        if operation == 'identity':
            return image
        elif operation == 'autocontrast':
            return ImageOps.autocontrast(image)
        elif operation == 'equalize':
            return ImageOps.equalize(image)
        elif operation == 'rotate':
            degrees = M * 30  # max 30 degrees
            return image.rotate(random.uniform(-degrees, degrees))
        elif operation == 'solarize':
            threshold = int((1 - M) * 256)
            return ImageOps.solarize(image, threshold)
        elif operation == 'color':
            factor = 1 + random.uniform(-M, M) * 0.9
            return ImageEnhance.Color(image).enhance(factor)
        elif operation == 'contrast':
            factor = 1 + random.uniform(-M, M) * 0.9
            return ImageEnhance.Contrast(image).enhance(factor)
        elif operation == 'brightness':
            factor = 1 + random.uniform(-M, M) * 0.9
            return ImageEnhance.Brightness(image).enhance(factor)
        elif operation == 'sharpness':
            factor = 1 + random.uniform(-M, M) * 0.9
            return ImageEnhance.Sharpness(image).enhance(factor)
        elif operation == 'posterize':
            bits = int(8 - M * 4)  # 8 to 4 bits
            return ImageOps.posterize(image, max(1, bits))
        elif operation == 'shear_x':
            shear = M * 0.3 * random.choice([-1, 1])
            return image.transform(
                image.size, Image.AFFINE, (1, shear, 0, 0, 1, 0)
            )
        elif operation == 'shear_y':
            shear = M * 0.3 * random.choice([-1, 1])
            return image.transform(
                image.size, Image.AFFINE, (1, 0, 0, shear, 1, 0)
            )
        elif operation == 'translate_x':
            pixels = int(M * image.size[0] * 0.45)
            pixels *= random.choice([-1, 1])
            return image.transform(
                image.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0)
            )
        elif operation == 'translate_y':
            pixels = int(M * image.size[1] * 0.45)
            pixels *= random.choice([-1, 1])
            return image.transform(
                image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels)
            )
        return image


class Cutout:
    """
    Cutout regularization: Random erasing of square patches.

    Paper: "Improved Regularization of Convolutional Neural Networks
    with Cutout" (DeVries & Taylor, 2017)
    """

    def __init__(self, n_holes: int = 1, length: int = 16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply cutout to tensor [C, H, W]."""
        h, w = tensor.shape[1:]
        mask = torch.ones_like(tensor)

        for _ in range(self.n_holes):
            # Random center
            y = np.random.randint(h)
            x = np.random.randint(w)

            # Compute patch bounds
            y1 = max(0, y - self.length // 2)
            y2 = min(h, y + self.length // 2)
            x1 = max(0, x - self.length // 2)
            x2 = min(w, x + self.length // 2)

            # Zero out patch
            mask[:, y1:y2, x1:x2] = 0

        return tensor * mask
```

RandAugment represents a paradigm shift in automated data augmentation. Prior approaches like AutoAugment used reinforcement learning to search for optimal augmentation policies—a computationally expensive process requiring thousands of GPU hours. RandAugment demonstrated that a much simpler approach works just as well.
The RandAugment Algorithm: For each image, sample $N$ operations uniformly at random from a fixed pool of transformations, then apply each at a shared magnitude $M$.

The key hyperparameters are just $N$, the number of operations applied sequentially (2 by default), and $M$, the shared magnitude (on a 0-30 scale, where 10 is standard).
RandAugment works because: (1) The optimal augmentation policy often doesn't depend on the dataset—strong random augmentations generalize across tasks. (2) Uniform magnitude across operations approximates learned policies well. (3) Simplicity enables better hyperparameter tuning—you only tune N and M instead of hundreds of parameters.
| Operation | Description | Magnitude Effect | Semantics Preservation |
|---|---|---|---|
| Identity | No change | N/A | Perfect |
| AutoContrast | Maximize image contrast | N/A (no magnitude) | High |
| Equalize | Histogram equalization | N/A (no magnitude) | High |
| Rotate | Rotate by angle | 0° to ±30° | Usually high* |
| Solarize | Invert pixels above threshold | 256 to 0 threshold | Medium |
| Color | Adjust color saturation | 0.1x to 1.9x | High |
| Posterize | Reduce bits per channel | 8 to 4 bits | Medium |
| Contrast | Adjust contrast | 0.1x to 1.9x | High |
| Brightness | Adjust brightness | 0.1x to 1.9x | High |
| Sharpness | Adjust sharpness | 0.1x to 1.9x | High |
| Shear X/Y | Shear transformation | 0° to ±30° | High |
| Translate X/Y | Translate image | 0 to ±45% of dim | High |
*Note: Rotation may not preserve semantics for orientation-sensitive tasks (e.g., digit recognition where 6 and 9 are different).
Magnitude Scheduling:
While RandAugment uses a fixed magnitude, some implementations benefit from magnitude scheduling, such as a linear ramp-up to the target magnitude during warmup, or random magnitude sampling per batch.
The intuition: early in training, the model is fragile and benefits from gentler augmentations. As training progresses, stronger augmentations provide harder challenges.
```python
import numpy as np
from typing import List, Tuple


class RandAugmentScheduled:
    """
    RandAugment with magnitude scheduling.

    Supports:
    - Fixed magnitude (standard RandAugment)
    - Linear ramp-up from low to target magnitude
    - Random magnitude sampling
    """

    def __init__(
        self,
        n_ops: int = 2,
        magnitude: int = 10,
        num_magnitude_bins: int = 31,
        schedule: str = "fixed",  # 'fixed', 'linear', 'random'
        warmup_epochs: int = 0,
        max_epochs: int = 1000,
    ):
        self.n_ops = n_ops
        self.target_magnitude = magnitude
        self.num_bins = num_magnitude_bins
        self.schedule = schedule
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self._current_epoch = 0

    def set_epoch(self, epoch: int):
        """Update current epoch for scheduling."""
        self._current_epoch = epoch

    def get_magnitude(self) -> int:
        """Get magnitude based on schedule."""
        if self.schedule == "fixed":
            return self.target_magnitude
        elif self.schedule == "linear":
            # Linear ramp-up during warmup
            if self._current_epoch < self.warmup_epochs:
                progress = self._current_epoch / self.warmup_epochs
                return int(self.target_magnitude * progress)
            return self.target_magnitude
        elif self.schedule == "random":
            # Uniform random sampling
            return np.random.randint(1, self.target_magnitude + 1)
        return self.target_magnitude


class CTAugment:
    """
    Control Theory Augment (CTAugment).

    Learns augmentation weights online by tracking which
    augmentation-magnitude combinations preserve prediction confidence.
    Used in ReMixMatch.

    Paper: "ReMixMatch: Semi-Supervised Learning with Distribution
    Alignment and Augmentation Anchoring"
    """

    def __init__(
        self,
        operations: List[str],
        num_bins: int = 17,
        decay_rate: float = 0.99,
        threshold: float = 0.8,
    ):
        self.operations = operations
        self.num_bins = num_bins
        self.decay = decay_rate
        self.threshold = threshold

        # Initialize weight bins for each (operation, magnitude) pair
        # Start at 1.0 (all augmentations equally likely)
        self.weights = {op: np.ones(num_bins) for op in operations}

    def update(self, operation: str, magnitude_bin: int, confidence: float):
        """
        Update weights based on model's prediction confidence after
        applying augmentation.

        Args:
            operation: Augmentation operation name
            magnitude_bin: Which magnitude bin was used
            confidence: Model's max softmax probability on augmented input
        """
        # Target is 1 if confidence above threshold, else 0
        target = 1.0 if confidence >= self.threshold else 0.0

        # Exponential moving average update
        self.weights[operation][magnitude_bin] = (
            self.decay * self.weights[operation][magnitude_bin]
            + (1 - self.decay) * target
        )

    def sample_operation(self) -> Tuple[str, int]:
        """
        Sample operation and magnitude proportional to learned weights.

        Returns:
            (operation_name, magnitude_bin) tuple
        """
        # Sample operation uniformly (or could weight by sum of bins)
        op = np.random.choice(self.operations)

        # Sample magnitude proportional to weights
        op_weights = self.weights[op]
        probs = op_weights / op_weights.sum()
        mag_bin = np.random.choice(self.num_bins, p=probs)

        return op, mag_bin

    def apply(self, image, n_ops: int = 2):
        """Apply n_ops sampled augmentations."""
        for _ in range(n_ops):
            op, mag_bin = self.sample_operation()
            magnitude = mag_bin / (self.num_bins - 1)  # Normalize to [0, 1]
            image = self._apply_operation(image, op, magnitude)
        return image

    def _apply_operation(self, image, operation: str, magnitude: float):
        """Apply operation at given normalized magnitude [0, 1]."""
        # Implementation similar to RandAugment
        # Magnitude mapped to operation-specific range
        pass  # Abbreviated for clarity
```

While image augmentation is well-studied, consistency regularization applies to all data modalities. Each domain requires carefully designed augmentations that preserve semantics while providing meaningful variation.
The challenge: Augmentations must be semantics-preserving for your specific task. What works for object classification may break digit recognition (rotation) or medical imaging (color changes).
Standard Image Augmentations:
| Category | Operations | Typical Settings |
|---|---|---|
| Geometric | Flip, rotate, crop, scale, translate | Horizontal flip: 50%, Rotation: ±15°, Scale: 0.8-1.2x |
| Color | Brightness, contrast, saturation, hue | Each ±20-40% of original |
| Noise/Blur | Gaussian noise, motion blur, JPEG compression | Varies by task |
| Erasing | Cutout, random erasing, GridMask | 10-25% of image area |
Domain-Specific Considerations: The same operation can be safe in one domain and destructive in another. Rotation can break digit recognition (a rotated 6 becomes a 9), and color transformations can destroy diagnostic information in medical imaging, so the operation pool should be filtered per task.
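One pragmatic way to encode such constraints is to filter the RandAugment operation pool per task. The task names and exclusion sets below are illustrative assumptions, not a standard API:

```python
# Illustrative task-conditional augmentation policy. The task names and
# exclusion sets are assumptions chosen for this sketch, not a library API.
SAFE_OPS = {
    "object_classification": None,  # all operations allowed
    "digit_recognition": {"rotate", "shear_x", "shear_y"},     # 6 vs. 9
    "medical_imaging": {"color", "solarize", "posterize"},     # color is diagnostic
}

ALL_OPS = [
    "identity", "autocontrast", "equalize", "rotate", "solarize",
    "color", "posterize", "contrast", "brightness", "sharpness",
    "shear_x", "shear_y", "translate_x", "translate_y",
]

def ops_for_task(task: str) -> list:
    """Filter the RandAugment operation pool for a given task."""
    excluded = SAFE_OPS.get(task) or set()
    return [op for op in ALL_OPS if op not in excluded]
```

The filtered list can then be passed to a RandAugment-style transform in place of its default operation pool.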
The integration of data augmentation with consistency loss requires careful design decisions. Different methods make different choices about how augmentations generate targets and predictions.
Core Question: Given an unlabeled sample $u$, how do we generate (1) the consistency target the model should match, and (2) the prediction that is compared against it?

Approaches:
| Method | Target Generation | Input for Prediction | Consistency Loss |
|---|---|---|---|
| Π-Model | Model(Aug(u)) | Model(Aug(u)) [different] | MSE between two forward passes |
| Mean Teacher | EMA_Model(Aug(u)) | Model(Aug(u)) | MSE between student and teacher |
| UDA | Model(Weak(u)) | Model(Strong(u)) | KL divergence |
| FixMatch | argmax Model(Weak(u)) | Model(Strong(u)) | Cross-entropy with hard pseudo-label |
| MixMatch | Sharpen(avg over K augs) | Model(Mixup(augs)) | MSE in label space |
Key Design Decisions:
1. Soft vs. Hard Targets:
Soft targets (full probability distributions) preserve uncertainty information but can propagate noisy, low-confidence signal; hard targets (argmax labels) provide a strong, low-entropy signal but amplify confident mistakes. FixMatch uses hard targets with confidence thresholding to get the best of both: strong signal only when confident.
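The contrast between the two target types fits in a few lines of numpy. The probability vectors are made up; the 0.4 temperature and 0.95 threshold mirror the UDA and FixMatch defaults used elsewhere on this page:

```python
import numpy as np

def sharpen(p, T=0.4):
    """Temperature sharpening of a probability vector (soft target, UDA-style)."""
    q = p ** (1.0 / T)
    return q / q.sum()

def hard_label(p, threshold=0.95):
    """Hard pseudo-label with confidence threshold (FixMatch-style).
    Returns None when the prediction is masked out as not confident enough."""
    return int(np.argmax(p)) if p.max() >= threshold else None

p = np.array([0.7, 0.2, 0.1])
q = sharpen(p)         # mass concentrates on the argmax but stays soft
label = hard_label(p)  # 0.7 < 0.95, so this sample is masked (None)
```

Sharpening keeps a full distribution but pushes it toward the argmax; thresholded hard labels discard uncertain samples entirely rather than training on them.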
2. Same vs. Different Augmentations: The Π-Model draws two independent samples from the same augmentation distribution, while UDA and FixMatch deliberately use different ones: weak for the target, strong for the prediction.

3. Single vs. Multiple Augmentations: Most methods use one augmented view per branch; MixMatch averages predictions over K augmentations before sharpening, trading extra forward passes for a lower-variance target.

4. Student-Teacher vs. Self-Consistency: Mean Teacher generates targets from an EMA teacher, decoupling target quality from the current student; the Π-Model, UDA, and FixMatch are self-consistent, with the same network (under stop-gradient) producing its own targets.
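The EMA teacher behind Mean Teacher is updated with one exponential-moving-average step per parameter. A numpy sketch (real implementations update torch parameters in-place, typically with a decay around 0.999):

```python
import numpy as np

def ema_update(teacher_params, student_params, alpha=0.999):
    """One EMA step: teacher <- alpha * teacher + (1 - alpha) * student."""
    return [alpha * t + (1 - alpha) * s
            for t, s in zip(teacher_params, student_params)]

teacher = [np.zeros(3)]  # toy parameter vectors
student = [np.ones(3)]
for _ in range(1000):
    teacher = ema_update(teacher, student)
# After many steps the teacher drifts smoothly toward the student's weights,
# averaging out the noise of any single training step.
```

With alpha=0.999 the teacher is effectively an average over roughly the last thousand student checkpoints, which is what makes its targets more stable than the student's own predictions.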
```python
from typing import Callable, Dict, Tuple

import torch
import torch.nn.functional as F


def pi_model_loss(
    model: torch.nn.Module,
    unlabeled: torch.Tensor,
    augment: Callable,
    n_augmentations: int = 2,
) -> torch.Tensor:
    """
    Π-Model: Consistency between multiple stochastic forward passes.

    Both target and prediction come from the same model with
    different augmentations and dropout.
    """
    model.train()  # Enable dropout

    predictions = []
    for _ in range(n_augmentations):
        aug_input = augment(unlabeled)
        logits = model(aug_input)
        predictions.append(F.softmax(logits, dim=-1))

    # Compute pairwise consistency loss
    loss = 0.0
    count = 0
    for i in range(n_augmentations):
        for j in range(i + 1, n_augmentations):
            loss += F.mse_loss(predictions[i], predictions[j])
            count += 1

    return loss / count


def mean_teacher_loss(
    student: torch.nn.Module,
    teacher: torch.nn.Module,  # EMA of student
    unlabeled: torch.Tensor,
    augment: Callable,
) -> torch.Tensor:
    """
    Mean Teacher: Consistency between student and EMA teacher.

    Teacher (EMA) provides stable targets. Student learns to match
    teacher on augmented inputs.
    """
    # Teacher generates target (no gradient)
    teacher.eval()
    with torch.no_grad():
        teacher_aug = augment(unlabeled)
        teacher_logits = teacher(teacher_aug)
        target = F.softmax(teacher_logits, dim=-1)

    # Student generates prediction
    student.train()
    student_aug = augment(unlabeled)
    student_logits = student(student_aug)
    prediction = F.softmax(student_logits, dim=-1)

    return F.mse_loss(prediction, target)


def uda_loss(
    model: torch.nn.Module,
    unlabeled: torch.Tensor,
    augment_weak: Callable,
    augment_strong: Callable,
    temperature: float = 0.4,
    threshold: float = 0.8,
) -> torch.Tensor:
    """
    UDA (Unsupervised Data Augmentation): Weak-to-strong consistency.

    Key insights:
    1. Use weak augmentation for pseudo-label quality
    2. Use strong augmentation for learning signal
    3. Apply confidence threshold and sharpening
    """
    # Generate pseudo-labels from weakly augmented input
    with torch.no_grad():
        weak_aug = augment_weak(unlabeled)
        weak_logits = model(weak_aug)
        weak_probs = F.softmax(weak_logits, dim=-1)

        # Sharpen the distribution
        sharp_probs = weak_probs ** (1 / temperature)
        sharp_probs = sharp_probs / sharp_probs.sum(dim=-1, keepdim=True)

        # Confidence mask
        max_probs, _ = weak_probs.max(dim=-1)
        mask = (max_probs >= threshold).float()

    # Predict on strongly augmented input
    strong_aug = augment_strong(unlabeled)
    strong_logits = model(strong_aug)

    # KL divergence loss with mask
    log_probs = F.log_softmax(strong_logits, dim=-1)
    loss = -(sharp_probs * log_probs).sum(dim=-1)
    loss = (loss * mask).sum() / (mask.sum() + 1e-8)

    return loss


def fixmatch_loss(
    model: torch.nn.Module,
    unlabeled: torch.Tensor,
    augment_weak: Callable,
    augment_strong: Callable,
    threshold: float = 0.95,
) -> Tuple[torch.Tensor, Dict]:
    """
    FixMatch: Hard pseudo-labels with high confidence threshold.

    Simplification of UDA:
    1. Hard pseudo-labels (not soft)
    2. Standard cross-entropy (not KL)
    3. Higher threshold (0.95)
    """
    # Generate pseudo-labels from weakly augmented input
    with torch.no_grad():
        weak_aug = augment_weak(unlabeled)
        weak_logits = model(weak_aug)
        weak_probs = F.softmax(weak_logits, dim=-1)

        # Hard pseudo-label
        max_probs, pseudo_labels = weak_probs.max(dim=-1)

        # High confidence mask
        mask = max_probs >= threshold

    # Only compute loss if we have confident samples
    if mask.sum() == 0:
        return torch.tensor(0.0, device=unlabeled.device), {
            "n_masked": 0,
            "mask_ratio": 0.0,
        }

    # Predict on strongly augmented input
    strong_aug = augment_strong(unlabeled)
    strong_logits = model(strong_aug)

    # Cross-entropy on masked samples
    loss = F.cross_entropy(
        strong_logits[mask], pseudo_labels[mask], reduction='mean'
    )

    return loss, {
        "n_masked": mask.sum().item(),
        "mask_ratio": mask.float().mean().item(),
        "avg_confidence": max_probs.mean().item(),
    }
```

Beyond augmentation strength, augmentation diversity plays a crucial role in consistency regularization. Diverse augmentations expose the model to varied transformations of the same semantic content, promoting robust feature learning.
Why Diversity Matters:
Coverage of invariance space: Different augmentations test different invariances (rotation, color, scale). Diverse augmentations collectively enforce a broad invariance prior.
Avoiding shortcuts: If augmentations are too predictable, models can learn to "undo" them rather than learning invariant features.
Gradient diversity: Different augmentations provide different gradient signals, smoothing the optimization landscape.
Reduced correlation: Diverse augmentations reduce correlation between training examples, improving generalization.
There's a tension between diversity and difficulty. Too many different augmentations might include some that are too easy (providing weak signal) or too hard (providing noisy signal). Methods like CTAugment balance this by learning which augmentations are effective.
Measuring Augmentation Effectiveness:
How do we know if augmentations are effective? Several metrics help:
| Metric | Description | Ideal Range |
|---|---|---|
| Prediction agreement | How often do original and augmented get same prediction? | 70-90% |
| Confidence drop | How much does confidence decrease on augmented? | 5-20% |
| Feature distance | How far are features for original vs. augmented? | Moderate |
| Gradient magnitude | How strong are gradients from consistency loss? | Non-zero but stable |
If prediction agreement is 100%, augmentations are too weak. If it's 50%, they might be too strong (semantics not preserved).
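Both prediction agreement and confidence drop are cheap to monitor during training. A minimal sketch over [N, C] probability arrays, with made-up example numbers:

```python
import numpy as np

def augmentation_metrics(probs_orig, probs_aug):
    """Prediction agreement and mean confidence drop between predictions
    on original and augmented inputs (both [N, C] probability arrays)."""
    agree = float(np.mean(
        probs_orig.argmax(axis=1) == probs_aug.argmax(axis=1)))
    conf_drop = float(np.mean(
        probs_orig.max(axis=1) - probs_aug.max(axis=1)))
    return agree, conf_drop

# Illustrative predictions for 4 samples, 2 classes
probs_orig = np.array([[0.9, 0.1], [0.8, 0.2], [0.6, 0.4], [0.3, 0.7]])
probs_aug = np.array([[0.7, 0.3], [0.6, 0.4], [0.4, 0.6], [0.2, 0.8]])

agree, drop = augmentation_metrics(probs_orig, probs_aug)
# agree = 0.75 (one flipped prediction), drop = 0.075
```

Logging these two numbers per epoch makes it easy to spot augmentations that are too weak (agreement pinned near 100%) or too strong (agreement near chance).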
Implementing augmentation-based consistency regularization effectively requires attention to several practical details that significantly impact performance.
```python
import torch
from torch.utils.data import DataLoader


class SemiSupervisedTrainer:
    """
    Training loop for semi-supervised learning with
    augmentation-based consistency regularization.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        labeled_loader: DataLoader,
        unlabeled_loader: DataLoader,
        augment_weak,
        augment_strong,
        optimizer: torch.optim.Optimizer,
        lambda_u: float = 1.0,
        threshold: float = 0.95,
        warmup_epochs: int = 5,
        total_epochs: int = 1024,
        device: str = "cuda",
    ):
        self.model = model
        self.labeled_loader = labeled_loader
        self.unlabeled_loader = unlabeled_loader
        self.augment_weak = augment_weak
        self.augment_strong = augment_strong
        self.optimizer = optimizer
        self.lambda_u = lambda_u
        self.threshold = threshold
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.device = device

        self.global_step = 0
        self.epoch = 0

    def get_lambda(self) -> float:
        """Linear warmup for consistency weight."""
        if self.epoch < self.warmup_epochs:
            return self.lambda_u * (self.epoch / self.warmup_epochs)
        return self.lambda_u

    def train_epoch(self) -> dict:
        """Train for one epoch."""
        self.model.train()
        metrics = {
            "sup_loss": 0.0,
            "unsup_loss": 0.0,
            "total_loss": 0.0,
            "mask_ratio": 0.0,
            "n_steps": 0,
        }

        # Iterate over labeled and unlabeled data together
        unlabeled_iter = iter(self.unlabeled_loader)

        for labeled_batch in self.labeled_loader:
            # Get next unlabeled batch (cycle if exhausted)
            try:
                unlabeled_batch = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = iter(self.unlabeled_loader)
                unlabeled_batch = next(unlabeled_iter)

            # Unpack data
            x_l, y_l = labeled_batch
            x_u, = unlabeled_batch  # No labels
            x_l, y_l = x_l.to(self.device), y_l.to(self.device)
            x_u = x_u.to(self.device)

            # =====================
            # Supervised loss
            # =====================
            x_l_aug = self.augment_weak(x_l)
            logits_l = self.model(x_l_aug)
            loss_sup = torch.nn.functional.cross_entropy(logits_l, y_l)

            # =====================
            # Unsupervised loss (FixMatch style)
            # =====================
            with torch.no_grad():
                x_u_weak = self.augment_weak(x_u)
                logits_weak = self.model(x_u_weak)
                probs_weak = torch.softmax(logits_weak, dim=-1)
                max_probs, pseudo_labels = probs_weak.max(dim=-1)
                mask = max_probs >= self.threshold

            x_u_strong = self.augment_strong(x_u)
            logits_strong = self.model(x_u_strong)

            if mask.sum() > 0:
                loss_unsup = torch.nn.functional.cross_entropy(
                    logits_strong[mask], pseudo_labels[mask],
                    reduction='mean'
                )
            else:
                loss_unsup = torch.tensor(0.0, device=self.device)

            # =====================
            # Combined loss
            # =====================
            lambda_eff = self.get_lambda()
            loss_total = loss_sup + lambda_eff * loss_unsup

            # Optimization step
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # Update metrics
            metrics["sup_loss"] += loss_sup.item()
            metrics["unsup_loss"] += loss_unsup.item()
            metrics["total_loss"] += loss_total.item()
            metrics["mask_ratio"] += mask.float().mean().item()
            metrics["n_steps"] += 1
            self.global_step += 1

        # Average metrics
        n = metrics["n_steps"]
        for key in ["sup_loss", "unsup_loss", "total_loss", "mask_ratio"]:
            metrics[key] /= n

        self.epoch += 1
        return metrics
```

Data augmentation is not just a preprocessing step—it's the core mechanism that makes consistency regularization effective. The quality and design of augmentations often matters more than the specific semi-supervised algorithm used.
What's Next:
With the foundations of consistency and augmentation established, we're ready to examine the state-of-the-art methods that combine these principles. The next page explores UDA and FixMatch—the elegant algorithms that have set new benchmarks in semi-supervised learning through careful integration of strong augmentation with consistency regularization.
You now understand how data augmentation drives consistency regularization. This knowledge is essential for implementing and improving semi-supervised methods—augmentation is where empirical gains are often found.