Hand-designing augmentation pipelines is laborious, error-prone, and domain-specific. A strategy that works for ImageNet may fail for medical imaging; what helps satellite imagery may hurt document analysis. Each new domain requires extensive experimentation to find the right combination of transformations, magnitudes, and probabilities.
AutoAugment changes this paradigm fundamentally. Instead of manually designing augmentation policies, we formulate augmentation design as a search problem and let algorithms discover optimal strategies automatically. The result: policies that consistently outperform human intuition, often by discovering non-obvious augmentation combinations.
This page explores the landscape of learned augmentation—from the original reinforcement learning approach through efficient differentiable alternatives. Understanding these methods is essential for practitioners working across diverse domains where hand-designed augmentations fail to transfer.
By the end of this page, you will understand the AutoAugment search space formulation, implement RandAugment's simplified approach requiring only two hyperparameters, apply Fast AutoAugment and differentiable augmentation search, and know when learned policies provide meaningful improvements over manual design.
AutoAugment (Cubuk et al., 2019) was the first systematic approach to learning augmentation policies. It formulates augmentation design as a discrete search problem over a structured policy space.
An AutoAugment policy consists of 5 sub-policies, each containing 2 operations. Each operation specifies:

- **which transformation to apply** (one of 16 operations),
- **the probability of applying it** (discretized into 11 values), and
- **the magnitude of the transformation** (discretized into 10 values).
During training, a random sub-policy is selected for each image, and both operations within that sub-policy are applied sequentially (if their probability check passes).
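A minimal sketch of applying one sub-policy at training time (the two operations and their parameters below are illustrative stand-ins, not taken from a learned policy):

```python
import random
from PIL import Image, ImageOps

# One sub-policy: two (operation, probability, magnitude) triples.
sub_policy = [
    (lambda img, m: img.rotate(m), 0.7, 10),                   # Rotate, p=0.7, 10 degrees
    (lambda img, m: ImageOps.posterize(img, int(m)), 0.4, 5),  # Posterize, p=0.4, 5 bits
]

def apply_sub_policy(img: Image.Image, sub_policy) -> Image.Image:
    """Apply both operations in order; each fires only if its
    probability check passes."""
    for op, prob, magnitude in sub_policy:
        if random.random() < prob:
            img = op(img, magnitude)
    return img
```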
The total number of possible policies is astronomical:
$$|\mathcal{S}| = (16 \times 10 \times 11)^{2 \times 5} = (1760)^{10} \approx 2.9 \times 10^{32}$$
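A quick sanity check of this count:

```python
# 16 operations, 10 magnitude bins, 11 probability bins;
# 2 operations per sub-policy, 5 sub-policies per policy
n_policies = (16 * 10 * 11) ** (2 * 5)
print(f"{n_policies:.1e}")  # 2.9e+32
```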
Exhaustive search is impossible. AutoAugment instead uses Proximal Policy Optimization (PPO), a reinforcement learning algorithm: a controller RNN samples candidate policies, a small "child" model is trained with each one, and the child's validation accuracy serves as the reward for updating the controller. The operation pool spans four categories:
| Category | Transformations | Magnitude Meaning |
|---|---|---|
| Geometric | ShearX, ShearY, TranslateX, TranslateY, Rotate | Degrees or pixel offset |
| Photometric | Brightness, Color, Contrast, Sharpness | Intensity factor [0.1, 1.9] |
| Distortion | AutoContrast, Equalize, Invert, Solarize, Posterize | Threshold or bit depth |
| Mixing | Cutout, SamplePairing | Region size or blend ratio |
This process is computationally expensive—the original AutoAugment required 15,000 GPU hours for ImageNet. However, learned policies transfer well: the ImageNet policy works on CIFAR-10, and the CIFAR-10 policy often helps other small-scale datasets.
AutoAugment often discovers non-intuitive strategies. For ImageNet, it heavily uses geometric transforms (rotation, shear) but avoids Cutout. For SVHN digit recognition, it emphasizes color inversion and shearing. These discoveries would be difficult to anticipate through manual design.
The complexity of AutoAugment's search motivated simpler alternatives. RandAugment (Cubuk et al., 2020) dramatically simplifies the approach while achieving comparable or better results.
AutoAugment searches for optimal (operation, probability, magnitude) tuples. RandAugment observes that:

- policies found on small proxy tasks do not necessarily transfer to the full task, because the optimal augmentation strength depends on model and dataset size;
- the learned per-operation probabilities add little: sampling every operation uniformly works about as well; and
- per-operation magnitudes can be collapsed into a single global value.

RandAugment therefore needs only two hyperparameters:

- **N**: the number of transformations applied sequentially to each image
- **M**: a single global magnitude shared by all transformations

For each image:

1. Sample N operations uniformly at random (with replacement) from the pool.
2. Apply each operation in sequence with magnitude M; there is no per-operation probability check.
This reduces the search space from $10^{32}$ possibilities to approximately $30 \times 3 = 90$ (N, M) pairs that can be searched via simple grid search.
```python
import random
from typing import Callable, Tuple

from PIL import Image, ImageEnhance, ImageOps


# Define augmentation operations
def shear_x(img: Image.Image, magnitude: float) -> Image.Image:
    """Shear image along x-axis."""
    return img.transform(
        img.size, Image.AFFINE, (1, magnitude, 0, 0, 1, 0),
        resample=Image.BILINEAR
    )


def shear_y(img: Image.Image, magnitude: float) -> Image.Image:
    """Shear image along y-axis."""
    return img.transform(
        img.size, Image.AFFINE, (1, 0, 0, magnitude, 1, 0),
        resample=Image.BILINEAR
    )


def translate_x(img: Image.Image, magnitude: float) -> Image.Image:
    """Translate image along x-axis by a fraction of its width."""
    pixels = int(magnitude * img.size[0])
    return img.transform(
        img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0),
        resample=Image.BILINEAR
    )


def translate_y(img: Image.Image, magnitude: float) -> Image.Image:
    """Translate image along y-axis by a fraction of its height."""
    pixels = int(magnitude * img.size[1])
    return img.transform(
        img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels),
        resample=Image.BILINEAR
    )


def rotate(img: Image.Image, magnitude: float) -> Image.Image:
    """Rotate image by the specified number of degrees."""
    return img.rotate(magnitude, resample=Image.BILINEAR,
                      fillcolor=(128, 128, 128))


def brightness(img: Image.Image, magnitude: float) -> Image.Image:
    """Adjust image brightness."""
    return ImageEnhance.Brightness(img).enhance(1 + magnitude)


def color(img: Image.Image, magnitude: float) -> Image.Image:
    """Adjust color saturation."""
    return ImageEnhance.Color(img).enhance(1 + magnitude)


def contrast(img: Image.Image, magnitude: float) -> Image.Image:
    """Adjust image contrast."""
    return ImageEnhance.Contrast(img).enhance(1 + magnitude)


def sharpness(img: Image.Image, magnitude: float) -> Image.Image:
    """Adjust image sharpness."""
    return ImageEnhance.Sharpness(img).enhance(1 + magnitude)


def auto_contrast(img: Image.Image, magnitude: float) -> Image.Image:
    """Apply auto contrast (magnitude unused)."""
    return ImageOps.autocontrast(img)


def equalize(img: Image.Image, magnitude: float) -> Image.Image:
    """Histogram equalization (magnitude unused)."""
    return ImageOps.equalize(img)


def invert(img: Image.Image, magnitude: float) -> Image.Image:
    """Invert colors (magnitude unused)."""
    return ImageOps.invert(img)


def solarize(img: Image.Image, magnitude: float) -> Image.Image:
    """Solarize (invert) pixels above a threshold."""
    threshold = int(magnitude * 255)
    return ImageOps.solarize(img, threshold)


def posterize(img: Image.Image, magnitude: float) -> Image.Image:
    """Reduce color bit depth."""
    bits = int(8 - magnitude * 4)  # map magnitude to 4-8 bits
    bits = max(1, min(8, bits))
    return ImageOps.posterize(img, bits)


class RandAugment:
    """
    RandAugment data augmentation as described in:
    'RandAugment: Practical automated data augmentation with a
    reduced search space' (Cubuk et al., 2020)

    Applies N random transformations, each with magnitude M.
    """

    # Operation pool: (name, function, (low, high) magnitude range).
    # A negative low marks a symmetric operation whose sign is
    # randomized. For Solarize, low > high so that M=0 keeps the
    # threshold at 255 (a no-op) and larger M solarizes more pixels.
    AUGMENTATION_POOL = [
        ("ShearX", shear_x, (-0.3, 0.3)),
        ("ShearY", shear_y, (-0.3, 0.3)),
        ("TranslateX", translate_x, (-0.45, 0.45)),
        ("TranslateY", translate_y, (-0.45, 0.45)),
        ("Rotate", rotate, (-30, 30)),
        ("Brightness", brightness, (-0.9, 0.9)),
        ("Color", color, (-0.9, 0.9)),
        ("Contrast", contrast, (-0.9, 0.9)),
        ("Sharpness", sharpness, (-0.9, 0.9)),
        ("AutoContrast", auto_contrast, (0, 0)),
        ("Equalize", equalize, (0, 0)),
        ("Invert", invert, (0, 0)),
        ("Solarize", solarize, (1, 0)),
        ("Posterize", posterize, (0, 1)),
    ]

    def __init__(self, n: int = 2, m: int = 10, max_magnitude: int = 30):
        """
        Parameters
        ----------
        n : int
            Number of transformations to apply per image
        m : int
            Global magnitude (0 to max_magnitude)
        max_magnitude : int
            Maximum possible magnitude value
        """
        self.n = n
        self.m = m
        self.max_magnitude = max_magnitude

    def _apply_op(
        self,
        img: Image.Image,
        op_name: str,
        op_fn: Callable,
        magnitude_range: Tuple[float, float],
    ) -> Image.Image:
        """Apply an operation with magnitude scaled into its range."""
        low, high = magnitude_range
        if low < 0:
            # Symmetric operation: M=0 means no distortion; flip the
            # sign half the time so both directions are covered.
            magnitude = (self.m / self.max_magnitude) * high
            if random.random() < 0.5:
                magnitude = -magnitude
        else:
            # One-sided operation: interpolate from the weak end (low)
            # at M=0 to the strong end (high) at M=max_magnitude.
            magnitude = (self.m / self.max_magnitude) * (high - low) + low
        return op_fn(img, magnitude)

    def __call__(self, img: Image.Image) -> Image.Image:
        """Apply N randomly chosen operations, each at magnitude M."""
        # Sample N operations uniformly, with replacement
        ops = random.choices(self.AUGMENTATION_POOL, k=self.n)
        for op_name, op_fn, magnitude_range in ops:
            img = self._apply_op(img, op_name, op_fn, magnitude_range)
        return img


# Example usage
def get_imagenet_train_transform(randaug_n: int = 2, randaug_m: int = 9):
    """Create an ImageNet training transform with RandAugment."""
    from torchvision import transforms

    return transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        RandAugment(n=randaug_n, m=randaug_m),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])
```

A key insight of RandAugment is that optimal magnitude scales with model size and dataset size:
| Model | Dataset | Optimal N | Optimal M |
|---|---|---|---|
| ResNet-50 | ImageNet | 2 | 9 |
| EfficientNet-B7 | ImageNet | 2 | 17 |
| Wide-ResNet-28-10 | CIFAR-10 | 3 | 6 |
| ResNet-200 | ImageNet | 2 | 14 |
Larger models have more capacity and need stronger regularization (higher M). Smaller datasets also benefit from stronger augmentation to prevent overfitting.
Start with N=2, M=10 for most applications. If training shows underfitting (training accuracy low), reduce M. If validation accuracy lags training (overfitting), increase M. Grid search over M ∈ {5, 7, 9, 11, 13, 15} typically suffices.
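A minimal sketch of that grid search, assuming a hypothetical `train_and_evaluate(n, m)` helper that trains a model with `RandAugment(n=n, m=m)` and returns validation accuracy:

```python
def search_randaugment_m(train_and_evaluate, n: int = 2):
    """Grid search over M; N is usually left at 2."""
    best_m, best_acc = None, 0.0
    for m in [5, 7, 9, 11, 13, 15]:
        acc = train_and_evaluate(n=n, m=m)  # full (or shortened) training run
        print(f"N={n}, M={m}: val acc = {acc:.4f}")
        if acc > best_acc:
            best_m, best_acc = m, acc
    return best_m, best_acc
```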
The 15,000 GPU-hour cost of AutoAugment motivated more efficient search methods. Fast AutoAugment (Lim et al., 2019) reduces search time to under 5 GPU-hours while achieving comparable performance.
AutoAugment trains child models to completion for each candidate policy—the main computational bottleneck. Fast AutoAugment observes that effective augmentation policies should create augmented images that match the distribution of held-out validation data:
$$\pi^* = \arg\min_{\pi} \, D_{\mathrm{KL}}\!\left( p_{\text{aug}}(x \mid \pi) \,\|\, p_{\text{val}}(x) \right)$$
This can be estimated without full training through density matching proxies.
The search uses Bayesian optimization with Tree-structured Parzen Estimators (TPE) to efficiently explore the policy space.
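As one concrete possibility, a TPE search can be wired up with the `hyperopt` library. The sketch below assumes hypothetical `build_subpolicy` and `evaluate_policy` helpers that construct a sub-policy from sampled parameters and score it with the density-matching proxy:

```python
from hyperopt import fmin, hp, tpe

OPERATIONS = ["ShearX", "ShearY", "Rotate", "Solarize", "Posterize"]  # abbreviated pool

space = {
    "op1": hp.choice("op1", OPERATIONS),
    "p1": hp.uniform("p1", 0, 1),
    "m1": hp.uniform("m1", 0, 1),
    "op2": hp.choice("op2", OPERATIONS),
    "p2": hp.uniform("p2", 0, 1),
    "m2": hp.uniform("m2", 0, 1),
}

def objective(params):
    policy = build_subpolicy(params)       # hypothetical: params -> sub-policy
    return 1.0 - evaluate_policy(policy)   # hyperopt minimizes, so negate the score

best = fmin(objective, space, algo=tpe.suggest, max_evals=200)
```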
```python
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
import torch.nn as nn


@dataclass
class AugmentOperation:
    """Represents a single augmentation operation."""
    name: str
    probability: float
    magnitude: float


@dataclass
class SubPolicy:
    """A sub-policy contains two operations."""
    op1: AugmentOperation
    op2: AugmentOperation


@dataclass
class Policy:
    """A full policy contains 5 sub-policies."""
    sub_policies: List[SubPolicy]


class FastAutoAugmentSearch:
    """
    Simplified Fast AutoAugment search implementation.

    Uses Bayesian optimization to find policies that maximize
    validation set predictability under augmentation.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        device: str = 'cuda',
        num_trials: int = 100
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_trials = num_trials

        # Available operations
        self.operations = [
            'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
            'AutoContrast', 'Invert', 'Equalize', 'Solarize',
            'Posterize', 'Contrast', 'Color', 'Brightness',
            'Sharpness', 'Cutout'
        ]

    def _sample_subpolicy(self) -> SubPolicy:
        """Sample a random sub-policy."""
        return SubPolicy(
            op1=AugmentOperation(
                name=np.random.choice(self.operations),
                probability=np.random.uniform(0, 1),
                magnitude=np.random.uniform(0, 1)
            ),
            op2=AugmentOperation(
                name=np.random.choice(self.operations),
                probability=np.random.uniform(0, 1),
                magnitude=np.random.uniform(0, 1)
            )
        )

    def _sample_policy(self) -> Policy:
        """Sample a random policy with 5 sub-policies."""
        return Policy(
            sub_policies=[self._sample_subpolicy() for _ in range(5)]
        )

    def _evaluate_policy(self, policy: Policy) -> float:
        """
        Evaluate a policy by measuring how well the model predicts
        augmented validation samples.

        Uses the insight that good augmentations create samples that
        the model can still correctly classify.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in self.val_loader:
                # Apply policy to validation images
                augmented = self._apply_policy(images, policy)
                augmented = augmented.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(augmented)
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)

        return correct / total

    def _apply_policy(
        self, images: torch.Tensor, policy: Policy
    ) -> torch.Tensor:
        """
        Apply policy to a batch of images.
        Randomly selects one sub-policy per image.
        """
        B = images.size(0)
        augmented = images.clone()

        for i in range(B):
            # Random sub-policy selection
            sub_policy = np.random.choice(policy.sub_policies)

            # Apply operations probabilistically
            if np.random.random() < sub_policy.op1.probability:
                augmented[i] = self._apply_operation(
                    augmented[i], sub_policy.op1.name, sub_policy.op1.magnitude
                )
            if np.random.random() < sub_policy.op2.probability:
                augmented[i] = self._apply_operation(
                    augmented[i], sub_policy.op2.name, sub_policy.op2.magnitude
                )

        return augmented

    def _apply_operation(
        self, img: torch.Tensor, op_name: str, magnitude: float
    ) -> torch.Tensor:
        """
        Apply a single operation to an image tensor.
        Implementation would map op_name to actual transforms.
        """
        # Placeholder - a real implementation would apply the transform
        return img

    def search(self) -> Policy:
        """
        Run the policy search.

        Uses random search as a simple baseline. A full implementation
        would use TPE Bayesian optimization.
        """
        best_policy = None
        best_score = 0.0

        for trial in range(self.num_trials):
            # Sample candidate policy
            policy = self._sample_policy()

            # Evaluate
            score = self._evaluate_policy(policy)

            if score > best_score:
                best_score = score
                best_policy = policy
                print(f"Trial {trial}: New best score = {score:.4f}")

        return best_policy
```

Population-Based Augmentation (PBA) takes a different approach: instead of searching for a fixed policy, it evolves policies during training:

1. Train a population of models in parallel, each with its own augmentation hyperparameters.
2. Periodically copy weights from the best performers to the worst (exploit).
3. Randomly perturb the copied augmentation hyperparameters (explore).
This discovers dynamic schedules where augmentation intensity changes through training—often starting mild and increasing.
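A compact sketch of one exploit-and-explore step in this style (illustrative only; the `score`, `weights`, and `aug_params` attributes on each worker are assumed for the example):

```python
import copy
import random

def pba_step(population, top_frac=0.25):
    """One population-based step: bottom workers clone a top
    performer's weights (exploit) and perturb its augmentation
    hyperparameters (explore)."""
    population.sort(key=lambda w: w.score, reverse=True)
    cutoff = max(1, int(len(population) * top_frac))
    for worker in population[cutoff:]:          # bottom performers
        parent = random.choice(population[:cutoff])
        worker.weights = copy.deepcopy(parent.weights)   # exploit
        worker.aug_params = {                            # explore
            k: min(1.0, max(0.0, v + random.uniform(-0.1, 0.1)))
            for k, v in parent.aug_params.items()
        }
```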
DADA (Li et al., 2020) makes the search fully differentiable by relaxing discrete choices:
$$\text{aug}(x) = \sum_{o \in \mathcal{O}} \alpha_o \cdot T_o(x, m_o)$$
where $\alpha_o$ are softmax-weighted probabilities over operations $o$, learned end-to-end with the model. This reduces search to a single training run.
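A minimal PyTorch sketch of this softmax relaxation (illustrative: the transforms must be differentiable tensor operations, and the actual DADA estimator is more sophisticated than a plain mixture):

```python
import torch
import torch.nn as nn

class SoftAugment(nn.Module):
    """Differentiable mixture over candidate transforms: the output is
    a softmax-weighted sum, so the operation logits receive gradients
    from the task loss. `transforms` is a list of callables x -> T(x)."""

    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms
        self.logits = nn.Parameter(torch.zeros(len(transforms)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha = torch.softmax(self.logits, dim=0)  # operation weights
        return sum(a * t(x) for a, t in zip(alpha, self.transforms))

# Usage: aug = SoftAugment([lambda x: x, lambda x: torch.flip(x, dims=[-1])])
```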
For most practitioners, RandAugment or the pretrained AutoAugment policies are sufficient. Fast AutoAugment or DADA are worthwhile only when: (1) working on a specialized domain where standard policies don't transfer, (2) training many models where amortized search cost is low, or (3) seeking the last 0.1-0.3% accuracy improvement.
Even RandAugment's two hyperparameters (N, M) require tuning. TrivialAugment (Müller & Hutter, 2021) takes simplification to its logical conclusion: zero hyperparameters.
For each image:

1. Uniformly sample one operation from the pool.
2. Uniformly sample a magnitude from that operation's range.
3. Apply the operation and return the result.
That's it. No N to tune, no M to set. Despite (or perhaps because of) this simplicity, TrivialAugment matches or exceeds RandAugment's performance.
The success of TrivialAugment challenges assumptions about augmentation design:
1. **Single operations suffice.** Applying multiple transforms (N > 1) doesn't consistently improve results. The key is consistent exposure to transformations, not stacking them.
2. **Random magnitude reduces overfitting.** A fixed magnitude can cause the model to memorize specific distortion levels; random magnitude forces learning across the full transform spectrum.
3. **Hyperparameter sensitivity harms generalization.** Models trained with "optimal" (N, M) for the validation set may overfit those specific settings. Random sampling is more robust.
```python
import random

from PIL import Image, ImageEnhance, ImageOps


class TrivialAugment:
    """
    TrivialAugment data augmentation.

    No hyperparameters: randomly samples one operation with random
    magnitude for each image.

    Reference: 'TrivialAugment: Tuning-free Yet State-of-the-Art
    Data Augmentation' (Müller & Hutter, 2021)
    """

    # Each operation: (name, function, magnitude_bins)
    # magnitude_bins[i] = magnitude value for bin i
    AUGMENTATION_SPACE = [
        ("Identity", lambda img, v: img, [0]),
        ("ShearX", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, v, 0, 0, 1, 0),
            resample=Image.BILINEAR
        ), [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]),
        ("ShearY", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, 0, 0, v, 1, 0),
            resample=Image.BILINEAR
        ), [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]),
        ("TranslateX", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, 0, int(v * img.size[0]), 0, 1, 0),
            resample=Image.BILINEAR
        ), [-0.45, -0.30, -0.15, 0, 0.15, 0.30, 0.45]),
        ("TranslateY", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, 0, 0, 0, 1, int(v * img.size[1])),
            resample=Image.BILINEAR
        ), [-0.45, -0.30, -0.15, 0, 0.15, 0.30, 0.45]),
        ("Rotate", lambda img, v: img.rotate(
            v, resample=Image.BILINEAR, fillcolor=(128, 128, 128)
        ), [-30, -20, -10, 0, 10, 20, 30]),
        ("Brightness", lambda img, v: ImageEnhance.Brightness(img).enhance(1 + v),
         [-0.9, -0.6, -0.3, 0, 0.3, 0.6, 0.9]),
        ("Color", lambda img, v: ImageEnhance.Color(img).enhance(1 + v),
         [-0.9, -0.6, -0.3, 0, 0.3, 0.6, 0.9]),
        ("Contrast", lambda img, v: ImageEnhance.Contrast(img).enhance(1 + v),
         [-0.9, -0.6, -0.3, 0, 0.3, 0.6, 0.9]),
        ("Sharpness", lambda img, v: ImageEnhance.Sharpness(img).enhance(1 + v),
         [-0.9, -0.6, -0.3, 0, 0.3, 0.6, 0.9]),
        ("AutoContrast", lambda img, v: ImageOps.autocontrast(img), [0]),
        ("Equalize", lambda img, v: ImageOps.equalize(img), [0]),
        ("Solarize", lambda img, v: ImageOps.solarize(img, int(v)),
         [256, 200, 150, 100, 50, 0]),
        ("Posterize", lambda img, v: ImageOps.posterize(img, int(v)),
         [8, 7, 6, 5, 4, 3, 2, 1]),
    ]

    def __init__(self, exclude_identity: bool = False):
        """
        Parameters
        ----------
        exclude_identity : bool
            If True, never sample the Identity (no-op) transform
        """
        self.augmentations = [
            aug for aug in self.AUGMENTATION_SPACE
            if not (exclude_identity and aug[0] == "Identity")
        ]

    def __call__(self, img: Image.Image) -> Image.Image:
        """
        Apply TrivialAugment to a single image.

        1. Uniformly sample one operation
        2. Uniformly sample one magnitude from that operation's range
        3. Apply and return
        """
        # Sample random operation
        op_name, op_fn, magnitudes = random.choice(self.augmentations)

        # Sample random magnitude
        magnitude = random.choice(magnitudes)

        # Apply operation
        return op_fn(img, magnitude)


class TrivialAugmentWide(TrivialAugment):
    """
    TrivialAugment-Wide variant with continuous magnitude sampling.

    Instead of discrete magnitude bins, samples uniformly from the
    full continuous range for each operation.
    """

    # (name, function, (min_magnitude, max_magnitude))
    AUGMENTATION_SPACE_WIDE = [
        ("Identity", lambda img, v: img, (0, 0)),
        ("ShearX", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, v, 0, 0, 1, 0),
            resample=Image.BILINEAR
        ), (-0.99, 0.99)),
        ("ShearY", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, 0, 0, v, 1, 0),
            resample=Image.BILINEAR
        ), (-0.99, 0.99)),
        ("TranslateX", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, 0, int(v * img.size[0]), 0, 1, 0)
        ), (-0.5, 0.5)),
        ("TranslateY", lambda img, v: img.transform(
            img.size, Image.AFFINE, (1, 0, 0, 0, 1, int(v * img.size[1]))
        ), (-0.5, 0.5)),
        ("Rotate", lambda img, v: img.rotate(v, fillcolor=(128, 128, 128)),
         (-135, 135)),
        ("Brightness", lambda img, v: ImageEnhance.Brightness(img).enhance(1 + v),
         (-0.99, 0.99)),
        ("Color", lambda img, v: ImageEnhance.Color(img).enhance(1 + v),
         (-0.99, 0.99)),
        ("Contrast", lambda img, v: ImageEnhance.Contrast(img).enhance(1 + v),
         (-0.99, 0.99)),
        ("Sharpness", lambda img, v: ImageEnhance.Sharpness(img).enhance(1 + v),
         (-0.99, 0.99)),
        ("AutoContrast", lambda img, v: ImageOps.autocontrast(img), (0, 0)),
        ("Equalize", lambda img, v: ImageOps.equalize(img), (0, 0)),
    ]

    def __init__(self):
        self.augmentations = self.AUGMENTATION_SPACE_WIDE

    def __call__(self, img: Image.Image) -> Image.Image:
        """Apply TrivialAugment-Wide with continuous magnitude."""
        op_name, op_fn, (min_mag, max_mag) = random.choice(self.augmentations)
        magnitude = random.uniform(min_mag, max_mag)
        return op_fn(img, magnitude)
```

| Method | Hyperparameters | Search Cost | ImageNet Top-1 | Simplicity |
|---|---|---|---|---|
| AutoAugment | 30 policy parameters | 15,000 GPU-hours | 77.6% | ★☆☆☆☆ |
| Fast AutoAugment | 30 policy parameters | 3.5 GPU-hours | 77.6% | ★★☆☆☆ |
| RandAugment | 2 (N, M) | Grid search | 77.6% | ★★★★☆ |
| TrivialAugment | 0 | None | 77.7% | ★★★★★ |
For new projects, TrivialAugment is an excellent default choice. It requires no tuning, matches state-of-the-art performance, and eliminates augmentation hyperparameter search from your workflow. Only consider alternatives if TrivialAugment underperforms on your specific domain.
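In practice you rarely need to implement these yourself: recent torchvision releases ship `transforms.AutoAugment`, `transforms.RandAugment`, and `transforms.TrivialAugmentWide`. A typical pipeline:

```python
from torchvision import transforms

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.TrivialAugmentWide(),  # no hyperparameters to set
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```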
While general-purpose policies work well for natural images, specialized domains often require carefully designed or searched policies that respect domain-specific constraints.
**Medical imaging.** Constraints:

- Pixel intensity often carries diagnostic meaning, so inversion, solarization, and strong color shifts can destroy the signal.
- Anatomy has a canonical orientation; flips and large rotations can change clinical interpretation (left vs. right structures).

Appropriate augmentations:

- Small rotations and translations
- Elastic deformation (simulates natural tissue variability)
- Mild intensity shifts, noise, and blur (simulate acquisition differences)
**Satellite and aerial imagery.** Constraints:

- Ground sample distance (scale) is physically meaningful, so aggressive random resizing can create implausible object sizes.
- Spectral bands beyond RGB may not tolerate standard color transforms.

Appropriate augmentations:

- Arbitrary rotations and flips (overhead views have no canonical orientation)
- Random crops at fixed scale
- Mild brightness and contrast changes (simulate atmospheric and lighting variation)
When standard policies don't transfer, domain-specific search is worthwhile:
1. Define a custom operation pool that respects domain constraints (remove unsafe transforms, add domain-specific ones)
2. Run an efficient search (a RandAugment-style grid search or Fast AutoAugment) over that pool
3. Validate the learned policy on held-out domain data before adopting it
```python
import random
from typing import Callable, List, Tuple

import numpy as np
from PIL import Image, ImageFilter


class DomainSpecificRandAugment:
    """
    RandAugment with a customizable operation pool for
    domain-specific applications.

    Allows defining custom operations and excluding inappropriate
    standard operations.
    """

    def __init__(
        self,
        n: int = 2,
        m: int = 10,
        operations: List[Tuple[str, Callable, Tuple[float, float]]] = None
    ):
        """
        Parameters
        ----------
        n : int
            Number of operations to apply
        m : int
            Global magnitude (0-30 scale)
        operations : list
            Custom operation pool. Each tuple contains:
            (name, function, (min_magnitude, max_magnitude))
            If None, uses default RandAugment operations
        """
        self.n = n
        self.m = m
        self.operations = operations or self._default_operations()

    def _default_operations(self):
        """Default RandAugment operation pool."""
        # Placeholder: supply `operations` explicitly, or wire in the
        # pool from the full RandAugment implementation above
        return []

    def __call__(self, img: Image.Image) -> Image.Image:
        """Apply domain-specific RandAugment."""
        ops = random.choices(self.operations, k=self.n)
        for name, op_fn, (min_mag, max_mag) in ops:
            if min_mag < 0:
                # Symmetric operation: scale into [0, max] and randomize sign
                magnitude = (self.m / 30) * max_mag
                if random.random() < 0.5:
                    magnitude = -magnitude
            else:
                # One-sided operation: interpolate from weak to strong
                magnitude = (self.m / 30) * (max_mag - min_mag) + min_mag
            img = op_fn(img, magnitude)
        return img


# Example: Medical Imaging Policy
def create_medical_randaugment(n: int = 2, m: int = 7):
    """
    Create a RandAugment policy appropriate for medical imaging.

    Excludes color changes that might affect diagnosis.
    Includes elastic deformation for tissue variability.
    """
    from scipy.ndimage import gaussian_filter, map_coordinates

    def elastic_deformation(img: Image.Image, magnitude: float) -> Image.Image:
        """Apply elastic deformation appropriate for medical images."""
        img_array = np.array(img)
        alpha = magnitude * 100  # Displacement intensity
        sigma = magnitude * 5    # Smoothness
        shape = img_array.shape[:2]

        # Smooth random displacement fields
        dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
        dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha

        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        indices = [np.reshape(y + dy, (-1,)), np.reshape(x + dx, (-1,))]

        if img_array.ndim == 2:  # grayscale
            result = map_coordinates(
                img_array, indices, order=1, mode='reflect'
            ).reshape(shape)
        else:
            result = np.zeros_like(img_array)
            for c in range(img_array.shape[2]):
                result[:, :, c] = map_coordinates(
                    img_array[:, :, c], indices, order=1, mode='reflect'
                ).reshape(shape)

        return Image.fromarray(result.astype(np.uint8))

    def intensity_shift(img: Image.Image, magnitude: float) -> Image.Image:
        """Shift all pixel intensities by `magnitude` units
        (simulates exposure variations)."""
        img_array = np.array(img).astype(np.float32)
        img_array = np.clip(img_array + magnitude, 0, 255)
        return Image.fromarray(img_array.astype(np.uint8))

    MEDICAL_OPERATIONS = [
        # Allowed geometric transforms (symmetric ranges: sign randomized)
        ("Rotate", lambda img, v: img.rotate(v, fillcolor=0), (-15, 15)),
        ("TranslateX", lambda img, v: img.transform(
            img.size, Image.AFFINE,
            (1, 0, int(v * img.size[0]), 0, 1, 0)
        ), (-0.05, 0.05)),
        ("TranslateY", lambda img, v: img.transform(
            img.size, Image.AFFINE,
            (1, 0, 0, 0, 1, int(v * img.size[1]))
        ), (-0.05, 0.05)),
        # Medical-specific
        ("ElasticDeform", elastic_deformation, (0.1, 0.5)),
        ("IntensityShift", intensity_shift, (-15, 15)),
        # Safe photometric
        ("GaussianNoise", lambda img, v: img, (0, 1)),  # placeholder - no-op until implemented
        ("GaussianBlur", lambda img, v: img.filter(
            ImageFilter.GaussianBlur(radius=v)
        ), (0, 2)),
    ]

    return DomainSpecificRandAugment(n=n, m=m, operations=MEDICAL_OPERATIONS)
```

We've traversed the landscape of learned augmentation—from expensive reinforcement learning search through surprisingly effective zero-hyperparameter approaches.
What's Next:
We've explored training-time augmentation extensively. Now we'll examine Test-Time Augmentation (TTA)—where augmentations are applied at inference to improve prediction robustness and uncertainty estimation. TTA provides nearly free accuracy gains by aggregating predictions across augmented views of test inputs.
You now understand the evolution of learned augmentation from expensive AutoAugment through practical TrivialAugment. For most applications, TrivialAugment or RandAugment provides an excellent starting point without hyperparameter tuning overhead.