Having understood the motivation behind normalization, we now turn to the precise mechanics of Batch Normalization (BatchNorm)—the technique that revolutionized deep learning training in 2015.
BatchNorm is deceptively simple in concept: normalize each feature to have zero mean and unit variance, then scale and shift. But this apparent simplicity masks profound effects on training dynamics, gradient flow, and network expressivity.
This page develops a complete mathematical understanding of BatchNorm, from basic formulas to subtle implementation details that can make or break training.
By the end of this page, you will master the complete BatchNorm forward pass, understand the role of learnable parameters γ and β, analyze the gradient flow through BatchNorm, and know the implementation details necessary for correct behavior.
Batch Normalization operates on a mini-batch of data, normalizing the activations at each layer using statistics computed over the batch. Let's build up the transformation step by step.
Setup:
Consider a mini-batch of m samples. For a given layer with d features, let x be the batch of pre-activations:
BatchNorm normalizes each feature dimension independently across the batch.
```python
import numpy as np


def batchnorm_forward_pedagogical(x, gamma, beta, eps=1e-5):
    """
    Step-by-step Batch Normalization forward pass.

    Args:
        x: Input tensor of shape (batch_size, features)
        gamma: Scale parameter, shape (features,)
        beta: Shift parameter, shape (features,)
        eps: Small constant for numerical stability

    Returns:
        y: Normalized and scaled output
        cache: Values needed for backward pass
    """
    m, d = x.shape

    # ═══════════════════════════════════════════════════════════
    # STEP 1: Compute batch mean for each feature
    # ═══════════════════════════════════════════════════════════
    # μ_k = (1/m) * Σᵢ x_i,k
    # Shape: (d,) - one mean per feature
    mu = np.mean(x, axis=0)

    print(f"Step 1 - Batch means: shape={mu.shape}")
    print(f"  First 3 features: {mu[:3]}")

    # ═══════════════════════════════════════════════════════════
    # STEP 2: Center the data (subtract mean)
    # ═══════════════════════════════════════════════════════════
    # x_centered_i,k = x_i,k - μ_k
    x_centered = x - mu  # Broadcasting over batch dimension

    print(f"\nStep 2 - Centered data: shape={x_centered.shape}")
    print(f"  New column means (should be ~0): {np.mean(x_centered, axis=0)[:3]}")

    # ═══════════════════════════════════════════════════════════
    # STEP 3: Compute batch variance for each feature
    # ═══════════════════════════════════════════════════════════
    # σ²_k = (1/m) * Σᵢ (x_i,k - μ_k)²
    # Note: Using biased variance (divide by m, not m-1)
    # Shape: (d,)
    var = np.mean(x_centered ** 2, axis=0)

    print(f"\nStep 3 - Batch variances: shape={var.shape}")
    print(f"  First 3 features: {var[:3]}")

    # ═══════════════════════════════════════════════════════════
    # STEP 4: Compute standard deviation (with epsilon for stability)
    # ═══════════════════════════════════════════════════════════
    # σ_k = √(σ²_k + ε)
    # epsilon prevents division by zero when variance is 0
    std = np.sqrt(var + eps)

    print(f"\nStep 4 - Standard deviations: {std[:3]}")

    # ═══════════════════════════════════════════════════════════
    # STEP 5: Normalize (divide by standard deviation)
    # ═══════════════════════════════════════════════════════════
    # x̂_i,k = (x_i,k - μ_k) / σ_k
    # Result: zero mean, unit variance for each feature
    x_norm = x_centered / std

    print(f"\nStep 5 - Normalized data statistics:")
    print(f"  Mean per feature (should be ~0): {np.mean(x_norm, axis=0)[:3]}")
    print(f"  Std per feature (should be ~1): {np.std(x_norm, axis=0)[:3]}")

    # ═══════════════════════════════════════════════════════════
    # STEP 6: Scale and shift (learnable transformation)
    # ═══════════════════════════════════════════════════════════
    # y_i,k = γ_k * x̂_i,k + β_k
    # gamma and beta are learned during training
    y = gamma * x_norm + beta

    print(f"\nStep 6 - Final output after scale and shift:")
    print(f"  Output shape: {y.shape}")
    print(f"  Output mean per feature: {np.mean(y, axis=0)[:3]}")
    print(f"  Output std per feature: {np.std(y, axis=0)[:3]}")

    # Cache values needed for backward pass
    cache = {
        'x': x,
        'x_centered': x_centered,
        'std': std,
        'x_norm': x_norm,
        'gamma': gamma,
        'eps': eps
    }

    return y, cache


# Example usage
np.random.seed(42)
batch_size, num_features = 32, 64

# Input with non-zero mean and non-unit variance
x = np.random.randn(batch_size, num_features) * 3 + 2

# Learnable parameters (initialized to identity transform)
gamma = np.ones(num_features)   # Scale
beta = np.zeros(num_features)   # Shift

y, cache = batchnorm_forward_pedagogical(x, gamma, beta)
```

The Complete Formula:
Combining all steps, Batch Normalization computes:
$$\hat{x}^{(k)} = \frac{x^{(k)} - \mu_{\mathcal{B}}^{(k)}}{\sqrt{\sigma_{\mathcal{B}}^{2(k)} + \epsilon}}$$
$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
where:
- $x^{(k)}$ is feature $k$ of a sample in the mini-batch $\mathcal{B}$,
- $\mu_{\mathcal{B}}^{(k)}$ and $\sigma_{\mathcal{B}}^{2(k)}$ are the batch mean and (biased) batch variance of feature $k$,
- $\epsilon$ is a small constant added for numerical stability,
- $\gamma^{(k)}$ and $\beta^{(k)}$ are learnable per-feature scale and shift parameters.
A crucial design choice in BatchNorm is the inclusion of learnable parameters γ (scale) and β (shift). Without these, BatchNorm would strictly enforce zero mean and unit variance, which could limit the network's expressivity.
Why Are γ and β Necessary?
Consider a sigmoid activation σ(x). If inputs are normalized to zero mean and unit variance, most activations will fall in the approximately-linear region around x=0. The network loses the ability to exploit the sigmoid's saturation properties when needed.
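To make this concrete, here is a small sketch (toy setup with 100,000 assumed standard-normal pre-activations) showing how many sigmoid outputs land in the saturated tails for a few choices of γ and β:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
x_norm = rng.standard_normal(100_000)  # normalized pre-activations: zero mean, unit variance

# Count outputs in the saturated tails of the sigmoid (below 0.05 or above 0.95)
for gamma, beta in [(1.0, 0.0), (4.0, 0.0), (1.0, 3.0)]:
    s = sigmoid(gamma * x_norm + beta)
    saturated = np.mean((s < 0.05) | (s > 0.95))
    print(f"gamma={gamma}, beta={beta}: {saturated:.1%} of activations saturated")
```

With γ=1, β=0 almost nothing saturates (the sigmoid stays in its near-linear region), while a larger γ or a shifted β lets the network deliberately push activations into the saturated regime when that is useful.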
With γ and β, the network can learn to rescale and shift each normalized feature to whatever statistics serve the layer best. The table below summarizes the effect of different parameter values:
| γ Value | β Value | Effect | Use Case |
|---|---|---|---|
| 1.0 | 0.0 | Identity (normalized) | Default initialization |
| < 1.0 | any | Compress variance | Reduce activation range |
| > 1.0 | any | Expand variance | Increase activation range |
| any | > 0 | Shift positive | Bias toward ReLU active region |
| any | < 0 | Shift negative | Bias toward ReLU inactive region |
| σ_B | μ_B | Undo normalization | Recover original distribution |
```python
import numpy as np


def demonstrate_gamma_beta_effects():
    """
    Show how gamma and beta affect the output distribution.
    """
    np.random.seed(42)

    # Generate normalized data (zero mean, unit variance)
    x_norm = np.random.randn(10000)

    # Different (gamma, beta) configurations
    configs = [
        (1.0, 0.0, "Identity: γ=1, β=0"),
        (2.0, 0.0, "Scaled: γ=2, β=0"),
        (0.5, 0.0, "Compressed: γ=0.5, β=0"),
        (1.0, 2.0, "Shifted: γ=1, β=2"),
        (2.0, 1.0, "Both: γ=2, β=1"),
    ]

    results = {}
    for gamma, beta, label in configs:
        y = gamma * x_norm + beta
        results[label] = {
            'mean': np.mean(y),
            'std': np.std(y),
            'min': np.min(y),
            'max': np.max(y)
        }

        print(f"{label}")
        print(f"  Mean: {results[label]['mean']:.3f}, Std: {results[label]['std']:.3f}")
        print(f"  Range: [{results[label]['min']:.3f}, {results[label]['max']:.3f}]\n")

    return results


# The key insight: gamma and beta give the network full control
# over the output distribution while still benefiting from
# the normalization during the forward pass.
results = demonstrate_gamma_beta_effects()

# Initialization strategy:
# gamma = 1, beta = 0 means the layer starts as identity
# This is important: at initialization, BatchNorm doesn't change the signal
print("\nInitialization insight:")
print("Starting with γ=1, β=0 means BatchNorm is initially a 'pass-through'")
print("The network can then learn to adjust these parameters as needed")
```

Initializing γ=1 and β=0 means BatchNorm outputs exactly the normalized values. This creates an identity mapping at initialization, allowing the network to gradually learn what scale and shift are optimal. This principle—starting as identity—appears throughout deep learning in skip connections, gating mechanisms, and residual networks.
Mathematical Perspective on Expressivity:
Claim: With learnable γ and β, BatchNorm can represent any linear transformation of the whitened features.
Proof sketch: The normalized output x̂ has zero mean and unit variance. The transformation y = γx̂ + β can produce:
- any per-feature standard deviation |γ|, by choosing the scale γ,
- any per-feature mean β, by choosing the shift β,
- in particular, the original pre-normalization distribution, by setting γ = σ_B and β = μ_B (the last row of the table above).
Therefore, BatchNorm doesn't restrict the family of functions the network can represent—it only changes how those functions are parameterized.
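As a quick sanity check of this claim, the sketch below (toy batch with assumed sizes) sets γ to the batch standard deviation and β to the batch mean and recovers the original activations:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 8)) * 3.0 + 2.0   # batch of 256 samples, 8 features
eps = 1e-5

mu = x.mean(axis=0)
var = x.var(axis=0)
x_norm = (x - mu) / np.sqrt(var + eps)

gamma = np.sqrt(var + eps)   # scale set to the batch std (including eps)
beta = mu                    # shift set to the batch mean
y = gamma * x_norm + beta    # undoes the normalization

print("max |y - x| =", np.max(np.abs(y - x)))  # ~0 up to floating-point error
```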
The computation of batch statistics—mean and variance—is central to BatchNorm's operation. Understanding the nuances of this computation is essential for correct implementation and debugging.
Batch Mean:
$$\mu_{\mathcal{B}}^{(k)} = \frac{1}{m} \sum_{i=1}^{m} x_i^{(k)}$$
This is a simple average over the batch dimension for each feature.
Batch Variance:
$$\sigma_{\mathcal{B}}^{2(k)} = \frac{1}{m} \sum_{i=1}^{m} (x_i^{(k)} - \mu_{\mathcal{B}}^{(k)})^2$$
Important: BatchNorm uses the biased variance estimator (dividing by m, not m-1). This is a design choice, not an oversight.
| Aspect | Biased (1/m) | Unbiased (1/(m-1)) |
|---|---|---|
| Formula | Σ(x-μ)²/m | Σ(x-μ)²/(m-1) |
| Used in | BatchNorm training | Statistical estimation |
| Gradient flow | Simpler gradients | Slightly more complex |
| Behavior at m=1 | Zero variance | Division by zero |
| Batch size sensitivity | Lower | Higher for small batches |
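A small numeric illustration of the two estimators (toy values), using NumPy's `ddof` argument:

```python
import numpy as np

x = np.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])  # one feature, batch of m = 8
m = x.size

biased = np.var(x)            # divides by m     -> 4.0000 (what BatchNorm uses in training)
unbiased = np.var(x, ddof=1)  # divides by m - 1 -> 4.5714 (classical sample variance)

print(f"biased (1/m):       {biased:.4f}")
print(f"unbiased (1/(m-1)): {unbiased:.4f}")
print(f"ratio:              {unbiased / biased:.4f}  (= m/(m-1) = {m / (m - 1):.4f})")
```

For large batches the two differ only by the negligible factor m/(m-1); the biased form simply yields cleaner gradients.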
Convolutional Networks: Spatial Statistics
For convolutional layers, the input has shape (N, C, H, W), where N is the batch size, C the number of channels, H the height, and W the width of the feature map.
BatchNorm in CNNs computes statistics per channel, averaging over the batch and both spatial dimensions:
```python
import numpy as np


def batchnorm_conv_forward(x, gamma, beta, eps=1e-5):
    """
    BatchNorm for convolutional layers.

    Args:
        x: Input tensor of shape (N, C, H, W)
        gamma: Scale parameter, shape (C,)
        beta: Shift parameter, shape (C,)

    The key difference from standard BatchNorm:
    Statistics are computed per-channel, averaging over N, H, and W.
    """
    N, C, H, W = x.shape

    # Treat each spatial location as a sample:
    # this gives us N*H*W "samples" per channel

    # Compute mean per channel
    # Mean over batch (axis 0) and spatial dimensions (axes 2, 3)
    mu = np.mean(x, axis=(0, 2, 3), keepdims=True)  # Shape: (1, C, 1, 1)

    # Compute variance per channel
    var = np.mean((x - mu) ** 2, axis=(0, 2, 3), keepdims=True)

    # Normalize
    x_norm = (x - mu) / np.sqrt(var + eps)

    # Scale and shift (broadcasting gamma and beta)
    # gamma and beta have shape (C,), need to reshape to (1, C, 1, 1)
    gamma_reshaped = gamma.reshape(1, C, 1, 1)
    beta_reshaped = beta.reshape(1, C, 1, 1)
    y = gamma_reshaped * x_norm + beta_reshaped

    # Effective number of samples per channel
    effective_batch_size = N * H * W
    print(f"Effective batch size per channel: {effective_batch_size}")
    print(f"Statistics computed from {N} images × {H}×{W} spatial locations")

    return y, (mu, var, x_norm)


# Example
np.random.seed(42)
N, C, H, W = 32, 64, 14, 14  # Typical intermediate CNN layer
x = np.random.randn(N, C, H, W) * 2 + 0.5

gamma = np.ones(C)
beta = np.zeros(C)

y, cache = batchnorm_conv_forward(x, gamma, beta)

# Each channel's statistics come from 32 * 14 * 14 = 6272 values
# This is much larger than the batch size alone!
```

In CNNs, the effective sample size for computing batch statistics is N × H × W, not just N. For a batch of 32 images at 14×14 spatial resolution, each channel's statistics are computed from 32 × 14 × 14 = 6,272 values. This makes BatchNorm statistics more stable in CNNs than in fully-connected layers.
The small constant ε in the denominator (√(σ² + ε)) prevents division by zero when variance is very small. While seemingly trivial, proper handling of ε is crucial for stable training.
When Variance Approaches Zero:
Variance can be very small when:
- a feature is nearly constant across the batch (for example, a dead ReLU unit or a saturated activation),
- the batch is very small, so a handful of similar samples dominate the statistics,
- the input data itself contains constant or near-constant features.
```python
import numpy as np


def demonstrate_epsilon_importance():
    """
    Show why epsilon is critical for numerical stability.
    """
    # Create a feature with very small variance
    x = np.array([1.0, 1.0, 1.0, 1.0, 1.0001])  # Nearly constant

    mu = np.mean(x)
    var = np.var(x)

    print(f"Feature values: {x}")
    print(f"Mean: {mu}")
    print(f"Variance: {var:.2e}")
    print()

    # Without epsilon: numerical disaster
    eps_values = [0, 1e-10, 1e-8, 1e-5, 1e-3]

    for eps in eps_values:
        try:
            std = np.sqrt(var + eps)
            x_norm = (x - mu) / std

            # Check for numerical issues
            has_nan = np.any(np.isnan(x_norm))
            has_inf = np.any(np.isinf(x_norm))
            max_abs = np.max(np.abs(x_norm))

            status = "✓ Stable" if (not has_nan and not has_inf and max_abs < 1e10) else "✗ Unstable"
            print(f"ε = {eps:.0e}: std = {std:.2e}, max_abs = {max_abs:.2e}  {status}")
        except Exception as e:
            print(f"ε = {eps:.0e}: Error - {e}")

    print()
    print("Key insight: ε = 1e-5 is the common default because it's:")
    print("  1. Large enough to prevent numerical issues")
    print("  2. Small enough not to affect normal variance values")


demonstrate_epsilon_importance()

# Gradient considerations
print("\n--- Gradient Analysis ---")
print("The gradient of x_norm w.r.t. variance involves 1/(var + eps)^(3/2)")
print("Without eps, this explodes when var → 0")
print("This would cause gradient explosions and training instability")
```

Choosing ε:
The choice of ε involves a trade-off:
- Too small: when the variance is near zero, 1/√(σ² + ε) and its gradient can still become extremely large, especially in low-precision arithmetic.
- Too large: the effective standard deviation is inflated, so the "normalized" output no longer has unit variance and the normalization is weakened.

Common choices: ε = 1e-5 (the PyTorch default, as in the code above) and ε = 1e-3 (the Keras BatchNormalization default); values in this range rarely need tuning.
Gradient Through the Normalization:
The gradient of the loss with respect to the variance involves:
$$\frac{\partial \mathcal{L}}{\partial \sigma^2} \propto \frac{1}{(\sigma^2 + \epsilon)^{3/2}}$$
Without ε, this term explodes as variance approaches zero, causing gradient explosions.
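A quick numeric sketch (arbitrary variance values) of how ε caps this factor:

```python
import numpy as np

eps = 1e-5
for var in [1e-2, 1e-4, 1e-6, 1e-8]:
    without_eps = var ** -1.5
    with_eps = (var + eps) ** -1.5
    print(f"var={var:.0e}:  1/var^1.5 = {without_eps:.2e}   1/(var+eps)^1.5 = {with_eps:.2e}")
```

Without ε the factor grows without bound as the variance shrinks; with ε = 1e-5 it can never exceed ε^(-3/2) ≈ 3 × 10⁷.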
When using float16 (half precision) for training, the default ε=1e-5 may be too small—float16 has limited precision. Many frameworks automatically upcast BatchNorm computations to float32 or use a larger ε. Always verify your framework's behavior when enabling mixed precision.
Understanding backpropagation through BatchNorm is essential for debugging and for understanding its regularization effects. The gradients are more complex than for simple element-wise operations because each output depends on the entire batch.
Backward Pass Derivation:
Let y = γx̂ + β where x̂ = (x - μ)/σ. Given ∂L/∂y, we need to compute ∂L/∂x, ∂L/∂γ, and ∂L/∂β.
```python
import numpy as np


def batchnorm_backward(dy, cache):
    """
    Backward pass for Batch Normalization.

    Args:
        dy: Gradient of loss w.r.t. output y, shape (m, d)
        cache: Values from forward pass

    Returns:
        dx: Gradient w.r.t. input x
        dgamma: Gradient w.r.t. scale parameter
        dbeta: Gradient w.r.t. shift parameter
    """
    x, x_centered, std, x_norm, gamma = (
        cache['x'], cache['x_centered'], cache['std'],
        cache['x_norm'], cache['gamma']
    )
    m, d = x.shape

    # ═══════════════════════════════════════════════════════════
    # Gradient w.r.t. beta (shift parameter)
    # ═══════════════════════════════════════════════════════════
    # y = gamma * x_norm + beta
    # ∂L/∂β = Σᵢ ∂L/∂yᵢ  (sum over batch)
    dbeta = np.sum(dy, axis=0)

    # ═══════════════════════════════════════════════════════════
    # Gradient w.r.t. gamma (scale parameter)
    # ═══════════════════════════════════════════════════════════
    # ∂L/∂γ = Σᵢ ∂L/∂yᵢ * x̂ᵢ
    dgamma = np.sum(dy * x_norm, axis=0)

    # ═══════════════════════════════════════════════════════════
    # Gradient w.r.t. normalized input
    # ═══════════════════════════════════════════════════════════
    dx_norm = dy * gamma  # Shape: (m, d)

    # ═══════════════════════════════════════════════════════════
    # Gradient w.r.t. input x (the complex part)
    # ═══════════════════════════════════════════════════════════
    # The complexity comes from the fact that μ and σ depend on x
    # x̂ = (x - μ) / σ
    #
    # Using the chain rule through the normalization:
    # dx = dx_norm / σ
    #      - (1/m) * Σⱼ dx_normⱼ / σ            (through mean)
    #      - x̂ * (1/(mσ)) * Σⱼ dx_normⱼ * x̂ⱼ    (through variance)

    # Inverse of the standard deviation (1/σ)
    inv_std = 1.0 / std

    # Gradient through standard deviation
    dx_centered = dx_norm * inv_std

    # Gradient through variance
    dvar = -0.5 * np.sum(dx_norm * x_centered, axis=0) * (inv_std ** 3)
    dx_centered += (2.0 / m) * x_centered * dvar

    # Gradient through mean
    dmu = -np.sum(dx_centered, axis=0)
    dx = dx_centered + (1.0 / m) * dmu

    return dx, dgamma, dbeta


def batchnorm_backward_efficient(dy, cache):
    """
    Efficient backward pass - single formula.
    This is mathematically equivalent but computationally more efficient.
    """
    x_norm, gamma = cache['x_norm'], cache['gamma']
    std = cache['std']
    m = dy.shape[0]

    dbeta = np.sum(dy, axis=0)
    dgamma = np.sum(dy * x_norm, axis=0)

    # Efficient combined formula
    dx_norm = dy * gamma
    dx = (1.0 / (m * std)) * (
        m * dx_norm
        - np.sum(dx_norm, axis=0)
        - x_norm * np.sum(dx_norm * x_norm, axis=0)
    )

    return dx, dgamma, dbeta


# Verify gradients numerically
def numerical_gradient(f, x, eps=1e-5):
    """Compute numerical gradient for verification."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        old_val = x[idx]
        x[idx] = old_val + eps
        fxph = f()
        x[idx] = old_val - eps
        fxmh = f()
        x[idx] = old_val
        grad[idx] = (fxph - fxmh) / (2 * eps)
        it.iternext()
    return grad


# Example verification
np.random.seed(42)
m, d = 4, 3
x = np.random.randn(m, d)
gamma = np.random.randn(d)
beta = np.random.randn(d)
dy = np.random.randn(m, d)

# Forward pass (reuses batchnorm_forward_pedagogical from earlier)
y, cache = batchnorm_forward_pedagogical(x, gamma, beta)

# Backward pass
dx, dgamma, dbeta = batchnorm_backward(dy, cache)

print("Gradient verification (should be very small differences):")
print(f"dx max diff: {np.max(np.abs(dx - numerical_gradient(lambda: np.sum(dy * batchnorm_forward_pedagogical(x, gamma, beta)[0]), x))):.2e}")
```

Key Insight: The Gradient Depends on the Whole Batch
Unlike element-wise operations where ∂L/∂xᵢ depends only on yᵢ, in BatchNorm each gradient ∂L/∂xᵢ depends on all outputs y₁, y₂, ..., yₘ through the batch statistics.
This has important implications:
- The gradient for one sample depends on every other sample in the batch, so changing the batch composition changes the gradients; this batch-dependent noise is one source of BatchNorm's regularization effect.
- Gradient magnitudes, and therefore training dynamics, depend on the batch size.
- Per-sample gradients are no longer independent (the numeric check after the compact formula below illustrates one consequence).
The formula can be written compactly as:
$$\frac{\partial \mathcal{L}}{\partial x_i} = \frac{1}{m\sigma} \left( m \frac{\partial \mathcal{L}}{\partial \hat{x}_i} - \sum_{j=1}^{m} \frac{\partial \mathcal{L}}{\partial \hat{x}_j} - \hat{x}_i \sum_{j=1}^{m} \frac{\partial \mathcal{L}}{\partial \hat{x}_j}\, \hat{x}_j \right), \qquad \text{where } \frac{\partial \mathcal{L}}{\partial \hat{x}_i} = \gamma\, \frac{\partial \mathcal{L}}{\partial y_i}$$
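One consequence of this formula is easy to verify numerically: per feature, the input gradients sum to approximately zero, because adding a constant to every sample of a feature is removed by the mean subtraction and therefore cannot change the loss. A self-contained sketch with toy shapes and assumed values:

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 32, 4
x = rng.standard_normal((m, d)) * 2.0 + 1.0
gamma = rng.standard_normal(d)
dy = rng.standard_normal((m, d))   # upstream gradient dL/dy
eps = 1e-5

std = np.sqrt(x.var(axis=0) + eps)
x_norm = (x - x.mean(axis=0)) / std

g = dy * gamma                     # dL/dx_hat
dx = (1.0 / (m * std)) * (m * g - g.sum(axis=0) - x_norm * (g * x_norm).sum(axis=0))

print("per-feature sum of dL/dx (should be ~0):", dx.sum(axis=0))
```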
During training, gradients flow through the batch statistics (μ and σ). During inference with fixed running statistics, the gradients would be different—but this doesn't matter because we don't backpropagate during inference. The distinction becomes important only in unusual scenarios like adversarial training or meta-learning.
Real-world BatchNorm implementations involve several additional details beyond the core algorithm. Understanding these is essential for debugging and for achieving optimal performance.
Affine Parameter Naming:
Different frameworks use different names for the learnable parameters:
| Framework | Scale (γ) | Shift (β) | Running Mean | Running Var |
|---|---|---|---|---|
| PyTorch | weight | bias | running_mean | running_var |
| TensorFlow/Keras | gamma | beta | moving_mean | moving_variance |
| Caffe | scale | bias | mean | variance |
| Academic Papers | γ | β | μ | σ² |
```python
import torch
import torch.nn as nn

# PyTorch BatchNorm implementation details
bn = nn.BatchNorm2d(
    num_features=64,           # Number of channels (C)
    eps=1e-5,                  # Epsilon for numerical stability
    momentum=0.1,              # For running statistics update
    affine=True,               # Whether to learn gamma and beta
    track_running_stats=True   # Whether to maintain running statistics
)

# Inspect the parameters
print("BatchNorm2d Parameters:")
print(f"  weight (gamma): shape={bn.weight.shape}, requires_grad={bn.weight.requires_grad}")
print(f"  bias (beta): shape={bn.bias.shape}, requires_grad={bn.bias.requires_grad}")
print(f"  running_mean: shape={bn.running_mean.shape}, requires_grad={bn.running_mean.requires_grad}")
print(f"  running_var: shape={bn.running_var.shape}, requires_grad={bn.running_var.requires_grad}")
print(f"  num_batches_tracked: {bn.num_batches_tracked}")

# The 'momentum' parameter controls the running statistics update:
# running_stat_new = (1 - momentum) * running_stat + momentum * batch_stat
# Note: This is the inverse of the typical momentum convention!

# Different BatchNorm variants in PyTorch
print("\nPyTorch BatchNorm Variants:")
print("  BatchNorm1d: For (N, C) or (N, C, L) inputs")
print("  BatchNorm2d: For (N, C, H, W) inputs")
print("  BatchNorm3d: For (N, C, D, H, W) inputs")

# affine=False disables learnable parameters
bn_no_affine = nn.BatchNorm2d(64, affine=False)
print("\nWith affine=False:")
print(f"  weight: {bn_no_affine.weight}")  # None
print(f"  bias: {bn_no_affine.bias}")      # None

# track_running_stats=False uses batch statistics even in eval mode
bn_no_running = nn.BatchNorm2d(64, track_running_stats=False)
print("\nWith track_running_stats=False:")
print(f"  running_mean: {bn_no_running.running_mean}")  # None
```

Training vs. Evaluation Mode:
BatchNorm behaves differently depending on whether the model is in training or evaluation mode:
Training mode (model.train()):
- Normalizes each batch using that batch's own mean and variance
- Updates the running mean and running variance using the momentum parameter
- Gradients flow through the batch statistics

Evaluation mode (model.eval()):
- Normalizes using the stored running mean and running variance
- Does not update any statistics
- Produces deterministic outputs that do not depend on the other samples in the batch
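A minimal PyTorch sketch of the difference (toy shapes; freshly initialized running statistics assumed):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)          # gamma=1, beta=0, running_mean=0, running_var=1
x = torch.randn(8, 4) * 3 + 2

bn.train()
y_train = bn(x)                 # uses this batch's statistics, updates running stats
bn.eval()
y_eval = bn(x)                  # uses the accumulated running statistics

print("train-mode output means:", y_train.mean(dim=0).detach())  # ~0 per feature
print("eval-mode output means: ", y_eval.mean(dim=0).detach())   # generally far from 0
print("running_mean after one batch:", bn.running_mean)          # 0.1 * batch mean
```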
Forgetting to call model.eval() before inference is a common source of bugs. With BatchNorm, this causes the model to use (invalid) batch statistics from the test batch instead of the learned running statistics. Always ensure model.eval() is called before inference and model.train() before training.
The success of batch normalization inspired numerous variants, each addressing specific limitations or use cases. Understanding these helps you choose the right normalization for your architecture.
Synchronized BatchNorm (SyncBatchNorm):
In distributed training across multiple GPUs, each GPU typically computes batch statistics from its local subset of the batch. SyncBatchNorm synchronizes statistics across GPUs.
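In PyTorch this conversion is a one-liner via nn.SyncBatchNorm.convert_sync_batchnorm; a minimal sketch is below (the converted layers only synchronize once the model actually runs inside an initialized distributed process group):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

# Replaces every BatchNorm*d module with a SyncBatchNorm of the same configuration
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)   # SyncBatchNorm

# Typical use: convert, then wrap with DistributedDataParallel so each forward pass
# reduces the batch mean/variance across all participating GPUs.
```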
| Variant | Use Case | Key Difference |
|---|---|---|
| Standard BatchNorm | Single GPU training | Statistics from local batch |
| SyncBatchNorm | Multi-GPU training | Statistics synchronized across GPUs |
| Ghost BatchNorm | Large batch training | Statistics from batch subsets |
| Batch Renormalization | Small batches | Constrains batch/running stat ratio |
| Virtual BatchNorm | GANs | Uses reference batch for generator |
Batch Renormalization:
Introduced to address BatchNorm's poor behavior with small batch sizes. It adds a constraint that the batch statistics should not deviate too far from the running statistics:
$$\hat{x} = \frac{x - \mu_B}{\sigma_B} \cdot r + d$$
where r = clip(σ_B/σ_running, 1/r_max, r_max) and d = clip((μ_B - μ_running)/σ_running, -d_max, d_max).
The clip bounds (r_max, d_max) are gradually relaxed during training, starting strict and becoming looser.
Why This Helps:
- When r and d are not clipped, the corrected output satisfies r·x̂ + d = (x − μ_running)/σ_running, so the layer effectively normalizes with the running statistics during training while still backpropagating through the batch statistics (r and d are treated as constants in the backward pass).
- The clipping keeps the correction bounded early in training, when the running statistics are still unreliable, and the gradually relaxed bounds provide a smooth transition from batch statistics to running statistics.
- The result is a much smaller train/inference mismatch when batches are small or not representative.
```python
import numpy as np


def batch_renorm_forward(x, gamma, beta, running_mean, running_var,
                         training=True, momentum=0.1, eps=1e-5,
                         r_max=3.0, d_max=5.0):
    """
    Batch Renormalization forward pass.

    Args:
        r_max: Maximum allowed ratio of batch std to running std
        d_max: Maximum allowed normalized difference in means
    """
    m, d = x.shape

    # Compute batch statistics
    batch_mean = np.mean(x, axis=0)
    batch_var = np.var(x, axis=0)
    batch_std = np.sqrt(batch_var + eps)

    if training:
        # Compute renormalization parameters r and d
        running_std = np.sqrt(running_var + eps)

        # r = σ_batch / σ_running, clipped to [1/r_max, r_max]
        r = np.clip(batch_std / running_std, 1/r_max, r_max)

        # d = (μ_batch - μ_running) / σ_running, clipped to [-d_max, d_max]
        d_correction = np.clip((batch_mean - running_mean) / running_std,
                               -d_max, d_max)

        # Normalize with batch statistics, then apply r and d correction
        x_norm = (x - batch_mean) / batch_std
        x_norm = x_norm * r + d_correction  # This is the key renormalization!

        # Update running statistics
        running_mean = (1 - momentum) * running_mean + momentum * batch_mean
        running_var = (1 - momentum) * running_var + momentum * batch_var
    else:
        # Evaluation: use running statistics directly
        x_norm = (x - running_mean) / np.sqrt(running_var + eps)

    # Scale and shift
    y = gamma * x_norm + beta

    return y, running_mean, running_var


# During training:
# - Start with strict r_max=1, d_max=0 (forces batch stats == running stats)
# - Gradually increase to r_max=3, d_max=5 (allows more deviation)
# - This creates a smooth transition from running stats to batch stats
```

Batch Renormalization is most useful when: (1) batch sizes are small (< 16), (2) batch composition varies significantly during training, or (3) you observe training instability with standard BatchNorm. With large batches, standard BatchNorm is usually sufficient and simpler.
We've developed a complete understanding of the Batch Normalization algorithm. Let's consolidate the key concepts:
- BatchNorm normalizes each feature with the batch mean and the biased batch variance, adding ε inside the square root for numerical stability.
- The learnable parameters γ and β restore full expressivity; initialized to γ = 1, β = 0, the layer begins as an identity on the normalized values.
- In convolutional networks, statistics are computed per channel over N × H × W values, making them far more stable than the batch size alone suggests.
- The backward pass couples every sample in the batch, because gradients flow through the shared mean and variance.
- Correct behavior hinges on implementation details: the value of ε, the momentum convention for running statistics, and switching between training and evaluation modes.
What's Next:
The next page examines the critical distinction between training and inference behavior in BatchNorm. We'll cover running statistics, the momentum parameter, and common pitfalls when switching between training and evaluation modes.
You now understand the complete mathematical formulation of Batch Normalization, including forward pass, backward pass, and the roles of all parameters. This foundation enables you to implement, debug, and optimize BatchNorm in any framework.