Traditional CNN architectures alternate between convolutional layers for feature extraction and pooling layers for spatial downsampling. But this separation raises a fundamental question: why use fixed, hand-designed pooling operations when we could learn the downsampling transformation end-to-end?
Strided convolutions answer this question by combining feature extraction and spatial reduction into a single learned operation. Instead of applying a convolution with stride 1 followed by pooling, we apply a convolution with stride greater than 1, directly producing a downsampled output. This simple change has profound implications for network expressiveness, gradient flow, and architectural design.
By the end of this page, you will understand the mathematical formulation of strided convolutions, how stride affects output dimensions and receptive fields, the trade-offs between strided convolutions and pooling, gradient flow properties and training dynamics, anti-aliasing considerations, and architectural patterns that leverage strided convolutions.
A convolution operation with stride $s > 1$ samples the output at every $s$-th position, effectively downsampling the feature map.
Standard Convolution Recap:
For an input $X \in \mathbb{R}^{C_{in} \times H \times W}$ and kernel $K \in \mathbb{R}^{C_{out} \times C_{in} \times k \times k}$, the standard 2D convolution with stride $s$ at output position $(i, j)$ is:
$$Y_{c_{out}, i, j} = \sum_{c_{in}} \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} K_{c_{out}, c_{in}, m, n} \cdot X_{c_{in}, i \cdot s + m, j \cdot s + n}$$
The key is the factor $s$ in the indexing: we skip $s$ input positions between each output position.
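As a quick sanity check (this helper is illustrative, not part of the page's main listings), the sum above can be evaluated with a naive triple loop and compared against PyTorch's `F.conv2d`:

```python
import torch
import torch.nn.functional as F

# Naive evaluation of the strided-convolution equation above, checked
# against F.conv2d. Names mirror the text: C_in, C_out, kernel k, stride s.
def strided_conv_naive(X, K, s):
    C_out, C_in, k, _ = K.shape
    _, H, W = X.shape
    H_out = (H - k) // s + 1
    W_out = (W - k) // s + 1
    Y = torch.zeros(C_out, H_out, W_out)
    for c_out in range(C_out):
        for i in range(H_out):
            for j in range(W_out):
                # The stride appears only in the indexing: the window
                # for output (i, j) starts at input position (i*s, j*s)
                patch = X[:, i * s:i * s + k, j * s:j * s + k]
                Y[c_out, i, j] = (K[c_out] * patch).sum()
    return Y

X = torch.randn(3, 8, 8)
K = torch.randn(5, 3, 3, 3)
Y_naive = strided_conv_naive(X, K, s=2)
Y_ref = F.conv2d(X.unsqueeze(0), K, stride=2).squeeze(0)
print(Y_naive.shape)                              # torch.Size([5, 3, 3])
print(torch.allclose(Y_naive, Y_ref, atol=1e-4))  # True
```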
Output Dimensions:
For input size $H \times W$, kernel size $k$, stride $s$, and padding $p$:
$$H_{out} = \left\lfloor \frac{H + 2p - k}{s} \right\rfloor + 1$$
$$W_{out} = \left\lfloor \frac{W + 2p - k}{s} \right\rfloor + 1$$
| Configuration | Kernel | Stride | Padding | Effect |
|---|---|---|---|---|
| Halving (common) | 3×3 | 2 | 1 | H,W → H/2, W/2 |
| Halving (alternative) | 4×4 | 2 | 1 | H,W → H/2, W/2 |
| Aggressive | 3×3 | 4 | varies | H,W → H/4, W/4 |
| Stem (ResNet) | 7×7 | 2 | 3 | 224→112 for initial reduction |
| No reduction | 3×3 | 1 | 1 | H,W → H, W (same size) |
When stride equals 2, using an even kernel size (like 4×4) provides symmetric coverage. Odd kernels (like 3×3) with stride 2 create slightly asymmetric sampling, though this rarely matters in practice. Most architectures use 3×3 with stride 2 for simplicity.
Comparison with Conv + Pool:
Consider reducing spatial dimensions by half. Two approaches:
Approach A: Conv (stride 1) + Pool (stride 2)
Input: (C_in, H, W)
↓ Conv 3×3, s=1, p=1 → (C_out, H, W)
↓ MaxPool 2×2, s=2 → (C_out, H/2, W/2)
Approach B: Strided Conv (stride 2)
Input: (C_in, H, W)
↓ Conv 3×3, s=2, p=1 → (C_out, H/2, W/2)
Approach B accomplishes the same spatial reduction in a single learned operation: the convolution is evaluated at only a quarter of the spatial positions, and the network learns which information survives downsampling instead of relying on a fixed pooling rule.
```python
import torch
import torch.nn as nn

def demonstrate_stride_effects():
    """
    Illustrate how stride affects convolution output.
    """
    # Input: 32 channels, 28×28 spatial
    x = torch.randn(1, 32, 28, 28)

    # Stride 1: preserves spatial size (with appropriate padding)
    conv_s1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
    out_s1 = conv_s1(x)
    print(f"Stride 1: {x.shape} → {out_s1.shape}")  # (1, 64, 28, 28)

    # Stride 2: halves spatial size
    conv_s2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
    out_s2 = conv_s2(x)
    print(f"Stride 2: {x.shape} → {out_s2.shape}")  # (1, 64, 14, 14)

    # Stride 4: quarters spatial size
    conv_s4 = nn.Conv2d(32, 64, kernel_size=3, stride=4, padding=1)
    out_s4 = conv_s4(x)
    print(f"Stride 4: {x.shape} → {out_s4.shape}")  # (1, 64, 7, 7)

    # Parameter counts are identical regardless of stride!
    print(f"Parameters (all same): {sum(p.numel() for p in conv_s1.parameters())}")

demonstrate_stride_effects()

def compute_output_dimensions(H, W, kernel, stride, padding):
    """
    Calculate output dimensions for strided convolution.
    """
    H_out = (H + 2*padding - kernel) // stride + 1
    W_out = (W + 2*padding - kernel) // stride + 1
    return H_out, W_out

def explore_common_configs():
    """
    Show output dimensions for common strided convolution configurations.
    """
    input_sizes = [(224, 224), (32, 32), (112, 112)]
    configs = [
        ("3×3, s=2, p=1", 3, 2, 1),
        ("4×4, s=2, p=1", 4, 2, 1),
        ("7×7, s=2, p=3", 7, 2, 3),
        ("3×3, s=2, p=0", 3, 2, 0),
    ]
    for name, k, s, p in configs:
        print(f"{name}:")
        for H, W in input_sizes:
            H_out, W_out = compute_output_dimensions(H, W, k, s, p)
            print(f"  {H}×{W} → {H_out}×{W_out}")

explore_common_configs()
```

The choice between pooling and strided convolutions is not purely a matter of performance—each approach has distinct properties that make it more suitable for certain scenarios.
The "All Convolutional Net" Paper:
Springenberg et al. (2014) systematically studied replacing pooling with strided convolutions in "Striving for Simplicity: The All Convolutional Net". Their key findings: max pooling can be replaced by strided convolution without loss of accuracy on CIFAR-10 and CIFAR-100, and a network built entirely from convolutions—with striding for downsampling and global average pooling before the classifier—matched or exceeded the state of the art of the time.
This work catalyzed the modern trend toward pooling-free architectures.
| Property | Strided Conv | Max Pool | Avg Pool |
|---|---|---|---|
| Parameters | k² × C_in × C_out | 0 | 0 |
| Learnable | Yes | No | No |
| Operation type | Linear (weighted sum) | Nonlinear (max) | Linear (mean) |
| Gradient distribution | Weighted by kernel | Winner-take-all | Uniform |
| Translation invariance | Must be learned | Built-in (local) | Built-in (local) |
| Memory (forward) | Store input + output | Store input + indices | Store input |
| Compute (forward) | MAC operations | Comparisons | Additions + division |
Most state-of-the-art architectures (ResNet, EfficientNet, ConvNeXt) use strided convolutions for intermediate downsampling and reserve pooling only for the final global aggregation (GAP before classification). This hybrid approach captures the best of both worlds.
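The hybrid pattern described above can be sketched minimally as follows (this toy network and its names are illustrative, not taken from any of the cited architectures): strided convolutions handle every intermediate downsampling step, and the only pooling is the global average pool before the classifier.

```python
import torch
import torch.nn as nn

class TinyHybridNet(nn.Module):
    """Strided convs for downsampling; pooling only as final global aggregation."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)  # the only pooling in the network
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.features(x)        # (B, 128, H/8, W/8)
        x = self.gap(x).flatten(1)  # (B, 128) — works for any input resolution
        return self.fc(x)

model = TinyHybridNet()
print(model(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 10])
print(model(torch.randn(2, 3, 96, 96)).shape)  # torch.Size([2, 10])
```

A side benefit of the GAP head, shown in the two calls above, is that the network accepts any input resolution.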
```python
import torch
import torch.nn as nn

class PoolingDownBlock(nn.Module):
    """
    Traditional: Conv (stride 1) + MaxPool
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        return x

class StridedDownBlock(nn.Module):
    """
    Modern: Strided Conv (stride 2)
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

def compare_approaches():
    """
    Compare the two approaches in terms of output shape and parameters.
    """
    in_ch, out_ch = 64, 128
    x = torch.randn(32, in_ch, 56, 56)

    pool_block = PoolingDownBlock(in_ch, out_ch)
    strided_block = StridedDownBlock(in_ch, out_ch)

    # Output shapes
    out_pool = pool_block(x)
    out_strided = strided_block(x)
    print("Output shapes (should be identical):")
    print(f"  Pooling approach: {out_pool.shape}")
    print(f"  Strided approach: {out_strided.shape}")

    # Parameter counts
    pool_params = sum(p.numel() for p in pool_block.parameters())
    strided_params = sum(p.numel() for p in strided_block.parameters())
    print(f"Parameter counts:")
    print(f"  Pooling approach: {pool_params:,}")
    print(f"  Strided approach: {strided_params:,}")
    # Note: Both blocks have the same conv parameters, but the pooling
    # approach runs its convolution at full resolution before downsampling;
    # the strided version evaluates only 1/4 of the positions directly.

compare_approaches()

def analyze_gradient_patterns():
    """
    Compare gradient patterns between max pooling and strided convolution.
    """
    x = torch.randn(1, 1, 4, 4, requires_grad=True)

    # Max pooling path
    pool = nn.MaxPool2d(2, 2)
    pooled = pool(x)
    loss_pool = pooled.sum()
    loss_pool.backward()
    grad_pool = x.grad.clone()
    x.grad.zero_()

    # Strided conv path (with identity-ish kernel for comparison)
    conv = nn.Conv2d(1, 1, 2, stride=2, bias=False)
    nn.init.constant_(conv.weight, 0.25)  # Average-like behavior
    strided = conv(x)
    loss_strided = strided.sum()
    loss_strided.backward()
    grad_strided = x.grad.clone()

    print("Input:")
    print(x.detach().squeeze())
    print("Max pooling gradient (sparse, winner-take-all):")
    print(grad_pool.squeeze())
    print("Strided conv gradient (distributed by kernel weights):")
    print(grad_strided.squeeze())

analyze_gradient_patterns()
```

Strided convolutions affect how receptive fields grow through the network. Understanding this is crucial for designing networks that "see" appropriately sized regions at each layer.
Receptive Field Growth Formula:
For a sequence of layers with kernel sizes $k_i$ and strides $s_i$, the receptive field $R_n$ at layer $n$ is:
$$R_n = 1 + \sum_{i=1}^{n} (k_i - 1) \prod_{j=1}^{i-1} s_j$$
The key insight: stride from earlier layers multiplies the contribution of later layers.
Example Comparison:
Consider building a network where the final feature map should "see" a 32×32 region of the input:
| Layer | Kernel | Stride | RF (Conv + Pool) | RF (Strided Conv) |
|---|---|---|---|---|
| Input | — | — | 1 | 1 |
| Layer 1 | 3×3 | 1, then pool s=2 | 3 → 4 | 3 |
| Layer 2 | 3×3 | 1, then pool s=2 | 8 → 10 | 7 |
| Layer 3 | 3×3 | 1, then pool s=2 | 18 → 22 | 15 |
| Layer 4 | 3×3 | 1, then pool s=2 | 42 → 46 | 31 |

For the conv+pool column, each entry shows the receptive field after the 3×3 convolution, then after the 2×2 pool; the strided column uses a single 3×3 convolution with stride 2 per layer.
The theoretical receptive field assumes equal contribution from all input pixels. In practice, the 'effective receptive field' is much smaller due to Gaussian-like attention patterns—central pixels contribute more than peripheral ones. Strided convolutions can have different effective receptive field profiles compared to pooling.
```python
import torch
import torch.nn as nn
import numpy as np

def calculate_receptive_field(layers):
    """
    Calculate theoretical receptive field for a sequence of conv/pool layers.

    Args:
        layers: List of (kernel_size, stride) tuples
    Returns:
        Final receptive field size
    """
    rf = 1
    cumulative_stride = 1
    for kernel, stride in layers:
        rf += (kernel - 1) * cumulative_stride
        cumulative_stride *= stride
    return rf

def compare_rf_growth():
    """
    Compare receptive field growth for different downsampling strategies.
    """
    # Strategy 1: Conv (3×3, s=1) + MaxPool (2×2, s=2) at each stage
    conv_pool_layers = [
        (3, 1), (2, 2),  # Stage 1
        (3, 1), (2, 2),  # Stage 2
        (3, 1), (2, 2),  # Stage 3
        (3, 1), (2, 2),  # Stage 4
    ]

    # Strategy 2: Strided conv (3×3, s=2) at each stage
    strided_layers = [
        (3, 2),  # Stage 1
        (3, 2),  # Stage 2
        (3, 2),  # Stage 3
        (3, 2),  # Stage 4
    ]

    # Strategy 3: Two 3×3 convs per stage, then strided
    deep_strided = [
        (3, 1), (3, 2),  # Stage 1
        (3, 1), (3, 2),  # Stage 2
        (3, 1), (3, 2),  # Stage 3
        (3, 1), (3, 2),  # Stage 4
    ]

    print("Receptive field comparison (4 downsampling stages):")
    print(f"  Conv + Pool:    {calculate_receptive_field(conv_pool_layers)}")
    print(f"  Strided Conv:   {calculate_receptive_field(strided_layers)}")
    print(f"  Deep + Strided: {calculate_receptive_field(deep_strided)}")

compare_rf_growth()

class ReceptiveFieldVisualizer:
    """
    Visualize the effective receptive field by backpropagating
    from a single output neuron.
    """
    def __init__(self, model):
        self.model = model

    def compute_erf(self, input_shape, output_position):
        """
        Compute effective receptive field via gradient-based method.
        """
        self.model.eval()

        # Create input that requires gradients
        x = torch.zeros(input_shape, requires_grad=True)

        # Forward pass
        output = self.model(x)

        # Create gradient target (single position)
        grad_target = torch.zeros_like(output)
        B, C, H, W = output.shape
        grad_target[0, :, output_position[0], output_position[1]] = 1.0

        # Backward pass
        output.backward(gradient=grad_target)

        # The gradient magnitude shows the receptive field
        erf = x.grad.abs().sum(dim=1).squeeze()  # Sum over channels
        return erf.detach().numpy()

# Demonstrate receptive field differences
def demonstrate_rf_differences():
    """
    Show how different architectures have different RF patterns.
    """
    input_shape = (1, 3, 64, 64)

    # Model 1: Max pooling based
    model_pool = nn.Sequential(
        nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        nn.MaxPool2d(2, 2),
    )

    # Model 2: Strided conv based
    model_strided = nn.Sequential(
        nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
        nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
    )

    # Both produce 16×16 output from 64×64 input
    x = torch.randn(input_shape)
    print(f"Pool model output:    {model_pool(x).shape}")
    print(f"Strided model output: {model_strided(x).shape}")
    # The receptive field sizes and patterns differ

demonstrate_rf_differences()
```

Design Implications:
Pure strided convolutions build receptive fields more slowly than conv+pool. More strided layers may be needed to achieve the same final receptive field.
To compensate, architectures using strided convolutions often stack additional stride-1 convolutions within each stage before the strided one (the "deep + strided" strategy above), use larger kernels in the stem, or simply rely on greater depth to accumulate receptive field.
The effective receptive field depends not just on architecture but on learned weights. Networks learn to focus on task-relevant regions, which may be smaller or differently shaped than theoretical calculations suggest.
A critical but often overlooked issue with strided convolutions (and pooling) is aliasing—the violation of the Nyquist sampling theorem that can harm translation invariance.
The Aliasing Problem:
When we downsample by stride $s$, we're sampling at 1/$s$ the original frequency. If the feature map contains high-frequency components beyond the Nyquist limit, these alias into lower frequencies, creating artifacts that change unpredictably with small input translations.
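The effect is easy to reproduce in one dimension (this demo is illustrative and not one of the page's listings): a signal near the Nyquist limit, subsampled by 2 with and without a [1, 2, 1]/4 low-pass blur, before and after a one-sample shift. Naive subsampling changes drastically under the shift; blur-then-subsample barely changes.

```python
import torch
import torch.nn.functional as F

# A high-frequency cosine, close to the Nyquist limit for stride-2 sampling
t = torch.arange(64, dtype=torch.float32)
signal = torch.cos(torch.pi * 0.9 * t)
shifted = torch.roll(signal, 1)  # one-sample translation

def subsample(x):
    # Naive stride-2 downsampling: keep every other sample
    return x[::2]

def blur_subsample(x):
    # Low-pass [1, 2, 1]/4 blur, then stride-2 downsampling (BlurPool idea)
    kernel = torch.tensor([[[1., 2., 1.]]]) / 4.0
    smoothed = F.conv1d(x.view(1, 1, -1), kernel, padding=1).view(-1)
    return smoothed[::2]

naive_change = (subsample(signal) - subsample(shifted)).abs().mean()
aa_change = (blur_subsample(signal) - blur_subsample(shifted)).abs().mean()
print(f"naive subsampling change:   {naive_change:.3f}")
print(f"blurred subsampling change: {aa_change:.3f}")  # far smaller
```

The blur attenuates the near-Nyquist component before sampling, so the downsampled signal no longer depends so violently on where the sampling grid happens to fall.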
Zhang's Key Observation (2019):
Richard Zhang's paper "Making Convolutional Networks Shift-Invariant Again" demonstrated that:
Classic CNNs with strided operations can classify an object as 'cat' in one position but 'dog' after a 1-pixel shift—not because of visibility changes, but because of aliasing artifacts in the downsampling process. This is a fundamental robustness issue.
BlurPool: The Solution:
BlurPool applies a low-pass blur filter before downsampling:
$$\text{BlurPool}(X) = \text{Subsample}(\text{Blur}(X))$$
The blur filter removes high-frequency components that would alias, ensuring that downsampling only affects frequencies we're actually keeping. This is the digital signal processing principle of anti-aliasing.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class BlurPool(nn.Module):
    """
    Anti-aliased downsampling layer.

    Applies a blur (low-pass filter) before subsampling
    to remove aliasing artifacts.
    """
    def __init__(self, channels, filter_size=4, stride=2, padding=1):
        super().__init__()
        self.channels = channels
        self.stride = stride
        self.padding = padding

        # Create blur kernel (binomial/Gaussian-like)
        if filter_size == 2:
            kernel = np.array([1., 1.])
        elif filter_size == 3:
            kernel = np.array([1., 2., 1.])
        elif filter_size == 4:
            kernel = np.array([1., 3., 3., 1.])
        elif filter_size == 5:
            kernel = np.array([1., 4., 6., 4., 1.])
        else:
            raise ValueError(f"filter_size {filter_size} not supported")

        # Create 2D kernel via outer product
        kernel = torch.tensor(kernel, dtype=torch.float32)
        kernel = kernel[:, None] * kernel[None, :]
        kernel = kernel / kernel.sum()  # Normalize

        # Expand to match channel count (depthwise)
        kernel = kernel.repeat(channels, 1, 1, 1)
        self.register_buffer('kernel', kernel)

    def forward(self, x):
        # Depthwise blur followed by subsampling
        blurred = F.conv2d(
            x, self.kernel,
            stride=self.stride,
            padding=self.padding,
            groups=self.channels
        )
        return blurred

class AntialiasedStridedConv(nn.Module):
    """
    Strided convolution with anti-aliasing.

    Applies convolution at stride 1, then uses BlurPool
    for anti-aliased downsampling.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=2, padding=1, blur_filter_size=4):
        super().__init__()
        # Convolution without downsampling
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=1, padding=padding
        )
        # Anti-aliased downsampling
        self.blur_pool = BlurPool(out_channels, blur_filter_size, stride)

    def forward(self, x):
        x = self.conv(x)
        x = self.blur_pool(x)
        return x

def test_shift_invariance():
    """
    Test shift invariance of different downsampling methods.
    """
    channels = 32

    # Standard strided conv
    standard = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    # Anti-aliased strided conv
    antialiased = AntialiasedStridedConv(channels, channels)

    # Create test input and shifted versions
    x = torch.randn(1, channels, 32, 32)
    x_shifted = torch.roll(x, shifts=1, dims=3)  # Shift by 1 pixel

    # Process both
    out_std = standard(x)
    out_std_shifted = standard(x_shifted)
    out_aa = antialiased(x)
    out_aa_shifted = antialiased(x_shifted)

    # Measure consistency (shifted output should be shifted version of original)
    # Perfect shift invariance: output shifts by stride/2 for 1-pixel input shift
    # Compare outputs (accounting for different output positions)
    std_diff = (out_std[:, :, :, :-1] - out_std_shifted[:, :, :, :-1]).abs().mean()
    aa_diff = (out_aa[:, :, :, :-1] - out_aa_shifted[:, :, :, :-1]).abs().mean()

    print("Shift consistency (lower = more invariant):")
    print(f"  Standard strided conv:    {std_diff:.4f}")
    print(f"  Antialiased strided conv: {aa_diff:.4f}")

test_shift_invariance()

class AntialiasedMaxPool(nn.Module):
    """
    BlurPool-style anti-aliased max pooling.

    Applies max at stride 1, then blur-downsamples.
    """
    def __init__(self, channels, kernel_size=2, stride=2, blur_filter_size=4):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size, stride=1,
                                     padding=(kernel_size - 1) // 2)
        self.blur_pool = BlurPool(channels, blur_filter_size, stride)

    def forward(self, x):
        x = self.max_pool(x)
        x = self.blur_pool(x)
        return x
```

When to Use Anti-Aliasing:
| Scenario | Anti-Aliasing Value | Notes |
|---|---|---|
| Standard classification | Moderate | Improves consistency, slight accuracy gain |
| Object detection | High | Position sensitivity matters greatly |
| Video processing | High | Frame-to-frame consistency crucial |
| Adversarial robustness | High | Aliasing can be exploited by attacks |
| Resource-constrained | Low | BlurPool adds compute overhead |
Trade-offs: BlurPool inserts an extra depthwise convolution at every downsampling point, adding compute and memory overhead; in exchange, it markedly improves shift consistency and often yields a small accuracy gain.
Modern CNN architectures have developed several patterns for incorporating strided convolutions effectively.
Pattern 1: ResNet-Style Downsampling
ResNet uses strided convolutions in the first layer of residual blocks that change spatial dimensions:
```python
import torch
import torch.nn as nn

class ResNetBottleneck(nn.Module):
    """
    ResNet bottleneck block with optional downsampling.

    The downsampling happens in the 3×3 conv (middle layer)
    for better gradient flow compared to the first layer.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        width = out_channels  # Bottleneck width

        # 1×1 reduce
        self.conv1 = nn.Conv2d(in_channels, width, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)

        # 3×3 process (with optional striding)
        self.conv2 = nn.Conv2d(width, width, 3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)

        # 1×1 expand
        self.conv3 = nn.Conv2d(width, out_channels * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.relu = nn.ReLU(inplace=True)

        # Shortcut needs adjustment if dimensions change
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, 1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))  # Downsampling here
        out = self.bn3(self.conv3(out))
        out += identity
        out = self.relu(out)
        return out

class ResNetV1Bottleneck(nn.Module):
    """
    Original ResNet-v1 downsampling in first conv.
    Generally considered inferior to v1.5/v2.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        width = out_channels

        # Downsampling in first 1×1 conv (v1 original)
        self.conv1 = nn.Conv2d(in_channels, width, 1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_channels * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, 1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))  # Downsampling here
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        return self.relu(out)

def compare_variants():
    """
    Compare information flow in different ResNet variants.
    """
    x = torch.randn(1, 256, 14, 14)

    v1 = ResNetV1Bottleneck(256, 128, stride=2)  # Stride in 1×1
    v2 = ResNetBottleneck(256, 128, stride=2)    # Stride in 3×3

    out_v1 = v1(x)
    out_v2 = v2(x)
    print(f"ResNet-v1 output:      {out_v1.shape}")
    print(f"ResNet-v1.5/v2 output: {out_v2.shape}")
    print("ResNet-v1.5 (stride in 3×3) is preferred because:")
    print("  - 3×3 conv has larger kernel, less information loss")
    print("  - Better gradient flow through residual branch")

compare_variants()
```

Pattern 2: Stem Downsampling
The network "stem" performs initial aggressive downsampling on high-resolution inputs:
```python
import torch
import torch.nn as nn

class ResNetStem(nn.Module):
    """
    Classic ResNet stem: 7×7 conv + max pool.
    Reduces 224→56 (4× reduction).
    """
    def __init__(self, in_channels=3, stem_channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, stem_channels, 7,
                              stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(stem_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)   # 224 → 112
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)   # 112 → 56
        return x

class EfficientNetStem(nn.Module):
    """
    EfficientNet stem: single strided 3×3 conv.
    More parameter-efficient than 7×7.
    """
    def __init__(self, in_channels=3, stem_channels=32):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, stem_channels, 3,
                              stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(stem_channels)
        self.act = nn.SiLU()  # Swish activation

    def forward(self, x):
        x = self.conv(x)   # 224 → 112
        x = self.bn(x)
        x = self.act(x)
        return x

class ConvNeXtStem(nn.Module):
    """
    ConvNeXt stem: Patchify with 4×4 strided conv.
    Directly reduces 4× in one step, ViT-inspired.
    """
    def __init__(self, in_channels=3, stem_channels=96):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, stem_channels,
                              kernel_size=4, stride=4)
        self.norm = nn.LayerNorm(stem_channels)  # Different norm!

    def forward(self, x):
        x = self.conv(x)   # 224 → 56 in one step
        # LayerNorm expects (B, H, W, C)
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        return x

def compare_stems():
    x = torch.randn(1, 3, 224, 224)
    stems = [
        ("ResNet", ResNetStem()),
        ("EfficientNet", EfficientNetStem()),
        ("ConvNeXt", ConvNeXtStem()),
    ]
    for name, stem in stems:
        out = stem(x)
        params = sum(p.numel() for p in stem.parameters())
        print(f"{name}: {x.shape} → {out.shape}, params: {params:,}")

compare_stems()
```

| Architecture | Stem Design | Reduction | Parameters |
|---|---|---|---|
| ResNet | 7×7 s=2 + MaxPool s=2 | 4× | ~9.5K |
| EfficientNet | 3×3 s=2 | 2× | ~0.9K |
| ConvNeXt | 4×4 s=4 | 4× | ~4.6K |
| Vision Transformer | 16×16 s=16 (patch embed) | 16× | ~0.6M |
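The approximate parameter figures in the table can be checked by counting convolution weights alone (bias and normalization terms are ignored here, which is why the figures are approximate; the ViT row assumes the ViT-Base embedding width of 768):

```python
# Conv-weight-only parameter counts for the stem table above
def conv_params(k_h, k_w, c_in, c_out):
    return k_h * k_w * c_in * c_out

print(conv_params(7, 7, 3, 64))     # ResNet 7×7 stem:          9408  (~9.5K)
print(conv_params(3, 3, 3, 32))     # EfficientNet 3×3 stem:     864  (~0.9K)
print(conv_params(4, 4, 3, 96))     # ConvNeXt 4×4 patchify:    4608  (~4.6K)
print(conv_params(16, 16, 3, 768))  # ViT-Base patch embed:   589824  (~0.6M)
```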
The choice of downsampling mechanism significantly affects how gradients flow during backpropagation, with implications for training stability and convergence.
Strided Convolution Gradients:
For a strided convolution with stride $s$, the backward pass with respect to the input involves upsampling the output gradient by inserting $s - 1$ zeros between adjacent elements, then convolving the result with the spatially flipped kernel.
This is sometimes called "fractional striding" or transposed convolution. The gradient is distributed based on the learned kernel weights.
Unlike max pooling's winner-take-all gradient, strided convolutions distribute gradients to all input positions, weighted by kernel values. This can provide more uniform training signals across the feature map, though the distribution is learned rather than predetermined.
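The transposed-convolution relationship described above can be verified directly in PyTorch (a small sketch, not one of the page's listings; an odd input size is used so that no `output_padding` adjustment is needed to recover the input shape):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# The input gradient of a strided convolution equals a transposed
# convolution of the output gradient with the same kernel.
x = torch.randn(1, 4, 15, 15, requires_grad=True)
conv = nn.Conv2d(4, 8, 3, stride=2, padding=1, bias=False)

out = conv(x)                      # (1, 8, 8, 8)
grad_out = torch.randn_like(out)
out.backward(gradient=grad_out)    # autograd computes the input gradient

# Same computation, written explicitly as a transposed convolution.
# (For an even input size, output_padding=1 would also be needed.)
grad_via_transpose = F.conv_transpose2d(grad_out, conv.weight,
                                        stride=2, padding=1)

print(torch.allclose(x.grad, grad_via_transpose, atol=1e-5))  # True
```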
```python
import torch
import torch.nn as nn

def analyze_gradient_statistics():
    """
    Compare gradient statistics for different downsampling methods.
    """
    batch_size = 32
    channels = 64
    H, W = 28, 28

    # Create input
    x = torch.randn(batch_size, channels, H, W, requires_grad=True)

    # Different downsampling approaches
    approaches = {
        "Max Pool": nn.MaxPool2d(2, 2),
        "Avg Pool": nn.AvgPool2d(2, 2),
        "Strided Conv": nn.Conv2d(channels, channels, 3, stride=2, padding=1),
    }

    print("Gradient Statistics (mean, std, sparsity):")
    print("-" * 60)

    for name, layer in approaches.items():
        x_copy = x.clone().detach().requires_grad_(True)

        # Forward pass
        out = layer(x_copy)

        # Simulate loss gradient from above
        grad_output = torch.randn_like(out)

        # Backward pass
        out.backward(gradient=grad_output)

        # Analyze gradients
        grad = x_copy.grad
        grad_mean = grad.abs().mean().item()
        grad_std = grad.std().item()
        grad_sparsity = (grad == 0).float().mean().item()

        print(f"{name:15s}: mean={grad_mean:.4f}, std={grad_std:.4f}, "
              f"sparsity={grad_sparsity*100:.1f}%")

analyze_gradient_statistics()

def visualize_gradient_patterns():
    """
    Visualize how different methods distribute gradients spatially.
    """
    # Single sample, single channel for clarity
    x = torch.randn(1, 1, 8, 8, requires_grad=True)

    # Create downstream gradient (single position activated)
    downstream_grad = torch.zeros(1, 1, 4, 4)
    downstream_grad[0, 0, 1, 1] = 1.0  # Single strong gradient

    methods = {
        "Max Pool": nn.MaxPool2d(2, 2),
        "Avg Pool": nn.AvgPool2d(2, 2),
        "Strided 2×2": nn.Conv2d(1, 1, 2, stride=2, bias=False),
        "Strided 3×3": nn.Conv2d(1, 1, 3, stride=2, padding=1, bias=False),
    }

    print("Gradient patterns from single downstream position [1,1]:")
    for name, layer in methods.items():
        x_copy = x.clone().detach().requires_grad_(True)
        out = layer(x_copy)
        out.backward(gradient=downstream_grad)
        gradient = x_copy.grad.squeeze()
        nonzero = (gradient != 0).sum().item()
        print(f"{name}:")
        print(f"  Input positions receiving gradient: {nonzero}/64")
        # The pattern depends on kernel weights for strided conv

visualize_gradient_patterns()

class GradientFlowMonitor:
    """
    Monitor gradient magnitudes through network layers during training.
    """
    def __init__(self, model):
        self.model = model
        self.gradient_stats = {}
        self._register_hooks()

    def _register_hooks(self):
        for name, module in self.model.named_modules():
            module.register_full_backward_hook(
                lambda m, grad_in, grad_out, n=name:
                    self._save_gradient(n, grad_in, grad_out)
            )

    def _save_gradient(self, name, grad_input, grad_output):
        if grad_output[0] is not None:
            self.gradient_stats[name] = {
                'out_mean': grad_output[0].abs().mean().item(),
                'out_std': grad_output[0].std().item(),
            }

    def report(self):
        for name, stats in self.gradient_stats.items():
            print(f"{name}: mean={stats['out_mean']:.6f}")
```

Training Implications:
| Aspect | Max Pool | Avg Pool | Strided Conv |
|---|---|---|---|
| Gradient sparsity | ~75% zeros | 0% zeros | 0% zeros |
| Gradient magnitude | Preserved | Reduced by 1/k² | Weighted by kernel |
| Training signal | Concentrated on "winners" | Distributed evenly | Learned distribution |
| Feature competition | Strong (winner-take-all) | None | Learned |
| Convergence behavior | Often faster initially | Slower, more stable | Task-dependent |
Practical Observations:
- Deep networks: strided convolutions often train more stably due to more uniform gradient flow
- Transfer learning: networks pretrained with pooling may not transfer perfectly when pooling is replaced
- Initialization: strided convolution weights should be initialized carefully; standard init is usually fine, but task-specific tuning can help
Strided convolutions offer a powerful, learnable alternative to fixed pooling operations, and have become the dominant approach in modern CNN architectures.
You now understand strided convolutions as a core building block of modern CNNs. This knowledge enables you to make informed decisions about downsampling strategies and appreciate why state-of-the-art architectures have evolved away from traditional pooling.
What's Next:
In the final page of this module, we explore Pooling Alternatives—including dilated/atrous convolutions, deformable convolutions, attention-based pooling, and other modern approaches that challenge or complement traditional spatial aggregation methods.