Standard neural networks produce point predictions without quantifying uncertainty. When a network classifies an image as "cat" with 99% confidence, we have no way to know if this confidence is well-calibrated or if the network is simply overconfident on out-of-distribution inputs.
Bayesian deep learning addresses this by treating network weights as random variables with posterior distributions. Instead of learning a single weight configuration, we learn a distribution over weights, enabling principled uncertainty quantification. Variational inference makes this tractable for modern deep networks with millions of parameters.
By the end of this page, you will understand Bayesian neural networks and why they matter, master variational weight posteriors and their parameterizations, learn practical techniques like MC Dropout and last-layer Bayesian methods, understand the computational tradeoffs, and know when to use Bayesian approaches in deep learning.
Standard deep learning finds point estimates of weights w* that minimize a loss:
$$w^* = \arg\min_w \mathcal{L}(w; \mathcal{D})$$
This point estimate ignores important questions: how confident should the model be, and where does its uncertainty come from? Two distinct types of uncertainty are worth separating:
| Type | Also Called | Source | Reducible? | Example |
|---|---|---|---|---|
| Aleatoric | Data uncertainty | Noise in observations | No | Sensor noise, label ambiguity |
| Epistemic | Model uncertainty | Limited training data | Yes (with more data) | Untrained regions of input space |
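The decomposition in this table can be made concrete: given class-probability vectors from several plausible models, the predictive entropy splits into an expected-entropy (aleatoric) term and a mutual-information (epistemic) term. A minimal sketch, where the `decompose` helper and the two toy scenarios are illustrative:

```python
import torch

def decompose(probs_samples):
    """Split predictive uncertainty into aleatoric and epistemic parts.
    probs_samples: (num_models, num_classes) class probabilities."""
    mean_probs = probs_samples.mean(dim=0)
    # Total uncertainty: entropy of the averaged prediction
    total = -(mean_probs * torch.log(mean_probs + 1e-10)).sum()
    # Aleatoric: average entropy of each individual prediction
    aleatoric = (-(probs_samples * torch.log(probs_samples + 1e-10)).sum(dim=-1)).mean()
    # Epistemic: mutual information = total - aleatoric
    epistemic = total - aleatoric
    return total.item(), aleatoric.item(), epistemic.item()

# All models agree the label is a coin flip: uncertainty is aleatoric.
agree = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
# Each model is confident but they disagree: uncertainty is epistemic.
disagree = torch.tensor([[0.99, 0.01], [0.01, 0.99]])

total_a, alea_a, epi_a = decompose(agree)
total_d, alea_d, epi_d = decompose(disagree)
```

Both scenarios have the same total uncertainty (the averaged prediction is 50/50 in each), but more data would resolve only the second one.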
The Bayesian Approach:
Instead of point estimates, maintain a posterior over weights:
$$p(w|\mathcal{D}) = \frac{p(\mathcal{D}|w)p(w)}{p(\mathcal{D})}$$
For prediction on new input x*:
$$p(y|x^*, \mathcal{D}) = \int p(y|x^*, w)\, p(w|\mathcal{D})\, dw$$
This integral over all possible weight configurations naturally captures uncertainty. If many weight configurations agree, we're confident. If they disagree, we're uncertain.
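This agreement-versus-disagreement intuition can be sketched as a plain Monte Carlo average over weight samples; here a toy one-weight "network" stands in for a real model, and the `predictive_mean_std` helper is illustrative:

```python
import torch

def predictive_mean_std(model_fn, weight_samples, x):
    """Monte Carlo approximation of the predictive distribution:
    average model outputs over weight samples drawn from the posterior."""
    preds = torch.stack([model_fn(x, w) for w in weight_samples])
    return preds.mean(dim=0), preds.std(dim=0)

# Toy model y = w * x, with three posterior samples of w that nearly agree.
x = torch.tensor([1.0, 2.0])
samples = [torch.tensor(0.9), torch.tensor(1.0), torch.tensor(1.1)]
mean, std = predictive_mean_std(lambda inp, w: w * inp, samples, x)
# The spread of the weight samples shows up directly as predictive std.
```

If the samples disagreed more, `std` would grow accordingly; with identical samples it would be zero.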
Benefits: calibrated confidence estimates, the ability to flag out-of-distribution inputs, and principled behavior when training data is limited.
The posterior p(w|D) is intractable for neural networks—the integral has millions of dimensions, the likelihood landscape is complex and multimodal. This is where VI becomes essential: we approximate p(w|D) with a tractable q(w) and optimize the ELBO.
Variational BNNs approximate the posterior over weights with a parametric family, typically factorized Gaussians.
The ELBO for BNNs:
$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(w)}\left[\log p(\mathcal{D}|w)\right] - D_{KL}(q_\phi(w) \,\|\, p(w))$$
Mean-Field Approximation:
The simplest choice is a factorized Gaussian:
$$q_\phi(w) = \prod_{i} \mathcal{N}(w_i; \mu_i, \sigma_i^2)$$
Each weight has its own mean μᵢ and standard deviation σᵢ. This doubles the number of parameters but enables efficient training.
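To make the mechanics concrete, here is a toy end-to-end sketch: a single-weight regression model with a mean-field Gaussian posterior, trained by maximizing the ELBO via the reparameterization trick and the closed-form Gaussian KL. All values (data, learning rate, noise level) are illustrative:

```python
import torch

torch.manual_seed(0)
# Data from y = 2x + noise; posterior q(w) = N(mu, sigma^2), prior p(w) = N(0, 1)
x = torch.linspace(-1, 1, 64)
y = 2.0 * x + 0.1 * torch.randn(64)

mu = torch.zeros(1, requires_grad=True)
log_std = torch.full((1,), -3.0, requires_grad=True)
opt = torch.optim.Adam([mu, log_std], lr=0.05)

for _ in range(500):
    opt.zero_grad()
    w = mu + log_std.exp() * torch.randn(1)              # reparameterization trick
    log_lik = -0.5 * ((w * x - y) ** 2).sum() / 0.1**2   # Gaussian likelihood, known noise
    var = (2 * log_std).exp()
    kl = 0.5 * (var + mu**2 - 1 - 2 * log_std).sum()     # KL(q || N(0,1)) in closed form
    loss = -log_lik + kl                                 # negative ELBO
    loss.backward()
    opt.step()
# mu converges near the true slope 2.0, and log_std shrinks
# as 64 data points leave little epistemic uncertainty about w.
```

The same three ingredients (sampled weights, likelihood term, closed-form KL) appear in the full layer implementation below, just vectorized over all weights.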
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class VariationalLinear(nn.Module):
    """
    Variational (Bayesian) linear layer.
    Weights are drawn from learned Gaussian distributions.
    """
    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Prior: N(0, prior_std^2)
        self.prior_std = prior_std
        self.prior_log_std = math.log(prior_std)

        # Variational parameters for weights
        self.weight_mu = nn.Parameter(
            torch.empty(out_features, in_features).normal_(0, 0.1)
        )
        self.weight_log_std = nn.Parameter(
            torch.empty(out_features, in_features).fill_(-3)  # Start with low variance
        )

        # Variational parameters for bias
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_log_std = nn.Parameter(torch.full((out_features,), -3.0))

    def forward(self, x, sample=True):
        """
        Forward pass with weight sampling.

        Args:
            x: Input tensor
            sample: If True, sample weights; if False, use means
        """
        if sample and self.training:
            # Sample weights using reparameterization trick
            weight_std = torch.exp(self.weight_log_std)
            weight = self.weight_mu + weight_std * torch.randn_like(weight_std)
            bias_std = torch.exp(self.bias_log_std)
            bias = self.bias_mu + bias_std * torch.randn_like(bias_std)
        else:
            # Use mean weights (MAP-like)
            weight = self.weight_mu
            bias = self.bias_mu
        return F.linear(x, weight, bias)

    def kl_divergence(self):
        """
        Compute KL(q(w) || p(w)) for this layer.
        Both are Gaussians, so we have closed form.
        """
        # KL for weights
        weight_var = torch.exp(2 * self.weight_log_std)
        prior_var = self.prior_std ** 2
        kl_weights = 0.5 * (
            weight_var / prior_var
            + self.weight_mu.pow(2) / prior_var
            - 1
            - 2 * self.weight_log_std
            + 2 * self.prior_log_std
        ).sum()

        # KL for bias
        bias_var = torch.exp(2 * self.bias_log_std)
        kl_bias = 0.5 * (
            bias_var / prior_var
            + self.bias_mu.pow(2) / prior_var
            - 1
            - 2 * self.bias_log_std
            + 2 * self.prior_log_std
        ).sum()

        return kl_weights + kl_bias


class VariationalBNN(nn.Module):
    """
    Variational Bayesian Neural Network for regression/classification.
    """
    def __init__(self, input_dim, hidden_dims, output_dim, prior_std=1.0):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(VariationalLinear(prev_dim, h_dim, prior_std))
            layers.append(nn.ReLU())
            prev_dim = h_dim
        layers.append(VariationalLinear(prev_dim, output_dim, prior_std))
        self.layers = nn.ModuleList(layers)

    def forward(self, x, sample=True):
        for layer in self.layers:
            if isinstance(layer, VariationalLinear):
                x = layer(x, sample=sample)
            else:
                x = layer(x)
        return x

    def kl_divergence(self):
        """Total KL divergence across all variational layers."""
        kl = 0.0
        for layer in self.layers:
            if isinstance(layer, VariationalLinear):
                kl += layer.kl_divergence()
        return kl

    def elbo_loss(self, x, y, num_samples=1, kl_weight=1.0):
        """
        Compute negative ELBO for training.

        Args:
            x, y: Input and target
            num_samples: Number of weight samples for likelihood estimation
            kl_weight: Weight on KL term (for KL annealing)
        """
        # Estimate expected log-likelihood
        log_liks = []
        for _ in range(num_samples):
            pred = self.forward(x, sample=True)
            # Assuming regression with Gaussian likelihood
            log_lik = -0.5 * (pred - y).pow(2).sum()
            log_liks.append(log_lik)
        expected_log_lik = torch.stack(log_liks).mean()

        # KL divergence
        kl = self.kl_divergence()

        # Negative ELBO
        loss = -expected_log_lik + kl_weight * kl
        return loss, {
            'log_likelihood': expected_log_lik.item(),
            'kl': kl.item(),
        }
```

Full variational BNNs are computationally expensive. Several practical approximations provide uncertainty estimates at much lower cost.
MC Dropout:
A remarkable result: standard dropout at test time approximates Bayesian inference.
Procedure:
1. Train a network with dropout as usual.
2. At test time, keep dropout active instead of disabling it.
3. Run multiple stochastic forward passes on the same input.
4. Use the mean of the predictions as the output and their spread as the uncertainty estimate.
Theoretically, this approximates a Gaussian process posterior. Computationally, it's nearly free—just requires multiple forward passes.
```python
import torch
import torch.nn as nn


class MCDropoutModel(nn.Module):
    """
    Standard neural network that uses MC Dropout for uncertainty estimation.
    """
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.1):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),  # Dropout after every hidden activation
            ])
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

    def predict_with_uncertainty(self, x, num_samples=50):
        """
        Make predictions with uncertainty estimates using MC Dropout.

        Args:
            x: Input tensor
            num_samples: Number of forward passes with different dropout masks

        Returns:
            mean: Mean prediction
            std: Standard deviation (epistemic uncertainty)
        """
        self.train()  # Enable dropout
        predictions = []
        with torch.no_grad():
            for _ in range(num_samples):
                pred = self.forward(x)
                predictions.append(pred)
        predictions = torch.stack(predictions)  # (num_samples, batch, output_dim)
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        return mean, std

    def get_classification_uncertainty(self, x, num_samples=50):
        """
        For classification, compute predictive entropy and mutual information.

        Returns:
            probs: Mean class probabilities
            predictive_entropy: Total uncertainty
            mutual_info: Epistemic uncertainty (model uncertainty)
        """
        self.train()
        logits_samples = []
        with torch.no_grad():
            for _ in range(num_samples):
                logits = self.forward(x)
                logits_samples.append(logits)
        logits_samples = torch.stack(logits_samples)
        probs_samples = torch.softmax(logits_samples, dim=-1)

        # Mean probabilities
        mean_probs = probs_samples.mean(dim=0)

        # Predictive entropy: H[E_w[p(y|x,w)]]
        predictive_entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)

        # Expected entropy: E_w[H[p(y|x,w)]]
        sample_entropies = -(probs_samples * torch.log(probs_samples + 1e-10)).sum(dim=-1)
        expected_entropy = sample_entropies.mean(dim=0)

        # Mutual information = Predictive entropy - Expected entropy
        # This measures epistemic (model) uncertainty
        mutual_info = predictive_entropy - expected_entropy

        return {
            'probs': mean_probs,
            'predictive_entropy': predictive_entropy,
            'epistemic_uncertainty': mutual_info,
            'aleatoric_uncertainty': expected_entropy,
        }
```

Last-Layer Bayesian Methods:
A cost-effective approach: keep all layers deterministic except the last one.
This is particularly effective because the feature extractor captures data structure while the Bayesian output layer quantifies prediction uncertainty.
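A minimal sketch of this idea for a regression head: a deterministic feature extractor feeding a variational Gaussian last layer. The class name, layer sizes, and the simple sampling-based `predict` loop are illustrative, not from the text above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class LastLayerBayesian(nn.Module):
    """Deterministic features + mean-field Gaussian posterior over the last layer."""
    def __init__(self, input_dim, feature_dim, output_dim):
        super().__init__()
        # Standard deterministic feature extractor
        self.features = nn.Sequential(nn.Linear(input_dim, feature_dim), nn.ReLU())
        # Variational parameters only for the output layer's weights
        self.w_mu = nn.Parameter(0.1 * torch.randn(output_dim, feature_dim))
        self.w_log_std = nn.Parameter(torch.full((output_dim, feature_dim), -3.0))
        self.b = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x, sample=True):
        h = self.features(x)
        if sample:
            # Reparameterized sample of the last-layer weights
            w = self.w_mu + self.w_log_std.exp() * torch.randn_like(self.w_mu)
        else:
            w = self.w_mu
        return F.linear(h, w, self.b)

    def predict(self, x, num_samples=20):
        """Average over sampled output layers; std reflects model uncertainty."""
        preds = torch.stack([self.forward(x) for _ in range(num_samples)])
        return preds.mean(dim=0), preds.std(dim=0)

model = LastLayerBayesian(4, 16, 1)
mean, std = model.predict(torch.randn(8, 4))
```

Only the last layer pays the doubled-parameter cost, and each extra sample reuses the (expensive) features via a cheap linear map.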
| Method | Training Cost | Inference Cost | Uncertainty Quality | Ease of Use |
|---|---|---|---|---|
| Full Variational BNN | High (2x params) | Medium | High | Medium |
| MC Dropout | None (standard) | Linear in samples | Medium | Very Easy |
| Last-Layer Bayesian | Low | Low-Medium | Medium | Easy |
| Deep Ensembles | High (K models) | Linear in K | High | Easy |
| SWAG | Low (store stats) | Medium | Medium-High | Medium |
Deep Ensembles are a simple but powerful approach to uncertainty: train multiple networks independently and aggregate their predictions.
Procedure:
1. Train K networks independently, each with a different random initialization (and optionally different data ordering or augmentation).
2. At test time, run all K networks on the same input.
3. Use the mean of their predictions as the output and the disagreement across members as the uncertainty estimate.
Why It Works:
Different initializations lead to different local optima. These optima represent different hypotheses about the data. The ensemble effectively samples from a posterior over functions, though without the formal Bayesian framework.
Connection to Bayesian Inference:
While not formally Bayesian, ensembles approximate the integral: $$p(y|x^*, \mathcal{D}) \approx \frac{1}{K} \sum_{k=1}^{K} p(y|x^*, w_k)$$
where each w_k is a mode of the posterior.
```python
import torch
import torch.nn as nn


class DeepEnsemble:
    """
    Deep Ensemble for uncertainty estimation.
    Trains multiple networks and aggregates predictions.
    """
    def __init__(self, model_class, model_kwargs, num_members=5):
        """
        Args:
            model_class: Neural network class
            model_kwargs: Arguments to instantiate models
            num_members: Number of ensemble members
        """
        self.num_members = num_members
        self.models = [model_class(**model_kwargs) for _ in range(num_members)]

    def train_member(self, member_idx, train_loader, optimizer, loss_fn, epochs):
        """Train a single ensemble member.
        The optimizer should be constructed over this member's parameters."""
        model = self.models[member_idx]
        model.train()
        for epoch in range(epochs):
            for x, y in train_loader:
                optimizer.zero_grad()
                pred = model(x)
                loss = loss_fn(pred, y)
                loss.backward()
                optimizer.step()

    def predict(self, x, return_individual=False):
        """
        Make predictions with uncertainty.

        Returns:
            mean: Ensemble mean prediction
            std: Standard deviation across members (epistemic uncertainty)
            individual: (optional) Predictions from each member
        """
        predictions = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                pred = model(x)
            predictions.append(pred)
        predictions = torch.stack(predictions)  # (K, batch, output_dim)
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        if return_individual:
            return mean, std, predictions
        return mean, std

    def predict_proba(self, x):
        """
        For classification: get mean probabilities and uncertainty.
        """
        logits_list = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                logits = model(x)
            logits_list.append(logits)
        logits = torch.stack(logits_list)
        probs = torch.softmax(logits, dim=-1)
        mean_probs = probs.mean(dim=0)

        # Predictive entropy
        pred_entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)

        # Expected entropy (aleatoric)
        member_entropies = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
        expected_entropy = member_entropies.mean(dim=0)

        # Mutual information (epistemic)
        mutual_info = pred_entropy - expected_entropy

        return {
            'probs': mean_probs,
            'uncertainty': pred_entropy,
            'epistemic': mutual_info,
            'aleatoric': expected_entropy,
        }
```

Use 5-10 ensemble members for a good tradeoff between uncertainty quality and compute. Diversify training with different data shuffling, augmentation, or even architectures. For classification, combine with temperature scaling for calibration.
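Temperature scaling itself can be sketched in a few lines: fit a single scalar T on held-out logits by minimizing negative log-likelihood, then divide logits by T before the softmax. The `fit_temperature` helper and the synthetic overconfident logits below are illustrative:

```python
import torch
import torch.nn.functional as F

def fit_temperature(val_logits, val_labels, lr=0.05, steps=300):
    """Fit a scalar temperature T > 0 minimizing NLL on validation logits."""
    log_t = torch.zeros(1, requires_grad=True)  # optimize log T so T stays positive
    opt = torch.optim.Adam([log_t], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(val_logits / log_t.exp(), val_labels)
        loss.backward()
        opt.step()
    return log_t.exp().item()

# Synthetic overconfident classifier: near-one-hot logits, but 20% wrong labels.
torch.manual_seed(0)
labels = torch.randint(0, 3, (200,))
logits = 6.0 * F.one_hot(labels, 3).float() + 0.5 * torch.randn(200, 3)
labels[:40] = (labels[:40] + 1) % 3   # corrupt 20% so the confidence is unwarranted

T = fit_temperature(logits, labels)
# T > 1 softens the probabilities toward better calibration
```

Note that dividing by a single T leaves the argmax (and hence accuracy) unchanged; only the confidence is recalibrated.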
Modern neural networks have millions to billions of parameters. Scaling VI to these sizes requires careful approximations.
Challenges: mean-field posteriors double the parameter count, sampled weights add variance to the gradients, and the KL term can dominate the likelihood early in training.
Solutions: restrict the posterior to a subset of weights (such as the last layer), use lower-variance gradient estimators, and anneal the KL weight during training.
SWAG (Stochastic Weight Averaging Gaussian):
A practical approach that extracts a posterior from standard training:
1. After convergence, continue SGD with a moderate learning rate.
2. Periodically collect weight snapshots along the trajectory.
3. Fit a Gaussian to the snapshots: a running mean, a diagonal variance, and a low-rank covariance from recent deviations.
4. At test time, sample weights from this Gaussian and average predictions.
SWAG is cheap (only stores running statistics) and provides surprisingly good uncertainty estimates.
```python
import torch
import torch.nn as nn
import copy


class SWAG:
    """
    Stochastic Weight Averaging Gaussian.
    Approximates the posterior by fitting a Gaussian to SGD iterates.
    """
    def __init__(self, base_model, max_rank=20):
        """
        Args:
            base_model: Trained neural network
            max_rank: Rank of low-rank covariance component
        """
        self.base_model = base_model
        self.max_rank = max_rank

        # Running mean
        self.mean = self._get_params_flat(base_model)
        # Running second moment (for diagonal variance)
        self.sq_mean = self.mean ** 2
        # Low-rank deviations
        self.deviations = []
        self.n_snapshots = 1

    def _get_params_flat(self, model):
        """Flatten all parameters into a single vector."""
        return torch.cat([p.data.view(-1) for p in model.parameters()])

    def _set_params_flat(self, model, flat_params):
        """Set model parameters from flat vector."""
        offset = 0
        for p in model.parameters():
            numel = p.numel()
            p.data.copy_(flat_params[offset:offset + numel].view(p.shape))
            offset += numel

    def collect(self, model):
        """Collect a weight snapshot during training."""
        params = self._get_params_flat(model)

        # Update mean and squared mean
        self.n_snapshots += 1
        n = self.n_snapshots
        self.mean = (self.mean * (n - 1) + params) / n
        self.sq_mean = (self.sq_mean * (n - 1) + params ** 2) / n

        # Store deviation for low-rank component
        deviation = params - self.mean
        self.deviations.append(deviation)
        # Keep only max_rank deviations
        if len(self.deviations) > self.max_rank:
            self.deviations.pop(0)

    def sample(self, scale=1.0):
        """
        Sample a model from the SWAG posterior.

        Args:
            scale: Scale factor for the variance

        Returns:
            Model with sampled weights
        """
        model = copy.deepcopy(self.base_model)

        # Diagonal variance
        var = torch.clamp(self.sq_mean - self.mean ** 2, min=1e-10)
        std = torch.sqrt(var)

        # Sample from diagonal
        sample = self.mean + scale * std * torch.randn_like(self.mean)

        # Add low-rank perturbation
        if len(self.deviations) > 0:
            D = torch.stack(self.deviations, dim=1)  # (num_params, num_deviations)
            z = torch.randn(len(self.deviations), device=D.device)
            low_rank_sample = (D @ z) / (len(self.deviations) ** 0.5)
            sample = sample + scale * low_rank_sample

        self._set_params_flat(model, sample)
        return model

    def predict(self, x, num_samples=30, scale=1.0):
        """Make predictions with SWAG uncertainty."""
        predictions = []
        for _ in range(num_samples):
            model = self.sample(scale=scale)
            model.eval()
            with torch.no_grad():
                pred = model(x)
            predictions.append(pred)
        predictions = torch.stack(predictions)
        return predictions.mean(dim=0), predictions.std(dim=0)
```

Bayesian approaches aren't always necessary. Here's guidance on when they provide value:
Use Bayesian Methods When: decisions are safety-critical; training data is scarce or expensive; out-of-distribution inputs are likely at deployment; or calibrated confidence is itself a product requirement.
Standard Methods May Suffice When: data is abundant and test inputs resemble the training distribution; only point-prediction accuracy matters; or latency and memory budgets rule out multiple forward passes.
Start with the cheapest method that might work: MC Dropout or deep ensembles. If uncertainty quality is insufficient, try SWAG or last-layer Bayesian methods. Only invest in full variational BNNs when simpler methods demonstrably fail and the application truly requires high-quality uncertainty.
We've explored how variational inference enables uncertainty quantification in deep neural networks. Key takeaways: Bayesian neural networks replace point weights with posterior distributions, separating epistemic from aleatoric uncertainty; mean-field variational posteriors make training tractable at the cost of doubled parameters; MC Dropout, last-layer methods, deep ensembles, and SWAG offer progressively cheaper approximations; and in practice, you should start with the cheapest method that meets your uncertainty-quality needs.
Congratulations! You've completed the Advanced VI Topics module. You now have a comprehensive understanding of modern variational inference techniques including normalizing flows, amortized inference, implicit VI, hierarchical models, and Bayesian deep learning. These tools form the foundation for state-of-the-art probabilistic modeling.