When a child learns to recognize a new animal from a few examples, they don't update millions of neural connections. Instead, they form a mental prototype—an abstract representation of the concept—and compare new instances to it. 'Does this creature look like what I learned an okapi looks like?'
Prototypical Networks (Snell et al., 2017) operationalize this intuition. Rather than adapting model parameters (like MAML), they learn an embedding space where:

- Each class is summarized by a single prototype: the mean of its embedded support examples.
- A query is classified by finding its nearest prototype under a fixed distance function.
This elegantly simple approach achieves competitive performance with MAML while being dramatically simpler to implement and faster at inference time. There are no gradient steps during adaptation—just embedding computation and nearest-neighbor classification.
By completing this page, you will understand: (1) The mathematical formulation of Prototypical Networks, (2) Why mean embeddings work as class prototypes, (3) The connection to mixture density estimation and Bregman divergences, (4) Implementation details and training, (5) Relation to other metric learning approaches, and (6) When Prototypical Networks are the right choice.
Prototypical Networks belong to the metric-based family of meta-learning methods. The core idea: learn a representation (embedding) space where distance corresponds to semantic dissimilarity, then use a simple distance-based classifier.
Key Components:
Embedding function $f_\theta: \mathcal{X} \rightarrow \mathbb{R}^d$
Prototypes $c_k \in \mathbb{R}^d$ for each class $k$
Distance function $d: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}$
Classification rule for query $x_q$: $$p(y = k | x_q) = \frac{\exp(-d(f_\theta(x_q), c_k))}{\sum_{k'} \exp(-d(f_\theta(x_q), c_{k'}))}$$
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrototypicalNetwork(nn.Module):
    """
    Prototypical Networks for few-shot classification.

    Key insight: Classification = nearest prototype in learned embedding space.

    Advantages:
    - No gradient computation at test time (unlike MAML)
    - Simple, interpretable mechanism
    - Efficient: O(N*K) prototype computation
    """

    def __init__(self, encoder: nn.Module, distance: str = 'euclidean'):
        """
        Args:
            encoder: Neural network mapping inputs to embeddings
            distance: 'euclidean' or 'cosine'
        """
        super().__init__()
        self.encoder = encoder
        self.distance = distance

    def compute_prototypes(
        self,
        support_embeddings: torch.Tensor,  # [n_way * k_shot, embed_dim]
        support_labels: torch.Tensor,      # [n_way * k_shot]
        n_way: int
    ) -> torch.Tensor:
        """
        Compute class prototypes as the mean of support embeddings.

        Args:
            support_embeddings: Embedded support examples
            support_labels: Class labels (0 to n_way-1)
            n_way: Number of classes

        Returns:
            prototypes: [n_way, embed_dim] tensor of class prototypes
        """
        prototypes = torch.zeros(
            n_way, support_embeddings.shape[1],
            device=support_embeddings.device
        )
        for k in range(n_way):
            # Gather embeddings for class k
            mask = support_labels == k
            # Prototype = mean embedding
            prototypes[k] = support_embeddings[mask].mean(dim=0)
        return prototypes

    def compute_distances(
        self,
        query_embeddings: torch.Tensor,  # [n_query, embed_dim]
        prototypes: torch.Tensor         # [n_way, embed_dim]
    ) -> torch.Tensor:
        """
        Compute distances from queries to all prototypes.

        Returns:
            distances: [n_query, n_way] tensor
        """
        if self.distance == 'euclidean':
            # Squared Euclidean distance:
            # ||q - p||^2 = ||q||^2 + ||p||^2 - 2*q·p
            distances = (
                (query_embeddings ** 2).sum(dim=1, keepdim=True)
                + (prototypes ** 2).sum(dim=1).unsqueeze(0)
                - 2 * query_embeddings @ prototypes.t()
            )
            return distances
        elif self.distance == 'cosine':
            # Cosine distance = 1 - cosine similarity
            query_norm = F.normalize(query_embeddings, dim=1)
            proto_norm = F.normalize(prototypes, dim=1)
            return 1 - query_norm @ proto_norm.t()
        else:
            raise ValueError(f"Unknown distance: {self.distance}")

    def forward(
        self,
        support_x: torch.Tensor,  # [n_way * k_shot, ...]
        support_y: torch.Tensor,  # [n_way * k_shot]
        query_x: torch.Tensor,    # [n_query, ...]
        n_way: int
    ) -> torch.Tensor:
        """
        Forward pass: embed support, compute prototypes, classify queries.

        Returns:
            log_probs: [n_query, n_way] log probabilities
        """
        # Embed support and query
        support_embeddings = self.encoder(support_x)
        query_embeddings = self.encoder(query_x)

        # Compute prototypes
        prototypes = self.compute_prototypes(support_embeddings, support_y, n_way)

        # Compute distances
        distances = self.compute_distances(query_embeddings, prototypes)

        # Convert distances to log probabilities (negative distances = logits)
        return F.log_softmax(-distances, dim=1)

    def predict(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """Get predicted class labels for queries."""
        log_probs = self.forward(support_x, support_y, query_x, n_way)
        return log_probs.argmax(dim=1)


# Example encoder architecture (Conv-4)
class ConvEncoder(nn.Module):
    """
    Standard Conv-4 encoder for few-shot learning.

    4 convolutional blocks, each: conv → batch_norm → relu → max_pool.
    Widely used in few-shot learning papers for fair comparison.
    """

    def __init__(self, in_channels: int = 3, hidden_dim: int = 64,
                 output_dim: int = 64):
        super().__init__()
        self.layer1 = self._conv_block(in_channels, hidden_dim)
        self.layer2 = self._conv_block(hidden_dim, hidden_dim)
        self.layer3 = self._conv_block(hidden_dim, hidden_dim)
        self.layer4 = self._conv_block(hidden_dim, output_dim)

    def _conv_block(self, in_channels: int, out_channels: int):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x.view(x.size(0), -1)  # Flatten to [batch, embed_dim]
```

Using squared Euclidean distance with softmax over negative distances is equivalent to assuming each class follows a Gaussian distribution in embedding space, with the prototype as the mean. This probabilistic interpretation explains why the approach is principled, not just a heuristic.
The elegance of Prototypical Networks isn't just aesthetic—it's grounded in principled probabilistic and geometric reasoning.
Mixture Density Estimation View:
Prototypical Networks perform mixture density estimation: classification corresponds to inference in a generative model where each class is represented by a spherical Gaussian in embedding space:
$$p(x | y = k) = \mathcal{N}(f_\theta(x) | c_k, \sigma^2 I)$$
With uniform priors over classes: $$p(y = k | x) = \frac{p(x | y = k)}{\sum_{k'} p(x | y = k')} = \frac{\exp(-\frac{1}{2\sigma^2}\|f_\theta(x) - c_k\|^2)}{\sum_{k'} \exp(-\frac{1}{2\sigma^2}\|f_\theta(x) - c_{k'}\|^2)}$$
This is exactly the Prototypical Network classification rule (with $\sigma^2$ absorbed into the distance function).
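A minimal numeric check of this equivalence; the tensor shapes and the $\sigma^2$ value are illustrative, not from the original:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embeddings = torch.randn(4, 8)   # f_theta(x) for 4 queries, embed_dim = 8
prototypes = torch.randn(3, 8)   # c_k for 3 classes
sigma2 = 1.0                     # spherical variance (illustrative)

sq_dists = torch.cdist(embeddings, prototypes) ** 2   # [4, 3]

# ProtoNet rule: softmax over scaled negative squared distances
protonet_probs = F.softmax(-sq_dists / (2 * sigma2), dim=1)

# Explicit Gaussian posterior with uniform class priors;
# the (2*pi*sigma2)^(-d/2) factor cancels in the normalization
d = embeddings.shape[1]
likelihood = (2 * torch.pi * sigma2) ** (-d / 2) * torch.exp(-sq_dists / (2 * sigma2))
posterior = likelihood / likelihood.sum(dim=1, keepdim=True)

assert torch.allclose(protonet_probs, posterior)
```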
Bregman Divergences and Generalizations:
The original paper shows that Prototypical Networks can be generalized using Bregman divergences. Any Bregman divergence $d_\phi$ corresponds to a distribution from the exponential family:
$$d_\phi(z, c) = \phi(z) - \phi(c) - \nabla\phi(c)^T(z - c)$$
Key properties:

- For any Bregman divergence, the point minimizing the total divergence $\sum_i d_\phi(z_i, c)$ to a set of points is their mean (Banerjee et al., 2005).
- Squared Euclidean distance is the Bregman divergence generated by $\phi(z) = \|z\|^2$; cosine distance, by contrast, is not a Bregman divergence.
- Each regular Bregman divergence pairs with a regular exponential-family distribution, so softmax over negative divergences yields a proper class posterior.
This explains the mathematical validity of using the mean as the prototype—it's not just intuitive, it's optimal under this framework.
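For squared Euclidean distance, the optimality of the mean follows from a one-line calculation (a worked instance of the general Bregman result above):

$$\frac{\partial}{\partial c} \sum_{i=1}^{n} \|z_i - c\|^2 = -2 \sum_{i=1}^{n} (z_i - c) = 0 \quad \Longrightarrow \quad c^* = \frac{1}{n} \sum_{i=1}^{n} z_i$$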
| Distance | Formula | Corresponding Distribution | Prototype |
|---|---|---|---|
| Squared Euclidean | $\lVert z - c \rVert^2$ | Spherical Gaussian | Mean |
| Mahalanobis | $(z-c)^T\Sigma^{-1}(z-c)$ | General Gaussian | Mean |
| Cosine | $1 - \frac{z \cdot c}{\lVert z \rVert \, \lVert c \rVert}$ | von Mises-Fisher | Normalized mean |
| KL Divergence | $\sum_i z_i \log(z_i/c_i)$ | Multinomial | Mean (on simplex) |
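For the cosine row, the "normalized mean" prototype can be sketched in a few lines; the helper name is illustrative, not from the original code:

```python
import torch
import torch.nn.functional as F

def cosine_prototypes(support_embeddings: torch.Tensor,
                      support_labels: torch.Tensor,
                      n_way: int) -> torch.Tensor:
    """Prototype = direction of the class mean (von Mises-Fisher view)."""
    means = torch.stack([
        support_embeddings[support_labels == k].mean(dim=0)
        for k in range(n_way)
    ])
    return F.normalize(means, dim=1)  # project each class mean onto the unit sphere
```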
While mean prototypes are optimal for Bregman divergences, some extensions use more complex prototypes: multiple prototypes per class (for multi-modal classes), learned prototype refinement, or attention-weighted prototypes. These can improve performance when class distributions are complex.
Training Prototypical Networks follows the episodic training paradigm, but with a key insight: training with more classes than at test time improves generalization.
Training Episode Structure:

1. Sample $N$ classes from the training split.
2. For each class, sample $K$ support examples and $Q$ query examples.
3. Embed the support set and compute one prototype per class.
4. Classify the queries against the prototypes and backpropagate the cross-entropy loss through the encoder.

A minimal episode sampler is sketched below.
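The training code later in this section calls a `sample_episode` helper that the original does not define. Here is a minimal sketch of one, assuming the dataset is a dict mapping consecutive integer class ids to tensors of examples (both the data layout and the helper's exact signature are assumptions):

```python
import torch
from typing import Dict

def sample_episode(
    dataset: Dict[int, torch.Tensor],  # class_id -> [n_examples, ...] (assumed layout)
    n_way: int,
    k_shot: int,
    q_query: int,
) -> Dict[str, torch.Tensor]:
    """Sample one n_way-way, k_shot-shot episode with q_query queries per class."""
    # Assumes class ids are 0..len(dataset)-1 so randperm can index them
    class_ids = torch.randperm(len(dataset))[:n_way].tolist()
    support_x, query_x, support_y, query_y = [], [], [], []
    for new_label, class_id in enumerate(class_ids):
        examples = dataset[class_id]
        perm = torch.randperm(examples.shape[0])[:k_shot + q_query]
        support_x.append(examples[perm[:k_shot]])   # K support examples
        query_x.append(examples[perm[k_shot:]])     # Q query examples
        support_y.extend([new_label] * k_shot)      # relabel classes to 0..n_way-1
        query_y.extend([new_label] * q_query)
    return {
        'support_x': torch.cat(support_x),
        'support_y': torch.tensor(support_y),
        'query_x': torch.cat(query_x),
        'query_y': torch.tensor(query_y),
    }
```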
The Higher-Way Training Trick:
Snell et al. found that training with more classes than testing (e.g., 60-way training for 5-way testing) significantly improves performance. Why? Discriminating among 60 classes is a harder problem than discriminating among 5, so the encoder is pushed to learn more finely separated embeddings.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List

def train_protonet(
    model: nn.Module,
    train_dataset,
    val_dataset,
    n_epochs: int = 100,
    episodes_per_epoch: int = 100,
    n_way_train: int = 60,   # More ways during training!
    n_way_val: int = 5,      # Match test-time setting
    k_shot: int = 5,
    q_query: int = 5,
    lr: float = 0.001,
    lr_decay_epochs: List[int] = [50, 75],
    lr_decay_factor: float = 0.1,
):
    """
    Train Prototypical Networks with episodic training.

    Key insight: Train with more classes than test time.
    This forces the encoder to learn highly discriminative embeddings.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_decay_epochs, gamma=lr_decay_factor
    )

    best_val_acc = 0.0

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_acc = 0.0

        for episode in range(episodes_per_epoch):
            # Sample training episode (high n_way);
            # sample_episode is sketched earlier in this section
            episode_data = sample_episode(
                train_dataset, n_way_train, k_shot, q_query
            )
            support_x = episode_data['support_x'].to(device)
            support_y = episode_data['support_y'].to(device)
            query_x = episode_data['query_x'].to(device)
            query_y = episode_data['query_y'].to(device)

            # Forward pass
            log_probs = model(support_x, support_y, query_x, n_way_train)

            # Cross-entropy loss on log probabilities
            loss = nn.functional.nll_loss(log_probs, query_y)

            # Accuracy
            preds = log_probs.argmax(dim=1)
            acc = (preds == query_y).float().mean()

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += acc.item()

        scheduler.step()

        avg_loss = epoch_loss / episodes_per_epoch
        avg_acc = epoch_acc / episodes_per_epoch
        print(f"Epoch {epoch + 1}: Loss={avg_loss:.4f}, Train Acc={avg_acc:.2%}")

        # Validation (use test-time n_way)
        if (epoch + 1) % 5 == 0:
            val_acc = evaluate_protonet(
                model, val_dataset, n_way_val, k_shot, q_query, device
            )
            print(f"  Validation Acc ({n_way_val}-way {k_shot}-shot): {val_acc:.2%}")
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), 'best_protonet.pt')

    return model


def evaluate_protonet(
    model: nn.Module,
    dataset,
    n_way: int,
    k_shot: int,
    q_query: int,
    device: torch.device,
    n_episodes: int = 600,
) -> float:
    """
    Evaluate a Prototypical Network on test episodes.

    NOTE: No adaptation/gradient steps needed!
    Just embed, compute prototypes, classify.
    """
    model.eval()
    accuracies = []

    with torch.no_grad():
        for _ in range(n_episodes):
            episode_data = sample_episode(dataset, n_way, k_shot, q_query)
            support_x = episode_data['support_x'].to(device)
            support_y = episode_data['support_y'].to(device)
            query_x = episode_data['query_x'].to(device)
            query_y = episode_data['query_y'].to(device)

            log_probs = model(support_x, support_y, query_x, n_way)
            preds = log_probs.argmax(dim=1)
            accuracies.append((preds == query_y).float().mean().item())

    return sum(accuracies) / len(accuracies)
```

Key training considerations:

1. Higher-way training: train with 20-60 ways, test with 5 ways; this forces truly discriminative embeddings (~2-3% improvement over matching train/test ways).
2. Episode sampling: ensure class balance within episodes; sample randomly from class pools; shuffle within support/query sets.
3. Data augmentation: standard augmentation (crop, flip, color jitter), applied before episode sampling; significant impact on miniImageNet.
4. Learning rate schedule: cosine annealing or step decay; decay after 50-75% of training.
5. Embedding normalization: L2-normalize embeddings before distance computation; can help with cosine distance, optional for Euclidean.
6. Label smoothing: soft labels prevent overconfident prototypes; typical value: 0.1.

Training with 60 ways for a 5-way test setting can improve accuracy by 2-3%. The intuition: discriminating among 60 classes requires more nuanced embeddings than discriminating among 5. The encoder must 'try harder' during training, yielding better representations.
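As a concrete instance of tip 6, here is a minimal sketch of label smoothing on ProtoNet logits; it relies on the `label_smoothing` argument of `F.cross_entropy` (available in PyTorch 1.10+), and the function name is illustrative:

```python
import torch.nn.functional as F

def smoothed_protonet_loss(distances, labels, smoothing: float = 0.1):
    """Cross-entropy over negative distances with label smoothing."""
    # Negative distances are the logits; smoothing softens the targets
    # so prototypes are not pushed toward overconfident predictions.
    return F.cross_entropy(-distances, labels, label_smoothing=smoothing)
```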
Prototypical Networks are part of a family of metric-based meta-learning methods. Understanding the relationships and trade-offs helps choose the right approach for each problem.
Matching Networks (Vinyals et al., 2016) preceded Prototypical Networks and use an attention-weighted combination of support labels:
$$p(y = k | x_q) \propto \sum_{(x_i, y_i) \in S} a(x_q, x_i) \cdot \mathbb{1}[y_i = k]$$
where $a(x_q, x_i)$ is an attention kernel (typically softmax over cosine similarities).
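A minimal sketch of this rule, using cosine-softmax attention (omitting Matching Networks' full-context embeddings; the function name is illustrative):

```python
import torch
import torch.nn.functional as F

def matching_net_probs(query_emb: torch.Tensor,    # [n_query, embed_dim]
                       support_emb: torch.Tensor,  # [n_support, embed_dim]
                       support_y: torch.Tensor,    # [n_support] int64 labels
                       n_way: int) -> torch.Tensor:
    """Attention-weighted label distribution over the support set."""
    # Attention kernel: softmax over cosine similarities
    q = F.normalize(query_emb, dim=1)
    s = F.normalize(support_emb, dim=1)
    attention = F.softmax(q @ s.t(), dim=1)        # [n_query, n_support]

    # p(y=k | x_q) = sum_i a(x_q, x_i) * 1[y_i = k]
    one_hot = F.one_hot(support_y, n_way).float()  # [n_support, n_way]
    return attention @ one_hot                     # [n_query, n_way]
```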
Key differences from ProtoNet:

- Matching Networks keep every support example and attend over all of them; Prototypical Networks compress each class into a single mean prototype.
- Matching Networks typically use cosine similarity inside a softmax attention kernel; Prototypical Networks use squared Euclidean distance.
- In the 1-shot setting the two coincide: each class has exactly one support example, which is its prototype.
When to prefer Matching Networks:

- When class distributions are multi-modal or have high within-class variation, so a single mean prototype washes out informative individual examples.
- When instance-level attention weights are useful for interpretability, since they show which support examples drove each prediction.
| Method | Class Representation | Distance | Accuracy (approx., 5-way 5-shot miniImageNet) | Inference Speed |
|---|---|---|---|---|
| Prototypical Net | Mean prototype | Euclidean (fixed) | ~68% | Very fast |
| Matching Net | All examples | Cosine attention | ~63% | Fast |
| Relation Net | Mean prototype | Learned CNN | ~65% | Medium |
| Siamese Net | Pairwise | Learned | ~61% | Slow (O(n²)) |
Despite being simpler, Prototypical Networks often outperform more complex alternatives. The mean prototype is a highly effective regularizer—it encourages embeddings where class means are sufficient statistics. This prevents overfitting to individual support examples.
The simplicity of Prototypical Networks makes them an excellent foundation for extensions. Researchers have developed numerous variants that address specific limitations.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransductiveProtoNet(PrototypicalNetwork):
    """
    Transductive Prototypical Network.

    Uses query set information to refine prototypes through
    iterative pseudo-labeling.

    Improves accuracy by ~1-2% over inductive ProtoNet.

    Inherits compute_prototypes / compute_distances from the
    PrototypicalNetwork class defined earlier in this section.
    """

    def __init__(self, encoder: nn.Module, n_refine_steps: int = 3):
        super().__init__(encoder)
        self.n_refine_steps = n_refine_steps

    def forward(self, support_x, support_y, query_x, n_way):
        # Initial prototypes from support only
        support_embeddings = self.encoder(support_x)
        query_embeddings = self.encoder(query_x)
        prototypes = self.compute_prototypes(
            support_embeddings, support_y, n_way
        )

        # Iteratively refine prototypes using high-confidence query predictions
        for step in range(self.n_refine_steps):
            # Soft assignment of queries to classes
            distances = self.compute_distances(query_embeddings, prototypes)
            soft_assignments = F.softmax(-distances, dim=1)  # [n_query, n_way]

            # Weight by confidence (entropy-based)
            entropy = -(soft_assignments
                        * torch.log(soft_assignments + 1e-8)).sum(1)
            confidence = torch.exp(-entropy)  # High entropy = low confidence

            # Refine prototypes with confident query embeddings
            new_prototypes = []
            for k in range(n_way):
                # Weighted query contribution
                query_contribution = (
                    soft_assignments[:, k:k+1]
                    * confidence.unsqueeze(1)
                    * query_embeddings
                ).sum(0)
                query_weight = (soft_assignments[:, k] * confidence).sum()

                # Support contribution
                mask = support_y == k
                support_contribution = support_embeddings[mask].sum(0)
                support_weight = mask.float().sum()

                # Updated prototype: support mean blended with
                # down-weighted (0.5) query pseudo-labels
                new_prototypes.append(
                    (support_contribution + 0.5 * query_contribution)
                    / (support_weight + 0.5 * query_weight)
                )
            # Rebuild rather than modify in place, so autograd stays valid
            prototypes = torch.stack(new_prototypes)

        # Final classification
        distances = self.compute_distances(query_embeddings, prototypes)
        return F.log_softmax(-distances, dim=1)


class MultiPrototypeNetwork(PrototypicalNetwork):
    """
    Multiple prototypes per class for handling multi-modal distributions.

    Instead of one mean prototype, uses K-means to find multiple
    cluster centers per class. Classification uses minimum distance
    to any prototype.

    Inherits compute_distances from PrototypicalNetwork; the _kmeans
    helper (returning [n_prototypes, embed_dim] centers) is not shown.
    """

    def __init__(self, encoder: nn.Module, prototypes_per_class: int = 3):
        super().__init__(encoder)
        self.prototypes_per_class = prototypes_per_class

    def compute_multi_prototypes(self, support_embeddings, support_labels, n_way):
        """Compute multiple prototypes per class using K-means."""
        all_prototypes = []
        for k in range(n_way):
            mask = support_labels == k
            class_embeddings = support_embeddings[mask]

            # Cap prototype count by the number of support examples
            n_support_per_class = class_embeddings.shape[0]
            n_prototypes = min(self.prototypes_per_class, n_support_per_class)

            if n_prototypes == 1:
                prototypes = class_embeddings.mean(0, keepdim=True)
            else:
                # K-means clustering (simplified: random init + assignments)
                prototypes = self._kmeans(class_embeddings, n_prototypes)
            all_prototypes.append(prototypes)
        return all_prototypes  # List of [k_i, embed_dim] tensors

    def forward(self, support_x, support_y, query_x, n_way):
        support_embeddings = self.encoder(support_x)
        query_embeddings = self.encoder(query_x)

        # Multiple prototypes per class
        multi_prototypes = self.compute_multi_prototypes(
            support_embeddings, support_y, n_way
        )

        # Minimum distance to any prototype in a class = class distance
        n_query = query_embeddings.shape[0]
        class_distances = torch.zeros(n_query, n_way,
                                      device=query_embeddings.device)
        for k, class_protos in enumerate(multi_prototypes):
            # Distance to each prototype of class k
            dists = self.compute_distances(query_embeddings, class_protos)
            class_distances[:, k] = dists.min(dim=1)[0]

        return F.log_softmax(-class_distances, dim=1)
```

Transductive extensions that use query set information during classification typically improve accuracy by 1-3%. The trade-off: you must have all query examples available simultaneously, which isn't always possible in streaming scenarios.
Prototypical Networks and MAML represent fundamentally different philosophies about how to achieve few-shot learning. Understanding their trade-offs is crucial for choosing the right approach.
| Aspect | Prototypical Networks | MAML |
|---|---|---|
| Core mechanism | Learn distance metric | Learn initialization |
| Adaptation at test time | None (just compute prototypes) | Gradient descent steps |
| Inference speed | Very fast (single forward pass) | Slower (K forward + backward passes) |
| Training complexity | Simple episodic training | Bi-level optimization |
| Memory at training | Standard | High (computational graphs for each step) |
| What's learned | Embedding function | Parameter initialization |
| Handles new tasks by | Computing new prototypes | Fine-tuning from initialization |
When the choice matters:
Classification with good encoder: ProtoNet often wins. If you have a strong pre-trained encoder (e.g., ResNet on ImageNet), learning prototypes in that space is effective and fast.
Reinforcement learning: MAML wins. Policy networks need parameter updates, not prototype classification.
Cross-domain adaptation: MAML may have an edge. Parameter adaptation can adjust representations more flexibly than fixed metrics.
Resource-constrained deployment: ProtoNet wins. No gradient computation means faster, cheaper inference.
Very low shot (1-shot): Often comparable. Both struggle; the bottleneck is data scarcity, not algorithmic approach.
Recent research suggests that a simple baseline—pre-trained encoder + linear classifier trained on support examples—often matches or beats both ProtoNet and MAML. This highlights that much of few-shot performance comes from the encoder quality, not the adaptation mechanism. Always compare against simple baselines.
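A minimal sketch of that baseline, assuming a frozen pre-trained encoder and borrowing scikit-learn's logistic regression for the linear classifier (both choices are ours, not from the original):

```python
import torch
from sklearn.linear_model import LogisticRegression

def linear_probe_baseline(encoder, support_x, support_y, query_x):
    """Few-shot baseline: frozen encoder + linear classifier fit on the support set."""
    encoder.eval()
    with torch.no_grad():
        support_emb = encoder(support_x).cpu().numpy()
        query_emb = encoder(query_x).cpu().numpy()

    # Fit a linear classifier on the handful of support examples
    clf = LogisticRegression(max_iter=1000)
    clf.fit(support_emb, support_y.cpu().numpy())
    return torch.from_numpy(clf.predict(query_emb))
```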
```python
import torch
import torch.nn.functional as F

# Common ProtoNet implementation gotchas and fixes

# ❌ Wrong: Computing Euclidean distance with broadcasting issues
def bad_distance(query, prototypes):
    # Shapes [n_query, d] and [n_way, d] don't broadcast: this only
    # runs when n_query == n_way, and even then it pairs query i with
    # prototype i instead of comparing every query to every prototype
    return torch.norm(query - prototypes, dim=-1)

# ✅ Correct: Efficient squared Euclidean distance
def good_distance(query, prototypes):
    # query: [n_query, embed_dim]
    # prototypes: [n_way, embed_dim]
    # output: [n_query, n_way]
    n_query = query.shape[0]
    n_way = prototypes.shape[0]
    # ||q - p||^2 = ||q||^2 + ||p||^2 - 2*q·p
    query_sq = (query ** 2).sum(dim=1).view(n_query, 1)
    proto_sq = (prototypes ** 2).sum(dim=1).view(1, n_way)
    cross = torch.mm(query, prototypes.t())
    return query_sq + proto_sq - 2 * cross

# ❌ Wrong: log of softmax computed in two separate steps
def bad_loss(distances, labels):
    probs = F.softmax(-distances, dim=1)
    return F.nll_loss(torch.log(probs), labels)  # log(0) -> -inf, unstable

# ✅ Correct: Use log_softmax for numerical stability
# (F.cross_entropy on the raw logits fuses the same two steps and is
# equally stable)
def good_loss(distances, labels):
    log_probs = F.log_softmax(-distances, dim=1)
    return F.nll_loss(log_probs, labels)

# ❌ Wrong: Prototype computation that silently produces NaNs
def bad_prototypes(embeddings, labels, n_way):
    prototypes = []
    for k in range(n_way):
        # mean of an empty slice is NaN if class k has no examples
        class_mean = embeddings[labels == k].mean(0)
        prototypes.append(class_mean)
    return torch.stack(prototypes)

# ✅ Correct: Safe prototype computation
def good_prototypes(embeddings, labels, n_way):
    embed_dim = embeddings.shape[1]
    prototypes = torch.zeros(n_way, embed_dim, device=embeddings.device)
    for k in range(n_way):
        mask = labels == k
        if mask.sum() == 0:
            # Handle edge case: no examples for class k
            # (shouldn't happen with proper episode sampling)
            prototypes[k] = torch.zeros(embed_dim, device=embeddings.device)
        else:
            prototypes[k] = embeddings[mask].mean(dim=0)
    return prototypes

# Training hyperparameters that work well
DEFAULT_HPARAMS = {
    'encoder': 'ResNet-12',      # For miniImageNet
    'embed_dim': 640,            # ResNet-12 output
    'distance': 'euclidean',     # Not cosine
    'n_way_train': 30,           # Higher than test
    'n_way_test': 5,
    'k_shot': 5,
    'q_query': 15,
    'lr': 0.001,
    'lr_decay_epochs': [60, 80],
    'n_epochs': 100,
    'episodes_per_epoch': 100,
}
```

You now have a comprehensive understanding of Prototypical Networks: from intuition to mathematics to implementation. Next, we'll survey meta-learning applications beyond image classification, seeing how these techniques impact NLP, robotics, drug discovery, and more.
Coming Next: Page 4 explores Meta-Learning Applications—how the techniques we've learned apply to natural language processing, reinforcement learning, healthcare, and other domains where learning from limited data is essential.