All augmentation techniques we've studied so far apply during training. But what if we could also leverage augmentations at inference time to improve predictions on individual test samples?
Test-Time Augmentation (TTA) does exactly this: instead of predicting on a single view of the input, we generate multiple augmented versions, predict on each, and aggregate the results. This simple technique often yields meaningful accuracy improvements without any retraining—effectively ensemble behavior without multiple models, paid for with extra inference compute rather than training cost.
TTA is particularly valuable in high-stakes applications where prediction reliability matters more than inference speed: medical diagnosis, autonomous driving, and competitive benchmarks. Understanding when and how to apply TTA is essential for extracting maximum performance from trained models.
By the end of this page, you will understand the theoretical foundation of TTA as implicit ensembling, implement TTA correctly with proper aggregation strategies, select appropriate augmentations for different tasks, and understand the trade-offs between accuracy gains and inference cost.
Test-time augmentation is grounded in the principle that robust predictions should be consistent across semantically-equivalent transformations of the input.
Given a test sample $x$ and a set of transformations $\mathcal{T} = \{T_1, T_2, \ldots, T_K\}$, TTA produces a prediction by:
$$\hat{y}_{TTA} = \text{Aggregate}(f(T_1(x)), f(T_2(x)), ..., f(T_K(x)))$$
where $f$ is the trained model and Aggregate combines predictions (typically averaging for probabilities, voting for classes, or geometric mean for scores).
Assume model predictions have some variance due to input sensitivity:
$$f(x) = f^*(x) + \epsilon$$
where $f^*(x)$ is the "true" prediction and $\epsilon$ is prediction noise. If the noise terms $\epsilon$ across augmented views are uncorrelated:
$$\text{Var}(\hat{y}_{TTA}) = \frac{1}{K}\text{Var}(\hat{y})$$
Averaging K predictions reduces variance by a factor of K, improving reliability.
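This claim is easy to verify numerically. The following minimal sketch simulates K noisy "predictions" around a true value and compares single-view variance with the variance of the K-view mean (all numbers here are illustrative):

```python
import torch

torch.manual_seed(0)

K = 8             # number of augmented views
n_trials = 10000  # repeated experiments to estimate variance

# Simulate a "true" scalar prediction plus i.i.d. noise per view
true_pred = 0.7
noise_std = 0.1
single = true_pred + noise_std * torch.randn(n_trials)
multi = true_pred + noise_std * torch.randn(n_trials, K)

print(f"Var(single view): {single.var():.5f}")              # ~ noise_std**2
print(f"Var(K-view mean): {multi.mean(dim=1).var():.5f}")   # ~ noise_std**2 / K
```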
TTA can be viewed as an implicit ensemble. Instead of training K separate models (expensive), we create K different "views" of the same input (cheap). The predictions from these views are correlated (same model) but not identical (different inputs), providing ensemble-like benefits.
Comparison with explicit ensembles:
| Aspect | Explicit Ensemble | TTA |
|---|---|---|
| Training Cost | K× model training | 1× training |
| Model Storage | K× parameters | 1× parameters |
| Inference Cost | K× forward passes | K× forward passes |
| Prediction Diversity | Independent models | Augmented views |
| Typical Improvement | 2-5% relative | 0.5-2% relative |
TTA provides smaller gains than explicit ensembles but at dramatically lower training and storage cost.
From a Bayesian perspective, TTA marginalizes over input uncertainty:
$$p(y|x) = \int p(y|T(x)) p(T) dT \approx \frac{1}{K}\sum_{k=1}^K p(y|T_k(x))$$
This treats augmentations as samples from a distribution over plausible inputs, with TTA approximating the posterior predictive distribution.
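As a sketch of this Monte Carlo view, we can sample random transforms from an assumed distribution $p(T)$ and average the resulting predictive distributions. The `sample_transform` helper below is a hypothetical example built from torchvision transforms (assuming the model accepts 224×224 inputs), not a prescribed recipe:

```python
import torch
import torch.nn.functional as F
import torchvision.transforms as T

def sample_transform():
    # One draw from an assumed distribution over plausible transforms p(T):
    # here, a random horizontal flip combined with a mild random crop.
    return T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomResizedCrop(224, scale=(0.9, 1.0)),
    ])

def monte_carlo_predictive(model, x, K=10):
    """Approximate p(y|x) = E_T[p(y|T(x))] with K sampled transforms."""
    model.eval()
    with torch.no_grad():
        probs = [F.softmax(model(sample_transform()(x)), dim=-1) for _ in range(K)]
    return torch.stack(probs).mean(dim=0)
```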
TTA is most beneficial when: (1) the model shows high variance across similar inputs, (2) test images may differ from training distribution, (3) the task has inherent ambiguity that benefits from multiple perspectives, or (4) you're in a compute-unconstrained setting (batch processing, competitions).
Effective TTA implementation requires careful attention to augmentation selection, inverse transformation handling, and prediction aggregation.
The simplest TTA averages softmax probabilities across augmented views:
$$p(y=c|x) = \frac{1}{K}\sum_{k=1}^K \text{softmax}(f(T_k(x)))_c$$
The final prediction is $\hat{y} = \arg\max_c p(y=c|x)$.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Callable, Optional


class TTAClassification:
    """
    Test-Time Augmentation for image classification.

    Aggregates predictions across multiple augmented views of the
    input to improve accuracy and calibration.
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Optional[List[Callable]] = None,
        n_augmentations: int = 5,
        aggregation: str = 'mean',  # 'mean', 'vote', 'geometric'
        device: str = 'cuda'
    ):
        """
        Parameters:
        -----------
        model : nn.Module
            Trained classification model
        transforms : list of callable
            Augmentation functions. Each takes an image tensor and returns
            a transformed tensor. If None, uses default flip and multi-crop.
        n_augmentations : int
            Number of augmented views per image (used if transforms not provided)
        aggregation : str
            How to combine predictions:
            - 'mean': Average softmax probabilities
            - 'vote': Majority voting on predictions
            - 'geometric': Geometric mean of probabilities
        """
        self.model = model
        self.device = device
        self.aggregation = aggregation

        if transforms is not None:
            self.transforms = transforms
        else:
            self.transforms = self._default_transforms(n_augmentations)

    def _default_transforms(self, n: int) -> List[Callable]:
        """
        Generate default TTA transforms.
        Includes identity, horizontal flip, and multi-scale crops.
        """
        transforms = [
            # Identity (original image)
            lambda x: x,
            # Horizontal flip
            lambda x: torch.flip(x, dims=[-1]),
        ]

        # Add random crops if more augmentations requested
        while len(transforms) < n:
            # Multi-scale center crops would be added here
            transforms.append(
                lambda x: x  # Placeholder
            )

        return transforms[:n]

    def predict_single(self, image: torch.Tensor) -> torch.Tensor:
        """
        TTA prediction for a single image.

        Parameters:
        -----------
        image : torch.Tensor
            Input image tensor of shape (C, H, W)

        Returns:
        --------
        Predicted class probabilities of shape (num_classes,)
        """
        self.model.eval()
        image = image.to(self.device)

        # Collect predictions from all augmented views
        all_probs = []
        with torch.no_grad():
            for transform in self.transforms:
                # Apply augmentation
                aug_image = transform(image.unsqueeze(0))

                # Get prediction
                logits = self.model(aug_image)
                probs = F.softmax(logits, dim=-1)
                all_probs.append(probs)

        # Stack predictions: (K, C)
        all_probs = torch.cat(all_probs, dim=0)

        # Aggregate
        if self.aggregation == 'mean':
            return all_probs.mean(dim=0)
        elif self.aggregation == 'geometric':
            return torch.exp(torch.log(all_probs + 1e-8).mean(dim=0))
        elif self.aggregation == 'vote':
            votes = all_probs.argmax(dim=1)
            # Return one-hot-like distribution from the vote counts
            vote_counts = torch.bincount(votes, minlength=all_probs.size(1))
            return vote_counts.float() / vote_counts.sum()
        else:
            raise ValueError(f"Unknown aggregation: {self.aggregation}")

    def predict_batch(self, images: torch.Tensor) -> torch.Tensor:
        """
        TTA prediction for a batch of images.
        More efficient than predict_single for multiple images.

        Parameters:
        -----------
        images : torch.Tensor
            Batch of images of shape (B, C, H, W)

        Returns:
        --------
        Predicted probabilities of shape (B, num_classes)
        """
        self.model.eval()
        images = images.to(self.device)
        B = images.size(0)

        all_probs = []
        with torch.no_grad():
            for transform in self.transforms:
                # Apply augmentation to entire batch
                aug_batch = transform(images)

                # Get predictions
                logits = self.model(aug_batch)
                probs = F.softmax(logits, dim=-1)
                all_probs.append(probs.unsqueeze(0))  # (1, B, C)

        # Stack: (K, B, C)
        all_probs = torch.cat(all_probs, dim=0)

        # Aggregate across augmentations
        if self.aggregation == 'mean':
            return all_probs.mean(dim=0)
        elif self.aggregation == 'geometric':
            return torch.exp(torch.log(all_probs + 1e-8).mean(dim=0))
        else:
            # Vote: compute per-sample
            results = []
            for b in range(B):
                votes = all_probs[:, b, :].argmax(dim=1)
                vote_counts = torch.bincount(votes, minlength=all_probs.size(2))
                results.append(vote_counts.float() / vote_counts.sum())
            return torch.stack(results)
```

For pixel-level prediction tasks like segmentation, TTA requires inverse transformations to align predictions before aggregation:
$$\hat{y}_{TTA}(p) = \frac{1}{K}\sum_{k=1}^K T_k^{-1}(f(T_k(x)))(p)$$
If we flip an image horizontally, we must flip the predicted mask back. If we rotate by 90°, we must rotate the mask by -90°.
```python
import torch
import torch.nn.functional as F
from typing import List, Tuple, Callable, Optional


class TTASegmentation:
    """
    Test-Time Augmentation for semantic segmentation.

    Applies geometric augmentations and properly inverts
    predictions before averaging.
    """

    def __init__(
        self,
        model,
        include_flips: bool = True,
        include_rotations: bool = False,
        scales: Optional[List[float]] = None,
        device: str = 'cuda'
    ):
        """
        Parameters:
        -----------
        model : nn.Module
            Segmentation model
        include_flips : bool
            Include horizontal and vertical flips
        include_rotations : bool
            Include 90°, 180°, 270° rotations
        scales : list of float
            Multi-scale inference scales (e.g., [0.5, 1.0, 1.5])
        """
        self.model = model
        self.device = device

        # Build transform/inverse pairs
        self.transform_pairs = self._build_transforms(
            include_flips, include_rotations, scales or [1.0]
        )

    def _build_transforms(
        self,
        include_flips: bool,
        include_rotations: bool,
        scales: List[float]
    ) -> List[Tuple[Callable, Callable]]:
        """
        Build list of (transform, inverse_transform) pairs.
        """
        pairs = []

        # Identity
        pairs.append((
            lambda x: x,
            lambda y: y
        ))

        if include_flips:
            # Horizontal flip
            pairs.append((
                lambda x: torch.flip(x, dims=[-1]),
                lambda y: torch.flip(y, dims=[-1])
            ))
            # Vertical flip
            pairs.append((
                lambda x: torch.flip(x, dims=[-2]),
                lambda y: torch.flip(y, dims=[-2])
            ))
            # Both flips
            pairs.append((
                lambda x: torch.flip(x, dims=[-2, -1]),
                lambda y: torch.flip(y, dims=[-2, -1])
            ))

        if include_rotations:
            # 90° rotation (inverted by rotating back with k=-1)
            pairs.append((
                lambda x: torch.rot90(x, k=1, dims=[-2, -1]),
                lambda y: torch.rot90(y, k=-1, dims=[-2, -1])
            ))
            # 180° rotation
            pairs.append((
                lambda x: torch.rot90(x, k=2, dims=[-2, -1]),
                lambda y: torch.rot90(y, k=-2, dims=[-2, -1])
            ))
            # 270° rotation
            pairs.append((
                lambda x: torch.rot90(x, k=3, dims=[-2, -1]),
                lambda y: torch.rot90(y, k=-3, dims=[-2, -1])
            ))

        return pairs

    def predict(
        self,
        image: torch.Tensor,
        scales: Optional[List[float]] = None
    ) -> torch.Tensor:
        """
        TTA prediction for segmentation.

        Parameters:
        -----------
        image : torch.Tensor
            Input image of shape (B, C, H, W)
        scales : list of float
            Optional multi-scale factors

        Returns:
        --------
        Averaged segmentation logits of shape (B, num_classes, H, W)
        """
        self.model.eval()
        image = image.to(self.device)
        scales = scales or [1.0]
        original_size = image.shape[-2:]

        accumulated = None
        count = 0

        with torch.no_grad():
            for scale in scales:
                # Scale image if needed
                if scale != 1.0:
                    scaled_size = (int(original_size[0] * scale),
                                   int(original_size[1] * scale))
                    scaled_image = F.interpolate(
                        image, size=scaled_size,
                        mode='bilinear', align_corners=False
                    )
                else:
                    scaled_image = image

                for transform, inverse_transform in self.transform_pairs:
                    # Apply forward transform
                    aug_image = transform(scaled_image)

                    # Get prediction
                    logits = self.model(aug_image)

                    # Apply inverse transform to prediction
                    aligned_logits = inverse_transform(logits)

                    # Resize back to original if scaled
                    if scale != 1.0:
                        aligned_logits = F.interpolate(
                            aligned_logits, size=original_size,
                            mode='bilinear', align_corners=False
                        )

                    # Accumulate
                    if accumulated is None:
                        accumulated = aligned_logits
                    else:
                        accumulated = accumulated + aligned_logits
                    count += 1

        return accumulated / count
```

For object detection, rotation TTA requires rotating bounding boxes back to original coordinates—non-trivial with axis-aligned boxes. Many practitioners limit detection TTA to horizontal flip only, which has a straightforward coordinate transformation.
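A minimal sketch of that flip-only detection TTA: un-flipping a box is just a mirror of its x-coordinates. The `[x1, y1, x2, y2]` box format and the `detect` callable are assumptions for illustration, not a specific detector's API:

```python
import torch

def unflip_boxes(boxes: torch.Tensor, image_width: int) -> torch.Tensor:
    """
    Map boxes predicted on a horizontally flipped image back to original
    coordinates. Boxes are (N, 4) in [x1, y1, x2, y2] format.
    """
    x1, y1, x2, y2 = boxes.unbind(dim=-1)
    # Mirror the x-coordinates; x1/x2 swap roles to keep x1 < x2
    return torch.stack([image_width - x2, y1, image_width - x1, y2], dim=-1)

def flip_tta_detection(detect, image: torch.Tensor):
    """
    detect(image) -> (boxes, scores): a hypothetical single-image detector.
    Returns pooled detections from the original and flipped views; in
    practice these would then be merged with non-maximum suppression.
    """
    W = image.shape[-1]
    boxes, scores = detect(image)
    boxes_f, scores_f = detect(torch.flip(image, dims=[-1]))
    boxes_f = unflip_boxes(boxes_f, W)
    return torch.cat([boxes, boxes_f]), torch.cat([scores, scores_f])
```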
The choice of aggregation method affects both accuracy and calibration. Different methods suit different scenarios.
The simplest option is the arithmetic mean of softmax probabilities:
$$p_{avg}(y=c) = \frac{1}{K}\sum_{k=1}^K p_k(y=c)$$
Advantages:
- Simple and a strong default for most tasks
- Preserves calibration reasonably well, since probabilities are combined on their natural scale
- Directly interpretable as the average predictive distribution

Disadvantages:
- A single overconfident but wrong view can pull the average toward the wrong class
- Treats all views as equally reliable
The geometric mean instead combines probabilities multiplicatively:
$$p_{geo}(y=c) \propto \left(\prod_{k=1}^K p_k(y=c)\right)^{1/K}$$
Advantages:
- Rewards consensus: a class scores highly only when every view assigns it non-negligible probability
- More robust to individual outlier views, which helps on rare or easily confused classes

Disadvantages:
- A single near-zero probability collapses the score, so epsilon smoothing is required
- Requires renormalization, and the result is harder to interpret as a probability
| Method | Formula | Best For | Calibration |
|---|---|---|---|
| Arithmetic Mean | 1/K Σ p_k | General use | Good |
| Geometric Mean | (∏ p_k)^(1/K) | Rare class prediction | Excellent |
| Majority Vote | mode(argmax p_k) | High-confidence decisions | Poor |
| Maximum Confidence | max_k(max p_k) | Very clean test data | Poor |
| Weighted Mean | Σ w_k p_k | Quality-varying transforms | Depends on weights |
```python
import torch
import torch.nn.functional as F
from typing import List


def aggregate_mean(predictions: torch.Tensor) -> torch.Tensor:
    """
    Arithmetic mean of predictions.

    Parameters:
    -----------
    predictions : torch.Tensor
        Predictions of shape (K, batch_size, num_classes)

    Returns:
    --------
    Aggregated predictions of shape (batch_size, num_classes)
    """
    return predictions.mean(dim=0)


def aggregate_geometric_mean(predictions: torch.Tensor) -> torch.Tensor:
    """
    Geometric mean of predictions.
    More robust to outliers, better for imbalanced classes.
    """
    # Add small epsilon to avoid log(0)
    log_probs = torch.log(predictions + 1e-8)
    mean_log = log_probs.mean(dim=0)

    # Exponentiate and renormalize
    geo_mean = torch.exp(mean_log)
    return geo_mean / geo_mean.sum(dim=-1, keepdim=True)


def aggregate_vote(predictions: torch.Tensor) -> torch.Tensor:
    """
    Majority voting.
    Returns one-hot-like probability based on vote counts.
    """
    K, B, C = predictions.shape

    # Get class predictions
    class_preds = predictions.argmax(dim=-1)  # (K, B)

    # Count votes for each sample
    results = []
    for b in range(B):
        votes = class_preds[:, b]
        counts = torch.bincount(votes, minlength=C).float()
        results.append(counts / counts.sum())

    return torch.stack(results)


def aggregate_weighted_mean(
    predictions: torch.Tensor,
    weights: torch.Tensor
) -> torch.Tensor:
    """
    Weighted mean of predictions.
    Useful when some augmentations are more reliable than others.

    Parameters:
    -----------
    predictions : torch.Tensor
        Shape (K, batch_size, num_classes)
    weights : torch.Tensor
        Shape (K,) - weights for each augmentation
    """
    weights = weights / weights.sum()  # Normalize
    weights = weights.view(-1, 1, 1)   # (K, 1, 1)
    return (predictions * weights).sum(dim=0)


def aggregate_max_confidence(predictions: torch.Tensor) -> torch.Tensor:
    """
    Select prediction with highest confidence.
    Good when only one view captures the object well.
    """
    K, B, C = predictions.shape

    # Max confidence for each augmentation
    confidences = predictions.max(dim=-1).values  # (K, B)

    # Index of most confident augmentation per sample
    best_k = confidences.argmax(dim=0)  # (B,)

    # Gather predictions from the most confident view
    batch_idx = torch.arange(B, device=predictions.device)
    return predictions[best_k, batch_idx, :]


class AdaptiveTTA:
    """
    Adaptive TTA that learns aggregation weights from validation data.
    """

    def __init__(
        self,
        model,
        transforms: List,
        device: str = 'cuda'
    ):
        self.model = model
        self.transforms = transforms
        self.device = device

        # Learnable weights for each transform
        self.weights = torch.ones(len(transforms), device=device)

    def learn_weights(self, val_loader):
        """
        Learn optimal weights from validation set.

        Uses accuracy-based weighting: transforms that produce
        higher individual accuracy get higher weights.
        """
        self.model.eval()

        # Track per-transform accuracy
        correct_per_transform = torch.zeros(len(self.transforms))
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                for i, transform in enumerate(self.transforms):
                    aug_images = transform(images)
                    logits = self.model(aug_images)
                    pred = logits.argmax(dim=-1)
                    correct_per_transform[i] += (pred == labels).sum().item()

                total += labels.size(0)

        # Weight by relative accuracy
        accuracies = correct_per_transform / total
        self.weights = accuracies / accuracies.sum()
        self.weights = self.weights.to(self.device)

        return dict(zip(range(len(self.transforms)), accuracies.tolist()))

    def predict(self, images: torch.Tensor) -> torch.Tensor:
        """Predict using learned weights."""
        self.model.eval()
        images = images.to(self.device)

        predictions = []
        with torch.no_grad():
            for transform in self.transforms:
                aug_images = transform(images)
                logits = self.model(aug_images)
                probs = F.softmax(logits, dim=-1)
                predictions.append(probs)

        predictions = torch.stack(predictions)  # (K, B, C)
        return aggregate_weighted_mean(predictions, self.weights)
```

Not all training augmentations are appropriate for TTA. Selection should consider model invariances, task requirements, and computational budget.
1. Consistency with training augmentations: the model should be familiar with the transformations used at test time. Applying augmentations not seen during training can hurt performance.
2. Semantic preservation: transformations must preserve the image's meaning. Color inversion changes semantics for many tasks, though not all.
3. Invertibility (for dense prediction): for segmentation and detection you need to un-transform predictions, so non-invertible transforms (random crop, dropout) can't be used.
4. Computational efficiency: each augmentation multiplies inference time, so balance benefit against cost; a minimal selection sketch follows this list.
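A minimal sketch of how these criteria might drive transform selection. The candidate list, the `seen_in_training` and `invertible` flags, and the budget logic are all illustrative assumptions rather than a standard API:

```python
import torch

# Candidate TTA transforms with the properties the checklist asks about.
# In practice 'seen_in_training' would come from your training config.
CANDIDATES = [
    {"name": "identity", "fn": lambda x: x,
     "seen_in_training": True, "invertible": True},
    {"name": "hflip", "fn": lambda x: torch.flip(x, dims=[-1]),
     "seen_in_training": True, "invertible": True},
    {"name": "vflip", "fn": lambda x: torch.flip(x, dims=[-2]),
     "seen_in_training": False, "invertible": True},
    {"name": "rot90", "fn": lambda x: torch.rot90(x, 1, [-2, -1]),
     "seen_in_training": False, "invertible": True},
]

def select_tta(dense_prediction: bool, budget: int):
    """Keep transforms passing criteria 1 and 3, up to a cost budget (4)."""
    chosen = [c for c in CANDIDATES
              if c["seen_in_training"]
              and (c["invertible"] or not dense_prediction)]
    return chosen[:budget]

# e.g. classification with a 2-pass budget -> identity + hflip
print([c["name"] for c in select_tta(dense_prediction=False, budget=2)])
```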
Multi-scale inference runs the model at different input resolutions and combines predictions:
$$\hat{y} = \frac{1}{|S|}\sum_{s \in S} f(\text{resize}(x, s))$$
where $S$ might be $\{0.5, 0.75, 1.0, 1.25, 1.5\}$.
Benefits:
- Objects are seen at several effective sizes, helping with both small and large instances
- Reduces sensitivity to the single resolution used during training
- Particularly effective for dense prediction tasks like segmentation

Challenges:
- Each scale adds a full forward pass, and larger scales increase memory use
- Predictions must be resized back to a common resolution, which can blur boundaries
- Architectures with stride constraints may require padding or cropping to valid input sizes
```python
import torch
import torch.nn.functional as F
from typing import List


class MultiScaleTTA:
    """
    Multi-scale test-time augmentation.

    Runs inference at multiple resolutions and combines predictions.
    """

    def __init__(
        self,
        model,
        scales: List[float] = [0.5, 0.75, 1.0, 1.25, 1.5],
        include_flip: bool = True,
        device: str = 'cuda'
    ):
        self.model = model
        self.scales = scales
        self.include_flip = include_flip
        self.device = device

    def predict_segmentation(self, image: torch.Tensor) -> torch.Tensor:
        """
        Multi-scale TTA for segmentation.

        Parameters:
        -----------
        image : torch.Tensor
            Input image of shape (B, C, H, W)

        Returns:
        --------
        Averaged logits of shape (B, num_classes, H, W)
        """
        self.model.eval()
        image = image.to(self.device)
        B, C, H, W = image.shape

        accumulated = None
        count = 0

        with torch.no_grad():
            for scale in self.scales:
                # Compute scaled size
                scaled_H = int(H * scale)
                scaled_W = int(W * scale)

                # Ensure dimensions divisible by 32 (required by some architectures)
                scaled_H = scaled_H - scaled_H % 32
                scaled_W = scaled_W - scaled_W % 32

                # Resize image
                scaled_image = F.interpolate(
                    image, size=(scaled_H, scaled_W),
                    mode='bilinear', align_corners=False
                )

                for flip in ([False, True] if self.include_flip else [False]):
                    # Apply flip if needed
                    if flip:
                        input_img = torch.flip(scaled_image, dims=[-1])
                    else:
                        input_img = scaled_image

                    # Get prediction
                    logits = self.model(input_img)

                    # Un-flip if needed
                    if flip:
                        logits = torch.flip(logits, dims=[-1])

                    # Resize back to original
                    logits = F.interpolate(
                        logits, size=(H, W),
                        mode='bilinear', align_corners=False
                    )

                    # Accumulate
                    if accumulated is None:
                        accumulated = logits
                    else:
                        accumulated = accumulated + logits
                    count += 1

        return accumulated / count

    def predict_classification(
        self,
        image: torch.Tensor,
        crop_mode: str = 'center'  # 'center', 'five', 'ten'
    ) -> torch.Tensor:
        """
        Multi-scale TTA for classification with multi-crop.

        Standard approach: scale image, take crops, aggregate.
        """
        self.model.eval()
        image = image.to(self.device)

        all_probs = []
        with torch.no_grad():
            for scale in self.scales:
                # Scale the image (must remain larger than the crop size)
                scaled = F.interpolate(
                    image, scale_factor=scale,
                    mode='bilinear', align_corners=False
                )

                # Get crops
                crops = self._get_crops(scaled, crop_mode)

                for crop in crops:
                    # With and without flip
                    for flip in ([False, True] if self.include_flip else [False]):
                        if flip:
                            crop_input = torch.flip(crop, dims=[-1])
                        else:
                            crop_input = crop

                        logits = self.model(crop_input)
                        probs = F.softmax(logits, dim=-1)
                        all_probs.append(probs)

        # Average all predictions
        return torch.stack(all_probs).mean(dim=0)

    def _get_crops(
        self,
        image: torch.Tensor,
        mode: str
    ) -> List[torch.Tensor]:
        """
        Extract crops from image for multi-crop inference.
        """
        B, C, H, W = image.shape
        crop_size = 224  # Assuming standard ImageNet crop

        if mode == 'center':
            # Center crop only
            h_start = (H - crop_size) // 2
            w_start = (W - crop_size) // 2
            return [image[:, :, h_start:h_start+crop_size,
                          w_start:w_start+crop_size]]

        elif mode == 'five':
            # Center + four corners
            crops = []

            # Center
            h_c, w_c = (H - crop_size) // 2, (W - crop_size) // 2
            crops.append(image[:, :, h_c:h_c+crop_size, w_c:w_c+crop_size])

            # Corners
            for h_start, w_start in [(0, 0), (0, W-crop_size),
                                     (H-crop_size, 0), (H-crop_size, W-crop_size)]:
                crops.append(image[:, :, h_start:h_start+crop_size,
                                   w_start:w_start+crop_size])

            return crops

        elif mode == 'ten':
            # Five crops + their horizontal flips
            # (note: combined with include_flip=True this duplicates flipped views)
            five_crops = self._get_crops(image, 'five')
            return five_crops + [torch.flip(c, dims=[-1]) for c in five_crops]

        else:
            raise ValueError(f"Unknown crop mode: {mode}")
```

Beyond improving accuracy, TTA provides a natural mechanism for estimating prediction uncertainty—the disagreement among augmented views indicates model confidence.
The variance of predictions across TTA views provides an uncertainty estimate:
$$\text{Uncertainty}(x) = \text{Var}_{k=1}^{K}\left[f(T_k(x))\right]$$
High variance indicates the model is sensitive to the input view, suggesting lower confidence. Low variance suggests robust prediction.
For classification, predictive entropy captures uncertainty:
$$H[Y|X] = -\sum_c p_{TTA}(y=c|x) \log p_{TTA}(y=c|x)$$
High entropy = high uncertainty (uniform-like distribution). Low entropy = confident prediction.
TTA enables decomposing uncertainty into epistemic (model uncertainty) and aleatoric (data uncertainty):
$$I[Y; T|X] = H[Y|X] - \mathbb{E}_{T}[H[Y|X,T]]$$
The mutual information between prediction and augmentation measures how much the model "disagrees with itself"—a pure epistemic uncertainty measure.
```python
import torch
import torch.nn.functional as F
from typing import Tuple, List, Callable
import numpy as np


class TTAUncertainty:
    """
    Test-Time Augmentation for uncertainty estimation.

    Provides predictions along with calibrated uncertainty estimates
    based on prediction variance across augmented views.
    """

    def __init__(
        self,
        model,
        transforms: List[Callable],
        device: str = 'cuda'
    ):
        self.model = model
        self.transforms = transforms
        self.device = device

    def predict_with_uncertainty(
        self,
        image: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """
        Get predictions with multiple uncertainty estimates.

        Parameters:
        -----------
        image : torch.Tensor
            Input image of shape (B, C, H, W)

        Returns:
        --------
        predictions : torch.Tensor
            Aggregated predictions of shape (B, num_classes)
        uncertainties : dict
            Dictionary containing various uncertainty measures:
            - 'predictive_entropy': Total uncertainty
            - 'expected_entropy': Aleatoric uncertainty
            - 'mutual_information': Epistemic uncertainty
            - 'prediction_variance': Variance across views
            - 'confidence': Max probability of aggregated prediction
        """
        self.model.eval()
        image = image.to(self.device)

        # Collect predictions from all views
        all_probs = []
        with torch.no_grad():
            for transform in self.transforms:
                aug_image = transform(image)
                logits = self.model(aug_image)
                probs = F.softmax(logits, dim=-1)
                all_probs.append(probs)

        # Stack: (K, B, C)
        all_probs = torch.stack(all_probs)
        K, B, C = all_probs.shape

        # Mean prediction
        mean_probs = all_probs.mean(dim=0)  # (B, C)

        # Variance across views (per class)
        variance = all_probs.var(dim=0)        # (B, C)
        mean_variance = variance.mean(dim=-1)  # (B,) - average across classes

        # Predictive entropy H[Y|X]
        predictive_entropy = -(mean_probs * torch.log(mean_probs + 1e-8)).sum(dim=-1)

        # Expected entropy E[H[Y|X,T]] - average entropy of individual predictions
        individual_entropies = -(all_probs * torch.log(all_probs + 1e-8)).sum(dim=-1)  # (K, B)
        expected_entropy = individual_entropies.mean(dim=0)  # (B,)

        # Mutual information I[Y;T|X] = H[Y|X] - E[H[Y|X,T]]
        # This is the "epistemic" uncertainty - the model's self-disagreement
        mutual_information = predictive_entropy - expected_entropy

        # Confidence: max probability in aggregated prediction
        confidence = mean_probs.max(dim=-1).values

        # Agreement rate: fraction of views predicting the same class
        predictions = all_probs.argmax(dim=-1)       # (K, B)
        mode_prediction = mean_probs.argmax(dim=-1)  # (B,)
        agreement = (predictions == mode_prediction.unsqueeze(0)).float().mean(dim=0)

        uncertainties = {
            'predictive_entropy': predictive_entropy,
            'expected_entropy': expected_entropy,      # Aleatoric
            'mutual_information': mutual_information,  # Epistemic
            'prediction_variance': mean_variance,
            'confidence': confidence,
            'agreement_rate': agreement
        }

        return mean_probs, uncertainties

    def calibration_curve(
        self,
        dataloader,
        n_bins: int = 10
    ) -> dict:
        """
        Compute reliability diagram for TTA predictions.
        Returns accuracy vs. confidence for calibration analysis.
        """
        confidences = []
        accuracies = []

        for images, labels in dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)

            probs, uncertainties = self.predict_with_uncertainty(images)
            preds = probs.argmax(dim=-1)

            confidences.extend(uncertainties['confidence'].cpu().numpy())
            accuracies.extend((preds == labels).cpu().numpy())

        confidences = np.array(confidences)
        accuracies = np.array(accuracies)

        # Bin by confidence
        bins = np.linspace(0, 1, n_bins + 1)
        bin_accs = []
        bin_confs = []
        bin_counts = []

        for i in range(n_bins):
            mask = (confidences >= bins[i]) & (confidences < bins[i + 1])
            if mask.sum() > 0:
                bin_accs.append(accuracies[mask].mean())
                bin_confs.append(confidences[mask].mean())
                bin_counts.append(mask.sum())
            else:
                bin_accs.append(0)
                bin_confs.append((bins[i] + bins[i + 1]) / 2)
                bin_counts.append(0)

        # Expected Calibration Error
        total = len(confidences)
        ece = sum(abs(bin_accs[i] - bin_confs[i]) * bin_counts[i] / total
                  for i in range(n_bins) if bin_counts[i] > 0)

        return {
            'bin_accuracies': bin_accs,
            'bin_confidences': bin_confs,
            'bin_counts': bin_counts,
            'ece': ece
        }
```

TTA is often compared to Monte Carlo Dropout for uncertainty. TTA samples different inputs to the same model; MC Dropout samples different models (dropout masks) on the same input. They capture complementary uncertainties and can be combined for even better calibration.
TTA multiplies inference cost by the number of augmentations. Understanding these trade-offs guides practical deployment.
For K augmentations:
- Compute: K forward passes per sample instead of one
- Latency: up to K× a single pass, depending on how views are batched
- Memory: unchanged if views run sequentially, up to K× activations if batched together
- Storage: unchanged (one set of model weights)

Parallel TTA: stack all K views along the batch dimension and run one forward pass. Latency stays close to a single pass when the accelerator has spare capacity, at the cost of K× activation memory.

Sequential TTA: run views one at a time. Memory stays constant, but latency grows linearly with K.

Hybrid: process views in chunks (e.g., groups of four), trading memory against latency; see the sketch after this list.
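A minimal sketch of the parallel variant, folding identity and horizontal-flip views of a whole batch into a single forward pass (the two-view transform set is just for illustration):

```python
import torch
import torch.nn.functional as F

def parallel_tta(model, images: torch.Tensor) -> torch.Tensor:
    """
    images: (B, C, H, W). Runs identity + horizontal flip in ONE forward
    pass by concatenating the views along the batch dimension.
    """
    B = images.size(0)
    views = torch.cat([images, torch.flip(images, dims=[-1])], dim=0)  # (2B, C, H, W)

    model.eval()
    with torch.no_grad():
        probs = F.softmax(model(views), dim=-1)  # (2B, num_classes)

    # Unfold back to (K=2, B, num_classes) and average over views
    return probs.view(2, B, -1).mean(dim=0)
```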
| Augmentation Set | Forward Passes | Typical Accuracy Gain | Recommended Use |
|---|---|---|---|
| None (baseline) | 1 | 0% | Real-time applications |
| H-Flip only | 2 | +0.3-0.5% | Good default TTA |
| H-Flip + V-Flip | 4 | +0.5-0.8% | Medical/satellite |
| Five-crop + flip | 10 | +1-2% | Image classification |
| Multi-scale + flip | 10-20 | +2-3% | Segmentation |
| Full TTA suite | 50+ | +2-4% | Competitions only |
TTA accuracy gains typically diminish with more augmentations: the first extra view (usually horizontal flip) captures most of the benefit, and each additional view contributes less because its prediction is increasingly correlated with views already in the average.

The law of diminishing returns suggests a practical recipe: start with horizontal flip alone (2 passes), add crops or scales only when validation shows a measurable gain, and reserve large augmentation suites for offline batch processing or competitions.
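One way to check where the returns flatten for your own model: evaluate validation accuracy as views are added one at a time. This sketch assumes `transforms` is ordered from most to least promising:

```python
import torch
import torch.nn.functional as F

def accuracy_vs_num_views(model, transforms, val_loader, device='cuda'):
    """For each K, report accuracy of averaging the first K transforms."""
    model.eval()
    correct = [0] * len(transforms)
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            probs = [F.softmax(model(t(images)), dim=-1) for t in transforms]
            probs = torch.stack(probs)  # (K, B, C)

            # Mean over the first k views, for every prefix length k
            for k in range(1, len(transforms) + 1):
                preds = probs[:k].mean(dim=0).argmax(dim=-1)
                correct[k - 1] += (preds == labels).sum().item()
            total += labels.size(0)

    return [c / total for c in correct]
```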
Smart TTA applies full augmentation only when needed:
```python
import torch
import torch.nn.functional as F
from typing import List, Callable


class AdaptiveEarlyStopTTA:
    """
    Adaptive TTA that stops early when predictions are confident.

    Reduces average inference cost while maintaining accuracy
    on most samples.
    """

    def __init__(
        self,
        model,
        transforms: List[Callable],
        confidence_threshold: float = 0.95,
        agreement_threshold: int = 3,  # Stop if 3 consecutive views agree
        min_augmentations: int = 2,
        device: str = 'cuda'
    ):
        """
        Parameters:
        -----------
        confidence_threshold : float
            Stop if aggregated confidence exceeds this
        agreement_threshold : int
            Stop if this many consecutive augmentations agree on the class
        min_augmentations : int
            Always run at least this many augmentations
        """
        self.model = model
        self.transforms = transforms
        self.confidence_threshold = confidence_threshold
        self.agreement_threshold = agreement_threshold
        self.min_augmentations = min_augmentations
        self.device = device

    def predict(self, image: torch.Tensor) -> tuple:
        """
        Adaptive TTA prediction with early stopping.

        Returns:
        --------
        predictions : torch.Tensor
            Class probabilities
        n_augmentations_used : int
            Number of augmentations actually applied
        """
        self.model.eval()
        image = image.to(self.device)

        all_probs = []
        all_preds = []

        with torch.no_grad():
            for i, transform in enumerate(self.transforms):
                aug_image = transform(image)
                logits = self.model(aug_image)
                probs = F.softmax(logits, dim=-1)

                all_probs.append(probs)
                all_preds.append(probs.argmax(dim=-1))

                # Check stopping conditions after the minimum number of views
                if i >= self.min_augmentations - 1:
                    # Compute current aggregated prediction
                    mean_probs = torch.stack(all_probs).mean(dim=0)
                    confidence = mean_probs.max(dim=-1).values

                    # Check confidence threshold
                    if confidence.min() > self.confidence_threshold:
                        break

                    # Check agreement threshold
                    if len(all_preds) >= self.agreement_threshold:
                        recent_preds = torch.stack(all_preds[-self.agreement_threshold:])
                        agreement = (recent_preds == recent_preds[0]).all(dim=0)
                        if agreement.all():
                            break

        # Final aggregation
        final_probs = torch.stack(all_probs).mean(dim=0)
        return final_probs, len(all_probs)
```

We've explored test-time augmentation comprehensively—from theoretical foundations through practical implementation and adaptive strategies.
What's Next:
We'll now explore Augmentation Strategies—how to combine training-time and test-time augmentations into coherent strategies tailored to specific domains, model architectures, and performance requirements. This final page synthesizes everything we've learned into actionable guidelines.
You now understand how to apply augmentations at inference time to improve predictions, estimate uncertainty, and make trade-offs between accuracy and computational cost. TTA is a valuable tool for high-stakes applications where prediction reliability is paramount.