The previous page introduced Large Margin Nearest Neighbors (LMNN)—a powerful method for learning a global Mahalanobis distance. But LMNN represents just one approach in a rich landscape of metric learning techniques.
The fundamental question of metric learning: Given labeled training data, how do we learn a distance function that best captures semantic similarity for downstream tasks?
This question admits many answers, varying along several dimensions: the form of the learned transformation (linear, kernelized, or deep), the supervision signal (class labels, pairwise constraints, or triplets), the scope of the metric (global or local), and whether the objective is discriminative or generative.
Each design choice creates a different method suited to different problem characteristics. Mastering the landscape enables you to select the right tool for each situation.
By completing this page, you will understand the taxonomy of metric learning methods, compare global vs. local approaches, learn Siamese networks and contrastive loss for nonlinear embeddings, understand the connections between metric learning and representation learning, and know how to select the right method for your problem.
The field of metric learning has grown substantially since its origins in the 1990s. A clear taxonomy helps navigate the options.
Linear Methods: Learn a linear transformation $L: \mathbb{R}^d \rightarrow \mathbb{R}^r$
Kernel Methods: Implicit nonlinear mapping via kernel trick
Deep Methods: Learn nonlinear embedding via neural networks
| Method | Type | Supervision | Loss | Scalability |
|---|---|---|---|---|
| LMNN | Linear | Classes | Margin (triplet) | O(n²d) |
| NCA | Linear | Classes | Softmax KL | O(n²d) |
| ITML | Linear | Pairs | LogDet | O(constraints × d²) |
| KISSME | Linear | Pairs | Likelihood ratio | O(nd²) |
| Kernel LMNN | Kernel | Classes | Margin (triplet) | O(n³) |
| Siamese Net | Deep | Pairs | Contrastive | Batch × epochs |
| Triplet Net | Deep | Triplets | Triplet margin | Batch × epochs |
Fully Supervised (Class Labels): Each sample has a class label. Same-class pairs are similar; different-class pairs are dissimilar.
Pairwise Constraints: Only know "these two are similar" or "these two are dissimilar" without global class structure.
Triplet Constraints: Only know "A is closer to B than to C" (relative comparisons).
Weakly Supervised / Self-Supervised: Learn from data structure without explicit labels (e.g., augmentations of same image should be similar).
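To make these supervision signals concrete, here is a minimal, purely illustrative sketch of how each is commonly represented in code; the variable names are placeholders rather than any library's API.

```python
import numpy as np

# Fully supervised: one class label per sample
y = np.array([0, 0, 1, 1, 2])

# Pairwise constraints: index pairs judged similar or dissimilar
similar_pairs = [(0, 1), (2, 3)]        # "these two are similar"
dissimilar_pairs = [(0, 2), (1, 4)]     # "these two are dissimilar"

# Triplet constraints: (anchor, positive, negative) relative comparisons
triplets = [(0, 1, 2), (2, 3, 4)]

# Weak/self-supervision: similar pairs built from two augmentations of the
# same sample, with no human-provided labels at all.
```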
Global Metric: One metric for entire feature space
Local Metric: Different metrics for different regions
Discriminative Metric: Optimized for classification performance
Generative Metric: Models data distribution
Start with linear methods (LMNN, NCA) unless you have evidence of nonlinear structure. Move to kernel methods for moderate nonlinearity with limited data. Use deep methods when you have abundant data and complex patterns. Local methods are rarely needed unless you have evidence of region-specific optimal metrics.
ITML, introduced by Davis et al. (2007), takes a fundamentally different approach from LMNN. Instead of maximizing margins, it keeps the learned Mahalanobis matrix close to a prior matrix $M_0$, measured by the KL-divergence between the Gaussian distributions the two matrices induce, subject to pairwise distance constraints.
Given: a set $S$ of similar pairs, a set $D$ of dissimilar pairs, upper and lower distance bounds $u$ and $\ell$, and a prior matrix $M_0$ (often the identity).
ITML Objective:
$$\min_M \; D_{KL}\left(p(x; M) \,\|\, p(x; M_0)\right)$$
subject to: $(\mathbf{x}_i - \mathbf{x}_j)^T M (\mathbf{x}_i - \mathbf{x}_j) \leq u$ for similar pairs $(i, j) \in S$, and $(\mathbf{x}_i - \mathbf{x}_j)^T M (\mathbf{x}_i - \mathbf{x}_j) \geq \ell$ for dissimilar pairs $(i, j) \in D$.
For Gaussians that share a mean and have precision matrices $M$ and $M_0$, the KL-divergence is, up to a factor of $\tfrac{1}{2}$:
$$D_{KL} = \text{tr}(M^{-1} M_0) - \log\det(M^{-1} M_0) - d$$
This is known as the LogDet divergence.
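As a quick numerical sanity check of this formula (a sketch with illustrative dimensions, not part of any library), the divergence should be zero when $M = M_0$ and positive for any other positive definite $M$:

```python
import numpy as np

def logdet_divergence(M, M0):
    """tr(M^{-1} M0) - log det(M^{-1} M0) - d, as in the formula above."""
    d = M.shape[0]
    A = np.linalg.solve(M, M0)        # M^{-1} M0 without forming an explicit inverse
    _, logdet = np.linalg.slogdet(A)
    return np.trace(A) - logdet - d

rng = np.random.default_rng(0)
d = 4
M0 = np.eye(d)

# Zero at the prior ...
print(np.isclose(logdet_divergence(M0, M0), 0.0))    # True

# ... and strictly positive once M moves away from it (while staying PD)
B = rng.standard_normal((d, d))
M = B @ B.T + d * np.eye(d)
print(logdet_divergence(M, M0) > 0)                   # True
```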
```python
import numpy as np


def itml(X, similar_pairs, dissimilar_pairs, u=1.0, l=4.0, gamma=1.0,
         max_iter=100, tol=1e-3):
    """
    Information-Theoretic Metric Learning.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
    similar_pairs : list of (i, j) tuples that should be close
    dissimilar_pairs : list of (i, j) tuples that should be far
    u : float, upper bound for similar pair distances
    l : float, lower bound for dissimilar pair distances
    gamma : float, slack parameter
    max_iter : int
    tol : float, convergence tolerance

    Returns
    -------
    M : ndarray, learned Mahalanobis matrix
    """
    n, d = X.shape

    # Initialize with identity (prior)
    M = np.eye(d)

    # Create constraint list: (i, j, target_dist, is_upper_bound)
    constraints = []
    for (i, j) in similar_pairs:
        constraints.append((i, j, u, True))    # upper bound
    for (i, j) in dissimilar_pairs:
        constraints.append((i, j, l, False))   # lower bound

    # Bregman projection iterations
    for iteration in range(max_iter):
        M_old = M.copy()

        for (i, j, bound, is_upper) in constraints:
            diff = X[i] - X[j]
            current_dist = diff @ M @ diff

            if is_upper:
                # Similar pair: want distance <= u
                if current_dist > bound:
                    # Project to satisfy constraint
                    alpha = min(gamma, (current_dist - bound) /
                                (current_dist * np.outer(diff, diff).flatten() @ M.flatten() + 1e-10))
                    M = M - alpha * (M @ np.outer(diff, diff) @ M)
            else:
                # Dissimilar pair: want distance >= l
                if current_dist < bound:
                    alpha = min(gamma, (bound - current_dist) /
                                (current_dist * np.outer(diff, diff).flatten() @ M.flatten() + 1e-10))
                    M = M + alpha * (M @ np.outer(diff, diff) @ M)

        # Check convergence
        if np.linalg.norm(M - M_old) < tol:
            print(f"ITML converged at iteration {iteration}")
            break

    # Ensure PSD
    eigenvalues, eigenvectors = np.linalg.eigh(M)
    eigenvalues = np.maximum(eigenvalues, 1e-8)
    M = eigenvectors @ np.diag(eigenvalues) @ eigenvectors.T

    return M


def generate_constraints_from_labels(y, n_similar=1000, n_dissimilar=1000):
    """Generate similar/dissimilar pairs from class labels."""
    n = len(y)
    similar_pairs = []
    dissimilar_pairs = []

    for _ in range(n_similar):
        i = np.random.randint(n)
        same_class = np.where(y == y[i])[0]
        same_class = same_class[same_class != i]
        if len(same_class) > 0:
            j = np.random.choice(same_class)
            similar_pairs.append((i, j))

    for _ in range(n_dissimilar):
        i = np.random.randint(n)
        diff_class = np.where(y != y[i])[0]
        if len(diff_class) > 0:
            j = np.random.choice(diff_class)
            dissimilar_pairs.append((i, j))

    return similar_pairs, dissimilar_pairs
```

| Aspect | LMNN | ITML |
|---|---|---|
| Input | Class labels | Pairwise constraints |
| Objective | Margin maximization | Divergence minimization |
| Regularization | Implicit (trace norm) | Explicit (prior $M_0$) |
| Constraints | Triplets (anchor, target, imposter) | Pairs (similar, dissimilar) |
| Optimization | SDP or gradient descent | Bregman projections |
| Complexity | O(n²d) | O(c × d²) where c = #constraints |
When to prefer ITML: supervision arrives as pairwise similarity judgments rather than class labels, you want the learned metric to stay close to a trusted prior $M_0$, or the number of constraints is small relative to $n^2$.
When to prefer LMNN: full class labels are available and the end goal is k-NN classification, where margin-based triplet constraints directly target that objective.
ITML's pairwise constraint formulation makes it ideal for active learning scenarios. You can iteratively query users for similarity judgments on ambiguous pairs, add those constraints to ITML, and update the metric. This is more natural than asking users to assign absolute class labels.
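A minimal sketch of that loop, reusing the `itml()` helper above. The "oracle" here is simulated from class labels purely for illustration; in a real system the answers would come from human similarity judgments, and the thresholds mirror the `u=1.0`, `l=4.0` defaults used earlier.

```python
import numpy as np

def active_itml(X, y_oracle, n_rounds=5, queries_per_round=20):
    """Iteratively query an oracle for similarity judgments and refit ITML."""
    n, d = X.shape
    similar_pairs, dissimilar_pairs = [], []
    M = np.eye(d)

    for _ in range(n_rounds):
        # Pick the most ambiguous candidate pairs: current distance near the
        # midpoint between the similar (u=1.0) and dissimilar (l=4.0) thresholds.
        candidates = [(np.random.randint(n), np.random.randint(n)) for _ in range(500)]
        candidates = [(i, j) for (i, j) in candidates if i != j]

        def ambiguity(pair):
            i, j = pair
            diff = X[i] - X[j]
            return abs(diff @ M @ diff - 2.5)

        for (i, j) in sorted(candidates, key=ambiguity)[:queries_per_round]:
            # Simulated user feedback (stand-in for a real similarity judgment)
            if y_oracle[i] == y_oracle[j]:
                similar_pairs.append((i, j))
            else:
                dissimilar_pairs.append((i, j))

        # Re-learn the metric with the accumulated constraints
        M = itml(X, similar_pairs, dissimilar_pairs, u=1.0, l=4.0, max_iter=50)

    return M
```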
Global methods like LMNN assume one transformation is optimal everywhere. But consider a dataset in which different regions of the feature space rely on different features for discrimination, for example one group of classes separated along one axis while another group is separated along an orthogonal axis.
No single global metric can handle this. Local metric learning addresses the problem by learning region-specific transformations.
1. Per-Sample Metrics: Learn a separate metric $M_i$ for each training sample $\mathbf{x}_i$. When classifying a test point, use the metric of its nearest training neighbor.
$$d_i(\mathbf{x}, \mathbf{y}) = \sqrt{(\mathbf{x} - \mathbf{y})^T M_i (\mathbf{x} - \mathbf{y})}$$
Challenge: Too many parameters ($n \times d^2$); prone to overfitting.
2. Cluster-Based Metrics: Partition training data into $K$ clusters. Learn one metric $M_k$ per cluster.
$$d(\mathbf{x}, \mathbf{y}) = \sqrt{(\mathbf{x} - \mathbf{y})^T M_{c(\mathbf{x})} (\mathbf{x} - \mathbf{y})}$$
where $c(\mathbf{x})$ is the cluster assignment of $\mathbf{x}$.
Challenge: How do you choose the clusters, and how do you handle points near cluster boundaries? (A minimal sketch of this approach appears after this list.)
3. Multi-Metric LMNN (MM-LMNN): Learn one metric per class. Each class's metric optimizes for distinguishing that class from others.
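Before the MM-LMNN implementation below, here is a minimal sketch of the cluster-based approach (option 2 above). It uses each cluster's regularized inverse covariance as a simple, unsupervised stand-in for a learned local metric; the function names are illustrative, not from a specific library.

```python
import numpy as np
from sklearn.cluster import KMeans

def fit_cluster_metrics(X, n_clusters=3, reg=1e-2):
    """One Mahalanobis matrix per k-means cluster (inverse of regularized covariance)."""
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(X)
    metrics = {}
    for k in range(n_clusters):
        X_k = X[km.labels_ == k]
        cov = np.cov(X_k, rowvar=False) + reg * np.eye(X.shape[1])
        metrics[k] = np.linalg.inv(cov)
    return km, metrics

def local_distance(x, y, km, metrics):
    """d(x, y) under the metric of x's cluster c(x), matching the formula above."""
    k = km.predict(x.reshape(1, -1))[0]
    diff = x - y
    return np.sqrt(diff @ metrics[k] @ diff)
```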
```python
import numpy as np
from scipy.spatial.distance import cdist


def multi_metric_lmnn(X, y, k=3, max_iter=50):
    """
    Multi-Metric LMNN: learn one metric per class.

    For each class c, learn M_c optimized for distinguishing
    class c from other classes.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
    y : ndarray of shape (n_samples,)
    k : int, number of target neighbors

    Returns
    -------
    metrics : dict mapping class labels to Mahalanobis matrices
    """
    classes = np.unique(y)
    d = X.shape[1]

    # Initialize with global LMNN or identity
    metrics = {c: np.eye(d) for c in classes}

    for iteration in range(max_iter):
        for c in classes:
            # Extract class c points
            class_mask = y == c
            X_c = X[class_mask]

            # Other class points
            X_other = X[~class_mask]

            # Current metric for this class
            M_c = metrics[c]

            # M_c = L_c L_c^T, so ||(x - y) @ L_c|| is the Mahalanobis distance under M_c
            L_c = np.linalg.cholesky(M_c)

            # Compute gradient for class c's metric
            grad = np.zeros((d, d))

            for i, x_i in enumerate(X_c):
                # Find k target neighbors (same class)
                dists_same = cdist([x_i @ L_c], X_c @ L_c)[0]
                dists_same[i] = np.inf  # Exclude self
                target_indices = np.argsort(dists_same)[:k]

                for j_idx in target_indices:
                    x_j = X_c[j_idx]
                    diff_ij = x_i - x_j

                    # Pull term gradient
                    grad += 0.5 * np.outer(diff_ij, diff_ij)

                    d_ij_sq = diff_ij @ M_c @ diff_ij

                    # Check imposters (other class points)
                    for x_l in X_other:
                        diff_il = x_i - x_l
                        d_il_sq = diff_il @ M_c @ diff_il

                        # Margin violation
                        if d_ij_sq + 1 > d_il_sq:
                            grad += 0.5 * (np.outer(diff_ij, diff_ij) - np.outer(diff_il, diff_il))

            # Gradient step
            metrics[c] = project_psd(M_c - 0.001 * grad)

    return metrics


def predict_multi_metric(X_train, y_train, X_test, metrics, k=3):
    """Predict using class-specific metrics."""
    classes = list(metrics.keys())
    predictions = []

    for x_test in X_test:
        # For each class, compute distances using that class's metric
        class_votes = {c: 0 for c in classes}

        for c in classes:
            M_c = metrics[c]
            # Distance from x_test to all training points using M_c
            dists = np.array([np.sqrt((x_test - x_train) @ M_c @ (x_test - x_train))
                              for x_train in X_train])
            # k nearest neighbors
            nearest_k = np.argsort(dists)[:k]
            # Vote based on neighbor labels
            for idx in nearest_k:
                class_votes[y_train[idx]] += 1

        predictions.append(max(class_votes.keys(), key=lambda c: class_votes[c]))

    return predictions


def project_psd(M):
    """Project onto the positive semidefinite cone."""
    eigenvalues, eigenvectors = np.linalg.eigh(M)
    eigenvalues = np.maximum(eigenvalues, 1e-8)
    return eigenvectors @ np.diag(eigenvalues) @ eigenvectors.T
```

Local metric learning provides benefit when the optimal distance function genuinely varies across the feature space.
Evidence for needing local metrics includes class-conditional covariances that differ sharply, features whose relevance depends on the class, or a global metric whose cross-validated errors concentrate in particular regions of the feature space.
Local metrics have many more parameters than global metrics ($K \times d^2$ vs. $d^2$). With limited data, they overfit easily. Always use cross-validation to compare local vs. global approaches (a sketch follows below). If global LMNN achieves similar performance, prefer it for its simplicity and robustness.
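As a concrete version of that comparison, the sketch below cross-validates a plain Euclidean k-NN baseline against the per-class metrics from `multi_metric_lmnn()` above; swapping the baseline for a global LMNN pipeline is a straightforward extension.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier

def compare_global_vs_local(X, y, k=3, n_splits=5):
    """Cross-validated accuracy: Euclidean k-NN vs. per-class local metrics."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
    global_acc, local_acc = [], []

    for train_idx, test_idx in skf.split(X, y):
        X_tr, y_tr = X[train_idx], y[train_idx]
        X_te, y_te = X[test_idx], y[test_idx]

        # Global baseline: plain Euclidean k-NN
        knn = KNeighborsClassifier(n_neighbors=k).fit(X_tr, y_tr)
        global_acc.append(knn.score(X_te, y_te))

        # Local metrics, learned on the training fold only
        metrics = multi_metric_lmnn(X_tr, y_tr, k=k, max_iter=10)
        preds = predict_multi_metric(X_tr, y_tr, X_te, metrics, k=k)
        local_acc.append(np.mean(np.array(preds) == y_te))

    return np.mean(global_acc), np.mean(local_acc)
```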
Deep metric learning replaces the linear transformation $L$ with a neural network $f_\theta$. This enables learning highly nonlinear embeddings that capture complex semantic relationships.
Architecture: Two identical networks (weight-sharing) process pairs of inputs. The output embeddings are compared via a distance function.
$$d(\mathbf{x}_1, \mathbf{x}_2) = ||f_\theta(\mathbf{x}_1) - f_\theta(\mathbf{x}_2)||_2$$
Contrastive Loss (Chopra et al., 2005):
$$\mathcal{L} = (1-y) \cdot \frac{1}{2} d^2 + y \cdot \frac{1}{2} \max(0, m - d)^2$$
where $y \in \{0, 1\}$ labels the pair ($y = 0$: similar, $y = 1$: dissimilar), $d = d(\mathbf{x}_1, \mathbf{x}_2)$ is the distance between the two embeddings, and $m$ is the margin that dissimilar pairs must exceed.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiameseNetwork(nn.Module):
    """
    Siamese network for metric learning.

    Uses a weight-shared encoder to embed pairs, then computes
    Euclidean distance for similarity.
    """

    def __init__(self, input_dim, embedding_dim=128, hidden_dims=[256, 128]):
        super().__init__()

        # Build encoder network
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, embedding_dim))

        self.encoder = nn.Sequential(*layers)

    def forward_one(self, x):
        """Encode a single input."""
        return self.encoder(x)

    def forward(self, x1, x2):
        """Encode a pair and compute distance."""
        e1 = self.forward_one(x1)
        e2 = self.forward_one(x2)
        distance = F.pairwise_distance(e1, e2)
        return distance, e1, e2


class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for Siamese networks.

    Similar pairs: minimize distance.
    Dissimilar pairs: push apart until margin.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, distance, label):
        """
        Parameters
        ----------
        distance : Tensor of shape (batch_size,)
            Euclidean distances between pairs
        label : Tensor of shape (batch_size,)
            0 = similar (same class), 1 = dissimilar (different class)
        """
        similar_loss = (1 - label) * distance.pow(2)
        dissimilar_loss = label * F.relu(self.margin - distance).pow(2)
        return (similar_loss + dissimilar_loss).mean()


def train_siamese(model, dataloader, epochs=50, lr=1e-3):
    """Train a Siamese network with contrastive loss."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = ContrastiveLoss(margin=1.0)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x1, x2, labels in dataloader:
            optimizer.zero_grad()
            distances, _, _ = model(x1, x2)
            loss = criterion(distances, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if epoch % 10 == 0:
            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

    return model
```

Triplet networks process triplets: (anchor, positive, negative). The goal: push the anchor closer to the positive than to the negative.
Triplet Loss:
$$\mathcal{L} = \sum_{(a,p,n)} \max(0, ||f(a) - f(p)||^2 - ||f(a) - f(n)||^2 + \alpha)$$
This is exactly LMNN's push term, with neural network $f$ instead of linear $L$.
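A quick check of that claim (an illustrative sketch, with a bias-free linear layer standing in for the matrix $L$): for a single triplet, the triplet hinge with squared distances equals LMNN's push term under $M = L^\top L$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d, r = 5, 3
f = nn.Linear(d, r, bias=False)            # linear embedding f(x) = Lx

a, p, n = torch.randn(3, d)                # anchor, positive (target), negative (imposter)
margin = 1.0

# Triplet loss with squared distances, as in the formula above
pos_sq = (f(a) - f(p)).pow(2).sum()
neg_sq = (f(a) - f(n)).pow(2).sum()
triplet = F.relu(pos_sq - neg_sq + margin)

# LMNN push term for the same triplet under M = L^T L
L = f.weight                               # shape (r, d)
M = L.T @ L
push = F.relu((a - p) @ M @ (a - p) - (a - n) @ M @ (a - n) + margin)

print(torch.allclose(triplet, push))       # True: the same quantity in two notations
```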
Triplet Mining: The key challenge is selecting informative triplets. Random triplets quickly become too easy and contribute no gradient, so practical systems mine hard or semi-hard triplets (negatives that violate, or nearly violate, the margin) online within each mini-batch, as in the code below.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    """
    Triplet loss for metric learning.

    The anchor should be closer to the positive than to the negative by a margin.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """
        Parameters
        ----------
        anchor, positive, negative : Tensors of shape (batch_size, embedding_dim)
        """
        pos_dist = F.pairwise_distance(anchor, positive)
        neg_dist = F.pairwise_distance(anchor, negative)
        loss = F.relu(pos_dist - neg_dist + self.margin)
        return loss.mean()


def mine_semi_hard_triplets(embeddings, labels, margin=1.0):
    """
    Online semi-hard triplet mining within a batch.

    Semi-hard: the negative is further than the positive but within the margin,
    i.e., d(a, p) < d(a, n) < d(a, p) + margin.
    """
    batch_size = len(labels)
    distances = torch.cdist(embeddings, embeddings)
    triplets = []

    for anchor_idx in range(batch_size):
        anchor_label = labels[anchor_idx]

        # Positive indices: same class as anchor
        positive_mask = (labels == anchor_label)
        positive_mask[anchor_idx] = False
        positive_indices = torch.where(positive_mask)[0]

        # Negative indices: different class
        negative_mask = (labels != anchor_label)
        negative_indices = torch.where(negative_mask)[0]

        if len(positive_indices) == 0 or len(negative_indices) == 0:
            continue

        for pos_idx in positive_indices:
            pos_dist = distances[anchor_idx, pos_idx]

            # Find semi-hard negatives
            for neg_idx in negative_indices:
                neg_dist = distances[anchor_idx, neg_idx]

                # Semi-hard condition
                if pos_dist < neg_dist < pos_dist + margin:
                    triplets.append((anchor_idx, pos_idx.item(), neg_idx.item()))

    return triplets


def batch_hard_triplet_loss(embeddings, labels, margin=1.0):
    """
    Batch-hard triplet loss: for each anchor, use the hardest positive and negative.
    """
    distances = torch.cdist(embeddings, embeddings)

    # Masks
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    labels_not_equal = ~labels_equal

    # For each anchor, find the hardest positive (same class, max distance).
    # Mask out negatives with -inf before taking the max.
    pos_distances = distances.clone()
    pos_distances[labels_not_equal] = -float('inf')
    pos_distances.fill_diagonal_(-float('inf'))
    hardest_positive_dist, _ = pos_distances.max(dim=1)

    # For each anchor, find the hardest negative (different class, min distance)
    neg_distances = distances.clone()
    neg_distances[labels_equal] = float('inf')
    hardest_negative_dist, _ = neg_distances.min(dim=1)

    # Triplet loss
    loss = F.relu(hardest_positive_dist - hardest_negative_dist + margin)
    return loss.mean()
```

Applying metric learning effectively requires a systematic approach. Here's a workflow that balances rigor with practicality.
Before any metric learning, establish a baseline: k-NN with plain Euclidean distance on standardized features, evaluated with cross-validation.
If baselines already achieve > 95% accuracy, metric learning may not be worth the complexity.
Understand your features before learning metrics: check their scales, look for redundant or irrelevant dimensions, and estimate per-feature relevance (for example, mutual information with the labels).
```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import mutual_info_classif


def metric_learning_workflow(X, y, verbose=True):
    """
    Complete workflow for evaluating whether and how to apply metric learning.

    Returns recommendations and baseline results.
    """
    n, d = X.shape
    n_classes = len(np.unique(y))
    results = {'recommendations': [], 'baselines': {}}

    if verbose:
        print(f"Dataset: {n} samples, {d} features, {n_classes} classes")

    # Step 1: Baseline evaluation
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    knn_baseline = KNeighborsClassifier(n_neighbors=5)
    baseline_scores = cross_val_score(knn_baseline, X_scaled, y, cv=5)
    baseline_acc = np.mean(baseline_scores)
    results['baselines']['euclidean_knn'] = baseline_acc

    if verbose:
        print(f"Baseline k-NN accuracy: {baseline_acc:.3f} (+/- {np.std(baseline_scores):.3f})")

    # Step 2: Feature analysis
    mi_scores = mutual_info_classif(X_scaled, y)
    useful_features = np.sum(mi_scores > 0.01)

    if verbose:
        print(f"Features with MI > 0.01: {useful_features}/{d}")

    # Step 3: Recommendations
    if baseline_acc > 0.95:
        results['recommendations'].append(
            "Baseline is strong. Metric learning likely offers marginal improvement.")
    if useful_features < d / 2:
        results['recommendations'].append(
            f"Many irrelevant features ({d - useful_features}/{d}). LMNN should help significantly.")
    if n < 10 * d * (d + 1) / 2:
        results['recommendations'].append(
            "Limited samples for full Mahalanobis. Consider low-rank or diagonal constraints.")
    if n > 10000 and d > 100:
        results['recommendations'].append(
            "Large scale problem. Consider deep metric learning or scalable approximations.")

    # Step 4: Quick LMNN test
    try:
        from metric_learn import LMNN
        lmnn = LMNN(k=5, max_iter=50)
        lmnn.fit(X_scaled, y)
        X_lmnn = lmnn.transform(X_scaled)

        lmnn_knn = KNeighborsClassifier(n_neighbors=5)
        lmnn_scores = cross_val_score(lmnn_knn, X_lmnn, y, cv=5)
        lmnn_acc = np.mean(lmnn_scores)
        results['baselines']['lmnn_knn'] = lmnn_acc

        improvement = lmnn_acc - baseline_acc
        if verbose:
            print(f"LMNN k-NN accuracy: {lmnn_acc:.3f} (improvement: {improvement:+.3f})")

        if improvement > 0.02:
            results['recommendations'].append(
                f"LMNN improves accuracy by {improvement:.3f}. Recommend full optimization.")
        else:
            results['recommendations'].append(
                "LMNN provides minimal improvement. May not be worth the overhead.")
    except ImportError:
        results['recommendations'].append(
            "metric-learn not installed. Install for automatic LMNN evaluation.")

    if verbose:
        print("Recommendations:")
        for rec in results['recommendations']:
            print(f"  • {rec}")

    return results
```

Based on your analysis, select an appropriate method:
| Scenario | Recommended Method |
|---|---|
| d < 100, n > 1000, features need weighting | LMNN or NCA |
| Pairwise constraints available | ITML or KISSME |
| Region-specific metrics needed | Multi-Metric LMNN |
| Large n (> 50K), complex patterns | Deep metric learning |
| High d, limited n | Low-rank LMNN or feature selection first |
Key hyperparameters to tune:
For LMNN: the number of target neighbors $k$, the relative weight of the push and pull terms, the number of optimization iterations, and (for low-rank variants) the output dimensionality.
For Deep Metric Learning: the embedding dimension, the margin, the pair/triplet mining strategy, the batch size, the learning rate, and the number of epochs.
Always evaluate generalization: learn the metric on training folds only, score the downstream k-NN task on held-out data, and compare against the Euclidean baseline, as in the sketch below.
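A minimal sketch of such an evaluation, assuming the metric-learn package with the same constructor arguments as the earlier snippet (argument names can differ across versions). Unlike the quick test in the workflow code above, which fits LMNN on all the data before cross-validating, here the metric is re-learned inside every fold, so the estimate is leak-free.

```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
from metric_learn import LMNN   # assumed installed, as in the workflow above

def evaluate_lmnn_leak_free(X, y, k=5, cv=5):
    """Cross-validate scaling + LMNN + k-NN as one pipeline, refit per fold."""
    pipe = Pipeline([
        ("scale", StandardScaler()),
        ("lmnn", LMNN(k=k, max_iter=50)),          # metric learned on the training fold only
        ("knn", KNeighborsClassifier(n_neighbors=k)),
    ])
    scores = cross_val_score(pipe, X, y, cv=cv)
    return scores.mean(), scores.std()
```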
In practice, 80% of metric learning's benefit comes from simple feature preprocessing (standardization, log transforms) and basic LMNN. Only invest in complex methods (local metrics, deep learning) when simpler approaches demonstrably underperform and you have sufficient data to train them reliably.
We've surveyed the rich landscape of metric learning, from classical linear methods to modern deep approaches. The key insights: linear methods such as LMNN and NCA are strong defaults when features mainly need reweighting; ITML is the natural fit when supervision arrives as pairwise constraints or a trusted prior exists; local metrics add flexibility but multiply the parameter count and the risk of overfitting; and deep methods (Siamese and triplet networks) extend the same pull/push objectives to nonlinear embeddings when data is abundant.
Metric learning improves k-NN by learning how to measure similarity. But what if we want to combine multiple k-NN classifiers, each with different perspectives on the data?
The next page explores KNN Ensembles—methods that combine multiple k-NN classifiers using different metrics, feature subsets, or parameter settings to achieve robustness and improved accuracy beyond any single classifier.
You now have a comprehensive understanding of metric learning for k-NN, from the taxonomy of methods through specific algorithms to a practical workflow. You can select the right approach for your problem, implement key methods, and validate improvements rigorously. Next, we explore combining k-NN classifiers through ensemble methods.