Standard dropout uses a fixed probability p for all neurons in a layer. But this is a strong assumption—why should every neuron be equally important, or equally prone to overfitting?
Variational Dropout extends the Bayesian interpretation of dropout to learn optimal dropout rates for each weight or neuron. The result is both a better-calibrated model and, remarkably, a form of automatic network pruning.
The Key Insight:
If dropout corresponds to a posterior over weights, then the dropout rate corresponds to the uncertainty about each weight. A high dropout rate means the weight is highly uncertain—perhaps unnecessary. A low dropout rate means the weight is important and well-determined.
By learning dropout rates during training, we can obtain better-calibrated uncertainty estimates and automatically identify weights that contribute nothing and can be pruned.
This page covers: (1) The mathematical formulation of variational dropout; (2) Learning individual dropout rates as variational parameters; (3) The multiplicative Gaussian interpretation; (4) Sparse variational dropout and automatic pruning; and (5) Practical implementation and training considerations.
Let's trace the development from standard dropout to variational dropout to understand the conceptual progression.
Standard Dropout Review:
In standard dropout, we multiply activations by a Bernoulli random variable: $$\tilde{a} = a \cdot \epsilon, \quad \epsilon \sim \text{Bernoulli}(1-p)$$
The dropout rate p is a hyperparameter, fixed across all neurons in a layer.
Gaussian Dropout:
Bernoulli dropout is discrete (0 or 1). Gaussian dropout uses continuous multiplicative noise: $$\tilde{a} = a \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(1, \alpha)$$
Here, α controls the noise variance. With α = p/(1-p), Gaussian dropout matches the mean and variance of inverted Bernoulli dropout (where surviving activations are rescaled by 1/(1-p)).
Why does this matter? Gaussian dropout is differentiable with respect to α. This opens the door to learning the noise level—i.e., learning the effective dropout rate.
```python
import numpy as np
import matplotlib.pyplot as plt


class GaussianDropout:
    """
    Gaussian dropout: multiply by Gaussian noise instead of Bernoulli.

    Equivalent to Bernoulli dropout in terms of mean and variance,
    but differentiable with respect to the noise parameter.
    """

    def __init__(self, alpha: float = 1.0):
        """
        Initialize with noise variance alpha.

        Args:
            alpha: Variance of multiplicative noise.
                   alpha = p/(1-p) gives equivalent regularization
                   to Bernoulli dropout with rate p.
        """
        self.alpha = alpha
        self.training = True

    def forward(self, x: np.ndarray) -> np.ndarray:
        """
        Apply Gaussian multiplicative noise.

        x̃ = x * ε, where ε ~ N(1, α)

        Note: Mean is preserved since E[ε] = 1
        """
        if not self.training:
            return x  # No noise at inference

        # Sample multiplicative noise
        epsilon = np.random.normal(1.0, np.sqrt(self.alpha), size=x.shape)
        return x * epsilon

    @staticmethod
    def bernoulli_to_gaussian_alpha(p: float) -> float:
        """
        Convert Bernoulli dropout rate to equivalent Gaussian alpha.

        For Bernoulli: E[ε] = 1-p, Var[ε] = p(1-p)
        For Gaussian with mean 1, we want the same relative variance:
            Var[ε] = α = p(1-p) / (1-p)² = p/(1-p)

        With inverted Bernoulli (scale by 1/(1-p)):
            E[ε] = 1, Var[ε] = p/(1-p)

        So α = p/(1-p) makes Gaussian equivalent to inverted Bernoulli.
        """
        if p >= 1.0:
            return float('inf')
        return p / (1 - p)


def compare_bernoulli_gaussian():
    """Show equivalence between Bernoulli and Gaussian dropout."""
    np.random.seed(42)

    print("Bernoulli vs Gaussian Dropout Comparison")
    print("=" * 55)

    x = 5.0  # Test activation value
    num_samples = 100000
    dropout_rates = [0.2, 0.5, 0.8]

    print(f"\nInput value: {x}")
    print(f"{'Dropout Rate':<15} {'Method':<15} {'Mean':<10} {'Std':<10}")
    print("-" * 55)

    for p in dropout_rates:
        # Bernoulli dropout (inverted)
        mask = np.random.binomial(1, 1 - p, size=num_samples)
        bernoulli_outputs = x * mask / (1 - p)

        # Gaussian dropout (equivalent)
        alpha = p / (1 - p)
        epsilon = np.random.normal(1.0, np.sqrt(alpha), size=num_samples)
        gaussian_outputs = x * epsilon

        print(f"{p:<15.1f} {'Bernoulli':<15} {bernoulli_outputs.mean():<10.4f} {bernoulli_outputs.std():<10.4f}")
        print(f"{'':<15} {'Gaussian':<15} {gaussian_outputs.mean():<10.4f} {gaussian_outputs.std():<10.4f}")

    print("-" * 55)
    print("\n✓ Gaussian dropout matches Bernoulli in mean and variance")
    print("  Key advantage: α is differentiable → can be learned!")


def demonstrate_noise_levels():
    """Show how alpha affects the output distribution."""
    np.random.seed(42)

    print("\n" + "=" * 55)
    print("Effect of Alpha (Noise Variance)")
    print("=" * 55)

    x = 10.0
    alphas = [0.0, 0.25, 1.0, 4.0, 16.0]

    print(f"\nInput value: {x}")
    print(f"{'Alpha':<10} {'Equiv. p':<12} {'Output Mean':<15} {'Output Std':<15}")
    print("-" * 55)

    for alpha in alphas:
        # Equivalent dropout rate (inverse of alpha = p/(1-p))
        p_equiv = alpha / (1 + alpha)

        # Generate samples
        if alpha > 0:
            epsilon = np.random.normal(1.0, np.sqrt(alpha), size=10000)
        else:
            epsilon = np.ones(10000)

        output = x * epsilon

        print(f"{alpha:<10.2f} {p_equiv:<12.1%} {output.mean():<15.2f} {output.std():<15.2f}")

    print("\nInterpretation:")
    print("  α → 0: No dropout (ε always near 1)")
    print("  α → ∞: Noise overwhelms the signal (very high dropout)")
    print("  Intermediate α: Controlled regularization")


compare_bernoulli_gaussian()
demonstrate_noise_levels()
```

The shift from Bernoulli to Gaussian dropout is crucial because Gaussian parameters are continuous and differentiable.
We can now backpropagate through the dropout operation with respect to the noise level, allowing us to learn optimal dropout rates during training.
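As a quick sanity check on that claim, here is a minimal sketch (assuming PyTorch; the variable names are illustrative) that reparameterizes the noise as ε = 1 + √α·ξ with ξ ~ N(0, 1) and backpropagates a toy loss into log α:

```python
import torch

# Gaussian dropout noise written so that gradients can flow into log(alpha)
log_alpha = torch.tensor(0.0, requires_grad=True)   # alpha = exp(0) = 1.0
x = torch.tensor([2.0, -1.0, 3.0])

xi = torch.randn_like(x)
epsilon = 1.0 + torch.exp(0.5 * log_alpha) * xi      # epsilon ~ N(1, alpha)
y = x * epsilon

loss = (y ** 2).mean()    # toy objective
loss.backward()

print(log_alpha.grad)     # non-zero: the noise level is a learnable parameter
```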
Variational dropout formalizes the idea of learning dropout rates within the variational inference framework.
Weight Posterior Parameterization:
In standard variational inference, we'd parameterize the posterior as: $$q(\mathbf{W}) = \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$$
where μ and Σ are learned. This requires storing variance parameters for every weight—often impractical.
The Multiplicative Noise Trick:
Variational dropout parameterizes the posterior through multiplicative noise on the weight means: $$W_{ij} = \theta_{ij} \cdot \epsilon_{ij}, \quad \epsilon_{ij} \sim \mathcal{N}(1, \alpha_{ij})$$
Here, θᵢⱼ is the weight mean (the deterministic weight we would ordinarily learn), and αᵢⱼ is a learned, per-weight noise variance that plays the role of a dropout rate.
The effective weight distribution is: $$q(W_{ij}) = \mathcal{N}(\theta_{ij}, \alpha_{ij} \theta_{ij}^2)$$
This is a Gaussian with mean θ and variance proportional to θ². The parameter αᵢⱼ controls relative uncertainty.
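The two views line up because sampling W = θ·ε with ε ~ N(1, α) is the same as sampling W directly from N(θ, αθ²). A small NumPy sketch to confirm this numerically (the specific θ and α values are just illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
theta, alpha = 0.8, 0.5          # weight mean and relative noise variance
n = 100_000

# Multiplicative form: W = theta * eps, with eps ~ N(1, alpha)
w_mult = theta * rng.normal(1.0, np.sqrt(alpha), size=n)

# Direct form: W ~ N(theta, alpha * theta^2)
w_direct = rng.normal(theta, np.sqrt(alpha) * abs(theta), size=n)

print(w_mult.mean(), w_mult.std())      # both ≈ 0.80, ≈ 0.57
print(w_direct.mean(), w_direct.std())
```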
The ELBO for Variational Dropout:
The training objective becomes: $$\mathcal{L}_{\text{VD}} = \mathbb{E}_{q(\mathbf{W})}\left[\log p(\mathcal{D}|\mathbf{W})\right] - \sum_{ij} \text{KL}\left[q(W_{ij}) \,\|\, p(W_{ij})\right]$$
The key insight: with a proper derivation of the KL term, we can learn αᵢⱼ using gradient descent!
The KL Divergence:
For a log-uniform prior (improper but commonly used): $$p(|W|) \propto 1/|W|$$
The negative KL divergence has an elegant approximation: $$-\text{KL}[q(W_{ij})\,\|\,p(W_{ij})] \approx k_1\, \sigma(k_2 + k_3 \log \alpha_{ij}) - 0.5\, m(-\log \alpha_{ij}) - k_1$$
where σ is the sigmoid function, m is the softplus, and k₁, k₂, k₃ are constants.
This approximation is differentiable with respect to α, enabling end-to-end learning.
```python
import numpy as np


class VariationalDropoutLinear:
    """
    Variational Dropout Linear Layer.

    Each weight has its own learned dropout rate (alpha parameter).
    The layer outputs a distribution over activations, not point estimates.

    Key innovation: alpha is learned during training, not set as hyperparameter.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        log_alpha_init: float = -5.0,  # log(alpha) for numerical stability
        threshold: float = 3.0         # For sparsification
    ):
        """
        Initialize layer.

        Args:
            in_features: Input dimension
            out_features: Output dimension
            log_alpha_init: Initial log(alpha); -5 ≈ very low dropout
            threshold: log(alpha) above which weights are pruned
        """
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold

        # Weight mean parameters (what we normally call "weights")
        self.theta = np.random.randn(in_features, out_features) * 0.01

        # Log-alpha parameters (per-weight dropout rates)
        # Initialized low = low dropout = keep all weights initially
        self.log_alpha = np.full((in_features, out_features), log_alpha_init)

        # Bias (no dropout applied)
        self.bias = np.zeros(out_features)

        self.training = True

    @property
    def alpha(self) -> np.ndarray:
        """Get alpha (dropout rate) parameters."""
        return np.exp(np.clip(self.log_alpha, -10, 10))

    def forward(self, x: np.ndarray) -> np.ndarray:
        """
        Forward pass with local reparameterization.

        Instead of sampling weights and computing x @ W,
        we directly sample the output distribution.

        For y = x @ W where W ~ N(θ, α·θ²):
            E[y]   = x @ θ
            Var[y] = x² @ (α·θ²)

        So we can sample: y ~ N(μ, σ²) where μ = x @ θ, σ² = x² @ (α·θ²)
        """
        # Output mean
        mu = x @ self.theta + self.bias

        if not self.training:
            return mu  # Use mean at inference

        # Output variance (element-wise)
        # Var[y_j] = sum_i x_i² * alpha_ij * theta_ij²
        sigma_sq = (x ** 2) @ (self.alpha * self.theta ** 2)
        sigma = np.sqrt(sigma_sq + 1e-8)

        # Sample output using reparameterization
        eps = np.random.randn(*mu.shape)
        return mu + sigma * eps

    def kl_divergence(self) -> float:
        """
        Compute KL divergence from the prior.

        Uses the approximation from Kingma et al. (2015) for
        KL[N(θ, α·θ²) || log-uniform prior].

        KL ≈ 0.5 * log(1 + 1/α) - C(α)
        where C(α) is a correction term.
        """
        # Constants for the approximation
        k1 = 0.63576
        k2 = 1.87320
        k3 = 1.48695

        # Compute approximation
        log_alpha = self.log_alpha

        # Sigmoid term
        sigmoid_term = k1 * self._sigmoid(k2 + k3 * log_alpha)

        # Softplus term
        softplus_term = 0.5 * self._softplus(-log_alpha)

        # Per-weight negative KL (the approximation targets -KL)
        neg_kl_per_weight = sigmoid_term - softplus_term - k1

        # Total KL (sum over all weights)
        return -np.sum(neg_kl_per_weight)

    def _sigmoid(self, x: np.ndarray) -> np.ndarray:
        return 1 / (1 + np.exp(-np.clip(x, -20, 20)))

    def _softplus(self, x: np.ndarray) -> np.ndarray:
        return np.log1p(np.exp(np.clip(x, -20, 20)))

    def sparsity(self) -> float:
        """Fraction of weights that would be pruned (high alpha)."""
        return np.mean(self.log_alpha > self.threshold)

    def get_sparse_weights(self) -> np.ndarray:
        """Get weights with high-dropout weights zeroed out."""
        mask = self.log_alpha <= self.threshold
        return self.theta * mask


def demonstrate_variational_dropout():
    """Demonstrate variational dropout layer behavior."""
    np.random.seed(42)

    print("Variational Dropout Layer Demonstration")
    print("=" * 60)

    # Create layer
    layer = VariationalDropoutLinear(
        in_features=100,
        out_features=50,
        log_alpha_init=-5.0  # Start with low dropout
    )

    # Input
    x = np.random.randn(32, 100)

    # Multiple forward passes (should vary due to sampling)
    layer.training = True
    outputs = [layer.forward(x) for _ in range(10)]

    print("\n1. Output variability (training mode):")
    output_means = [out.mean() for out in outputs]
    output_stds = [out.std() for out in outputs]
    print(f"   Mean of outputs: {np.mean(output_means):.4f} ± {np.std(output_means):.4f}")
    print(f"   Std of outputs:  {np.mean(output_stds):.4f} ± {np.std(output_stds):.4f}")

    # Inference mode (deterministic)
    layer.training = False
    inference_outputs = [layer.forward(x) for _ in range(5)]

    print("\n2. Output variability (inference mode):")
    print(f"   All outputs identical: {all(np.allclose(inference_outputs[0], out) for out in inference_outputs)}")

    # KL divergence
    print(f"\n3. KL divergence: {layer.kl_divergence():.4f}")

    # Simulate training - some weights become high-dropout
    print("\n4. Simulating learned dropout rates...")

    # Artificially set some weights to have high dropout
    layer.log_alpha[:30, :] = 5.0    # High alpha = high dropout
    layer.log_alpha[30:60, :] = 0.0  # Medium alpha
    layer.log_alpha[60:, :] = -5.0   # Low alpha = keep

    print(f"   Sparsity (fraction of pruned weights): {layer.sparsity():.1%}")

    # Compare full vs sparse weights
    full_weights = layer.theta
    sparse_weights = layer.get_sparse_weights()

    print(f"   Full weights non-zero:   {np.sum(full_weights != 0)}")
    print(f"   Sparse weights non-zero: {np.sum(sparse_weights != 0)}")

    print("\n✓ High log(α) weights can be pruned for sparse networks!")


demonstrate_variational_dropout()
```

Instead of sampling weights and then computing outputs, variational dropout uses "local reparameterization": compute the mean and variance of the output, then sample outputs directly. This reduces gradient variance during training and is more computationally efficient.
One of the most remarkable properties of variational dropout is its ability to induce extreme sparsity—zeroing out the vast majority of weights while maintaining performance.
The Sparsity-Inducing Property:
When αᵢⱼ → ∞ (log αᵢⱼ → ∞), the multiplicative noise becomes so large that the weight effectively contributes nothing. The KL divergence term encourages this: for the log-uniform prior, setting weights to "infinite dropout" incurs zero KL cost.
This means the model can freely prune weights that don't contribute to prediction accuracy. The optimization naturally finds which weights are necessary and which can be removed.
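One way to quantify "contributes nothing" is the weight's signal-to-noise ratio, SNR = |θ| / (√α·|θ|) = 1/√α, which depends only on α. The sketch below (an illustrative helper, using the log α > 3 threshold that also appears in the code later on) shows the resulting pruning rule:

```python
import numpy as np

def prune_by_log_alpha(theta: np.ndarray, log_alpha: np.ndarray,
                       threshold: float = 3.0) -> np.ndarray:
    """Zero out weights whose learned noise swamps their signal.

    SNR per weight is 1/sqrt(alpha); log_alpha > 3 corresponds to
    SNR < exp(-1.5) ≈ 0.22, i.e. mostly noise.
    """
    return theta * (log_alpha <= threshold)

theta = np.array([0.9, -0.4, 0.05, 1.2])
log_alpha = np.array([-4.0, 5.0, 6.0, -2.0])   # two weights learned huge dropout
print(prune_by_log_alpha(theta, log_alpha))     # [ 0.9  0.   0.   1.2]
```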
Automatic Relevance Determination (ARD):
This behavior connects to a classical Bayesian technique: Automatic Relevance Determination. In ARD, we place a hierarchical prior on weights where each weight has its own precision (inverse variance). Weights whose precision is driven to infinity, so that their prior variance collapses to zero, are effectively pruned.
Variational dropout achieves ARD behavior through a different mechanism—learning dropout rates—but the outcome is the same: input features and hidden units that don't contribute are automatically identified and removed.
```python
import numpy as np
from typing import Tuple, List


class SparseVariationalDropoutNet:
    """
    Neural network with Sparse Variational Dropout.

    Key properties:
    1. Each weight learns its own dropout rate
    2. High-dropout weights can be pruned (set to zero)
    3. Results in extremely sparse networks while maintaining accuracy
    """

    def __init__(
        self,
        layer_dims: List[int],
        threshold: float = 3.0,
        initial_log_alpha: float = -5.0
    ):
        """
        Initialize network.

        Args:
            layer_dims: List of layer dimensions [input, hidden1, ..., output]
            threshold: log(alpha) above which weights are considered pruned
            initial_log_alpha: Initial log(alpha) for all weights
        """
        self.threshold = threshold
        self.layers = []

        for i in range(len(layer_dims) - 1):
            layer = {
                'theta': np.random.randn(layer_dims[i], layer_dims[i+1]) * 0.01,
                'log_alpha': np.full((layer_dims[i], layer_dims[i+1]), initial_log_alpha),
                'bias': np.zeros(layer_dims[i+1])
            }
            self.layers.append(layer)

        self.training = True

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Forward pass with variational dropout."""
        h = x

        for i, layer in enumerate(self.layers):
            theta = layer['theta']
            alpha = np.exp(np.clip(layer['log_alpha'], -10, 10))
            bias = layer['bias']

            # Mean output
            mu = h @ theta + bias

            if self.training and i < len(self.layers) - 1:  # No dropout on output
                # Variance output
                sigma_sq = (h ** 2) @ (alpha * theta ** 2)
                sigma = np.sqrt(sigma_sq + 1e-8)

                # Sample
                eps = np.random.randn(*mu.shape)
                h = mu + sigma * eps
            else:
                h = mu

            # Activation (except last layer)
            if i < len(self.layers) - 1:
                h = np.maximum(0, h)  # ReLU

        return h

    def total_kl(self) -> float:
        """Total KL divergence across all layers."""
        total = 0.0
        k1, k2, k3 = 0.63576, 1.87320, 1.48695

        for layer in self.layers:
            log_alpha = layer['log_alpha']
            sigmoid_term = k1 / (1 + np.exp(-(k2 + k3 * log_alpha)))
            softplus_term = 0.5 * np.log1p(np.exp(-log_alpha))
            kl_per_weight = sigmoid_term - softplus_term - k1
            total -= np.sum(kl_per_weight)

        return total

    def sparsity(self) -> Tuple[float, List[float]]:
        """
        Compute sparsity statistics.

        Returns:
            overall: Fraction of pruned weights across entire network
            per_layer: List of per-layer sparsity fractions
        """
        per_layer = []
        total_weights = 0
        total_pruned = 0

        for layer in self.layers:
            pruned = np.sum(layer['log_alpha'] > self.threshold)
            total = layer['log_alpha'].size
            per_layer.append(pruned / total)
            total_weights += total
            total_pruned += pruned

        return total_pruned / total_weights, per_layer

    def get_effective_weights(self) -> List[np.ndarray]:
        """Get weight matrices with pruned weights zeroed."""
        effective = []
        for layer in self.layers:
            mask = layer['log_alpha'] <= self.threshold
            effective.append(layer['theta'] * mask)
        return effective

    def compression_ratio(self) -> float:
        """Compute compression ratio (dense / sparse parameters)."""
        dense_params = sum(l['theta'].size for l in self.layers)
        sparse_params = sum(
            np.sum(l['log_alpha'] <= self.threshold) for l in self.layers
        )
        return dense_params / max(1, sparse_params)


def simulate_sparse_training():
    """
    Simulate the effect of variational dropout training.

    In real training, log_alpha would be learned via gradient descent.
    Here we simulate the outcome where many weights become high-dropout.
    """
    np.random.seed(42)

    print("Sparse Variational Dropout Simulation")
    print("=" * 60)

    # Create network
    net = SparseVariationalDropoutNet(
        layer_dims=[784, 1000, 500, 10],
        threshold=3.0
    )

    # Initial state
    initial_sparsity, _ = net.sparsity()
    print(f"\nInitial state:")
    print(f"  Sparsity: {initial_sparsity:.1%}")
    print(f"  Compression ratio: {net.compression_ratio():.1f}x")
    print(f"  KL divergence: {net.total_kl():.1f}")

    # Simulate training outcome
    # In real training, optimization drives many log_alpha values up
    print("\nSimulating training (setting learned dropout rates)...")

    for layer in net.layers:
        # Most weights become high-dropout (pruned)
        log_alpha = layer['log_alpha']

        # Simulate: ~95% of weights become prunable
        random_mask = np.random.rand(*log_alpha.shape) < 0.95
        log_alpha[random_mask] = np.random.uniform(3.5, 10.0, size=random_mask.sum())

        # Remaining ~5% have learned appropriate dropout rates
        remaining = ~random_mask
        log_alpha[remaining] = np.random.uniform(-5.0, 2.0, size=remaining.sum())

    # Final state
    final_sparsity, per_layer = net.sparsity()
    print(f"\nAfter training:")
    print(f"  Sparsity: {final_sparsity:.1%}")
    print(f"  Compression ratio: {net.compression_ratio():.1f}x")
    print(f"  KL divergence: {net.total_kl():.1f}")

    print(f"\nPer-layer sparsity:")
    for i, sp in enumerate(per_layer):
        print(f"  Layer {i+1}: {sp:.1%} pruned")

    # Compare original vs effective weight matrices
    print(f"\nWeight matrix statistics:")
    effective = net.get_effective_weights()
    for i, (layer, eff) in enumerate(zip(net.layers, effective)):
        orig_nonzero = np.sum(layer['theta'] != 0)
        eff_nonzero = np.sum(eff != 0)
        print(f"  Layer {i+1}: {orig_nonzero:,} → {eff_nonzero:,} non-zero weights")

    # Test inference
    x = np.random.randn(32, 784)
    net.training = False
    output = net.forward(x)
    print(f"\nInference output shape: {output.shape}")
    print(f"  Output mean: {output.mean():.4f}")
    print(f"  Output std:  {output.std():.4f}")


simulate_sparse_training()
```

| Network | Dataset | Sparsity | Test Accuracy |
|---|---|---|---|
| LeNet-300-100 | MNIST | 98.9% | 98.3% |
| LeNet-5-Caffe | MNIST | 99.1% | 99.2% |
| VGG-like | CIFAR-10 | 98.1% | 91.8% |
| ResNet-56 | CIFAR-10 | 93.2% | 93.1% |
Sparse variational dropout can prune 98%+ of weights while maintaining accuracy. This is far more aggressive than magnitude-based pruning, L1 regularization, or other sparsity-inducing methods. The key is that VD learns which weights are truly unnecessary.
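To translate a sparsity level into a parameter count: if a fraction s of weights is pruned, the dense-to-sparse ratio is 1/(1 − s). For the LeNet-300-100 row above,

$$\frac{1}{1 - 0.989} \approx 91\times$$

fewer stored weights (ignoring the indexing overhead of a sparse storage format).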
Training variational dropout networks requires careful attention to optimization dynamics and gradient estimation.
The Training Objective:
$$\mathcal{L} = \frac{1}{N} \sum_n \mathbb{E}_{q(\mathbf{W})}\left[\ell(f_{\mathbf{W}}(\mathbf{x}_n), y_n)\right] + \frac{\beta}{N} \sum_{ij} \text{KL}\left[q(W_{ij}) \,\|\, p(W_{ij})\right]$$
where β is an optional weight on the KL term.
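In practice, β is often annealed from 0 toward 1 over the first part of training so that the likelihood term can shape θ before the KL term starts driving dropout rates up; starting with the full KL weight can prune too aggressively too early. A minimal sketch of such a warm-up schedule (the linear schedule and the epoch count are illustrative choices, not prescribed by the method):

```python
def kl_weight(epoch: int, warmup_epochs: int = 10) -> float:
    """Linearly anneal the KL weight beta from 0 to 1 over `warmup_epochs`."""
    return min(1.0, epoch / warmup_epochs)

# Inside a training loop one would use something like:
#   loss = nll + (kl_weight(epoch) / N) * total_kl
for epoch in (0, 5, 10, 20):
    print(epoch, kl_weight(epoch))   # 0.0, 0.5, 1.0, 1.0
```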
Gradient Estimation:
We need gradients with respect to both θ (weight means) and log α (dropout rates). The reparameterization trick provides low-variance gradient estimates:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class VariationalDropoutLinear(nn.Module):
    """
    PyTorch implementation of Variational Dropout linear layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        log_alpha_init: float = -5.0,
        threshold: float = 3.0
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold

        # Weight mean parameters
        self.theta = nn.Parameter(torch.randn(in_features, out_features) * 0.01)

        # Log alpha parameters (learnable dropout rates)
        self.log_alpha = nn.Parameter(torch.full((in_features, out_features), log_alpha_init))

        # Bias
        self.bias = nn.Parameter(torch.zeros(out_features))

        # KL constants
        self.k1 = 0.63576
        self.k2 = 1.87320
        self.k3 = 1.48695

    @property
    def alpha(self) -> Tensor:
        return torch.exp(torch.clamp(self.log_alpha, -10, 10))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass using local reparameterization.
        """
        # Compute mean output
        mu = F.linear(x, self.theta.t(), self.bias)

        if not self.training:
            return mu

        # Compute variance output
        sigma_sq = F.linear(x.pow(2), (self.alpha * self.theta.pow(2)).t())
        sigma = torch.sqrt(sigma_sq + 1e-8)

        # Sample
        eps = torch.randn_like(mu)
        return mu + sigma * eps

    def kl_divergence(self) -> Tensor:
        """Compute KL divergence for this layer."""
        log_alpha = torch.clamp(self.log_alpha, -10, 10)
        sigmoid_term = self.k1 * torch.sigmoid(self.k2 + self.k3 * log_alpha)
        softplus_term = 0.5 * F.softplus(-log_alpha)
        return -torch.sum(sigmoid_term - softplus_term - self.k1)

    def sparsity(self) -> float:
        """Fraction of prunable weights."""
        with torch.no_grad():
            return (self.log_alpha > self.threshold).float().mean().item()


class VariationalDropoutNetwork(nn.Module):
    """Complete network with variational dropout layers."""

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list,
        output_dim: int,
        log_alpha_init: float = -5.0
    ):
        super().__init__()

        layers = []
        dims = [input_dim] + hidden_dims + [output_dim]

        for i in range(len(dims) - 1):
            layers.append(VariationalDropoutLinear(
                dims[i], dims[i+1], log_alpha_init
            ))
            if i < len(dims) - 2:  # No activation after last layer
                layers.append(nn.ReLU())

        self.layers = nn.ModuleList(layers)
        self.vd_layers = [l for l in self.layers if isinstance(l, VariationalDropoutLinear)]

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

    def total_kl(self) -> Tensor:
        """Total KL divergence across all VD layers."""
        return sum(layer.kl_divergence() for layer in self.vd_layers)

    def total_sparsity(self) -> float:
        """Average sparsity across all VD layers."""
        return sum(l.sparsity() for l in self.vd_layers) / len(self.vd_layers)


def train_variational_dropout():
    """Training loop for variational dropout network."""
    torch.manual_seed(42)

    print("Variational Dropout Training")
    print("=" * 60)

    # Create network
    model = VariationalDropoutNetwork(
        input_dim=784,
        hidden_dims=[400, 200],
        output_dim=10
    )

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Simulated training data
    X_train = torch.randn(1000, 784)
    y_train = torch.randint(0, 10, (1000,))

    batch_size = 100
    n_epochs = 20
    kl_weight = 1.0 / 1000  # Scale KL by 1/N

    print("\nTraining progress:")
    print(f"{'Epoch':<8} {'Loss':<12} {'NLL':<12} {'KL':<12} {'Sparsity':<12}")
    print("-" * 60)

    for epoch in range(n_epochs):
        model.train()
        epoch_nll = 0
        epoch_kl = 0

        for i in range(0, len(X_train), batch_size):
            X_batch = X_train[i:i+batch_size]
            y_batch = y_train[i:i+batch_size]

            optimizer.zero_grad()

            # Forward pass
            logits = model(X_batch)

            # NLL loss (cross-entropy)
            nll = F.cross_entropy(logits, y_batch)

            # KL divergence
            kl = model.total_kl()

            # Total loss
            loss = nll + kl_weight * kl

            # Backward
            loss.backward()
            optimizer.step()

            epoch_nll += nll.item()
            epoch_kl += kl.item()

        # Epoch statistics
        avg_nll = epoch_nll / (len(X_train) // batch_size)
        avg_kl = epoch_kl / (len(X_train) // batch_size)
        sparsity = model.total_sparsity()

        if epoch % 4 == 0 or epoch == n_epochs - 1:
            print(f"{epoch+1:<8} {avg_nll + kl_weight*avg_kl:<12.4f} "
                  f"{avg_nll:<12.4f} {avg_kl:<12.1f} {sparsity:<12.1%}")

    # Final statistics
    print("\n" + "-" * 60)
    print("Final layer statistics:")
    for i, layer in enumerate(model.vd_layers):
        print(f"  Layer {i+1}: sparsity = {layer.sparsity():.1%}")

    print(f"\nTotal sparsity: {model.total_sparsity():.1%}")


if __name__ == "__main__":
    train_variational_dropout()
```

Variational dropout has inspired several important variants and extensions.
1. Structured Variational Dropout:
Instead of per-weight dropout rates, learn a shared dropout rate for a group of weights: one α per neuron (all incoming weights of a unit), per convolutional filter, or per layer.
Structured VD is simpler, lower overhead, and can prune entire neurons/filters.
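A hedged sketch of the per-neuron variant, following the NumPy layer above rather than any particular library: keep a single log α per output unit and broadcast it across that unit's incoming weights. Pruning then removes whole columns of θ, i.e. entire neurons.

```python
import numpy as np

class PerNeuronVariationalDropoutLinear:
    """Variational dropout with one learned log(alpha) per output neuron (a sketch)."""

    def __init__(self, in_features: int, out_features: int,
                 log_alpha_init: float = -5.0):
        self.theta = np.random.randn(in_features, out_features) * 0.01
        self.log_alpha = np.full(out_features, log_alpha_init)  # one per neuron
        self.bias = np.zeros(out_features)

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        mu = x @ self.theta + self.bias
        if not training:
            return mu
        alpha = np.exp(np.clip(self.log_alpha, -10, 10))         # shape (out_features,)
        # Shared alpha broadcasts across each neuron's incoming weights
        sigma_sq = ((x ** 2) @ (self.theta ** 2)) * alpha
        return mu + np.sqrt(sigma_sq + 1e-8) * np.random.randn(*mu.shape)

    def prunable_neurons(self, threshold: float = 3.0) -> np.ndarray:
        """Indices of output neurons whose shared log(alpha) exceeds the threshold."""
        return np.where(self.log_alpha > threshold)[0]
```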
2. Concrete Dropout:
Uses a continuous relaxation of Bernoulli dropout (via the Gumbel-Softmax trick). This allows learning binary dropout rates with gradient descent, bridging the gap between Gaussian and Bernoulli dropout.
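A minimal sketch of the relaxation at the core of Concrete Dropout (after Gal et al., 2017; the temperature and tensor shapes here are illustrative): a uniform sample is pushed through a sigmoid so the keep/drop mask becomes a smooth, differentiable function of the dropout probability p.

```python
import torch

def concrete_dropout_mask(p: torch.Tensor, shape, temperature: float = 0.1) -> torch.Tensor:
    """Relaxed Bernoulli 'keep' mask that is differentiable w.r.t. the dropout rate p."""
    eps = 1e-7
    u = torch.rand(shape)  # uniform noise
    drop_prob = torch.sigmoid(
        (torch.log(p + eps) - torch.log(1 - p + eps)
         + torch.log(u + eps) - torch.log(1 - u + eps)) / temperature
    )
    return 1.0 - drop_prob           # close to 0/1 but never exactly binary

p = torch.tensor(0.3, requires_grad=True)
x = torch.randn(4, 8)
y = x * concrete_dropout_mask(p, x.shape) / (1 - p)   # inverted-dropout rescaling
y.sum().backward()
print(p.grad)                        # gradient flows into the dropout rate itself
```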
3. Information Dropout:
Learns to drop information (via an information bottleneck) rather than neurons. The objective includes mutual information terms, leading to more principled feature selection.
4. Bayesian Compression:
Extends sparse variational dropout with quantization for extreme model compression. Combines weight pruning with low-bit weight quantization.
| Variant | Key Idea | Advantage |
|---|---|---|
| Per-weight VD | Individual α per weight | Maximum flexibility |
| Per-neuron VD | Shared α per neuron | Prunes entire neurons |
| Per-filter VD | Shared α per conv filter | Structured sparsity for CNNs |
| Concrete Dropout | Continuous Bernoulli relaxation | Exact α interpretation |
| Information Dropout | Information-theoretic objective | Principled feature selection |
Variational dropout can be viewed as a soft version of neural architecture search. Instead of searching over discrete architectures, VD learns which parts of a large network to keep or remove. The final sparse network is an 'architecture' discovered through training.
When should you use variational dropout? What are the trade-offs?
When to Use Variational Dropout:
- Uncertainty is important: When you need calibrated uncertainty estimates, VD provides better uncertainties than standard dropout
- Model compression is a goal: If you want to discover a smaller, sparser model from a large overparameterized one
- Automatic feature selection: When you don't know which input features or hidden units are important
- Bayesian inference with scalability: When you want approximate Bayesian inference without the cost of MCMC or full VI
When NOT to Use:
- Simple regularization: If you just need regularization, standard dropout is simpler and faster
- Speed-critical training: VD is slower due to the variance computations and KL terms
- Modern architectures: ResNets and Transformers may not benefit much, since they typically rely on other regularizers (normalization layers, weight decay, data augmentation)
Start with standard dropout for regularization. Consider variational dropout when you specifically need: (1) uncertainty quantification, (2) model compression via learned sparsity, or (3) automatic feature relevance discovery. For most applications, standard dropout with MC inference is sufficient.
Variational dropout extends the Bayesian interpretation of dropout to learn optimal dropout rates for each weight. Let's consolidate the key insights:
- Gaussian multiplicative noise makes the noise level α differentiable, so dropout rates can be learned by gradient descent.
- Each weight's approximate posterior is N(θ, αθ²); with a log-uniform prior, the KL term has a simple differentiable approximation.
- Local reparameterization samples activations directly from their induced distribution, reducing gradient variance and computation.
- Weights whose learned log α exceeds a threshold can be pruned, yielding extreme sparsity (often 98%+ of weights) with little accuracy loss.
What's Next:
In the final page of this module, we explore dropping other things—extensions of the dropout concept beyond neuron activations. We'll cover DropConnect (dropping weights), Spatial Dropout (dropping feature maps), DropBlock, and other creative applications of the dropout principle.
You now understand variational dropout—how to learn optimal dropout rates and leverage this for extreme sparsity and better uncertainty. The key insight: by treating dropout rates as learnable parameters within the variational inference framework, we can automatically discover which parts of a network are necessary.