While multi-class classification assumes exactly one label per example, many real-world problems require assigning multiple labels simultaneously. An image might contain both 'dog' and 'cat'. A movie can be 'action', 'comedy', and 'romance'. A document may discuss 'machine learning', 'healthcare', and 'ethics'. A customer support ticket might relate to 'billing', 'technical', and 'account access'.
Multi-label classification fundamentally differs from multi-class: labels are not mutually exclusive. Each label is an independent binary decision. This changes everything—the output activation, the loss function, the evaluation metrics, and the training dynamics.
This page provides comprehensive coverage of multi-label output layer design. By understanding the principles, you'll be able to handle any multi-label scenario, from simple tag prediction to complex hierarchical or correlated label structures.
This page covers: the fundamental difference between multi-class and multi-label, independent sigmoid formulation, binary cross-entropy for multiple labels, handling label imbalance per tag, macro/micro/sample-averaged metrics, label correlation methods, thresholding strategies, and production architecture patterns.
The distinction between multi-class and multi-label is critical and frequently misunderstood. Confusing them leads to fundamentally incorrect models.
Multi-class classification: each example receives exactly one of $K$ mutually exclusive classes, so the softmax probabilities compete and sum to 1.
Multi-label classification: each example can receive any subset of the $K$ labels, including none or all of them; each label is a separate yes/no decision.
| Aspect | Multi-Class | Multi-Label |
|---|---|---|
| Label constraint | Exactly one label | Zero to $K$ labels |
| Output activation | Softmax | Element-wise sigmoid |
| Loss function | Categorical cross-entropy | Binary cross-entropy per label |
| Probability constraint | $\sum p_k = 1$ | No constraint on sum |
| Evaluation metrics | Accuracy, top-k accuracy | Subset accuracy, F1, AUC per label |
| Label space size | $K$ possible outcomes | $2^K$ possible label sets |
| Threshold | Usually argmax (implicit) | Per-label threshold needed |
With $K$ labels, there are $2^K$ possible label combinations. For $K = 20$, that's over 1 million possible outputs. This combinatorial explosion is why we model each label independently rather than treating every combination as a separate class.
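Concretely, $2^{20} = 1{,}048{,}576$: a "power set" classifier would need over a million softmax outputs, most of which would have no training examples at all, whereas 20 independent binary outputs need only 20 logits.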
Formal problem definition:
Given input $\mathbf{x}$, predict a binary vector $\mathbf{y} \in \{0, 1\}^K$ where $y_k = 1$ indicates label $k$ is present.
The model outputs $K$ probability estimates: $$p_k = \sigma(z_k) = \frac{1}{1 + e^{-z_k}}$$
Each $p_k$ is the probability that label $k$ applies to this example, estimated independently of other labels.
When to use multi-label: whenever the categories are not mutually exclusive and an example can legitimately carry several of them at once — image tagging, document topic labeling, movie genre assignment, or support ticket routing, as in the examples above.
The output layer for multi-label classification produces $K$ logits, each passed through an independent sigmoid:
$$p_k = \sigma(z_k) \quad \text{for } k = 1, \ldots, K$$
Unlike softmax, these probabilities are not normalized—each can be any value in $(0, 1)$ independent of the others. This models $K$ separate Bernoulli distributions:
$$p(y_k = 1 | \mathbf{x}) = \sigma(z_k)$$ $$p(\mathbf{y} | \mathbf{x}) = \prod_{k=1}^K p_k^{y_k} (1 - p_k)^{1-y_k}$$
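For example, with $K = 3$ and independent estimates $p = (0.9, 0.2, 0.7)$, the probability assigned to the label set $\mathbf{y} = (1, 0, 1)$ is

$$p(\mathbf{y} | \mathbf{x}) = 0.9 \times (1 - 0.2) \times 0.7 = 0.504$$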
Architecture:
```
Input → Hidden Layers → Linear(d, K) → [No Activation] → BCEWithLogitsLoss (training)
                                           ↓
                                 K sigmoids at inference
```
The key insight: this is exactly $K$ parallel binary classifiers sharing the same feature representation. The hidden layers learn features useful for all labels, while the output layer makes $K$ independent predictions.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLabelClassifier(nn.Module):
    """
    Multi-label classification with independent sigmoids.
    Each output is an independent binary classification.
    """

    def __init__(
        self,
        input_dim: int,
        num_labels: int,
        hidden_dims: list = [512, 256],
        dropout: float = 0.2,
    ):
        super().__init__()
        self.num_labels = num_labels

        # Shared feature extractor
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim
        self.features = nn.Sequential(*layers)

        # Multi-label output: K independent logits
        # NO activation - loss function handles sigmoid
        self.classifier = nn.Linear(prev_dim, num_labels)

    def forward(self, x) -> torch.Tensor:
        """
        Returns K logits per sample.
        Shape: [batch_size, num_labels]
        """
        features = self.features(x)
        logits = self.classifier(features)
        return logits

    def predict_proba(self, x) -> torch.Tensor:
        """
        Returns K independent probabilities per sample.
        Each is in (0, 1), sums unconstrained.
        """
        with torch.no_grad():
            logits = self.forward(x)
            return torch.sigmoid(logits)

    def predict(self, x, threshold: float = 0.5) -> torch.Tensor:
        """
        Returns binary predictions using given threshold.
        Shape: [batch_size, num_labels]
        """
        probs = self.predict_proba(x)
        return (probs >= threshold).long()

    def predict_with_thresholds(
        self, x, thresholds: torch.Tensor
    ) -> torch.Tensor:
        """
        Per-label thresholds for more nuanced prediction.

        Args:
            thresholds: tensor of shape [num_labels]
        """
        probs = self.predict_proba(x)
        return (probs >= thresholds.unsqueeze(0)).long()


# Example usage
model = MultiLabelClassifier(
    input_dim=2048,   # e.g., ResNet features
    num_labels=20,    # 20 possible labels
    hidden_dims=[512, 256],
)

# Sample input
x = torch.randn(8, 2048)        # 8 images
logits = model(x)               # Shape: [8, 20]
probs = model.predict_proba(x)  # Shape: [8, 20]

print(f"Logits shape: {logits.shape}")
print(f"Sample probabilities (first example):")
print(f"  {probs[0].numpy().round(3)}")
print(f"Sum of probs: {probs[0].sum():.2f} (unconstrained!)")
print(f"Predicted labels: {model.predict(x)[0].numpy()}")
```

The power of multi-label with neural networks: the hidden layers learn shared representations that benefit all labels (transfer learning within the model), while the output layer specializes per label. This is more parameter-efficient and often more accurate than training K separate binary classifiers.
The loss function for multi-label classification is the sum of binary cross-entropy losses over all labels:
$$\mathcal{L} = -\frac{1}{nK} \sum_{i=1}^n \sum_{k=1}^K \left[ y_{ik} \log(p_{ik}) + (1 - y_{ik})\log(1 - p_{ik}) \right]$$
This treats each label as an independent binary classification problem. PyTorch's BCEWithLogitsLoss handles this directly:
```python
criterion = nn.BCEWithLogitsLoss()          # Expects float targets!
loss = criterion(logits, targets.float())   # Convert int to float
```
Important details:
- Targets must be float tensors of shape `[batch_size, num_labels]` with values in {0, 1}
- `reduction='none'` gives per-element losses for custom aggregation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def multilabel_bce_manual(logits, targets):
    """
    Manual multi-label BCE for understanding.

    Args:
        logits: [batch_size, num_labels] raw logits
        targets: [batch_size, num_labels] binary labels
    """
    # For each label, compute binary cross-entropy
    probs = torch.sigmoid(logits)
    eps = 1e-7
    probs = torch.clamp(probs, eps, 1 - eps)

    bce = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))
    return bce.mean()  # Mean over all samples and labels


def multilabel_bce_per_label(logits, targets):
    """
    Compute loss separately for monitoring individual labels.
    """
    num_labels = logits.size(1)
    losses = {}
    for k in range(num_labels):
        label_loss = F.binary_cross_entropy_with_logits(
            logits[:, k], targets[:, k],
        )
        losses[f'label_{k}'] = label_loss.item()

    total = F.binary_cross_entropy_with_logits(logits, targets)
    losses['total'] = total.item()
    return total, losses


# Sample data
batch_size = 32
num_labels = 10
logits = torch.randn(batch_size, num_labels)

# Multi-label targets: sparse (few positives per example)
targets = (torch.rand(batch_size, num_labels) > 0.8).float()

# Different loss formulations
criterion = nn.BCEWithLogitsLoss()

print("=== Multi-Label Loss Comparison ===")
print(f"Manual:     {multilabel_bce_manual(logits, targets):.6f}")
print(f"PyTorch:    {criterion(logits, targets):.6f}")
print(f"Functional: {F.binary_cross_entropy_with_logits(logits, targets):.6f}")

# Per-label loss monitoring
total_loss, per_label_losses = multilabel_bce_per_label(logits, targets)
print(f"\nPer-label losses:")
for k in range(min(5, num_labels)):
    print(f"  Label {k}: {per_label_losses[f'label_{k}']:.4f}")

# Alternative: sum instead of mean (for backward compatibility)
criterion_sum = nn.BCEWithLogitsLoss(reduction='sum')
print(f"\nSum reduction: {criterion_sum(logits, targets):.4f}")
print(f"Equivalent to mean * n * K: {criterion(logits, targets) * batch_size * num_labels:.4f}")
```

The default 'mean' reduction divides by (batch_size × num_labels). With 'sum', the loss scales with both dimensions. For comparable learning rates across different label counts, 'mean' is usually preferred. However, some papers use 'sum' for consistency with their formulations.
Multi-label datasets often suffer from severe imbalance at multiple levels: within each label, positives are typically far rarer than negatives (most examples do not carry any given tag), and across labels, a few head labels dominate while many tail labels have only a handful of positive examples.
Without addressing imbalance, the model learns to predict all negatives—achieving high accuracy but useless predictions.
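To make the failure mode concrete: if a label is positive for only 2% of examples, a model that always predicts "absent" scores 98% per-label accuracy on that label while achieving zero recall.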
Strategies for multi-label imbalance:
1. Per-label pos_weight
Weight positive examples more heavily for rare labels:
$$\mathcal{L}_k = -\frac{1}{n}\sum_i \left[ w_k^+ \, y_{ik} \log(p_{ik}) + (1 - y_{ik})\log(1 - p_{ik}) \right]$$
where $w_k^+ = n_k^- / n_k^+$ is the negative-to-positive ratio for label $k$.
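For instance, a label with 50 positives among 1,000 training examples gets $w_k^+ = 950 / 50 = 19$, so each positive contributes as much to the loss as 19 negatives; in PyTorch this is passed as the `pos_weight` tensor (one value per label) to `BCEWithLogitsLoss`. Two further strategies, focal loss and asymmetric loss, down-weight easy examples instead of reweighting by frequency; both are implemented in the code below.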
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def compute_pos_weight(targets):
    """
    Compute per-label positive weights from training data.

    pos_weight[k] = num_negatives[k] / num_positives[k]

    This balances the contribution of positive and negative
    examples for each label.
    """
    # Count positives and negatives per label
    pos_counts = targets.sum(dim=0)  # [num_labels]
    neg_counts = targets.size(0) - pos_counts

    # Avoid division by zero for labels with no positives
    pos_counts = torch.clamp(pos_counts, min=1)

    pos_weight = neg_counts / pos_counts
    return pos_weight


class FocalLossMultiLabel(nn.Module):
    """
    Focal Loss adapted for multi-label classification.

    Down-weights easy examples (both easy positives and
    easy negatives) to focus on hard cases.
    """

    def __init__(self, alpha=0.25, gamma=2.0, pos_weight=None):
        super().__init__()
        self.alpha = alpha  # Balance factor
        self.gamma = gamma  # Focusing parameter
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)

        # p_t = p if y=1 else 1-p
        p_t = probs * targets + (1 - probs) * (1 - targets)

        # Alpha factor
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        # Focal factor
        focal_weight = (1 - p_t) ** self.gamma

        # BCE (with optional pos_weight)
        if self.pos_weight is not None:
            bce = F.binary_cross_entropy_with_logits(
                logits, targets, pos_weight=self.pos_weight, reduction='none'
            )
        else:
            bce = F.binary_cross_entropy_with_logits(
                logits, targets, reduction='none'
            )

        focal_loss = alpha_t * focal_weight * bce
        return focal_loss.mean()


class AsymmetricLoss(nn.Module):
    """
    Asymmetric Loss for Multi-Label Classification.
    From: "Asymmetric Loss For Multi-Label Classification" (CVPR 2021)

    Key idea: Use different focusing parameters for positive (gamma_pos)
    and negative (gamma_neg) samples. Also applies probability shifting
    to down-weight hard negatives.
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        disable_torch_grad_focal_loss=True
    ):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)

        # Probability shifting: clip negatives
        probs_pos = probs
        probs_neg = probs.clamp(max=1 - self.clip)

        # Loss components
        loss_pos = targets * torch.log(probs_pos.clamp(min=1e-8))
        loss_neg = (1 - targets) * torch.log((1 - probs_neg).clamp(min=1e-8))

        # Asymmetric focusing
        if self.disable_torch_grad_focal_loss:
            with torch.no_grad():
                asymmetric_weight = torch.pow(
                    1 - probs_pos * targets - probs_neg * (1 - targets),
                    self.gamma_pos * targets + self.gamma_neg * (1 - targets)
                )
        else:
            asymmetric_weight = torch.pow(
                1 - probs_pos * targets - probs_neg * (1 - targets),
                self.gamma_pos * targets + self.gamma_neg * (1 - targets)
            )

        loss = -asymmetric_weight * (loss_pos + loss_neg)
        return loss.mean()


# Demonstration with imbalanced data
print("=== Multi-Label Imbalance Handling ===")

batch_size = 100
num_labels = 10

# Create imbalanced targets (most labels are rare)
label_probs = torch.tensor([0.5, 0.3, 0.1, 0.05, 0.02,
                            0.01, 0.01, 0.005, 0.002, 0.001])
targets = (torch.rand(batch_size, num_labels) < label_probs).float()

print(f"Label frequencies: {targets.mean(dim=0).numpy().round(3)}")

# Compute pos_weight
pos_weight = compute_pos_weight(targets)
print(f"Pos weights: {pos_weight.numpy().round(1)}")

# Simulate predictions (random logits)
logits = torch.randn(batch_size, num_labels)

# Compare losses
bce_standard = F.binary_cross_entropy_with_logits(logits, targets)
bce_weighted = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=pos_weight
)
focal = FocalLossMultiLabel(gamma=2.0)(logits, targets)
asymmetric = AsymmetricLoss()(logits, targets)

print(f"\nLoss values:")
print(f"  Standard BCE: {bce_standard:.4f}")
print(f"  Weighted BCE: {bce_weighted:.4f}")
print(f"  Focal Loss:   {focal:.4f}")
print(f"  Asymmetric:   {asymmetric:.4f}")
```

When positives are very rare (<1% of examples), Asymmetric Loss often outperforms Focal Loss. It uses stronger down-weighting for negatives (higher gamma_neg) and probability clipping to focus even more on finding rare positives.
Evaluating multi-label classifiers is more nuanced than single-label accuracy. We need metrics that capture performance across labels, across samples, and at different thresholds.
Key metrics:
1. Subset Accuracy (Exact Match) $$\text{SubsetAccuracy} = \frac{1}{n} \sum_i \mathbb{1}[\hat{\mathbf{y}}_i = \mathbf{y}_i]$$ Binary: is the predicted label set exactly correct? Very strict.
2. Hamming Loss $$\text{HammingLoss} = \frac{1}{nK} \sum_i \sum_k \mathbb{1}[\hat{y}_{ik} \neq y_{ik}]$$ Fraction of labels incorrectly predicted. Lower is better.
3. Sample-Averaged Metrics Compute P/R/F1 per sample, average across samples.
4. Label-Averaged (Macro) Metrics Compute P/R/F1 per label, average across labels.
5. Micro-Averaged Metrics Pool all predictions, compute P/R/F1 on the aggregate.
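A small worked example (numbers chosen purely for illustration) shows why the averaging choice matters. Suppose label A (frequent) has TP = 90, FP = 10, FN = 10, and label B (rare) has TP = 1, FP = 1, FN = 9. Per-label F1 scores are 0.90 and roughly 0.17, so macro-F1 ≈ 0.53. Pooling counts first gives TP = 91, FP = 11, FN = 19, so micro-precision ≈ 0.89, micro-recall ≈ 0.83, and micro-F1 ≈ 0.86. Micro-averaging is dominated by the frequent label; macro-averaging exposes the weak rare-label performance.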
```python
import torch
import numpy as np
from sklearn.metrics import (
    accuracy_score, hamming_loss, precision_score, recall_score,
    f1_score, average_precision_score, roc_auc_score,
    multilabel_confusion_matrix
)


def compute_multilabel_metrics(y_true, y_pred, y_probs=None):
    """
    Compute comprehensive multi-label metrics.

    Args:
        y_true: [n_samples, n_labels] ground truth
        y_pred: [n_samples, n_labels] binary predictions
        y_probs: [n_samples, n_labels] probabilities (optional)

    Returns:
        Dictionary of metrics
    """
    metrics = {}

    # 1. Subset Accuracy (exact match ratio)
    metrics['subset_accuracy'] = accuracy_score(y_true, y_pred)

    # 2. Hamming Loss
    metrics['hamming_loss'] = hamming_loss(y_true, y_pred)

    # 3. Sample-averaged metrics (per-sample, then average)
    metrics['sample_precision'] = precision_score(
        y_true, y_pred, average='samples', zero_division=0
    )
    metrics['sample_recall'] = recall_score(
        y_true, y_pred, average='samples', zero_division=0
    )
    metrics['sample_f1'] = f1_score(
        y_true, y_pred, average='samples', zero_division=0
    )

    # 4. Macro-averaged metrics (per-label, then average)
    metrics['macro_precision'] = precision_score(
        y_true, y_pred, average='macro', zero_division=0
    )
    metrics['macro_recall'] = recall_score(
        y_true, y_pred, average='macro', zero_division=0
    )
    metrics['macro_f1'] = f1_score(
        y_true, y_pred, average='macro', zero_division=0
    )

    # 5. Micro-averaged metrics (aggregate, then compute)
    metrics['micro_precision'] = precision_score(
        y_true, y_pred, average='micro', zero_division=0
    )
    metrics['micro_recall'] = recall_score(
        y_true, y_pred, average='micro', zero_division=0
    )
    metrics['micro_f1'] = f1_score(
        y_true, y_pred, average='micro', zero_division=0
    )

    # 6. Probability-based metrics (if probs provided)
    if y_probs is not None:
        # Mean Average Precision (mAP)
        try:
            metrics['mAP'] = average_precision_score(
                y_true, y_probs, average='macro'
            )
        except ValueError:
            metrics['mAP'] = float('nan')

        # ROC-AUC (macro)
        try:
            metrics['macro_auc'] = roc_auc_score(
                y_true, y_probs, average='macro'
            )
        except ValueError:
            metrics['macro_auc'] = float('nan')

    return metrics


def find_optimal_thresholds(y_true, y_probs, metric='f1'):
    """
    Find per-label optimal thresholds.

    Returns:
        thresholds: [n_labels] optimal threshold for each label
    """
    n_labels = y_true.shape[1]
    thresholds = []

    for k in range(n_labels):
        best_thresh = 0.5
        best_metric = 0

        for thresh in np.arange(0.1, 0.9, 0.05):
            pred = (y_probs[:, k] >= thresh).astype(int)

            if metric == 'f1':
                score = f1_score(y_true[:, k], pred, zero_division=0)
            elif metric == 'precision':
                score = precision_score(y_true[:, k], pred, zero_division=0)
            elif metric == 'recall':
                score = recall_score(y_true[:, k], pred, zero_division=0)

            if score > best_metric:
                best_metric = score
                best_thresh = thresh

        thresholds.append(best_thresh)

    return np.array(thresholds)


# Demonstration
print("=== Multi-Label Metrics Demo ===")

# Create sample data
n_samples, n_labels = 200, 10
np.random.seed(42)

# Ground truth (imbalanced)
label_probs = np.array([0.3, 0.2, 0.15, 0.1, 0.08, 0.06, 0.05, 0.03, 0.02, 0.01])
y_true = (np.random.rand(n_samples, n_labels) < label_probs).astype(int)

# Simulated probabilities (correlated with truth, with noise)
y_probs = np.clip(y_true * 0.6 + np.random.randn(n_samples, n_labels) * 0.2 + 0.2, 0, 1)

# Standard threshold predictions
y_pred_05 = (y_probs >= 0.5).astype(int)

# Compute metrics with default threshold
metrics = compute_multilabel_metrics(y_true, y_pred_05, y_probs)

print("Metrics with threshold=0.5:")
for name, value in metrics.items():
    print(f"  {name}: {value:.4f}")

# Find optimal per-label thresholds
optimal_thresholds = find_optimal_thresholds(y_true, y_probs, metric='f1')
print(f"\nOptimal thresholds per label: {optimal_thresholds.round(2)}")

# Predictions with optimal thresholds
y_pred_optimal = (y_probs >= optimal_thresholds).astype(int)
metrics_optimal = compute_multilabel_metrics(y_true, y_pred_optimal, y_probs)

print("\nMetrics with optimal thresholds:")
for name, value in metrics_optimal.items():
    if name in ['sample_f1', 'macro_f1', 'micro_f1']:
        print(f"  {name}: {value:.4f} (was {metrics[name]:.4f})")
```

Micro-F1 weights labels by frequency (majority labels dominate). Macro-F1 treats all labels equally (rare labels matter more). Sample-F1 focuses on per-example quality. mAP is threshold-free and often best for ranking. Choose based on application: if rare labels are critical, use macro; if overall correctness matters, use micro.
The independent sigmoid assumption ignores label correlations. In reality, labels often co-occur ('outdoor' with 'sky') or are mutually exclusive ('indoors' vs 'outdoors'). Modeling these dependencies can improve predictions.
Approaches to capture label relationships:
1. Classifier Chains Predictions are made sequentially, with each label conditioned on previous predictions: $$p(y_k | \mathbf{x}, y_1, \ldots, y_{k-1})$$
2. Label Attention Use attention mechanisms over label embeddings to model dependencies.
3. Graph Neural Networks Model label relationships as a graph, propagate information between related labels.
4. Output Regularization Add losses that penalize unlikely label combinations (a minimal sketch appears after this list).
5. Multi-task with Shared Structure Explicitly share parameters between related labels.
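As an illustration of approach 4, here is a minimal sketch of a mutual-exclusion regularizer: for label pairs known to be incompatible (e.g., 'indoors'/'outdoors'), it penalizes assigning high probability to both at once. The function name, pair list, and weight are illustrative assumptions, not from the original text.

```python
import torch
from typing import List, Tuple


def mutual_exclusion_penalty(logits: torch.Tensor,
                             exclusive_pairs: List[Tuple[int, int]],
                             weight: float = 0.1) -> torch.Tensor:
    """Penalize predicting both labels of a known mutually exclusive pair.

    Args:
        logits: [batch_size, num_labels] raw model outputs
        exclusive_pairs: (i, j) label index pairs that should not co-occur
        weight: strength of the regularizer relative to the BCE loss
    """
    probs = torch.sigmoid(logits)
    penalty = logits.new_zeros(())
    for i, j in exclusive_pairs:
        # p_i * p_j is large only when both labels are predicted confidently
        penalty = penalty + (probs[:, i] * probs[:, j]).mean()
    return weight * penalty


# Usage sketch: total_loss = bce_loss + mutual_exclusion_penalty(logits, [(0, 1)])
```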
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassifierChain(nn.Module):
    """
    Classifier Chain for correlated multi-label prediction.

    Labels are predicted sequentially, with each classifier receiving
    the predictions of previous labels as input.
    """

    def __init__(self, input_dim, num_labels, hidden_dim=128):
        super().__init__()
        self.num_labels = num_labels

        # Shared feature extractor
        self.shared_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )

        # Per-label classifiers (each takes features + previous predictions)
        self.classifiers = nn.ModuleList()
        for k in range(num_labels):
            # Input: hidden_dim + k (previous predictions)
            self.classifiers.append(
                nn.Sequential(
                    nn.Linear(hidden_dim + k, hidden_dim // 2),
                    nn.ReLU(),
                    nn.Linear(hidden_dim // 2, 1),
                )
            )

    def forward(self, x, y_prev=None, teacher_forcing=True):
        """
        Forward pass.

        Args:
            x: Input features [batch_size, input_dim]
            y_prev: Ground truth for teacher forcing (training only)
            teacher_forcing: Use ground truth (training) or predictions (inference)

        Returns:
            logits: [batch_size, num_labels]
        """
        features = self.shared_encoder(x)

        logits = []
        predictions = []

        for k in range(self.num_labels):
            if k == 0:
                # First label: only features
                classifier_input = features
            else:
                # Subsequent labels: features + previous predictions
                if teacher_forcing and y_prev is not None:
                    prev_labels = y_prev[:, :k].float()
                else:
                    prev_labels = torch.stack(predictions, dim=1)
                classifier_input = torch.cat([features, prev_labels], dim=1)

            logit = self.classifiers[k](classifier_input).squeeze(-1)
            logits.append(logit)

            # For next iteration: store own predictions unless teacher forcing
            if not (teacher_forcing and y_prev is not None):
                predictions.append(torch.sigmoid(logit))

        return torch.stack(logits, dim=1)


class LabelAttentionClassifier(nn.Module):
    """
    Multi-label classification with label attention.

    Each label has a learnable embedding. Predictions are made by
    scoring the sample's features against each label embedding.
    """

    def __init__(self, input_dim, num_labels, label_embed_dim=64, hidden_dim=256):
        super().__init__()
        self.num_labels = num_labels

        # Feature encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Learnable label embeddings
        self.label_embeddings = nn.Parameter(
            torch.randn(num_labels, label_embed_dim)
        )

        # Projections for a fuller attention variant (key/value kept for
        # reference; the simplified forward below only uses query_proj)
        self.query_proj = nn.Linear(label_embed_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

        # Final classifier (unused in the simplified scoring below)
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """
        Args:
            x: [batch_size, input_dim]

        Returns:
            logits: [batch_size, num_labels]
        """
        # Encode features
        features = self.encoder(x)  # [batch, hidden_dim]

        # With a single feature vector per sample there is nothing to attend
        # over, so label "attention" reduces to a scaled bilinear score:
        #   logits[i, k] = features[i] @ W @ label_embeddings[k]
        queries = self.query_proj(self.label_embeddings)  # [num_labels, hidden_dim]
        logits = torch.matmul(features, queries.T) / (features.size(-1) ** 0.5)

        return logits


# Demonstration
print("=== Label Correlation Models ===")

input_dim = 128
num_labels = 10
batch_size = 16

x = torch.randn(batch_size, input_dim)
y_true = (torch.rand(batch_size, num_labels) > 0.7).long()

# Classifier chain
chain = ClassifierChain(input_dim, num_labels)
logits_chain = chain(x, y_true, teacher_forcing=True)
print(f"Classifier Chain output shape: {logits_chain.shape}")

# Label attention
attention = LabelAttentionClassifier(input_dim, num_labels)
logits_attention = attention(x)
print(f"Label Attention output shape: {logits_attention.shape}")

# Loss computation
criterion = nn.BCEWithLogitsLoss()
loss_chain = criterion(logits_chain, y_true.float())
loss_attention = criterion(logits_attention, y_true.float())

print(f"\nChain loss: {loss_chain:.4f}")
print(f"Attention loss: {loss_attention:.4f}")
```

Simple independent sigmoids often work surprisingly well. Model correlations only when: (1) you have strong prior knowledge of dependencies, (2) independent models underperform, (3) you have enough data to learn complex interactions. Start simple, add complexity only if needed.
Let's consolidate a production-ready multi-label classification system with best practices:
Key considerations: per-label calibrated probabilities, tunable per-label thresholds, constraints on the minimum and maximum number of predicted labels, and human-readable outputs keyed by label name — all reflected in the implementation below.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict, List


class ProductionMultiLabelClassifier(nn.Module):
    """
    Production-ready multi-label classifier with:
    - Per-label threshold tuning
    - Label constraints
    - Confidence calibration
    - Fallback handling
    """

    def __init__(
        self,
        input_dim: int,
        label_names: List[str],
        hidden_dims: List[int] = [512, 256],
        dropout: float = 0.2,
    ):
        super().__init__()
        self.label_names = label_names
        self.num_labels = len(label_names)

        # Feature extractor
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Linear(prev_dim, self.num_labels)

        # Per-label thresholds (learnable or fixed)
        self.register_buffer(
            'thresholds', torch.full((self.num_labels,), 0.5)
        )

        # Per-label temperature for calibration
        self.register_buffer(
            'temperatures', torch.ones(self.num_labels)
        )

    def forward(self, x) -> torch.Tensor:
        """Returns raw logits."""
        features = self.features(x)
        return self.classifier(features)

    def predict_proba(
        self, x, calibrated: bool = True
    ) -> torch.Tensor:
        """Returns calibrated probabilities per label."""
        with torch.no_grad():
            logits = self.forward(x)
            if calibrated:
                logits = logits / self.temperatures.unsqueeze(0)
            return torch.sigmoid(logits)

    def predict(
        self,
        x,
        min_labels: int = 0,
        max_labels: Optional[int] = None,
        min_confidence: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict with constraints.

        Args:
            x: Input features
            min_labels: Minimum number of labels to predict
            max_labels: Maximum number of labels to predict
            min_confidence: Minimum confidence to assign any label

        Returns:
            predictions: Binary label predictions
            confidences: Probability scores
        """
        probs = self.predict_proba(x)
        batch_size = probs.size(0)

        # Apply per-label thresholds
        predictions = probs >= self.thresholds.unsqueeze(0)

        # Apply constraints per sample
        for i in range(batch_size):
            sample_probs = probs[i]
            sample_preds = predictions[i]
            n_predicted = sample_preds.sum().item()

            # Enforce minimum labels
            if n_predicted < min_labels:
                # Add top labels until minimum met
                sorted_indices = torch.argsort(sample_probs, descending=True)
                for idx in sorted_indices[:min_labels]:
                    if sample_probs[idx] >= min_confidence:
                        predictions[i, idx] = True

            # Enforce maximum labels
            if max_labels is not None and n_predicted > max_labels:
                # Keep only top-k labels
                sorted_indices = torch.argsort(sample_probs, descending=True)
                keep = sorted_indices[:max_labels]
                new_preds = torch.zeros_like(sample_preds, dtype=torch.bool)
                new_preds[keep] = sample_preds[keep] >= self.thresholds[keep]
                predictions[i] = new_preds

        return predictions.long(), probs

    def set_thresholds(self, thresholds: torch.Tensor):
        """Set per-label classification thresholds."""
        assert len(thresholds) == self.num_labels
        self.thresholds.copy_(thresholds)

    def set_temperatures(self, temperatures: torch.Tensor):
        """Set per-label calibration temperatures."""
        assert len(temperatures) == self.num_labels
        self.temperatures.copy_(temperatures)

    def predict_with_names(
        self,
        x,
        threshold: float = None,
    ) -> List[Dict]:
        """
        Human-readable predictions.

        Returns list of dicts with label names and confidences.
        """
        probs = self.predict_proba(x)

        if threshold is None:
            thresh = self.thresholds
        else:
            thresh = torch.full((self.num_labels,), threshold)

        results = []
        for i in range(probs.size(0)):
            sample_result = {}
            for k, name in enumerate(self.label_names):
                conf = probs[i, k].item()
                if conf >= thresh[k]:
                    sample_result[name] = round(conf, 4)
            results.append(sample_result)

        return results


class MultiLabelTrainer:
    """Training wrapper with multi-label specific handling."""

    def __init__(
        self,
        model: ProductionMultiLabelClassifier,
        pos_weight: Optional[torch.Tensor] = None,
        learning_rate: float = 1e-3,
    ):
        self.model = model
        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=pos_weight
        )
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=1e-4,
        )

    def train_step(self, x, y) -> Tuple[float, Dict]:
        """Single training step with per-label metrics."""
        self.model.train()
        self.optimizer.zero_grad()

        logits = self.model(x)
        loss = self.criterion(logits, y.float())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        # Compute per-label accuracy
        with torch.no_grad():
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).long()
            per_label_acc = (preds == y).float().mean(dim=0)

        return loss.item(), {
            'per_label_accuracy': per_label_acc.mean().item()
        }


# Example usage
label_names = ['cat', 'dog', 'person', 'car', 'bicycle',
               'tree', 'building', 'sky', 'grass', 'water']

model = ProductionMultiLabelClassifier(
    input_dim=2048,
    label_names=label_names,
    hidden_dims=[512, 256],
)

# Simulate predictions
x = torch.randn(3, 2048)
predictions, confidences = model.predict(x, min_labels=1, max_labels=5)

print("=== Production Multi-Label Demo ===")
print(f"Predictions shape: {predictions.shape}")
print(f"\nHuman-readable predictions:")
for i, result in enumerate(model.predict_with_names(x)):
    print(f"  Sample {i}: {result}")
```

If your labels have hierarchy (e.g., 'animal' → 'mammal' → 'dog'), predictions should be consistent: if 'dog' is predicted, so should 'mammal' and 'animal'. Enforce this in post-processing or via hierarchical loss functions.
Multi-label classification is fundamentally different from multi-class classification. Understanding this distinction—and designing appropriate output layers—is crucial for any task where examples can belong to multiple categories simultaneously.
You now understand multi-label classification output design comprehensively. Next, we'll explore structured outputs, where the prediction space has complex internal structure—sequences, trees, graphs, or other non-decomposable formats that require specialized output architectures.