Traditional pooling operations—max pooling, average pooling, and strided convolutions—have served as workhorses for spatial downsampling since the earliest CNNs. However, the research community has developed numerous alternatives that address specific limitations or offer enhanced capabilities for particular applications.
These alternatives range from dilated convolutions that expand receptive fields without downsampling, to attention-based pooling that learns what to aggregate, to deformable operations that adapt spatial sampling to image content. Understanding this landscape is essential for designing modern architectures and selecting the right tool for each task.
By the end of this page, you will understand dilated (atrous) convolutions for receptive field expansion, deformable convolutions that learn spatial sampling offsets, attention-based pooling mechanisms, learnable pooling and aggregation modules, specialized pooling for dense prediction tasks, and emerging approaches from transformer architectures.
Dilated convolutions (also called atrous convolutions) insert gaps between kernel elements, effectively expanding the receptive field without adding parameters or losing resolution through pooling.
Mathematical Formulation:
For a kernel $K$ of size $k \times k$, a dilated convolution with dilation rate $d$ places the kernel taps $d$ pixels apart instead of adjacently:
$$y(p) = \sum_{i=0}^{k-1}\sum_{j=0}^{k-1} K[i,j] \cdot x\big(p + d\cdot(i,j)\big)$$
The effective kernel size becomes:
$$k_{eff} = k + (k-1)(d-1) = d(k-1) + 1$$
For a 3×3 kernel ($k = 3$), the effective kernel size is $2d + 1$:
- Dilation 1: 3×3 (standard convolution)
- Dilation 2: 5×5
- Dilation 4: 9×9
- Dilation 8: 17×17
The term 'atrous' comes from French 'à trous' meaning 'with holes'. This aptly describes the algorithm: a convolution kernel with holes (zeros) inserted between parameters. The technique was independently developed in multiple fields, including wavelet analysis ('algorithme à trous') and signal processing.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DilatedConvBlock(nn.Module):
    """
    Dilated convolution block demonstrating receptive field expansion.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        # Padding must account for dilation to maintain spatial size
        padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class AtrousSpatialPyramidPooling(nn.Module):
    """
    ASPP module from DeepLab for multi-scale feature extraction.

    Applies parallel dilated convolutions at different rates to capture
    objects at multiple scales without explicit pooling/downsampling.
    """
    def __init__(self, in_channels, out_channels, rates=[6, 12, 18]):
        super().__init__()

        # 1×1 convolution
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Dilated convolutions at different rates
        self.dilated_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3,
                          padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            for rate in rates
        ])

        # Image-level pooling for global context
        self.image_pooling = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Fusion of all branches
        num_branches = 2 + len(rates)  # 1×1, image pool, and dilated convs
        self.fusion = nn.Sequential(
            nn.Conv2d(out_channels * num_branches, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        size = x.shape[2:]  # H, W

        # 1×1 convolution
        feat_1x1 = self.conv1x1(x)

        # Dilated convolutions
        dilated_feats = [conv(x) for conv in self.dilated_convs]

        # Image-level pooling (upsampled to match spatial size)
        img_feat = self.image_pooling(x)
        img_feat = F.interpolate(img_feat, size=size, mode='bilinear', align_corners=False)

        # Concatenate all features
        all_feats = torch.cat([feat_1x1, *dilated_feats, img_feat], dim=1)

        # Fuse to output channels
        return self.fusion(all_feats)


def demonstrate_dilation_effects():
    """
    Show how dilation expands receptive field without losing resolution.
    """
    x = torch.randn(1, 64, 32, 32)

    dilations = [1, 2, 4, 8]
    for d in dilations:
        conv = nn.Conv2d(64, 64, kernel_size=3, dilation=d, padding=d, bias=False)
        out = conv(x)
        effective_size = 3 + (3-1)*(d-1)
        print(f"Dilation {d}: input {x.shape[2:]} → output {out.shape[2:]}, "
              f"effective kernel: {effective_size}×{effective_size}")


demonstrate_dilation_effects()


def test_aspp():
    """Test ASPP module."""
    x = torch.randn(2, 2048, 32, 32)
    aspp = AtrousSpatialPyramidPooling(2048, 256)
    out = aspp(x)
    print(f"ASPP: {x.shape} → {out.shape}")


test_aspp()
```

Key Applications:
| Architecture | Dilation Use | Purpose |
|---|---|---|
| DeepLab v2/v3 | ASPP module | Multi-scale dense prediction |
| WaveNet | Exponential dilation (1,2,4,8...) | Long-range audio dependencies |
| Dilated ResNet | Replace pooling in backbone | Preserve resolution for segmentation |
| TCN | Stacked dilated 1D convs | Temporal sequence modeling |
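The WaveNet and TCN rows in the table use dilation in 1D, which the lesson's 2D examples above do not cover. Below is a minimal sketch of an exponentially dilated 1D stack; the `DilatedStack1D` class is illustrative only, and it uses symmetric (non-causal) padding to keep the sketch short, whereas real WaveNet/TCN layers pad causally.

```python
import torch
import torch.nn as nn

class DilatedStack1D(nn.Module):
    """Stack of dilated 1D convolutions with exponentially growing rates (WaveNet/TCN flavor)."""
    def __init__(self, channels=32, rates=(1, 2, 4, 8)):
        super().__init__()
        self.rates = rates
        self.layers = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=3, dilation=r, padding=r)
            for r in rates
        ])

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

stack = DilatedStack1D()
x = torch.randn(1, 32, 128)
out = stack(x)
# Each 3-tap layer with rate r widens the receptive field by 2r
receptive_field = 1 + 2 * sum(stack.rates)
print(f"{tuple(x.shape)} → {tuple(out.shape)}, receptive field: {receptive_field} timesteps")
```

Doubling the rate at every layer makes the receptive field grow exponentially with depth while the parameter count grows only linearly, which is why this pattern suits long-range audio and temporal modeling.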
The Gridding Problem:
When using large dilation rates, the sparse sampling pattern can create gridding artifacts—the kernel only touches isolated pixels, potentially missing local patterns. Solutions include:
- Combining several dilation rates in parallel branches (as ASPP does) so their sampling patterns complement each other.
- Choosing rates across stacked layers so that their combined footprint covers the input densely (the hybrid dilated convolution heuristic), rather than repeating one large rate.
- Interleaving dilated layers with standard (dilation-1) convolutions to retain local detail.
A dilation rate of 8 with a 3×3 kernel samples pixels at positions 0, 8, and 16—missing everything in between. If important details exist at non-sampled positions, they're invisible to this convolution. Multi-rate ASPP mitigates this by covering different sampling patterns.
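To make the gridding effect concrete, the sketch below (the `sampling_footprint` helper is illustrative, not part of the lesson code) uses autograd to mark which input pixels actually influence one output pixel of a stack of dilated 3×3 convolutions: a single rate-8 layer touches only 9 isolated pixels, while a stack of smaller, mixed rates covers a dense region.

```python
import torch
import torch.nn as nn

def sampling_footprint(rates):
    """Mark which input pixels influence the center output pixel of stacked dilated 3x3 convs."""
    x = torch.zeros(1, 1, 33, 33, requires_grad=True)
    out = x
    for r in rates:
        conv = nn.Conv2d(1, 1, kernel_size=3, dilation=r, padding=r, bias=False)
        nn.init.constant_(conv.weight, 1.0)
        out = conv(out)
    out[0, 0, 16, 16].backward()              # gradient is nonzero only at sampled pixels
    touched = (x.grad[0, 0] != 0).sum().item()
    print(f"rates={rates}: {touched} of {x.numel()} input pixels touched")

sampling_footprint([8])        # 9 isolated pixels — the gridding pattern
sampling_footprint([1, 2, 4])  # 225 pixels forming a dense 15x15 region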
Deformable convolutions allow the network to learn spatially-varying sampling offsets, enabling the kernel to adapt its shape to the image content rather than using a fixed rectangular grid.
Core Idea:
Standard convolution samples at fixed positions around the center pixel. Deformable convolution adds learnable 2D offsets $\Delta p$ to each sampling position:
$$y(p) = \sum_{k=1}^{K} w_k \cdot x(p + p_k + \Delta p_k)$$
where $p_k$ are the fixed kernel positions and $\Delta p_k$ are learned offsets.
Bilinear Interpolation:
Since offsets $\Delta p_k$ are typically non-integer (computed by a learned convolution layer), bilinear interpolation is used to sample at subpixel locations, making the operation fully differentiable.
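The standalone sketch below (not the lesson's deformable-convolution implementation) shows the bilinear-sampling piece in isolation using `F.grid_sample`: values are read at fractional positions interpolated from the four nearest pixels, which is what allows gradients to flow through learned offsets in a full implementation. The random `offsets` tensor stands in for what an offset-prediction convolution would output.

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 8, 16, 16
x = torch.randn(B, C, H, W)

# Regular pixel grid plus fractional offsets (stand-ins for predicted delta-p, in pixels)
ys, xs = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                        torch.arange(W, dtype=torch.float32), indexing='ij')
offsets = 0.5 * torch.randn(B, H, W, 2)
sample_pts = torch.stack([xs, ys], dim=-1).unsqueeze(0) + offsets  # (B, H, W, 2), (x, y)

# grid_sample expects coordinates normalized to [-1, 1]
grid = torch.stack([2 * sample_pts[..., 0] / (W - 1) - 1,
                    2 * sample_pts[..., 1] / (H - 1) - 1], dim=-1)

# Each output value is bilinearly interpolated from the 4 nearest input pixels
sampled = F.grid_sample(x, grid, mode='bilinear', align_corners=True)
print(sampled.shape)  # torch.Size([1, 8, 16, 16])
```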
Deformable convolutions effectively let the network learn geometric transformations of receptive fields. For object detection, this allows box-shaped features to adapt to object shapes. For segmentation, edges can be traced more accurately.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeformableConv2d(nn.Module):
    """
    Deformable Convolution v1.

    Learns per-kernel-position offsets that allow the convolution
    to sample from adaptive locations based on input content.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Regular convolution weights
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, kernel_size, kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        # Offset prediction: 2 values (x,y) per kernel position
        self.offset_conv = nn.Conv2d(
            in_channels,
            2 * kernel_size * kernel_size,  # 2 offsets per position
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )

        self._init_weights()

    def _init_weights(self):
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        nn.init.zeros_(self.offset_conv.weight)
        nn.init.zeros_(self.offset_conv.bias)

    def forward(self, x):
        # Predict offsets for each spatial position
        offsets = self.offset_conv(x)  # (B, 2*k*k, H, W)

        # Apply deformable convolution
        # Note: Actual implementation uses optimized CUDA kernels
        # This is a conceptual representation
        return self._deform_conv(x, offsets)

    def _deform_conv(self, x, offsets):
        """
        Simplified deformable convolution implementation.
        Production code uses torchvision.ops.deform_conv2d
        """
        B, C, H, W = x.shape
        k = self.kernel_size

        # Create base grid of kernel positions
        grid_y, grid_x = torch.meshgrid(
            torch.arange(k), torch.arange(k), indexing='ij'
        )
        base_offsets = torch.stack([grid_x, grid_y], dim=-1).float()
        base_offsets = base_offsets.reshape(1, k*k, 1, 1, 2)

        # This is a simplified version; real implementation is more complex
        # and uses torchvision.ops.deform_conv2d for efficiency
        # For demonstration, fall back to regular conv
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding)


class DeformableConv2dV2(nn.Module):
    """
    Deformable Convolution v2 with modulation.

    Adds learnable modulation scalars that control the importance
    of each sampling position, providing attention-like behavior.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Regular convolution
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, kernel_size, kernel_size)
        )

        # Offset AND modulation prediction
        self.offset_modulation = nn.Conv2d(
            in_channels,
            3 * kernel_size * kernel_size,  # 2 offsets + 1 modulation
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )

        self._init_weights()

    def _init_weights(self):
        nn.init.kaiming_uniform_(self.weight)
        nn.init.zeros_(self.offset_modulation.weight)
        nn.init.zeros_(self.offset_modulation.bias)

    def forward(self, x):
        k_sq = self.kernel_size ** 2
        offset_mod = self.offset_modulation(x)

        # Split into offsets and modulation
        offsets = offset_mod[:, :2*k_sq]
        modulation = torch.sigmoid(offset_mod[:, 2*k_sq:])

        # Apply modulated deformable conv (conceptual)
        # Real implementation: torchvision.ops.deform_conv2d with mask
        return F.conv2d(x, self.weight, stride=self.stride, padding=self.padding)


# Using torchvision's optimized implementation
def use_torchvision_deform():
    """
    Demonstrate using torchvision's deformable convolution.
    """
    try:
        from torchvision.ops import DeformConv2d

        x = torch.randn(2, 64, 28, 28)

        # Create deformable conv layer
        deform_conv = DeformConv2d(64, 128, kernel_size=3, padding=1)

        # Create offset conv (separate from deformable conv in torchvision)
        offset_conv = nn.Conv2d(64, 2 * 3 * 3, kernel_size=3, padding=1)

        # Predict offsets
        offsets = offset_conv(x)

        # Apply deformable convolution
        out = deform_conv(x, offsets)

        print(f"Deformable conv: {x.shape} → {out.shape}")
        print(f"Offsets shape: {offsets.shape}")

    except ImportError:
        print("torchvision not available for deformable conv demo")


use_torchvision_deform()
```

Deformable Convolution v1 vs v2:
| Feature | v1 | v2 |
|---|---|---|
| Offsets | Learned per position | Learned per position |
| Modulation | None | Learned importance weights |
| Behavior | All positions contribute equally | Soft attention over positions |
| Use case | General geometric adaptation | Better for occlusion handling |
Applications:
- Object detection: deformable layers in the backbone and detection head let box-shaped receptive fields adapt to object geometry.
- Semantic and instance segmentation: sampling positions can follow object boundaries, tracing edges more accurately.
- Tasks with large geometric variation (scale, pose, deformation) that a fixed rectangular grid handles poorly.
Implementation Considerations:
- Use the optimized implementation in `torchvision.ops` (`deform_conv2d` / `DeformConv2d`) rather than a hand-rolled version; the operation relies on custom kernels for efficiency.
- Initialize the offset (and modulation) predictors to zero so training starts from an ordinary convolution and offsets grow gradually.
- Deformable layers add the cost of the offset-prediction convolution and irregular memory access, so they are commonly restricted to the later stages of a backbone.
A sketch using the torchvision API follows.
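Below is a minimal sketch of the modulated (v2) form using torchvision's `deform_conv2d`, which the code comments above point to as the production path. The `offset_pred` and `mask_pred` layers are illustrative helpers, and the `mask` keyword assumes a torchvision version recent enough to support it (0.9+).

```python
import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d

B, C_in, C_out, H, W, k = 2, 64, 128, 28, 28, 3
x = torch.randn(B, C_in, H, W)

# Weight of the deformable conv itself, plus predictors for offsets and modulation
weight = nn.Parameter(torch.empty(C_out, C_in, k, k))
nn.init.kaiming_uniform_(weight)
offset_pred = nn.Conv2d(C_in, 2 * k * k, k, padding=1)  # 2 offsets per kernel position
mask_pred = nn.Conv2d(C_in, k * k, k, padding=1)        # 1 modulation scalar per position

# Zero-init the predictors: offsets start on the regular grid, modulation starts at 0.5
for m in (offset_pred, mask_pred):
    nn.init.zeros_(m.weight)
    nn.init.zeros_(m.bias)

offsets = offset_pred(x)               # (B, 2*k*k, H, W)
mask = torch.sigmoid(mask_pred(x))     # (B, k*k, H, W), values in (0, 1)
out = deform_conv2d(x, offsets, weight, padding=1, mask=mask)
print(out.shape)  # torch.Size([2, 128, 28, 28])
```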
Attention-based pooling replaces fixed aggregation rules (max, average) with learned attention weights that determine the contribution of each spatial position to the output.
Core Concept:
Instead of $y = \frac{1}{N}\sum_i x_i$ (average) or $y = \max_i x_i$ (max), compute:
$$y = \sum_i a_i \cdot x_i, \quad \text{where} \quad a_i = \frac{\exp(score(x_i))}{\sum_j \exp(score(x_j))}$$
The attention weights $a_i$ are computed by a learned scoring function and normalized via softmax.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftAttentionPooling(nn.Module):
    """
    Learnable soft attention pooling.

    A small network predicts importance scores for each spatial position,
    then aggregates features using these scores as weights.
    """
    def __init__(self, channels, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or channels // 4

        self.attention = nn.Sequential(
            nn.Conv2d(channels, hidden_dim, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, 1, 1),  # Single attention map
        )

    def forward(self, x):
        B, C, H, W = x.shape

        # Compute attention scores
        scores = self.attention(x)  # (B, 1, H, W)

        # Softmax over spatial dimensions
        weights = F.softmax(scores.view(B, -1), dim=1)
        weights = weights.view(B, 1, H, W)

        # Weighted aggregation
        pooled = (x * weights).sum(dim=[2, 3])  # (B, C)

        return pooled


class MultiHeadAttentionPooling(nn.Module):
    """
    Multi-head attention pooling for richer aggregation.

    Uses multiple attention heads to capture different aspects
    of the spatial distribution.
    """
    def __init__(self, channels, num_heads=4, head_dim=64):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Project to query space (learnable query per head)
        self.queries = nn.Parameter(torch.randn(num_heads, head_dim))

        # Project features to key/value space
        self.key_proj = nn.Conv2d(channels, num_heads * head_dim, 1)
        self.value_proj = nn.Conv2d(channels, num_heads * head_dim, 1)

        # Output projection
        self.out_proj = nn.Linear(num_heads * head_dim, channels)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W

        # Project to keys and values
        keys = self.key_proj(x).view(B, self.num_heads, self.head_dim, N)
        values = self.value_proj(x).view(B, self.num_heads, self.head_dim, N)

        # Compute attention (query @ key.T)
        queries = self.queries.unsqueeze(0).unsqueeze(-1)  # (1, heads, dim, 1)
        attention = torch.matmul(queries.transpose(-2, -1), keys) / (self.head_dim ** 0.5)
        attention = F.softmax(attention, dim=-1)  # (B, heads, 1, N)

        # Aggregate values
        attended = torch.matmul(values, attention.squeeze(-2).unsqueeze(-1))
        attended = attended.squeeze(-1).view(B, -1)  # (B, heads * dim)

        return self.out_proj(attended)


class OrderlessPooling(nn.Module):
    """
    Orderless pooling using Fisher Vector style aggregation.

    Captures the distribution of features rather than just
    their mean or max.
    """
    def __init__(self, channels, num_components=64):
        super().__init__()
        self.num_components = num_components

        # Learnable Gaussian mixture model parameters
        self.means = nn.Parameter(torch.randn(num_components, channels))
        self.log_vars = nn.Parameter(torch.zeros(num_components, channels))
        self.mixture_weights = nn.Parameter(torch.ones(num_components) / num_components)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W

        # Flatten spatial dimensions
        x_flat = x.view(B, C, N).permute(0, 2, 1)  # (B, N, C)

        # Compute soft assignments to components
        means = self.means.unsqueeze(0).unsqueeze(1)  # (1, 1, K, C)
        x_expanded = x_flat.unsqueeze(2)  # (B, N, 1, C)

        # Squared distances
        sq_dist = ((x_expanded - means) ** 2).sum(-1)  # (B, N, K)

        # Soft assignment
        assignments = F.softmax(-sq_dist, dim=2)  # (B, N, K)

        # First-order statistics (difference from means)
        diff = x_expanded - means  # (B, N, K, C)
        weighted_diff = (assignments.unsqueeze(-1) * diff).sum(dim=1)  # (B, K, C)

        return weighted_diff.view(B, -1)  # (B, K*C)


class SelfAttentionPool(nn.Module):
    """
    Self-attention based global pooling.

    Each position attends to all others before aggregation,
    capturing long-range dependencies.
    """
    def __init__(self, channels):
        super().__init__()
        self.query = nn.Conv2d(channels, channels // 8, 1)
        self.key = nn.Conv2d(channels, channels // 8, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        B, C, H, W = x.shape

        # Self-attention
        q = self.query(x).view(B, -1, H*W).permute(0, 2, 1)  # (B, N, C')
        k = self.key(x).view(B, -1, H*W)  # (B, C', N)
        v = self.value(x).view(B, -1, H*W)  # (B, C, N)

        attention = F.softmax(torch.bmm(q, k), dim=-1)  # (B, N, N)
        out = torch.bmm(v, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(B, C, H, W)

        # Residual connection
        out = self.gamma * out + x

        # Global pooling after self-attention
        return self.gap(out).view(B, C)


def demonstrate_attention_pooling():
    x = torch.randn(4, 512, 14, 14)

    poolers = {
        "GAP": nn.AdaptiveAvgPool2d(1),
        "Soft Attention": SoftAttentionPooling(512),
        "Multi-Head Attention": MultiHeadAttentionPooling(512),
        "Self-Attention + Pool": SelfAttentionPool(512),
    }

    for name, pooler in poolers.items():
        out = pooler(x)
        if len(out.shape) == 4:
            out = out.view(out.size(0), -1)
        print(f"{name}: {x.shape} → {out.shape}")


demonstrate_attention_pooling()
```

| Method | Parameters | Captures | Best For |
|---|---|---|---|
| Soft Attention | Few | Spatial importance | Single salient region |
| Multi-Head | Moderate | Multiple aspects | Complex scenes |
| Orderless | Many | Feature distribution | Texture/fine-grained |
| Self-Attention + Pool | Moderate | Long-range dependencies | Relational reasoning |
Attention pooling shines when different spatial regions have different importance for the task. For uniform textures, GAP works fine. For scenes with localized objects of interest, attention can significantly improve performance by focusing on relevant regions.
Region of Interest (ROI) Pooling extracts fixed-size features from arbitrary-sized regions within feature maps—a fundamental operation for object detection and instance-level tasks.
Basic ROI Pooling:
Given a region of interest (x₁, y₁, x₂, y₂) in image coordinates:
1. Project the ROI onto the feature map by multiplying its coordinates by the spatial scale (e.g., 1/16 for a stride-16 backbone) and rounding to integer feature-map coordinates.
2. Divide the projected region into a fixed grid of output bins (e.g., 7×7), rounding bin boundaries to integers as well.
3. Max-pool the features inside each bin, producing a fixed-size output regardless of the ROI's size.
Problem: Integer quantization of region boundaries loses spatial precision.
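A small worked example of where the precision is lost; the 1/16 spatial scale and the floor/ceil rounding are assumptions standing in for one common quantization scheme, not torchvision's exact rounding rule:

```python
import math

# An ROI in image coordinates, and the feature-map stride of an assumed backbone
x1, y1, x2, y2 = 137.4, 62.8, 301.9, 210.3
spatial_scale = 1.0 / 16   # feature map is 16x smaller than the input image

# ROI Pooling: snap the projected box to integer feature-map coordinates
quantized = (math.floor(x1 * spatial_scale), math.floor(y1 * spatial_scale),
             math.ceil(x2 * spatial_scale), math.ceil(y2 * spatial_scale))
print("quantized box:", quantized)   # (8, 3, 19, 14)

# ROI Align: keep the exact fractional box and sample it with bilinear interpolation
exact = tuple(v * spatial_scale for v in (x1, y1, x2, y2))
print("exact box:", exact)           # ~ (8.59, 3.93, 18.87, 13.14), no snapping to the grid
```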
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_pool, roi_align, ps_roi_pool


class ROIOperationsDemo:
    """
    Demonstrate different ROI pooling variants.
    """

    @staticmethod
    def basic_roi_pool():
        """
        Standard ROI pooling with integer quantization.
        """
        # Feature maps (batch, channels, height, width)
        features = torch.randn(1, 256, 50, 50)

        # ROIs: (batch_idx, x1, y1, x2, y2) in feature map coordinates
        # Each row is one ROI, first column is batch index
        rois = torch.tensor([
            [0, 10.6, 8.3, 35.2, 28.7],  # Fractional coordinates
            [0, 5.1, 12.4, 22.8, 30.1],
        ], dtype=torch.float32)

        # Use torchvision's roi_pool
        output_size = 7
        spatial_scale = 1.0  # feature_map_size / image_size

        pooled = roi_pool(features, rois, output_size, spatial_scale)
        print(f"ROI Pool output: {pooled.shape}")  # (2, 256, 7, 7)

        return pooled

    @staticmethod
    def roi_align_demo():
        """
        ROI Align: bilinear interpolation avoids quantization.

        Instead of snapping to integer positions, samples at any
        position using bilinear interpolation.
        """
        features = torch.randn(1, 256, 50, 50)

        rois = torch.tensor([
            [0, 10.6, 8.3, 35.2, 28.7],
            [0, 5.1, 12.4, 22.8, 30.1],
        ], dtype=torch.float32)

        output_size = 7
        spatial_scale = 1.0

        # sampling_ratio: number of sampling points per output bin
        aligned = roi_align(features, rois, output_size, spatial_scale, sampling_ratio=2)
        print(f"ROI Align output: {aligned.shape}")  # (2, 256, 7, 7)

        return aligned

    @staticmethod
    def ps_roi_pool_demo():
        """
        Position-Sensitive ROI Pooling.

        Used in R-FCN: channels are position-sensitive, meaning
        different channels represent different relative positions
        within the object.
        """
        # Position-sensitive feature maps need special channel structure
        # k×k output requires k²×C channels
        k = 7   # Output size
        C = 10  # Classes (or features per position)

        features = torch.randn(1, k*k*C, 50, 50)  # 490 channels

        rois = torch.tensor([
            [0, 10, 8, 35, 28],
            [0, 5, 12, 22, 30],
        ], dtype=torch.float32)

        output_size = k
        spatial_scale = 1.0

        ps_pooled = ps_roi_pool(features, rois, output_size, spatial_scale)
        print(f"PS ROI Pool output: {ps_pooled.shape}")  # (2, 10, 7, 7)

        return ps_pooled


class ROIAlignLayer(nn.Module):
    """
    ROI Align wrapper with learnable feature refinement.
    """
    def __init__(self, output_size, spatial_scale, sampling_ratio=2, channels=256):
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

        # Optional post-processing
        self.refine = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, features, rois):
        aligned = roi_align(
            features, rois,
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio
        )
        return self.refine(aligned)


# Compare quantization effects
def compare_roi_methods():
    """
    Show the difference between ROI pooling and ROI align
    in terms of alignment precision.
    """
    # Create feature map with known gradient
    features = torch.zeros(1, 1, 20, 20)
    for i in range(20):
        for j in range(20):
            features[0, 0, i, j] = i + j * 0.1

    # ROI that doesn't align with feature grid
    rois = torch.tensor([[0, 5.3, 3.7, 15.6, 12.9]])

    pool_out = roi_pool(features, rois, 7, 1.0)
    align_out = roi_align(features, rois, 7, 1.0, 2)

    print("Feature sample values at extracted ROI:")
    print(f"  ROI Pool (quantized): mean={pool_out.mean():.4f}")
    print(f"  ROI Align (bilinear): mean={align_out.mean():.4f}")
    print("\nROI Align preserves more precise spatial information.")


compare_roi_methods()

# Run demonstrations
demo = ROIOperationsDemo()
demo.basic_roi_pool()
demo.roi_align_demo()
```

| Method | Quantization | Speed | Precision | Used In |
|---|---|---|---|---|
| ROI Pooling | Integer | Fast | Lower | Fast R-CNN, SPP-Net |
| ROI Align | Bilinear | Moderate | Higher | Mask R-CNN, modern detectors |
| PS ROI Pool | Integer | Fast | Position-sensitive | R-FCN |
| Deformable ROI Pool | Learned offsets | Slower | Adaptive | Deformable ConvNets |
The Quantization Problem:
In standard ROI pooling:
- The projected ROI coordinates are rounded to integer feature-map positions (first quantization).
- The region is then divided into output bins whose boundaries are rounded again (second quantization).
- Each rounding can shift the sampled features by several input pixels, which hurts small objects and pixel-level tasks most.
ROI Align Solution:
Bilinear interpolation samples at exact positions, avoiding all quantization:
- The fractional ROI and bin boundaries are kept as-is—no rounding at any stage.
- A small, regular set of sampling points (e.g., 2×2) is placed inside each bin.
- Each point's value is computed by bilinear interpolation from the four nearest feature-map cells, and the points are then averaged (or max-pooled) per bin.
This seemingly simple change yielded significant improvements in Mask R-CNN for instance segmentation, where pixel-precise features matter.
Dense prediction tasks require recovering spatial resolution lost during pooling. Unpooling and related operations address this need.
Unpooling Strategies:
| Method | Description | Quality | Parameters |
|---|---|---|---|
| Nearest Neighbor | Repeat pixel values | Blocky | 0 |
| Bilinear | Linear interpolation | Smooth | 0 |
| Bicubic | Cubic interpolation | Smoother | 0 |
| Max Unpooling | Place values at saved indices | Sparse | 0 (needs indices) |
| Transposed Conv | Learned upsampling | Adaptive | k² × C_in × C_out |
| Pixel Shuffle | Rearrange depth to space | Efficient | 0 (preceding conv expands channels) |
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaxUnpooling(nn.Module):
    """
    Max unpooling using saved indices from max pooling.

    Places pooled values back at their original positions,
    filling remaining positions with zeros.
    """
    def __init__(self, kernel_size=2, stride=2):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size, stride, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size, stride)

    def forward(self, x):
        # Pool and save indices
        pooled, indices = self.pool(x)

        # Unpool using saved indices
        unpooled = self.unpool(pooled, indices)

        return pooled, unpooled


class TransposedConvUpsampling(nn.Module):
    """
    Learned upsampling via transposed convolution.

    Also called 'deconvolution' (technically a misnomer) or
    'fractionally strided convolution'.
    """
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.deconv(x)))


class PixelShuffleUpsampling(nn.Module):
    """
    Pixel shuffle (sub-pixel convolution) for efficient upsampling.

    Rearranges elements from (C×r², H, W) to (C, H×r, W×r).
    Popular in super-resolution networks.
    """
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        # Expand channels to accommodate pixel shuffle
        self.conv = nn.Conv2d(in_channels, out_channels * scale_factor**2,
                              kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels * scale_factor**2)
        self.relu = nn.ReLU(inplace=True)
        self.shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.shuffle(x)


class BilinearUpsampling(nn.Module):
    """
    Simple bilinear upsampling followed by 1×1 conv for channel adjustment.

    Often used in U-Net style decoders.
    """
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor,
                          mode='bilinear', align_corners=False)
        return self.conv(x)


class LearnedUpsampling(nn.Module):
    """
    Combines bilinear upsampling with learned refinement.

    Often better than pure transposed convolution (avoids checkerboard).
    """
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.refine = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor,
                          mode='bilinear', align_corners=False)
        return self.refine(x)


def compare_upsampling_methods():
    """
    Compare different upsampling approaches.
    """
    x = torch.randn(1, 64, 16, 16)
    target_size = (32, 32)

    methods = {
        "Nearest": lambda t: F.interpolate(t, scale_factor=2, mode='nearest'),
        "Bilinear": lambda t: F.interpolate(t, scale_factor=2, mode='bilinear',
                                            align_corners=False),
        "TransposedConv": TransposedConvUpsampling(64, 64),
        "PixelShuffle": PixelShuffleUpsampling(64, 64),
        "LearnedBilinear": LearnedUpsampling(64, 64),
    }

    for name, method in methods.items():
        try:
            out = method(x)
            params = sum(p.numel() for p in method.parameters()) if hasattr(method, 'parameters') else 0
            print(f"{name}: {x.shape} → {out.shape}, params: {params}")
        except Exception as e:
            print(f"{name}: Error - {e}")


compare_upsampling_methods()


class SegmentationDecoder(nn.Module):
    """
    Example decoder for semantic segmentation using various upsampling.
    """
    def __init__(self, encoder_channels=[256, 512, 1024, 2048], num_classes=21):
        super().__init__()

        # Progressive upsampling with skip connections
        self.up4 = LearnedUpsampling(2048, 1024)
        self.up3 = LearnedUpsampling(1024+1024, 512)  # +1024 from skip
        self.up2 = LearnedUpsampling(512+512, 256)
        self.up1 = LearnedUpsampling(256+256, 64)

        # Final classifier
        self.head = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, 1)
        )

    def forward(self, features):
        # features: list of [f1, f2, f3, f4] from encoder
        c1, c2, c3, c4 = features

        x = self.up4(c4)
        x = torch.cat([x, c3], dim=1)
        x = self.up3(x)
        x = torch.cat([x, c2], dim=1)
        x = self.up2(x)
        x = torch.cat([x, c1], dim=1)
        x = self.up1(x)

        return self.head(x)
```

Transposed convolutions can produce checkerboard artifacts when kernel size isn't divisible by stride (uneven overlap). Solutions: (1) use kernel_size = 2×stride, (2) use bilinear upsampling + convolution, or (3) use PixelShuffle, which doesn't have this problem.
Vision Transformers (ViT) and related architectures approach spatial processing fundamentally differently from CNNs, often eliminating traditional pooling entirely.
The ViT Paradigm:
Instead of progressive pooling:
1. The image is split into non-overlapping patches (e.g., 16×16) and each patch is linearly projected to a token embedding in a single step.
2. Position embeddings are added and the token sequence is processed by self-attention blocks at a constant resolution.
3. A classification output is read out at the end via a [CLS] token or by averaging the tokens.
This approach maintains full spatial token count throughout the network.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """
    ViT-style patch embedding: converts image to sequence of tokens.

    This is the equivalent of 'stem' downsampling in CNNs, but done
    in a single step without intermediate pooling.
    """
    def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Projection via convolution with stride = patch_size
        self.projection = nn.Conv2d(in_channels, embed_dim,
                                    kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) → (B, E, H/P, W/P) → (B, N, E)
        x = self.projection(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x


class WindowedAttention(nn.Module):
    """
    Swin Transformer-style windowed attention.

    Partitions feature map into windows and applies attention
    within windows. Reduces the quadratic complexity of full attention.
    """
    def __init__(self, dim, window_size=7, num_heads=8):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, H, W, C = x.shape
        ws = self.window_size

        # Partition into windows
        x = x.view(B, H // ws, ws, W // ws, ws, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(-1, ws * ws, C)  # (num_windows*B, ws*ws, C)

        # Multi-head self-attention within windows
        qkv = self.qkv(x).reshape(-1, ws*ws, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        attn = (q @ k.transpose(-2, -1)) / (C // self.num_heads) ** 0.5
        attn = F.softmax(attn, dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(-1, ws*ws, C)
        x = self.proj(x)

        # Reverse window partition (not shown for brevity)
        return x


class PatchMerging(nn.Module):
    """
    Swin Transformer patch merging for downsampling.

    Concatenates 2×2 neighboring patches and projects to half the channels.
    Effectively downsamples by 2× while doubling channels.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        B, H, W, C = x.shape

        # Partition into 2×2 patches
        x0 = x[:, 0::2, 0::2, :]  # Top-left
        x1 = x[:, 1::2, 0::2, :]  # Bottom-left
        x2 = x[:, 0::2, 1::2, :]  # Top-right
        x3 = x[:, 1::2, 1::2, :]  # Bottom-right

        # Concatenate and reduce
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C)
        x = self.norm(x)
        x = self.reduction(x)  # (B, H/2, W/2, 2C)

        return x


class HybridViTPooling(nn.Module):
    """
    Hybrid approach: CNN backbone with transformer-style pooling.

    Uses CNN for local features, then attention for global aggregation.
    """
    def __init__(self, in_channels, embed_dim=512, num_heads=8):
        super().__init__()

        # Project CNN features to tokens
        self.to_tokens = nn.Conv2d(in_channels, embed_dim, 1)

        # Cross-attention with learnable query
        self.query = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        B, C, H, W = x.shape

        # Reshape to tokens
        tokens = self.to_tokens(x)  # (B, embed_dim, H, W)
        tokens = tokens.flatten(2).transpose(1, 2)  # (B, H*W, embed_dim)

        # Cross-attention: query learns what to extract
        query = self.query.expand(B, -1, -1)  # (B, 1, embed_dim)
        pooled, _ = self.cross_attn(query, tokens, tokens)

        return pooled.squeeze(1)  # (B, embed_dim)


def demonstrate_transformer_pooling():
    # ViT-style: single large patch embedding
    x = torch.randn(4, 3, 224, 224)
    patch_embed = PatchEmbedding()
    tokens = patch_embed(x)
    print(f"ViT Patch Embedding: {x.shape} → {tokens.shape}")
    # (4, 3, 224, 224) → (4, 196, 768) where 196 = 14×14 patches

    # Swin-style: patch merging
    x_swin = torch.randn(4, 56, 56, 96)  # (B, H, W, C)
    merger = PatchMerging(96)
    merged = merger(x_swin)
    print(f"Swin Patch Merging: {x_swin.shape} → {merged.shape}")
    # (4, 56, 56, 96) → (4, 28, 28, 192)

    # Hybrid CNN + Transformer pooling
    cnn_features = torch.randn(4, 2048, 7, 7)
    hybrid = HybridViTPooling(2048, 512)
    pooled = hybrid(cnn_features)
    print(f"Hybrid pooling: {cnn_features.shape} → {pooled.shape}")


demonstrate_transformer_pooling()
```

| Aspect | CNN (Traditional) | Transformer (ViT-style) | Hybrid |
|---|---|---|---|
| Spatial reduction | Progressive pooling | Single patchify step | CNN backbone + attention pool |
| Local processing | Convolutions | Window attention (Swin) | Convolutions |
| Global processing | GAP at end | Full self-attention | Cross-attention |
| Inductive bias | Strong (locality) | Weak (position embeddings) | Moderate |
| Data efficiency | Good | Needs more data | Good with pretrain |
The landscape of spatial aggregation extends far beyond traditional max and average pooling. Understanding these alternatives enables you to select the right approach for each application.
| Task | Recommended Approach |
|---|---|
| Classification | Standard pooling or GAP |
| Semantic segmentation | Dilated convolutions + ASPP |
| Object detection | ROI Align + FPN |
| Instance segmentation | ROI Align + mask head |
| Fine-grained recognition | Attention pooling or deformable |
| Very high resolution | Windowed attention (Swin) |
Congratulations! You've completed the Pooling and Downsampling module. You now understand the full spectrum of spatial aggregation techniques—from classical pooling through modern attention-based and transformer approaches. This knowledge enables you to make informed architectural decisions for any computer vision task.
What's Next:
With your comprehensive understanding of pooling and downsampling, you're ready to explore CNN Architectures in the next module, where we'll see how these operations combine with other building blocks to create influential networks like LeNet, AlexNet, VGGNet, and Inception.