Every classification metric—accuracy, precision, recall, F1, kappa—derives from the same fundamental structure: the confusion matrix. For binary classification, this is a simple 2×2 table. But as the number of classes grows, the confusion matrix becomes a rich K×K structure that captures the complete pattern of model predictions and errors.
Understanding the multi-class confusion matrix isn't just about reading off numbers—it's about extracting insights: Which classes does the model confuse? Are errors symmetric? Where should you focus improvement efforts? The confusion matrix answers all these questions.
By the end of this page, you will construct and normalize confusion matrices, extract per-class and aggregate metrics, diagnose systematic confusion patterns, create effective visualizations, and use confusion matrix analysis to guide model improvement.
A confusion matrix C for K classes is a K×K matrix where:
Cᵢⱼ = |{samples where true label = i AND predicted label = j}|
By convention, rows index the true labels and columns index the predicted labels, so Cᵢⱼ counts samples of true class i that were predicted as class j.
Example for 3 classes:
| | Pred: A | Pred: B | Pred: C |
|---|---|---|---|
| True: A | 85 | 10 | 5 |
| True: B | 8 | 72 | 20 |
| True: C | 3 | 15 | 82 |
Interpretation:
• The diagonal entries (85, 72, 82) count correct predictions; everything off the diagonal is an error.
• The largest single confusion is true B predicted as C (20 samples), so the model struggles most to separate B from C.
• Row sums give the support of each true class (100 samples each here); column sums show how often each class was predicted.
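To make the table concrete, here is a small sketch (values copied directly from the matrix above) that reads off accuracy, per-class recall, and per-class precision with NumPy:

```python
import numpy as np

# The 3-class example above: rows = true labels, columns = predicted labels
C_example = np.array([[85, 10,  5],
                      [ 8, 72, 20],
                      [ 3, 15, 82]])

accuracy = np.trace(C_example) / C_example.sum()        # (85 + 72 + 82) / 300 ≈ 0.797
recall = np.diag(C_example) / C_example.sum(axis=1)     # [0.85, 0.72, 0.82]
precision = np.diag(C_example) / C_example.sum(axis=0)  # ≈ [0.885, 0.742, 0.766]
print(accuracy, recall, precision)
```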
```python
import numpy as np
from sklearn.metrics import confusion_matrix
from typing import List, Optional


def build_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: Optional[List] = None
) -> np.ndarray:
    """
    Build a confusion matrix from predictions.

    Parameters
    ----------
    y_true : array-like
        Ground truth labels
    y_pred : array-like
        Predicted labels
    labels : list, optional
        Explicit ordering of classes

    Returns
    -------
    C : ndarray of shape (n_classes, n_classes)
        Confusion matrix where C[i, j] = count(true=i, pred=j)
    """
    if labels is None:
        labels = sorted(np.unique(np.concatenate([y_true, y_pred])))
    return confusion_matrix(y_true, y_pred, labels=labels)


def manual_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray
) -> np.ndarray:
    """Manual implementation for educational purposes."""
    classes = sorted(np.unique(np.concatenate([y_true, y_pred])))
    K = len(classes)
    class_to_idx = {c: i for i, c in enumerate(classes)}

    C = np.zeros((K, K), dtype=int)
    for true_label, pred_label in zip(y_true, y_pred):
        i = class_to_idx[true_label]
        j = class_to_idx[pred_label]
        C[i, j] += 1
    return C


def print_confusion_matrix(C: np.ndarray, class_names: Optional[List[str]] = None):
    """Pretty-print a confusion matrix."""
    K = C.shape[0]
    if class_names is None:
        class_names = [f"Class {i}" for i in range(K)]

    # Header (raw string keeps the literal backslash)
    header = r"True\Pred".ljust(12) + "".join(f"{name:>10}" for name in class_names)
    print(header)
    print("-" * len(header))

    # Rows
    for i, name in enumerate(class_names):
        row = name.ljust(12) + "".join(f"{C[i, j]:>10}" for j in range(K))
        print(row)


# Example usage
np.random.seed(42)
y_true = np.random.choice(['Cat', 'Dog', 'Bird'], 300, p=[0.5, 0.3, 0.2])
y_pred = y_true.copy()
# Add noise: randomly re-label about 60 of the 300 samples
flip_idx = np.random.choice(300, 60)
y_pred[flip_idx] = np.random.choice(['Cat', 'Dog', 'Bird'], 60)

C = build_confusion_matrix(y_true, y_pred, labels=['Cat', 'Dog', 'Bird'])
print_confusion_matrix(C, ['Cat', 'Dog', 'Bird'])
```

Raw counts can be difficult to interpret when class sizes vary. Normalizing the confusion matrix converts counts to proportions, enabling clearer interpretation.
| Normalization | Formula | Interpretation | Reveals |
|---|---|---|---|
| By row (true) | C[i,j] / Σⱼ C[i,j] | P(pred=j \| true=i) | Recall/sensitivity per class |
| By column (pred) | C[i,j] / Σᵢ C[i,j] | P(true=i \| pred=j) | Precision per class |
| By total | C[i,j] / ΣᵢΣⱼ C[i,j] | P(true=i, pred=j) | Joint distribution |
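In recent versions, scikit-learn's confusion_matrix also accepts a normalize argument ('true', 'pred', or 'all') that applies these same variants directly. A minimal sketch, reusing the Cat/Dog/Bird arrays from the earlier example:

```python
from sklearn.metrics import confusion_matrix

labels = ['Cat', 'Dog', 'Bird']
C_recall = confusion_matrix(y_true, y_pred, labels=labels, normalize='true')     # rows sum to 1
C_precision = confusion_matrix(y_true, y_pred, labels=labels, normalize='pred')  # columns sum to 1
C_joint = confusion_matrix(y_true, y_pred, labels=labels, normalize='all')       # matrix sums to 1
```

The manual implementation below makes the same three variants explicit: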
```python
import numpy as np


def normalize_confusion_matrix(
    C: np.ndarray,
    normalize: str = 'true'
) -> np.ndarray:
    """
    Normalize a confusion matrix.

    Parameters
    ----------
    C : ndarray
        Raw confusion matrix
    normalize : {'true', 'pred', 'all'}
        - 'true': normalize by row (true label)   -> shows recall
        - 'pred': normalize by column (predicted) -> shows precision
        - 'all':  normalize by total              -> shows joint probability

    Returns
    -------
    C_norm : ndarray
        Normalized confusion matrix
    """
    if normalize == 'true':
        # Row normalization: each row sums to 1
        row_sums = C.sum(axis=1, keepdims=True)
        row_sums = np.where(row_sums == 0, 1, row_sums)  # Avoid division by zero
        return C.astype(float) / row_sums
    elif normalize == 'pred':
        # Column normalization: each column sums to 1
        col_sums = C.sum(axis=0, keepdims=True)
        col_sums = np.where(col_sums == 0, 1, col_sums)
        return C.astype(float) / col_sums
    elif normalize == 'all':
        # Total normalization: entire matrix sums to 1
        total = C.sum()
        return C.astype(float) / total if total > 0 else C.astype(float)
    else:
        raise ValueError(f"Unknown normalization: {normalize}")


def show_all_normalizations(C: np.ndarray, class_names: list):
    """Display all normalization variants side by side."""
    print("Raw Counts:")
    print(C)
    print()

    print("Row-normalized (Recall per class):")
    C_recall = normalize_confusion_matrix(C, 'true')
    for i, name in enumerate(class_names):
        print(f"  {name}: Recall = {C_recall[i, i]:.3f}")
    print()

    print("Column-normalized (Precision per class):")
    C_precision = normalize_confusion_matrix(C, 'pred')
    for j, name in enumerate(class_names):
        print(f"  {name}: Precision = {C_precision[j, j]:.3f}")
```

• Row normalization when you care about recall: 'Of all actual cats, what fraction did we catch?'
• Column normalization when you care about precision: 'Of all predicted cats, what fraction were actually cats?'
• Total normalization for understanding overall distribution of predictions.
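A quick usage sketch, assuming the matrix C and labels from the first example are still in scope:

```python
# Show raw counts plus row- and column-normalized views side by side
show_all_normalizations(C, ['Cat', 'Dog', 'Bird'])

# Or inspect one variant directly
print(normalize_confusion_matrix(C, normalize='true').round(3))
```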
Every classification metric can be derived directly from the confusion matrix. For class k in a K-class problem:
True Positives (TPₖ) = Cₖₖ (diagonal entry)
False Positives (FPₖ) = Σᵢ≠ₖ Cᵢₖ (column k, excluding diagonal)
False Negatives (FNₖ) = Σⱼ≠ₖ Cₖⱼ (row k, excluding diagonal)
True Negatives (TNₖ) = ΣᵢΣⱼ Cᵢⱼ - (TPₖ + FPₖ + FNₖ) (everything else)
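Applied to class B in the 3-class example above: TP = 72, FP = 10 + 15 = 25 (the rest of column B), FN = 8 + 20 = 28 (the rest of row B), and TN = 300 − 72 − 25 − 28 = 175. From these, precision = 72/97 ≈ 0.74 and recall = 72/100 = 0.72.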
```python
import numpy as np
from typing import Dict, List


def extract_per_class_metrics(C: np.ndarray) -> Dict[int, Dict]:
    """
    Extract TP, FP, FN, TN and derived metrics for each class.

    Parameters
    ----------
    C : ndarray of shape (K, K)
        Confusion matrix

    Returns
    -------
    metrics : dict
        Per-class metrics including precision, recall, F1
    """
    K = C.shape[0]
    total = C.sum()
    metrics = {}

    for k in range(K):
        tp = C[k, k]
        fp = C[:, k].sum() - tp  # Column sum minus diagonal
        fn = C[k, :].sum() - tp  # Row sum minus diagonal
        tn = total - tp - fp - fn

        # Derived metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if (precision + recall) > 0 else 0.0)
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        metrics[k] = {
            'tp': int(tp), 'fp': int(fp), 'fn': int(fn), 'tn': int(tn),
            'precision': precision, 'recall': recall,
            'f1': f1, 'specificity': specificity,
            'support': int(tp + fn)
        }

    return metrics


def confusion_to_report(C: np.ndarray, class_names: List[str] = None) -> str:
    """Generate a complete classification report from a confusion matrix."""
    K = C.shape[0]
    if class_names is None:
        class_names = [f"Class {i}" for i in range(K)]

    metrics = extract_per_class_metrics(C)

    lines = []
    lines.append("=" * 70)
    lines.append(f"{'Class':<15} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Support':>10}")
    lines.append("-" * 70)

    total_support = 0
    weighted_p = weighted_r = weighted_f1 = 0.0

    for k, name in enumerate(class_names):
        m = metrics[k]
        lines.append(
            f"{name:<15} {m['precision']:>10.4f} {m['recall']:>10.4f} "
            f"{m['f1']:>10.4f} {m['support']:>10}"
        )
        total_support += m['support']
        weighted_p += m['precision'] * m['support']
        weighted_r += m['recall'] * m['support']
        weighted_f1 += m['f1'] * m['support']

    lines.append("-" * 70)

    # Aggregates
    macro_p = np.mean([m['precision'] for m in metrics.values()])
    macro_r = np.mean([m['recall'] for m in metrics.values()])
    macro_f1 = np.mean([m['f1'] for m in metrics.values()])

    lines.append(f"{'Macro avg':<15} {macro_p:>10.4f} {macro_r:>10.4f} {macro_f1:>10.4f}")
    lines.append(f"{'Weighted avg':<15} {weighted_p/total_support:>10.4f} "
                 f"{weighted_r/total_support:>10.4f} {weighted_f1/total_support:>10.4f}")
    lines.append(f"{'Accuracy':<15} {np.trace(C)/C.sum():>10.4f}")
    lines.append("=" * 70)

    return "\n".join(lines)
```

Beyond metrics, the confusion matrix reveals systematic error patterns (majority-class bias, consistently hard classes, frequently confused pairs) that guide model improvement.
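Before turning to those patterns, a quick usage sketch (assuming the Cat/Dog/Bird arrays from the first example are still in scope) cross-checks the report against scikit-learn's built-in classification_report:

```python
from sklearn.metrics import classification_report

# Report reconstructed purely from the confusion matrix
print(confusion_to_report(C, ['Cat', 'Dog', 'Bird']))

# scikit-learn's report, computed from the raw labels, should agree
print(classification_report(y_true, y_pred, labels=['Cat', 'Dog', 'Bird']))
```

The helpers below automate the pattern diagnosis itself: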
```python
import numpy as np
from typing import List, Tuple


def find_confusion_pairs(
    C: np.ndarray,
    threshold: float = 0.1
) -> List[Tuple[int, int, float, float]]:
    """
    Find pairs of classes that are frequently confused.

    Returns pairs (i, j, C[i,j]/row_sum_i, C[j,i]/row_sum_j)
    where either ratio exceeds the threshold.
    """
    K = C.shape[0]
    row_sums = C.sum(axis=1)
    pairs = []

    for i in range(K):
        for j in range(i + 1, K):
            ratio_ij = C[i, j] / row_sums[i] if row_sums[i] > 0 else 0
            ratio_ji = C[j, i] / row_sums[j] if row_sums[j] > 0 else 0
            if ratio_ij > threshold or ratio_ji > threshold:
                pairs.append((i, j, ratio_ij, ratio_ji))

    # Sort by max confusion
    pairs.sort(key=lambda x: max(x[2], x[3]), reverse=True)
    return pairs


def diagnose_confusion_matrix(C: np.ndarray, class_names: List[str] = None):
    """Automated diagnosis of confusion matrix patterns."""
    K = C.shape[0]
    if class_names is None:
        class_names = [f"Class {i}" for i in range(K)]

    row_sums = C.sum(axis=1)
    col_sums = C.sum(axis=0)
    total = C.sum()

    print("Confusion Matrix Diagnosis:")
    print("=" * 60)

    # Check for majority class bias
    col_fractions = col_sums / total
    if col_fractions.max() > 0.5:
        dominant_class = class_names[col_fractions.argmax()]
        print(f"⚠️ MAJORITY BIAS: {col_fractions.max()*100:.1f}% of predictions "
              f"go to '{dominant_class}'")

    # Find hardest classes (lowest recall)
    recalls = np.diag(C) / np.where(row_sums > 0, row_sums, 1)
    hardest_idx = recalls.argmin()
    print(f"📉 HARDEST CLASS: '{class_names[hardest_idx]}' "
          f"(recall = {recalls[hardest_idx]:.3f})")

    # Find most confusing pairs
    pairs = find_confusion_pairs(C)
    if pairs:
        i, j, ratio_ij, ratio_ji = pairs[0]
        print(f"🔄 TOP CONFUSION: '{class_names[i]}' ↔ '{class_names[j]}'")
        print(f"   {class_names[i]} → {class_names[j]}: {ratio_ij*100:.1f}%")
        print(f"   {class_names[j]} → {class_names[i]}: {ratio_ji*100:.1f}%")

        # Check symmetry
        if abs(ratio_ij - ratio_ji) < 0.05:
            print("   (Symmetric - classes may be inherently similar)")
        else:
            print("   (Asymmetric - investigate feature space)")

    print("=" * 60)
```

Effective visualization transforms confusion matrices from tables of numbers into actionable insights. Key visualization principles include using color scales that highlight errors, normalizing appropriately, and annotating with percentages.
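For a quick look without custom code, recent scikit-learn versions provide ConfusionMatrixDisplay, which handles normalization and annotation out of the box; a minimal sketch using the earlier example arrays:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

# Row-normalized heatmap: each row shows where a true class's samples went
ConfusionMatrixDisplay.from_predictions(
    y_true, y_pred, labels=['Cat', 'Dog', 'Bird'], normalize='true', cmap='Blues'
)
plt.show()
```

For finer control over annotations, color thresholds, and error-focused views, the custom matplotlib helpers below build the plots from scratch: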
```python
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Optional


def plot_confusion_matrix(
    C: np.ndarray,
    class_names: Optional[List[str]] = None,
    normalize: Optional[str] = None,
    title: str = "Confusion Matrix",
    cmap: str = "Blues",
    figsize: tuple = (10, 8)
) -> plt.Figure:
    """
    Create a publication-quality confusion matrix visualization.

    Parameters
    ----------
    C : ndarray
        Confusion matrix
    class_names : list
        Class labels for axes
    normalize : {'true', 'pred', 'all', None}
        Normalization to apply before plotting
    title : str
        Plot title
    cmap : str
        Colormap name
    figsize : tuple
        Figure size

    Returns
    -------
    fig : matplotlib Figure
    """
    K = C.shape[0]
    if class_names is None:
        class_names = [f"Class {i}" for i in range(K)]

    # Apply normalization
    if normalize is not None:
        C_plot = normalize_confusion_matrix(C, normalize)
    else:
        C_plot = C.astype(float)

    fig, ax = plt.subplots(figsize=figsize)

    # Create heatmap
    im = ax.imshow(C_plot, interpolation='nearest', cmap=cmap)
    ax.set_title(title, fontsize=14, fontweight='bold')

    # Add colorbar
    cbar = fig.colorbar(im, ax=ax, shrink=0.8)
    cbar.ax.set_ylabel('Count' if normalize is None else 'Proportion')

    # Set labels
    ax.set_xticks(range(K))
    ax.set_yticks(range(K))
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_yticklabels(class_names)
    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)

    # Add text annotations (white text on dark cells, black on light cells)
    thresh = C_plot.max() / 2
    for i in range(K):
        for j in range(K):
            value = C_plot[i, j]
            text = f"{value:.2%}" if normalize else f"{int(C[i, j])}"
            color = "white" if value > thresh else "black"
            ax.text(j, i, text, ha="center", va="center", color=color, fontsize=9)

    plt.tight_layout()
    return fig


def plot_error_analysis(C: np.ndarray, class_names: List[str]):
    """Create error-focused visualization."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: Per-class recall
    ax1 = axes[0]
    row_sums = C.sum(axis=1)
    recalls = np.diag(C) / np.where(row_sums > 0, row_sums, 1)
    colors = ['green' if r > 0.8 else 'orange' if r > 0.5 else 'red' for r in recalls]
    ax1.barh(class_names, recalls, color=colors)
    ax1.set_xlim(0, 1)
    ax1.set_xlabel('Recall')
    ax1.set_title('Per-class Recall')
    ax1.axvline(0.8, color='green', linestyle='--', alpha=0.5, label='Good')
    ax1.axvline(0.5, color='red', linestyle='--', alpha=0.5, label='Poor')

    # Right: Per-class precision
    ax2 = axes[1]
    col_sums = C.sum(axis=0)
    precisions = np.diag(C) / np.where(col_sums > 0, col_sums, 1)
    colors = ['green' if p > 0.8 else 'orange' if p > 0.5 else 'red' for p in precisions]
    ax2.barh(class_names, precisions, color=colors)
    ax2.set_xlim(0, 1)
    ax2.set_xlabel('Precision')
    ax2.set_title('Per-class Precision')

    plt.tight_layout()
    return fig
```

You now have a complete understanding of multi-class confusion matrices. Next, we'll explore One-vs-Rest metrics—a powerful strategy for reducing multi-class problems to a series of binary problems.