If you've trained a neural network for classification—whether recognizing handwritten digits, classifying images, or predicting sentiment—you've almost certainly used cross-entropy loss. This loss function is so pervasive that it's often applied without much thought: "it's just what you use for classification."
But why cross-entropy? Why not mean squared error? Why not absolute difference between predicted and true probabilities? The answer lies deep in information theory and connects to fundamental questions about what it means to learn a probability distribution.
Cross-entropy H(P, Q) measures the average number of bits needed to encode data from distribution P when using a code optimized for distribution Q. When P is the true data distribution and Q is our model, minimizing cross-entropy means building a model whose predictions would require the fewest bits to encode the actual outcomes.
This page traces cross-entropy from its information-theoretic origins through its emergence as the canonical classification loss. By the end, you'll understand not just how to use cross-entropy, but why it uniquely captures what we want from a probabilistic classifier.
By the end of this page, you will understand cross-entropy's formal definition and information-theoretic interpretation, see exactly how it relates to entropy and KL divergence, connect cross-entropy loss to maximum likelihood estimation, and appreciate why cross-entropy is superior to naive alternatives for classification.
Recall that entropy H(P) measures the minimum average bits needed to encode samples from distribution P using an optimal code:
H(P) = −Σᵢ P(xᵢ) · log P(xᵢ)
But what if we don't know the true distribution P? What if we only have an approximation Q, and we design our code based on Q instead?
The answer is cross-entropy: the average bits needed to encode samples from P using a code optimized for Q.
```python
# Cross-Entropy Definition
# ========================
#
# H(P, Q) = -∑ᵢ P(xᵢ) · log Q(xᵢ)
#
# In words:
# - Sample outcomes according to the TRUE distribution P
# - Use code lengths based on the MODEL distribution Q
# - Compute the average code length
#
# Comparison to Entropy:
# H(P)    = -∑ P(x) · log P(x)  ← Using optimal P-based code
# H(P, Q) = -∑ P(x) · log Q(x)  ← Using suboptimal Q-based code
#
# Key insight: H(P, Q) ≥ H(P) always
# We pay a "penalty" for using the wrong distribution

# Python Implementation:
import numpy as np

def cross_entropy(p, q, epsilon=1e-15):
    """
    Compute cross-entropy H(P, Q).

    Args:
        p: True distribution (array of probabilities)
        q: Model distribution (array of probabilities)
        epsilon: Small constant to avoid log(0)

    Returns:
        Cross-entropy in bits (using log base 2)
    """
    p = np.array(p)
    q = np.array(q) + epsilon  # Avoid log(0)
    q = q / q.sum()            # Renormalize after adding epsilon
    # Only sum over non-zero p values
    mask = p > 0
    return -np.sum(p[mask] * np.log2(q[mask]))

# Examples
p_true = [0.7, 0.2, 0.1]  # True distribution

# Model 1: Perfect match
q_exact = [0.7, 0.2, 0.1]
print(f"H(P, Q_exact) = {cross_entropy(p_true, q_exact):.4f} bits")

# Model 2: Slight mismatch
q_close = [0.6, 0.3, 0.1]
print(f"H(P, Q_close) = {cross_entropy(p_true, q_close):.4f} bits")

# Model 3: Poor approximation
q_poor = [0.33, 0.33, 0.34]
print(f"H(P, Q_poor) = {cross_entropy(p_true, q_poor):.4f} bits")

# For reference: entropy of P
def entropy(p):
    p = np.array(p)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

print(f"H(P) = {entropy(p_true):.4f} bits (theoretical minimum)")
```

Interpretation through coding:
Recall that in an optimal code, a symbol with probability p is assigned approximately −log p bits. If the true probability is P(x) but we assign code lengths based on Q(x), each occurrence of x costs −log Q(x) bits instead of the optimal −log P(x), so the average cost over samples from P is exactly H(P, Q).
The asymmetry is important. Underestimating the probability of common events is much worse than overestimating rare ones. This explains why cross-entropy severely penalizes confident wrong predictions.
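To make the asymmetry concrete, here is a minimal sketch (the helper name and the distributions are illustrative, not from the page) comparing the cost of underestimating a common event against underestimating a rare one by a comparable margin:

```python
import numpy as np

def cross_entropy_bits(p, q):
    """Average bits to encode P-distributed outcomes with a Q-optimized code."""
    p, q = np.array(p), np.array(q)
    mask = p > 0
    return -np.sum(p[mask] * np.log2(q[mask]))

p = [0.9, 0.1]  # one common event, one rare event

h_true = cross_entropy_bits(p, p)                    # optimal code: H(P)
h_under_common = cross_entropy_bits(p, [0.5, 0.5])   # common event underestimated
h_under_rare = cross_entropy_bits(p, [0.98, 0.02])   # rare event underestimated

print(f"H(P)                    = {h_true:.3f} bits")
print(f"common event underrated = {h_under_common:.3f} bits")
print(f"rare event underrated   = {h_under_rare:.3f} bits")
```

Shaving probability off the common event inflates the average code length far more than shaving the same kind of margin off the rare event, because the common event's longer code is paid on almost every sample.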
Cross-entropy decomposes as:
H(P, Q) = H(P) + D_KL(P || Q)
where D_KL(P || Q) is the KL divergence (next page). Since D_KL ≥ 0:
H(P, Q) ≥ H(P)
Equality holds if and only if Q = P. Cross-entropy exceeds entropy by exactly the KL divergence—the "extra bits" we pay for using the wrong distribution.
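The decomposition is easy to check numerically; the sketch below (helper names are mine) verifies H(P, Q) = H(P) + D_KL(P || Q) on a small distribution:

```python
import numpy as np

def entropy_bits(p):
    p = np.array(p)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

def cross_entropy_bits(p, q):
    p, q = np.array(p), np.array(q)
    mask = p > 0
    return -np.sum(p[mask] * np.log2(q[mask]))

def kl_bits(p, q):
    p, q = np.array(p), np.array(q)
    mask = p > 0
    return np.sum(p[mask] * np.log2(p[mask] / q[mask]))

p = [0.7, 0.2, 0.1]
q = [0.6, 0.3, 0.1]

lhs = cross_entropy_bits(p, q)
rhs = entropy_bits(p) + kl_bits(p, q)
print(f"H(P, Q)           = {lhs:.6f} bits")
print(f"H(P) + D_KL(P||Q) = {rhs:.6f} bits")
```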
In machine learning, we typically don't know the true distribution P—that's what we're trying to learn! Instead, we observe samples from P (our training data) and try to learn a model Q that approximates P.
For classification, P is the label distribution for a given input, which for hard labels is one-hot (all mass on the true class), and Q is the model's predicted distribution over classes.
The cross-entropy loss for a single example simplifies beautifully:
```python
# Cross-Entropy Loss for Classification
# =====================================
#
# For a single training example:
# - True label y (one-hot): only position k has value 1
# - Predicted probabilities q = [q_1, q_2, ..., q_C] for C classes
#
# L = H(y, q) = -∑ᵢ yᵢ · log(qᵢ) = -log(q_k)
#
# Only the true class k contributes!
# This is just the negative log probability of the correct class.
#
# For a batch of N examples:
# L_batch = -(1/N) ∑ⱼ log(q_{j,yⱼ})
# where q_{j,yⱼ} is the predicted probability for the true class of example j

import torch
import torch.nn.functional as F

# Example: 3-class classification
# Model outputs (logits, before softmax)
logits = torch.tensor([[2.0, 1.0, 0.1],   # Example 1
                       [0.5, 2.5, 0.3],   # Example 2
                       [0.1, 0.2, 3.0]])  # Example 3

# True labels (class indices)
labels = torch.tensor([0, 1, 2])  # Examples 1, 2, 3 have classes 0, 1, 2

# Method 1: Manual computation
probs = F.softmax(logits, dim=1)
print("Predicted probabilities:")
print(probs)

# Pick out the probability of the true class for each example
true_class_probs = probs[range(3), labels]
print(f"True class probabilities: {true_class_probs}")

loss_manual = -torch.log(true_class_probs).mean()
print(f"Manual cross-entropy loss: {loss_manual.item():.4f}")

# Method 2: Using PyTorch's built-in (combines softmax + CE)
loss_pytorch = F.cross_entropy(logits, labels)
print(f"PyTorch cross-entropy loss: {loss_pytorch.item():.4f}")
# They match!
```

Why this formulation works:
Since the true distribution is one-hot, cross-entropy simplifies to −log(q_correct). The loss becomes:
The logarithm is crucial. If the model predicts q_correct = 0.01 for the true class, the loss is −log(0.01) ≈ 6.64. If it predicts q_correct = 0.001, the loss jumps to 9.97. The loss severely punishes confident wrong predictions, creating strong gradients for correction.
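The blow-up is easy to tabulate; this short sketch reproduces the base-2 figures quoted above:

```python
import numpy as np

# Loss assigned to the true class as its predicted probability shrinks (base-2 logs)
losses = {q: -np.log2(q) for q in [0.9, 0.5, 0.1, 0.01, 0.001]}
for q_correct, loss in losses.items():
    print(f"q_correct = {q_correct:>6}: loss = {loss:.2f} bits")
```

Halving the true class's probability always adds a fixed one bit of loss, so each order of magnitude of overconfidence in the wrong direction costs the same large increment.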
For binary classification (two classes), cross-entropy takes a special form known as binary cross-entropy (BCE) or log loss. With only two classes, we can parameterize the model with a single probability p ∈ [0, 1] for class 1, implying probability 1-p for class 0.
```python
# Binary Cross-Entropy (Log Loss)
# ===============================
#
# For a single example:
# - y ∈ {0, 1}: true binary label
# - p ∈ [0, 1]: predicted probability of class 1
#
# L = -[y · log(p) + (1-y) · log(1-p)]
#
# When y = 1: L = -log(p)   → penalizes p < 1
# When y = 0: L = -log(1-p) → penalizes p > 0
#
# For a batch of N examples:
# L_batch = -(1/N) ∑ᵢ [yᵢ · log(pᵢ) + (1-yᵢ) · log(1-pᵢ)]

import numpy as np
import torch
import torch.nn.functional as F

def binary_cross_entropy(y_true, y_pred, epsilon=1e-15):
    """
    Compute binary cross-entropy loss.

    Args:
        y_true: True binary labels (0 or 1)
        y_pred: Predicted probabilities for class 1
        epsilon: Small constant for numerical stability

    Returns:
        Binary cross-entropy loss
    """
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.mean(y_true * np.log(y_pred)
                    + (1 - y_true) * np.log(1 - y_pred))

# Example predictions
y_true = np.array([1, 0, 1, 1, 0])
y_pred = np.array([0.9, 0.2, 0.8, 0.7, 0.1])

print(f"True labels: {y_true}")
print(f"Predictions: {y_pred}")
print(f"BCE Loss: {binary_cross_entropy(y_true, y_pred):.4f}")

# What if predictions were perfect? (the function clips internally, avoiding log(0))
y_pred_perfect = np.array([1.0, 0.0, 1.0, 1.0, 0.0])
print(f"Perfect predictions BCE: {binary_cross_entropy(y_true, y_pred_perfect):.4f}")

# What if predictions were completely wrong?
y_pred_wrong = np.array([0.1, 0.9, 0.1, 0.1, 0.9])
print(f"Wrong predictions BCE: {binary_cross_entropy(y_true, y_pred_wrong):.4f}")

# PyTorch comparison
y_true_t = torch.tensor(y_true, dtype=torch.float32)
y_pred_t = torch.tensor(y_pred, dtype=torch.float32)
print(f"PyTorch BCE: {F.binary_cross_entropy(y_pred_t, y_true_t).item():.4f}")
```

BCE from logits (numerically stable):
In practice, directly computing log(sigmoid(z)) can cause numerical issues for large |z|. Deep learning frameworks provide stable implementations that work with raw logits:
```python
import torch
import torch.nn.functional as F

def bce_with_logits_manual(logits, targets):
    """
    Numerically stable binary cross-entropy from logits.

    Uses the identity: log(sigmoid(x)) = x - softplus(x)
    where softplus(x) = log(1 + exp(x)).
    This avoids computing sigmoid explicitly, preventing overflow.
    """
    # For numerical stability, we use:
    # max(z, 0) - z*y + log(1 + exp(-abs(z)))
    max_val = torch.clamp(logits, min=0)
    loss = max_val - logits * targets + torch.log(1 + torch.exp(-torch.abs(logits)))
    return loss.mean()

# Example
logits = torch.tensor([5.0, -3.0, 2.0, 1.0, -5.0])
targets = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])

# Manual implementation
loss_manual = bce_with_logits_manual(logits, targets)
print(f"Manual BCE with logits: {loss_manual.item():.6f}")

# PyTorch built-in (recommended)
loss_pytorch = F.binary_cross_entropy_with_logits(logits, targets)
print(f"PyTorch BCE with logits: {loss_pytorch.item():.6f}")
# They match! But PyTorch's version handles edge cases better.

# Extreme logits that would break a naive implementation
extreme_logits = torch.tensor([100.0, -100.0])
extreme_targets = torch.tensor([1.0, 0.0])
print(f"Extreme logits BCE: {F.binary_cross_entropy_with_logits(extreme_logits, extreme_targets).item():.6f}")
# Works fine! Naive sigmoid + log would give inf or nan
```

In production code, always use F.cross_entropy(logits, labels) or F.binary_cross_entropy_with_logits(logits, targets) rather than manually computing softmax/sigmoid followed by log. The combined operations are numerically stable even for extreme values, while the two-step approach can produce inf or nan.
One of the most elegant aspects of cross-entropy is its equivalence to maximum likelihood estimation (MLE). This connection explains why cross-entropy is the principled choice for classification—it arises naturally from statistical first principles.
Maximum Likelihood Formulation:
Given training data {(xᵢ, yᵢ)}ᵢ₌₁ᴺ where yᵢ is the true class label, the likelihood of the data under model parameters θ is:
L(θ) = ∏ᵢ P(yᵢ | xᵢ; θ)
The log-likelihood (more convenient for optimization):
log L(θ) = Σᵢ log P(yᵢ | xᵢ; θ)
Maximum likelihood seeks parameters θ that maximize this quantity.
```python
# MLE and Cross-Entropy are Equivalent (up to constants)
# ======================================================
#
# Negative Log-Likelihood (NLL):
#   NLL = -log L(θ) = -∑ᵢ log P(yᵢ | xᵢ; θ)
#
# For a softmax classifier, P(y=k | x; θ) = softmax(f_θ(x))_k
# So: NLL = -∑ᵢ log(softmax(f_θ(xᵢ))_{yᵢ})
#
# This is EXACTLY the cross-entropy loss summed over the dataset!
#
# The connection:
#   Cross-Entropy Loss = -E_P[log Q]
# If P is the empirical distribution (1/N for each training sample),
# then E_P[log Q] = (1/N) ∑ᵢ log Q(yᵢ | xᵢ)
#
# Therefore:
#   Minimizing cross-entropy loss = Maximizing log-likelihood
# They are the SAME optimization problem!

import torch
import torch.nn as nn
import torch.nn.functional as F

# Example: Simple linear classifier
class LinearClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)  # Returns logits

    def predict_proba(self, x):
        return F.softmax(self.forward(x), dim=1)

# Generate synthetic data
torch.manual_seed(42)
num_samples = 100
input_dim = 5
num_classes = 3

X = torch.randn(num_samples, input_dim)
y = torch.randint(0, num_classes, (num_samples,))

# Create model
model = LinearClassifier(input_dim, num_classes)

# Method 1: Cross-Entropy Loss
logits = model(X)
ce_loss = F.cross_entropy(logits, y)

# Method 2: Negative Log-Likelihood (manual)
probs = F.softmax(logits, dim=1)
# Get probability of true class for each sample
true_class_probs = probs[torch.arange(num_samples), y]
nll_loss = -torch.log(true_class_probs + 1e-10).mean()

print(f"Cross-Entropy Loss: {ce_loss.item():.6f}")
print(f"Manual NLL Loss: {nll_loss.item():.6f}")
print(f"Difference: {abs(ce_loss.item() - nll_loss.item()):.2e}")
# They are the same!
```

Why this matters:
The MLE-cross-entropy equivalence tells us that:
Cross-entropy is principled: It's not an arbitrary choice but emerges from the fundamental goal of finding probability distributions that best explain the observed data.
Statistical properties transfer: MLE has well-understood properties (consistency, asymptotic efficiency, normality). Since cross-entropy training is MLE, the same guarantees apply.
Probabilistic interpretation: The trained model's outputs can be interpreted as probabilities—they represent the model's belief about class membership, not just scores.
Generalizes naturally: For any exponential family distribution, MLE leads to a corresponding loss function. Cross-entropy is the special case for categorical distributions.
Another way to see this: since P_data is fixed, minimizing cross-entropy H(P_data, Q_model) is equivalent to minimizing KL divergence D_KL(P_data || Q_model), because the two differ only by the constant H(P_data). We're finding the model Q that is "closest" to the true data distribution P in the sense of KL divergence. More on this in the next page!
A natural question arises: why use cross-entropy instead of Mean Squared Error (MSE) for classification? After all, we could treat class probabilities as regression targets and minimize the squared difference between predicted and true probabilities.
It turns out there are compelling reasons to prefer cross-entropy:
```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def sigmoid_derivative(z):
    s = sigmoid(z)
    return s * (1 - s)

# True label y = 1 (positive class)
y = 1

# Range of logits (pre-sigmoid values)
z = np.linspace(-6, 6, 200)
p = sigmoid(z)  # Predictions

# MSE Loss: L = (y - p)² / 2
# dL/dz = dL/dp * dp/dz = (p - y) * sigmoid'(z)
mse_grad = (p - y) * sigmoid_derivative(z)

# Cross-Entropy Loss: L = -[y*log(p) + (1-y)*log(1-p)]
# dL/dz = p - y (sigmoid derivative cancels!)
ce_grad = p - y

print("Gradient Analysis (y = 1, wanting p close to 1):")
print("=" * 50)

for z_val in [-5, -2, 0, 2, 5]:
    p_val = sigmoid(z_val)
    mse_g = (p_val - y) * sigmoid_derivative(z_val)
    ce_g = p_val - y
    print(f"z = {z_val:+.0f}: p = {p_val:.4f}, MSE grad = {mse_g:+.4f}, CE grad = {ce_g:+.4f}")

print()
print("Key observations:")
print("- At z=-5 (very wrong): MSE grad ≈ 0, CE grad ≈ -1 (CE learns fast)")
print("- At z=+5 (very right): Both grads ≈ 0 (both stop, as they should)")
print("- CE gradient is always meaningful; MSE vanishes when saturated")
```

The Beautiful Gradient Cancellation:
For sigmoid output with cross-entropy loss, the gradient has a remarkably simple form:
∂L/∂z = σ(z) - y = p - y
The sigmoid's derivative completely cancels! This is not a coincidence: in the theory of exponential families, the sigmoid is the canonical link for the Bernoulli distribution, and pairing a canonical link with its matching log loss always produces this clean p − y gradient. Softmax paired with categorical cross-entropy enjoys the same cancellation.
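The cancellation can be confirmed with autograd; this sketch compares PyTorch's computed gradient against the analytic p − y:

```python
import torch
import torch.nn.functional as F

# Check ∂L/∂z = sigmoid(z) - y using autograd
z = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0], requires_grad=True)
y = torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0])

# Sum reduction so each element's gradient is exactly sigmoid(z_i) - y_i
loss = F.binary_cross_entropy_with_logits(z, y, reduction='sum')
loss.backward()

analytic = torch.sigmoid(z).detach() - y
print("autograd gradient:", z.grad)
print("p - y            :", analytic)
```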
Impact on Training:
| Scenario | Prediction | MSE Gradient | CE Gradient |
|---|---|---|---|
| Very wrong | p = 0.01, y = 1 | ≈ -0.01 | ≈ -0.99 |
| Moderately wrong | p = 0.30, y = 1 | ≈ -0.15 | ≈ -0.70 |
| Nearly correct | p = 0.90, y = 1 | ≈ -0.01 | ≈ -0.10 |
| Correct | p = 0.99, y = 1 | ≈ -0.0001 | ≈ -0.01 |
With MSE, the worst case (p = 0.01 when y = 1) has the smallest gradient! Cross-entropy correctly assigns largest gradients to worst errors.
Cross-entropy applies differently depending on whether your problem is multi-class (exactly one class per sample) or multi-label (zero or more classes per sample). Understanding this distinction is crucial for correct implementation.
| Aspect | Multi-Class | Multi-Label |
|---|---|---|
| Classes per sample | Exactly one | Zero, one, or many |
| Example problem | ImageNet (1 of 1000 objects) | Movie genres (multiple per film) |
| Output activation | Softmax | Independent Sigmoids |
| Output constraint | Probabilities sum to 1 | Each probability independent |
| Loss function | Categorical Cross-Entropy | Binary Cross-Entropy per class |
| PyTorch function | F.cross_entropy | F.binary_cross_entropy_with_logits |
| Label format | Integer class index | Binary vector (multi-hot) |
```python
import torch
import torch.nn.functional as F

# ==========================================
# MULTI-CLASS: Exactly one class per sample
# ==========================================

# Logits from model (batch_size=4, num_classes=5)
logits_multiclass = torch.randn(4, 5)

# Labels: integer indices of correct class
labels_multiclass = torch.tensor([0, 3, 2, 1])

# Loss: Softmax + Cross-Entropy (combined for stability)
loss_multiclass = F.cross_entropy(logits_multiclass, labels_multiclass)
print(f"Multi-class CE loss: {loss_multiclass.item():.4f}")

# Predicted probabilities (for inspection)
probs_multiclass = F.softmax(logits_multiclass, dim=1)
print(f"Multi-class probs sum to 1: {probs_multiclass.sum(dim=1)}")

# ==========================================
# MULTI-LABEL: Multiple classes per sample
# ==========================================

# Same logits shape, but different interpretation
logits_multilabel = torch.randn(4, 5)

# Labels: binary vector indicating which classes are present
# Each sample can have 0, 1, or more classes active
labels_multilabel = torch.tensor([
    [1, 0, 1, 0, 0],  # Sample 1: classes 0 and 2
    [0, 0, 0, 1, 1],  # Sample 2: classes 3 and 4
    [1, 1, 1, 0, 0],  # Sample 3: classes 0, 1, and 2
    [0, 0, 0, 0, 1],  # Sample 4: class 4 only
], dtype=torch.float32)

# Loss: Independent binary cross-entropy per class
loss_multilabel = F.binary_cross_entropy_with_logits(logits_multilabel, labels_multilabel)
print(f"Multi-label BCE loss: {loss_multilabel.item():.4f}")

# Predicted probabilities (independent sigmoids)
probs_multilabel = torch.sigmoid(logits_multilabel)
print(f"Multi-label probs (don't sum to 1): {probs_multilabel.sum(dim=1)}")

# Common mistake: Using cross_entropy for multi-label
# This would force probabilities to sum to 1, which is WRONG!
print("⚠️ Do NOT use F.cross_entropy for multi-label problems!")
print("   Softmax assumes mutually exclusive classes.")
```

Using softmax + categorical cross-entropy for multi-label classification is incorrect.
Softmax forces outputs to sum to 1, implying mutual exclusivity. For multi-label, use independent sigmoids with binary cross-entropy for each class. The total loss is the sum (or mean) of per-class BCE losses.
When using cross-entropy in practice, several implementation details matter for training stability, speed, and accuracy.
```python
import torch
import torch.nn.functional as F

# Label Smoothing Cross-Entropy
def label_smoothing_ce(logits, targets, smoothing=0.1):
    """
    Cross-entropy with label smoothing.

    Instead of one-hot targets, uses:
    - (1 - smoothing) for the true class
    - smoothing / (num_classes - 1) for other classes
    """
    num_classes = logits.size(-1)
    with torch.no_grad():
        # Create smoothed labels
        smooth_targets = torch.full_like(logits, smoothing / (num_classes - 1))
        smooth_targets.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
    log_probs = F.log_softmax(logits, dim=-1)
    return -(smooth_targets * log_probs).sum(dim=-1).mean()

# Class-Weighted Cross-Entropy
def weighted_ce(logits, targets, class_weights):
    """
    Cross-entropy with per-class weights.
    Useful for imbalanced datasets.
    """
    return F.cross_entropy(logits, targets, weight=class_weights)

# Focal Loss (for extreme imbalance)
def focal_loss(logits, targets, gamma=2.0, alpha=None):
    """
    Focal Loss: -alpha * (1-p)^gamma * log(p)

    Down-weights easy examples, focuses on hard ones.
    gamma=0 recovers standard cross-entropy.
    """
    ce_loss = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce_loss)  # probability of true class
    focal_weight = (1 - pt) ** gamma
    if alpha is not None:
        alpha_weight = alpha[targets]
        focal_weight = focal_weight * alpha_weight
    return (focal_weight * ce_loss).mean()

# Example usage
logits = torch.randn(8, 5)
targets = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])

print("Standard CE:", F.cross_entropy(logits, targets).item())
print("Label smoothed (ε=0.1):", label_smoothing_ce(logits, targets, 0.1).item())

# For imbalanced data (class 4 is rare, gets 5x weight)
weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 5.0])
print("Weighted CE:", weighted_ce(logits, targets, weights).item())

print("Focal loss (γ=2):", focal_loss(logits, targets, gamma=2.0).item())
```

Label smoothing prevents the model from becoming overconfident by softening the target distribution.
Information-theoretically, it adds entropy to the target, preventing the model from driving predictions to extreme values (0 or 1). This acts as implicit regularization and often improves generalization, especially for large models.
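The entropy that smoothing injects is easy to quantify; a small sketch (helper name and values are illustrative):

```python
import numpy as np

def entropy_bits(p):
    p = np.array(p)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

num_classes, smoothing = 5, 0.1
one_hot = np.eye(num_classes)[2]                       # hard target for class 2
smooth = np.full(num_classes, smoothing / (num_classes - 1))
smooth[2] = 1.0 - smoothing                            # softened target

h_hard = entropy_bits(one_hot)
h_smooth = entropy_bits(smooth)
print(f"One-hot target entropy:  {h_hard:.4f} bits")
print(f"Smoothed target entropy: {h_smooth:.4f} bits")
```

The one-hot target has zero entropy, so the model is pushed toward probability 1; the smoothed target has positive entropy, which keeps the optimal prediction away from the extremes.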
We've traced cross-entropy from information theory to practical neural network training. Let's consolidate:

- H(P, Q) = −Σ P(x) · log Q(x) is the average code length for encoding P-distributed data with a Q-optimized code, and H(P, Q) = H(P) + D_KL(P || Q) ≥ H(P), with equality only when Q = P.
- For one-hot labels, cross-entropy loss reduces to −log(q_correct), the negative log probability of the true class, which punishes confident wrong predictions severely.
- Minimizing cross-entropy on training data is exactly maximum likelihood estimation, so the model's outputs carry a genuine probabilistic interpretation.
- Paired with sigmoid or softmax, cross-entropy yields the clean gradient p − y, avoiding the vanishing gradients that plague MSE on saturated outputs.
- Use softmax + categorical cross-entropy for multi-class problems and independent sigmoids + binary cross-entropy for multi-label problems, always via the numerically stable logits-based functions.
What's next:
We've seen that cross-entropy exceeds entropy by the KL divergence. The next page dives deep into KL divergence—the measure of how different two distributions are. KL divergence is central to variational inference, generative models, and understanding the geometry of probability distributions.
You now understand cross-entropy as both an information-theoretic concept and the canonical classification loss function. You can explain why it's superior to alternatives, implement it correctly for multi-class and multi-label problems, and apply advanced techniques like label smoothing.