Multi-class classification extends binary classification to scenarios where inputs belong to exactly one of $K > 2$ mutually exclusive classes. Recognizing which digit (0-9) appears in an image, identifying the language of a text snippet, classifying species of flowers, or predicting the next word in a sequence—these are all multi-class problems.
The output layer must produce a probability distribution over $K$ classes: a vector of non-negative values that sum to one. This requires the softmax function, which generalizes the sigmoid to multiple outputs, and the categorical cross-entropy loss, which generalizes binary cross-entropy.
This page provides rigorous treatment of the mathematical foundations, numerical stability considerations, temperature scaling for calibration, and the connection to information theory. Understanding these concepts deeply enables not just correct implementation, but informed design decisions for novel prediction tasks.
This page covers: the softmax function derivation and properties, categorical cross-entropy as maximum likelihood, the log-sum-exp trick for numerical stability, temperature scaling for calibration, label smoothing regularization, the connection between softmax and exponential family distributions, and multi-class architecture patterns.
The softmax function transforms a vector of $K$ real numbers (logits) into a probability distribution:
$$\text{softmax}(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}$$
where $\mathbf{z} = (z_1, \ldots, z_K)$ are the logits (pre-activation outputs from the final linear layer), and the result is a vector of probabilities $(p_1, \ldots, p_K)$ with $\sum_i p_i = 1$ and $p_i > 0$ for all $i$.
Key properties of softmax:
Normalization: Outputs sum to exactly 1 (valid probability distribution)
Strictly positive: Every output is positive, even for very negative logits
Monotonicity: Higher logits → higher probabilities (ranking preserved)
Translation invariance: $\text{softmax}(\mathbf{z} + c) = \text{softmax}(\mathbf{z})$ for any scalar $c$. This is exploited for numerical stability.
Sensitivity to scale: Scaling the logits (e.g., dividing by a temperature $\tau$) changes the 'sharpness' of the distribution (temperature scaling)
Gradient: $\frac{\partial p_i}{\partial z_j} = p_i (\delta_{ij} - p_j)$, where $\delta_{ij}$ is the Kronecker delta
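The Jacobian in the last property can be verified numerically. The short check below is an illustrative sketch (not part of the reference code on this page): it compares $p_i(\delta_{ij} - p_j)$ against PyTorch autograd for an arbitrary logit vector.

```python
import torch

# Check of the softmax Jacobian dp_i/dz_j = p_i (delta_ij - p_j) on arbitrary logits
z = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)
p = torch.softmax(z, dim=0)

# Analytic Jacobian: diag(p) - p p^T
analytic = torch.diag(p) - torch.outer(p, p)

# Autograd Jacobian, built row by row
autograd_rows = []
for i in range(len(z)):
    grad_i = torch.autograd.grad(p[i], z, retain_graph=True)[0]
    autograd_rows.append(grad_i)
autograd = torch.stack(autograd_rows)

print(torch.allclose(analytic, autograd, atol=1e-6))  # True
```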
```python
import numpy as np
import torch
import torch.nn.functional as F

def softmax_naive(z):
    """
    Naive softmax implementation.
    WARNING: Numerically unstable for large values!
    """
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z)

def softmax_stable(z):
    """
    Numerically stable softmax using translation invariance.

    Key insight: softmax(z - max(z)) = softmax(z)
    But subtracting max prevents exp overflow.
    """
    z_shifted = z - np.max(z)
    exp_z = np.exp(z_shifted)
    return exp_z / np.sum(exp_z)

def softmax_with_temperature(z, temperature=1.0):
    """
    Temperature-scaled softmax.

    temperature > 1: Softer distribution (more uniform)
    temperature < 1: Sharper distribution (more peaked)
    temperature → 0: Approaches argmax (one-hot)
    temperature → ∞: Approaches uniform distribution
    """
    z_scaled = z / temperature
    return softmax_stable(z_scaled)

# Demonstrate instability
print("=== Numerical Stability Demo ===")
z_normal = np.array([1.0, 2.0, 3.0])
z_large = np.array([1000.0, 2000.0, 3000.0])

print(f"Normal logits: {softmax_naive(z_normal)}")
print(f"Large logits (naive): {softmax_naive(z_large)}")  # Will show nan!
print(f"Large logits (stable): {softmax_stable(z_large)}")

# Temperature scaling demonstration
print("\n=== Temperature Scaling Demo ===")
logits = np.array([2.0, 1.0, 0.5, 0.0, -1.0])
for temp in [0.1, 0.5, 1.0, 2.0, 10.0]:
    probs = softmax_with_temperature(logits, temp)
    entropy = -np.sum(probs * np.log(probs + 1e-10))
    print(f"T={temp:4.1f}: {probs.round(3)}, entropy={entropy:.2f}")

# PyTorch comparison
z_torch = torch.tensor([[-1.0, 0.0, 1.0, 2.0]])
print(f"\nPyTorch softmax: {F.softmax(z_torch, dim=-1)}")
print(f"PyTorch log_softmax: {F.log_softmax(z_torch, dim=-1)}")
```

When computing cross-entropy, we need log(softmax(z)). Computing softmax then log is wasteful and numerically fragile. Log-softmax computes the log directly using the log-sum-exp trick: $\log \text{softmax}(\mathbf{z})_i = z_i - \log\sum_j e^{z_j}$, which is both faster and more stable.
Connection to sigmoid:
For binary classification ($K = 2$), softmax reduces to sigmoid. If $\mathbf{z} = (z_1, z_2)$:
$$p_1 = \frac{e^{z_1}}{e^{z_1} + e^{z_2}} = \frac{1}{1 + e^{-(z_1 - z_2)}} = \sigma(z_1 - z_2)$$
This is why binary classification uses a single output unit predicting $z_1 - z_2$ rather than two outputs with softmax—they're mathematically equivalent, but one unit is more efficient.
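A two-line check (illustrative, with arbitrary logit values) confirms the equivalence:

```python
import torch

# For K = 2, softmax(z)_1 equals sigmoid(z_1 - z_2)
z1, z2 = torch.tensor(1.7), torch.tensor(-0.3)
p1_softmax = torch.softmax(torch.stack([z1, z2]), dim=0)[0]
p1_sigmoid = torch.sigmoid(z1 - z2)
print(torch.allclose(p1_softmax, p1_sigmoid))  # True
```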
The softmax bottleneck:
Softmax has an inherent limitation: because the logits are a linear function of a $d$-dimensional hidden state, the matrix of log-probabilities the model can express has rank at most $d + 1$. This 'softmax bottleneck' matters for tasks like language modeling, where the vocabulary size $K$ is far larger than the hidden dimension. Solutions include mixture-of-softmaxes models or breaking the single softmax into hierarchical structures.
Just as binary classification models a Bernoulli distribution, multi-class classification models a Categorical distribution (also called multinoulli or generalized Bernoulli). A single draw from a categorical distribution with $K$ categories and probability vector $\mathbf{p} = (p_1, \ldots, p_K)$ produces one of $K$ outcomes:
$$p(Y = k \mid \mathbf{x}) = p_k = \text{softmax}(\mathbf{z})_k$$
For a one-hot encoded label vector $\mathbf{y}$ where $y_k = 1$ for the true class and $y_j = 0$ otherwise:
$$p(\mathbf{y} | \mathbf{x}) = \prod_{k=1}^K p_k^{y_k}$$
Since exactly one $y_k = 1$, this simplifies to:
$$p(\mathbf{y} | \mathbf{x}) = p_c$$
where $c$ is the true class index.
Maximum likelihood leads to cross-entropy:
To find optimal network parameters, we maximize the likelihood of observed data:
$$\mathcal{L}(\theta) = \prod_{i=1}^n p(y_i | x_i; \theta)$$
Taking the negative log and averaging:
$$-\frac{1}{n} \log \mathcal{L} = -\frac{1}{n} \sum_i \log p_{c_i}$$
This is exactly categorical cross-entropy loss.
| Property | Formula | Interpretation |
|---|---|---|
| PMF | $p(Y=k) = p_k$ | Direct probability of class $k$ |
| Mean (one-hot) | $E[\mathbf{Y}] = \mathbf{p}$ | Expected one-hot vector equals probability vector |
| Variance | $\text{Var}(Y_k) = p_k(1-p_k)$ | Variance of each indicator variable |
| Covariance | $\text{Cov}(Y_i, Y_j) = -p_i p_j$ for $i \neq j$ | Indicators are negatively correlated (exactly one is 1) |
| Entropy | $H = -\sum_k p_k \log p_k$ | Uncertainty in the distribution; maximum when uniform |
| Support | $\{1, 2, \ldots, K\}$ | $K$ mutually exclusive outcomes |
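These moments can be sanity-checked by simulation. The snippet below is an illustrative sketch with an arbitrary probability vector; it draws one-hot samples and compares the empirical moments to the table.

```python
import numpy as np

# Monte Carlo check of the categorical moments (illustrative probability vector)
rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])
K, n = len(p), 200_000

# Draw class indices, then convert to one-hot rows Y
draws = rng.choice(K, size=n, p=p)
Y = np.eye(K)[draws]

print("E[Y]        ≈", Y.mean(axis=0).round(3), " (should be p)")
print("Var(Y_k)    ≈", Y.var(axis=0).round(3), " (should be p*(1-p))")
print("Cov(Y_0,Y_1) ≈", np.cov(Y[:, 0], Y[:, 1])[0, 1].round(3),
      " (should be -p_0*p_1 =", round(-p[0] * p[1], 3), ")")
```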
Multi-class classification assumes exactly one class is correct per example. If examples can belong to multiple classes simultaneously, you need multi-label classification (covered in the next page), which uses independent sigmoids instead of softmax. Confusing these leads to incorrect models.
Categorical cross-entropy (also called softmax cross-entropy or log loss) measures the dissimilarity between the true label distribution and the predicted probability distribution:
$$\mathcal{L}_{\text{CE}} = -\frac{1}{n} \sum_{i=1}^n \sum_{k=1}^K y_{ik} \log(p_{ik})$$
Since $\mathbf{y}_i$ is one-hot (only $y_{i,c_i} = 1$), this simplifies to:
$$\mathcal{L}_{\text{CE}} = -\frac{1}{n} \sum_{i=1}^n \log(p_{i,c_i})$$
where $c_i$ is the true class for sample $i$ and $p_{i,c_i}$ is the predicted probability for that class.
Information-theoretic interpretation:
Cross-entropy measures the expected number of bits needed to encode labels using the predicted distribution $\hat{p}$ when the true distribution is $p$:
$$H(p, \hat{p}) = -\sum_k p_k \log \hat{p}_k = H(p) + D_{\text{KL}}(p \| \hat{p})$$
where $H(p)$ is the true entropy (fixed with respect to the model) and $D_{\text{KL}}$ is the KL divergence. Minimizing cross-entropy is therefore equivalent to minimizing the KL divergence $D_{\text{KL}}(p \| \hat{p})$ between the true distribution and the predictions.
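The decomposition $H(p, \hat{p}) = H(p) + D_{\text{KL}}(p \| \hat{p})$ is easy to verify numerically; the snippet below uses arbitrary example distributions chosen for illustration.

```python
import numpy as np

# Numeric check of H(p, q) = H(p) + KL(p || q) for example distributions
p = np.array([0.7, 0.2, 0.1])   # "true" distribution (illustrative)
q = np.array([0.5, 0.3, 0.2])   # predicted distribution (illustrative)

cross_entropy = -np.sum(p * np.log(q))
entropy = -np.sum(p * np.log(p))
kl = np.sum(p * np.log(p / q))

print(f"H(p, q)         = {cross_entropy:.4f}")
print(f"H(p) + KL(p||q) = {entropy + kl:.4f}")  # identical
```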
Gradient of cross-entropy with respect to logits:
$$\frac{\partial \mathcal{L}}{\partial z_k} = p_k - y_k$$
This is the residual—identical in form to the binary case! The gradient is simply the difference between predicted probability and target. For the true class, this is $p_{\text{true}} - 1 = -(1 - p_{\text{true}})$; for other classes, it's just $p_k$.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def cross_entropy_manual(logits, targets):
    """
    Manual cross-entropy computation.
    Shows the math but is NOT numerically optimal.

    Args:
        logits: [batch_size, num_classes]
        targets: [batch_size] (class indices)
    """
    # Stable softmax via log-sum-exp
    log_sum_exp = torch.logsumexp(logits, dim=-1)
    log_probs = logits - log_sum_exp.unsqueeze(-1)

    # Select log probability of true class
    batch_size = logits.size(0)
    true_log_probs = log_probs[torch.arange(batch_size), targets]

    return -true_log_probs.mean()

def cross_entropy_with_one_hot(logits, one_hot_targets):
    """
    Cross-entropy using one-hot encoded targets.
    Equivalent to the formula: -sum(y * log(p))
    """
    log_probs = F.log_softmax(logits, dim=-1)
    return -(one_hot_targets * log_probs).sum(dim=-1).mean()

# Create sample data
batch_size = 4
num_classes = 5
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size,))

# One-hot encoding
one_hot = F.one_hot(targets, num_classes).float()

# Compare implementations
print("=== Cross-Entropy Implementations ===")
print(f"Manual: {cross_entropy_manual(logits, targets):.6f}")
print(f"One-hot: {cross_entropy_with_one_hot(logits, one_hot):.6f}")
print(f"PyTorch CE: {F.cross_entropy(logits, targets):.6f}")
print(f"PyTorch NLL+LSM: {F.nll_loss(F.log_softmax(logits, dim=-1), targets):.6f}")

# Gradient verification
logits_grad = logits.clone().requires_grad_(True)
loss = F.cross_entropy(logits_grad, targets)
loss.backward()

probs = F.softmax(logits, dim=-1)
expected_grad = (probs - one_hot) / batch_size

print(f"\n=== Gradient Verification ===")
print(f"Computed gradient (first sample):\n{logits_grad.grad[0]}")
print(f"Expected (p - y):\n{expected_grad[0]}")
print(f"Match: {torch.allclose(logits_grad.grad, expected_grad, atol=1e-5)}")
```

PyTorch's CrossEntropyLoss combines log_softmax and NLLLoss in a numerically stable way. Don't apply softmax then log then NLL separately—this loses precision and is slower. Always pass raw logits to CrossEntropyLoss.
Computing softmax and cross-entropy naively can lead to overflow (exp of large positive numbers) or underflow (exp of large negative numbers followed by log of tiny numbers). The solution is the log-sum-exp (LSE) trick.
The problem:
For logits like $\mathbf{z} = (1000, 1001, 1002)$, $e^{1000}$ already exceeds the largest double-precision float (about $1.8 \times 10^{308}$), so the naive softmax overflows to inf and the division returns nan. For very negative logits such as $(-1000, -1001, -1002)$, every $e^{z_j}$ underflows to 0 and the naive log-sum-exp returns $-\infty$.
The solution: shift logits by their maximum
Due to translation invariance: $$\text{softmax}(\mathbf{z}) = \text{softmax}(\mathbf{z} - \max_j z_j)$$
After shifting, the largest logit is 0, so $e^0 = 1$ (no overflow), and every other term is $e^{\text{negative}} \in (0, 1]$, so the sum can neither overflow nor vanish.
For log-softmax:
The log-sum-exp trick computes $\log(\sum_j e^{z_j})$ stably:
$$\text{LSE}(\mathbf{z}) = m + \log\left(\sum_j e^{z_j - m}\right)$$
where $m = \max_j z_j$. Then:
$$\log \text{softmax}(z_i) = z_i - \text{LSE}(\mathbf{z})$$
```python
import numpy as np
import torch

def logsumexp_naive(z):
    """Naive log-sum-exp. Overflows for large values!"""
    return np.log(np.sum(np.exp(z)))

def logsumexp_stable(z):
    """
    Stable log-sum-exp using the max trick.

    log(sum(exp(z))) = m + log(sum(exp(z - m)))

    This works because:
    - exp(z - m) <= exp(0) = 1 for all z (since m = max(z))
    - At least one term equals 1 (when z_j = m)
    - No overflow possible
    """
    m = np.max(z)
    return m + np.log(np.sum(np.exp(z - m)))

def log_softmax_stable(z):
    """Stable log-softmax using LSE."""
    return z - logsumexp_stable(z)

def cross_entropy_stable(logits, target_idx):
    """
    Stable cross-entropy from scratch.

    CE = -log(p_target)
       = -log(softmax(z)_target)
       = -(z_target - LSE(z))
       = LSE(z) - z_target
    """
    lse = logsumexp_stable(logits)
    return lse - logits[target_idx]

# Demonstrate stability
print("=== Log-Sum-Exp Stability ===")

# Small values: both work
z_small = np.array([1.0, 2.0, 3.0])
print(f"Small logits: {z_small}")
print(f"  Naive LSE:  {logsumexp_naive(z_small):.6f}")
print(f"  Stable LSE: {logsumexp_stable(z_small):.6f}")

# Large values: naive fails
z_large = np.array([1000.0, 1001.0, 1002.0])
print(f"\nLarge logits: {z_large}")
print(f"  Naive LSE:  {logsumexp_naive(z_large)}")  # inf or overflow
print(f"  Stable LSE: {logsumexp_stable(z_large):.6f}")

# Very negative values
z_neg = np.array([-1000.0, -1001.0, -1002.0])
print(f"\nNegative logits: {z_neg}")
print(f"  Stable LSE: {logsumexp_stable(z_neg):.6f}")

# PyTorch built-in
z_torch = torch.tensor([1000.0, 1001.0, 1002.0])
print(f"\nPyTorch logsumexp: {torch.logsumexp(z_torch, dim=0):.6f}")

# Cross-entropy example
logits = np.array([2.0, 1.0, 0.5, 0.0])
target = 0  # True class
ce = cross_entropy_stable(logits, target)
print(f"\nCross-entropy for target={target}: {ce:.6f}")
print(f"Probability of true class: {np.exp(-ce):.6f}")
```

While understanding the math is crucial, always use framework-provided functions (torch.logsumexp, F.cross_entropy, F.log_softmax) in production. They're optimized for performance across CPU/GPU and handle edge cases. Implementing your own risks subtle bugs.
Modern neural networks often produce overconfident predictions. Even when wrong, they may output probabilities like 0.99. This miscalibration is problematic for decision-making, uncertainty quantification, and downstream applications.
Temperature scaling is a simple post-hoc calibration method that learns a single scalar $T > 0$ that divides the logits:
$$p_i = \text{softmax}(\mathbf{z}/T)_i = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}$$
The effect of temperature:
$T = 1$: standard softmax; probabilities unchanged
$T > 1$: softer, more uniform distribution (lower confidence)
$T < 1$: sharper, more peaked distribution (higher confidence); as $T \to 0$ the output approaches a one-hot argmax
Calibration procedure:
1. Train the model as usual and freeze its weights
2. On a held-out validation set, find the single $T$ that minimizes the negative log-likelihood of the temperature-scaled logits
3. At inference, divide the logits by the learned $T$ before applying softmax
This is remarkably effective—a single parameter often dramatically improves calibration without hurting accuracy.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.optimize import minimize_scalar

class TemperatureScaledModel(nn.Module):
    """
    Wrapper that applies learned temperature scaling to a trained model.
    """
    def __init__(self, base_model, initial_temperature=1.0):
        super().__init__()
        self.base_model = base_model
        # Temperature as a learnable parameter (log space for positivity)
        self.log_temperature = nn.Parameter(
            torch.log(torch.tensor(initial_temperature))
        )

    @property
    def temperature(self):
        return torch.exp(self.log_temperature)

    def forward(self, x):
        logits = self.base_model(x)
        return logits / self.temperature

    def calibrate(self, val_loader, max_iters=100, lr=0.01):
        """
        Learn optimal temperature on validation data.
        Only temperature is updated; base model is frozen.
        """
        # Freeze base model
        for param in self.base_model.parameters():
            param.requires_grad = False

        optimizer = torch.optim.LBFGS([self.log_temperature], lr=lr, max_iter=max_iters)
        criterion = nn.CrossEntropyLoss()

        # Collect all validation data
        all_logits = []
        all_labels = []
        with torch.no_grad():
            for x, y in val_loader:
                logits = self.base_model(x)
                all_logits.append(logits)
                all_labels.append(y)

        all_logits = torch.cat(all_logits)
        all_labels = torch.cat(all_labels)

        def closure():
            optimizer.zero_grad()
            scaled_logits = all_logits / self.temperature
            loss = criterion(scaled_logits, all_labels)
            loss.backward()
            return loss

        optimizer.step(closure)

        print(f"Learned temperature: {self.temperature.item():.4f}")
        return self.temperature.item()

def expected_calibration_error(probs, labels, n_bins=15):
    """
    Compute Expected Calibration Error (ECE).

    ECE measures the average difference between confidence
    and accuracy across confidence bins. Lower is better.
    """
    confidences = probs.max(dim=-1).values
    predictions = probs.argmax(dim=-1)
    accuracies = (predictions == labels).float()

    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    ece = torch.tensor(0.0)

    for i in range(n_bins):
        in_bin = (confidences > bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
        prop_in_bin = in_bin.float().mean()

        if prop_in_bin > 0:
            avg_confidence = confidences[in_bin].mean()
            avg_accuracy = accuracies[in_bin].mean()
            ece += torch.abs(avg_confidence - avg_accuracy) * prop_in_bin

    return ece.item()

# Demonstration with synthetic data
print("=== Temperature Scaling Demo ===")

# Simulate overconfident predictions
# Logits that give ~98% confidence predictions
np.random.seed(42)
n_samples = 1000
n_classes = 10

# True labels
true_labels = torch.randint(0, n_classes, (n_samples,))

# Overconfident logits: correct class gets high logit
logits = torch.randn(n_samples, n_classes) * 0.5
logits[torch.arange(n_samples), true_labels] += 4.0  # Make correct class dominant

# But some are wrong
n_wrong = 100
wrong_idx = torch.randperm(n_samples)[:n_wrong]
wrong_labels = (true_labels[wrong_idx] + torch.randint(1, n_classes, (n_wrong,))) % n_classes
logits[wrong_idx, wrong_labels] = logits[wrong_idx, true_labels[wrong_idx]] + 1

# Before temperature scaling
probs_before = F.softmax(logits, dim=-1)
ece_before = expected_calibration_error(probs_before, true_labels)
accuracy = (probs_before.argmax(dim=-1) == true_labels).float().mean()

print(f"Accuracy: {accuracy:.4f}")
print(f"ECE before scaling: {ece_before:.4f}")
print(f"Mean confidence: {probs_before.max(dim=-1).values.mean():.4f}")

# Find optimal temperature
def nll_at_temp(temp):
    scaled_probs = F.softmax(logits / temp, dim=-1)
    log_probs = torch.log(scaled_probs + 1e-10)
    nll = -log_probs[torch.arange(n_samples), true_labels].mean()
    return nll.item()

result = minimize_scalar(nll_at_temp, bounds=(0.1, 10.0), method='bounded')
optimal_temp = result.x

# After temperature scaling
probs_after = F.softmax(logits / optimal_temp, dim=-1)
ece_after = expected_calibration_error(probs_after, true_labels)

print(f"\nOptimal temperature: {optimal_temp:.4f}")
print(f"ECE after scaling: {ece_after:.4f}")
print(f"Mean confidence after: {probs_after.max(dim=-1).values.mean():.4f}")
```

Temperature scaling is recommended whenever calibrated probabilities matter: medical diagnosis, autonomous driving, risk assessment, or any setting where 'how confident are we?' influences decisions. It's cheap (one parameter), effective, and doesn't hurt accuracy since it doesn't change the argmax.
Label smoothing is a regularization technique that replaces hard one-hot labels with soft targets. Instead of training on $y_k \in \{0, 1\}$, we use:
$$y_k^{\text{smooth}} = \begin{cases} 1 - \epsilon + \epsilon/K & \text{if } k = c \\ \epsilon/K & \text{otherwise} \end{cases}$$
where $\epsilon$ is the smoothing parameter (typically 0.1) and $K$ is the number of classes.
Why label smoothing helps:
Prevents overconfidence: The model can't drive probabilities to exactly 0 or 1, keeping it calibrated
Implicit regularization: Encourages the model to not be infinitely confident, preventing logit explosion
Accounts for label noise: Real datasets have mislabeled examples; treating all labels as 100% certain is overconfident
Better generalization: Empirically improves test accuracy in many tasks, especially with limited data
Mathematical interpretation:
With label smoothing, we're minimizing:
$$\mathcal{L}_{\text{LS}} = (1 - \epsilon)\,\mathcal{L}_{\text{CE}} + \epsilon\,\mathcal{L}_{\text{uniform}}$$
A convex combination of the cross-entropy with the true label and the cross-entropy with a uniform distribution. Up to a constant, the second term penalizes the KL divergence between the uniform distribution and the predictions, which discourages extreme confidence.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross-entropy loss with label smoothing.

    Instead of one-hot targets [0, 0, 1, 0, 0],
    uses soft targets [ε/K, ε/K, 1-ε+ε/K, ε/K, ε/K]
    """
    def __init__(self, smoothing=0.1, reduction='mean'):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, logits, targets):
        n_classes = logits.size(-1)

        # Compute log-softmax
        log_probs = F.log_softmax(logits, dim=-1)

        # One-hot encoding of targets
        targets_one_hot = F.one_hot(targets, n_classes).float()

        # Smooth the labels
        smooth_targets = targets_one_hot * (1 - self.smoothing) + self.smoothing / n_classes

        # Cross-entropy with smooth targets
        loss = -(smooth_targets * log_probs).sum(dim=-1)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

class LabelSmoothingCrossEntropyEfficient(nn.Module):
    """
    Efficient version that avoids explicit one-hot construction.

    The smoothed loss decomposes into two terms:
    1. (1 - ε) * CE with true label
    2. ε * mean of all log-probs (KL with uniform)
    """
    def __init__(self, smoothing=0.1, reduction='mean'):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)

        # Term 1: CE with true class
        nll_loss = F.nll_loss(log_probs, targets, reduction='none')

        # Term 2: Mean log-prob (negative entropy encouragement)
        smooth_loss = -log_probs.mean(dim=-1)

        # Combine
        loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

# Compare with PyTorch built-in
print("=== Label Smoothing Comparison ===")

batch_size = 32
n_classes = 10
logits = torch.randn(batch_size, n_classes)
targets = torch.randint(0, n_classes, (batch_size,))

# Our implementations
ce_smooth_v1 = LabelSmoothingCrossEntropy(smoothing=0.1)
ce_smooth_v2 = LabelSmoothingCrossEntropyEfficient(smoothing=0.1)

# PyTorch built-in (PyTorch 1.10+)
ce_pytorch = nn.CrossEntropyLoss(label_smoothing=0.1)

# Standard CE for comparison
ce_standard = nn.CrossEntropyLoss()

print(f"Standard CE:             {ce_standard(logits, targets):.6f}")
print(f"Label Smoothing (v1):    {ce_smooth_v1(logits, targets):.6f}")
print(f"Label Smoothing (v2):    {ce_smooth_v2(logits, targets):.6f}")
print(f"PyTorch Label Smoothing: {ce_pytorch(logits, targets):.6f}")

# Effect on training dynamics
print("\n=== Effect on Confidence ===")
# Model with high confidence
confident_logits = torch.tensor([[10.0, 0.0, 0.0, 0.0, 0.0]])
target = torch.tensor([0])  # Correct

print(f"Predicted probs: {F.softmax(confident_logits, dim=-1)}")
print(f"Standard CE loss: {ce_standard(confident_logits, target):.6f}")
print(f"Smoothed CE loss: {ce_pytorch(confident_logits, target):.6f}")
print("Note: Smoothed loss is higher because it penalizes overconfidence")
```

Let's consolidate a production-ready multi-class classification architecture that incorporates best practices:
Standard pattern:
```
Input → Hidden Layers → Linear(d, K) → [No Softmax] → CrossEntropyLoss   (training)
                                     ↓
                          Softmax at inference only
```
Key design decisions:
Raw logits out: the final layer is a plain Linear(d, K) with no activation; softmax is applied only at inference
Loss on logits: training uses CrossEntropyLoss, which fuses log-softmax and NLL in a numerically stable way
Optional regularization: label smoothing and per-class weights are passed directly to the loss function
Post-hoc calibration: a temperature buffer on the model, learned after training, scales the logits at prediction time
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class MultiClassClassifier(nn.Module):
    """
    Production-ready multi-class classification model.

    Design principles:
    1. Output K logits (no softmax during training)
    2. Use CrossEntropyLoss for training
    3. Temperature scaling for calibration (optional)
    4. Label smoothing for regularization (optional)
    """
    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        hidden_dims: list = [256, 128],
        dropout: float = 0.1,
        use_batch_norm: bool = True,
    ):
        super().__init__()
        self.num_classes = num_classes

        # Build feature extractor
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim

        self.features = nn.Sequential(*layers)

        # Classification head: K output units, NO activation
        self.classifier = nn.Linear(prev_dim, num_classes)

        # Temperature for calibration (learned post-training)
        self.register_buffer('temperature', torch.tensor(1.0))

    def forward(self, x) -> torch.Tensor:
        """
        Returns raw logits (NOT probabilities).
        Shape: [batch_size, num_classes]
        """
        features = self.features(x)
        logits = self.classifier(features)
        return logits

    def predict_proba(self, x, use_temperature: bool = True) -> torch.Tensor:
        """Returns class probabilities."""
        with torch.no_grad():
            logits = self.forward(x)
            if use_temperature:
                logits = logits / self.temperature
            return F.softmax(logits, dim=-1)

    def predict(self, x) -> torch.Tensor:
        """Returns class predictions."""
        probs = self.predict_proba(x)
        return probs.argmax(dim=-1)

    def set_temperature(self, temperature: float):
        """Set calibration temperature."""
        self.temperature.fill_(temperature)

class MultiClassTrainer:
    """
    Trainer with full multi-class best practices.
    """
    def __init__(
        self,
        model: MultiClassClassifier,
        learning_rate: float = 1e-3,
        label_smoothing: float = 0.0,
        class_weights: Optional[torch.Tensor] = None,
    ):
        self.model = model

        # Loss function with optional smoothing and weighting
        self.criterion = nn.CrossEntropyLoss(
            weight=class_weights,
            label_smoothing=label_smoothing,
        )

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=1e-4,
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100
        )

    def train_step(self, x, y) -> Tuple[float, float]:
        """Single training step."""
        self.model.train()
        self.optimizer.zero_grad()

        logits = self.model(x)
        loss = self.criterion(logits, y)
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        # Compute accuracy
        with torch.no_grad():
            preds = logits.argmax(dim=-1)
            accuracy = (preds == y).float().mean().item()

        return loss.item(), accuracy

    @torch.no_grad()
    def evaluate(self, x, y) -> Tuple[float, float]:
        """Evaluate on data."""
        self.model.eval()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=-1)
        accuracy = (preds == y).float().mean().item()
        return loss.item(), accuracy

# Usage example
model = MultiClassClassifier(
    input_dim=784,  # e.g., MNIST
    num_classes=10,
    hidden_dims=[512, 256],
    dropout=0.2,
)

# For imbalanced data, compute class weights
class_counts = torch.tensor([1000, 500, 100, 50, 20, 800, 600, 400, 200, 100])
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_weights)

trainer = MultiClassTrainer(
    model,
    learning_rate=1e-3,
    label_smoothing=0.1,  # Use label smoothing
    class_weights=class_weights,
)

# Training iteration
x_batch = torch.randn(64, 784)
y_batch = torch.randint(0, 10, (64,))

loss, acc = trainer.train_step(x_batch, y_batch)
print(f"Loss: {loss:.4f}, Accuracy: {acc:.4f}")
```

When K is very large (e.g., language modeling with 50K+ vocabulary), computing the full softmax is expensive. Alternatives include: hierarchical softmax (O(log K) per sample), sampled softmax (approximate with negative sampling), or adaptive softmax (combine frequent and rare words differently).
Multi-class classification is a foundational capability in deep learning. The softmax function and cross-entropy loss form a principled pair derived from maximum likelihood estimation under the categorical distribution. Understanding their properties enables not just correct implementation, but informed design choices for novel problems.
You have mastered multi-class classification output design from mathematical foundations to production implementation. Next, we'll explore multi-label classification, where each example can belong to any subset of classes simultaneously, requiring independent sigmoid outputs rather than a single softmax.