Imagine you're about to sprint across unfamiliar terrain. Would you rather start from a random position, or from a carefully chosen launching point that provides the best angle for any direction you might need to run? Model-Agnostic Meta-Learning (MAML) embodies this simple but powerful insight: the right starting point makes adaptation dramatically faster.
Introduced by Finn, Abbeel, and Levine in 2017, MAML became one of the most influential meta-learning algorithms by demonstrating that a model's initial parameters can be explicitly optimized for fast adaptation to new tasks. Unlike methods that learn task-specific components, MAML is model-agnostic—it works with any model trainable by gradient descent, from simple classifiers to deep neural networks to reinforcement learning policies.
The elegance of MAML lies in its simplicity: it doesn't add new components to the model architecture or learning procedure. Instead, it shifts the optimization objective from 'find parameters that solve this task' to 'find parameters from which a few gradient steps produce excellent task-specific solutions.'
By completing this page, you will understand: (1) The MAML algorithm in complete detail, (2) The mathematical derivation of the bi-level optimization, (3) First-order approximations (FOMAML, Reptile) for computational efficiency, (4) Practical implementation considerations, (5) MAML++ and other important extensions, and (6) When to use MAML vs. alternative approaches.
Before diving into mathematics, let's build intuition for what MAML accomplishes.
The Problem with Random Initialization:
When training a neural network from random initialization, early gradient steps are noisy. The network must first learn basic features (edges, textures for images) before it can learn task-specific patterns. This wastes many gradient steps on 'bootstrapping' rather than task-specific learning.
The Pre-training Approach:
Pre-training (e.g., on ImageNet) provides better initialization by learning general features. But pre-training optimizes for performance on the pre-training task, not for adaptability to new tasks. A model pre-trained to classify 1000 ImageNet classes isn't optimized to quickly adapt to 5-way few-shot classification.
The MAML Insight:
What if we explicitly optimize the initialization for adaptability? We want parameters $\theta$ such that a few gradient steps on a new task $\mathcal{T}_i$ produce excellent task-specific parameters $\theta'_i$.
Mathematically, for any new task: $$\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\theta)$$
We want $\theta$ chosen such that $\theta'_i$ performs well on $\mathcal{T}_i$—for all tasks $\mathcal{T}_i$ we might encounter.
Imagine the loss landscape as a mountain range. Each task has its optimum in a different valley. Random initialization drops you on a random peak. Pre-training puts you in a specific valley (good for that task, far from others). MAML finds a ridge position from which you can quickly descend into any nearby valley—a position of maximum adaptability.
What Makes MAML Special:
Model-agnostic: Works with any differentiable model—CNNs, RNNs, Transformers, policy networks. No architectural changes needed.
Algorithm-agnostic: Uses standard gradient descent both for task adaptation (inner loop) and meta-learning (outer loop).
Learned learning rate sensitivity: The initialization inherently encodes which parameters should change (high gradient) and which should remain stable (low gradient) during adaptation.
Explicit optimization for adaptation: Unlike heuristic approaches, MAML directly optimizes the quantity of interest—post-adaptation performance.
MAML operates through a nested loop structure: the inner loop adapts to specific tasks, while the outer loop optimizes the shared initialization for fast adaptation.
Algorithm Overview:
Require: p(𝒯): distribution over tasks
Require: α: inner loop learning rate
Require: β: outer loop (meta) learning rate
1: Randomly initialize θ
2: while not done do
3: Sample batch of tasks 𝒯_i ~ p(𝒯)
4: for all 𝒯_i do
5: Sample K datapoints D_i = {(x_j, y_j)} from 𝒯_i for adaptation
6: Evaluate ∇_θ ℒ_𝒯_i(f_θ) using D_i
7: Compute adapted parameters: θ'_i = θ - α∇_θ ℒ_𝒯_i(f_θ)
8: Sample datapoints D'_i = {(x_j, y_j)} from 𝒯_i for meta-update
9: end for
10: Update θ ← θ - β∇_θ Σ_i ℒ_𝒯_i(f_{θ'_i}) using D'_i for each 𝒯_i
11: end while
The critical insight is in step 10: we update $\theta$ to minimize loss after adaptation ($\theta'_i$), not before. This directly optimizes for adaptability.
```python
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import List, Tuple, Dict


class MAML:
    """
    Model-Agnostic Meta-Learning (MAML) implementation.

    Key insight: Learn an initialization from which a few gradient
    steps produce excellent task-specific parameters.
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,     # α: task adaptation learning rate
        outer_lr: float = 0.001,    # β: meta-learning rate
        inner_steps: int = 5,       # Number of gradient steps for adaptation
        first_order: bool = False   # Use first-order approximation?
    ):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps
        self.first_order = first_order

        # Meta-optimizer updates the initialization
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=outer_lr)

    def inner_loop(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        params: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Inner loop: Adapt to a specific task.

        Takes K gradient steps on the support set, starting from
        the current meta-parameters.

        Args:
            support_x: Support set inputs [n_support, ...]
            support_y: Support set labels [n_support]
            params: Current parameters (shared initialization)

        Returns:
            adapted_params: Task-specific parameters after adaptation
        """
        adapted_params = {k: v.clone() for k, v in params.items()}

        for step in range(self.inner_steps):
            # Forward pass with current adapted parameters
            logits = self.functional_forward(support_x, adapted_params)
            loss = F.cross_entropy(logits, support_y)

            # Compute gradients w.r.t. adapted parameters
            grads = torch.autograd.grad(
                loss,
                adapted_params.values(),
                create_graph=not self.first_order  # Need graph for 2nd order
            )

            # Gradient descent update
            adapted_params = {
                k: adapted_params[k] - self.inner_lr * g
                for (k, _), g in zip(adapted_params.items(), grads)
            }

        return adapted_params

    def outer_loop(
        self,
        task_batch: List[Tuple[torch.Tensor, torch.Tensor,
                               torch.Tensor, torch.Tensor]]
    ) -> float:
        """
        Outer loop: Update initialization for better adaptability.

        For each task:
        1. Adapt parameters using support set (inner loop)
        2. Evaluate adapted parameters on query set
        3. Accumulate meta-gradients

        Key: Gradient flows THROUGH the inner loop adaptation.

        Args:
            task_batch: List of (support_x, support_y, query_x, query_y) tuples

        Returns:
            Mean meta-loss across tasks
        """
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0

        # Get current parameters as a dictionary
        params = dict(self.model.named_parameters())

        for support_x, support_y, query_x, query_y in task_batch:
            # Inner loop: task-specific adaptation
            adapted_params = self.inner_loop(support_x, support_y, params)

            # Evaluate adapted parameters on query set
            query_logits = self.functional_forward(query_x, adapted_params)
            task_loss = F.cross_entropy(query_logits, query_y)
            meta_loss += task_loss

        # Average over task batch
        meta_loss = meta_loss / len(task_batch)

        # Compute meta-gradients and update
        # Gradients flow through inner_loop adaptation!
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item()

    def functional_forward(
        self,
        x: torch.Tensor,
        params: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Forward pass using specified parameters.

        This allows using adapted parameters without modifying
        model state. Implementation depends on model architecture.
        """
        # Example for a simple CNN with named parameters
        # Actual implementation requires model-specific handling
        # For Linear layers: F.linear(x, params['layer.weight'], params['layer.bias'])
        # For Conv layers: F.conv2d(x, params['conv.weight'], params['conv.bias'], ...)
        # Many implementations use higher-level libraries like 'higher' or 'learn2learn'
        raise NotImplementedError("Implement based on model architecture")

    def adapt_and_evaluate(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        query_y: torch.Tensor
    ) -> Tuple[float, float]:
        """
        Adapt to a new task and evaluate.

        Used at test time to evaluate few-shot performance.
        """
        params = dict(self.model.named_parameters())

        # Adapt to task
        adapted_params = self.inner_loop(support_x, support_y, params)

        # Evaluate on query
        with torch.no_grad():
            query_logits = self.functional_forward(query_x, adapted_params)
            loss = F.cross_entropy(query_logits, query_y)
            predictions = query_logits.argmax(dim=1)
            accuracy = (predictions == query_y).float().mean()

        return loss.item(), accuracy.item()
```

MAML requires a 'functional forward' that takes parameters as input rather than using model.parameters() directly. This is necessary because adapted_params must remain in the computational graph for gradient computation. Libraries like 'higher' and 'learn2learn' provide clean abstractions for this.
Understanding why MAML requires second-order derivatives illuminates both its power and computational cost.
The Meta-Objective:
$$\min_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})$$
where $\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)$
Computing the Meta-Gradient:
We need $\frac{\partial}{\partial \theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})$
Using the chain rule: $$\frac{\partial \mathcal{L}(\theta'_i)}{\partial \theta} = \frac{\partial \mathcal{L}}{\partial \theta'_i} \cdot \frac{\partial \theta'_i}{\partial \theta}$$
Now, $\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}(\theta)$, so: $$\frac{\partial \theta'_i}{\partial \theta} = I - \alpha \frac{\partial^2 \mathcal{L}}{\partial \theta^2}$$
This is the Hessian of the task loss! The full meta-gradient is: $$\nabla_\theta \mathcal{L}(\theta'_i) = \nabla_{\theta'_i} \mathcal{L}(\theta'_i) \cdot \left(I - \alpha H_\theta\right)$$
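Because everything is scalar in one dimension, this formula is easy to verify numerically. The self-contained check below (with arbitrary illustrative values for the curvature $c$ and step size $\alpha$) confirms that backpropagating through one inner step on the quadratic loss $\mathcal{L}(\theta) = \tfrac{1}{2} c \theta^2$ reproduces the analytic meta-gradient $c(1 - \alpha c)^2 \theta$:

```python
import torch

# Task loss L(θ) = 0.5·c·θ², so ∇L(θ) = cθ and the Hessian is the scalar c
c, alpha = 3.0, 0.1
theta = torch.tensor(2.0, requires_grad=True)

inner_loss = 0.5 * c * theta ** 2
# create_graph=True keeps the inner gradient differentiable w.r.t. θ
grad = torch.autograd.grad(inner_loss, theta, create_graph=True)[0]
theta_prime = theta - alpha * grad        # θ' = (1 - αc)θ

meta_loss = 0.5 * c * theta_prime ** 2    # L(θ')
meta_grad = torch.autograd.grad(meta_loss, theta)[0]

analytic = c * (1 - alpha * c) ** 2 * theta.item()   # ∇L(θ')·(1 - αH)
print(meta_grad.item(), analytic)          # both 2.94
```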
| Component | Complexity | Memory | Description |
|---|---|---|---|
| Forward pass | O(\|θ\|) | O(\|θ\|) | Standard neural network forward |
| Inner-loop gradient | O(\|θ\|) | O(\|θ\|) | Backprop through task loss |
| Hessian (full) | O(\|θ\|²) | O(\|θ\|²) | Second derivatives; prohibitive to form |
| Hessian-vector product | O(\|θ\|) | O(\|θ\|) | Exact product without forming the Hessian |
| Per inner step | O(\|θ\|) | O(K·\|θ\|) | Must store K computational graphs in total |
The Hessian-Vector Product Trick:
Computing the full Hessian $H$ is $O(|\theta|^2)$—prohibitively expensive. But we only need $H \cdot v$ for some vector $v$ (the outer gradient). This can be computed in $O(|\theta|)$ using automatic differentiation:
$$H \cdot v = \nabla_\theta (\nabla_\theta \mathcal{L} \cdot v)$$
PyTorch/JAX compute this efficiently through reverse-mode autodiff composition.
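As a minimal standalone illustration (the toy function and the names x and v here are illustrative, not part of the MAML code above), the double-backward pattern looks like this:

```python
import torch

# Toy scalar function of a parameter vector x
x = torch.randn(5, requires_grad=True)
v = torch.randn(5)                      # the vector in H·v (e.g., an outer gradient)

loss = (x ** 3).sum()                   # f(x); its Hessian is diag(6x)

# First backward pass: keep the graph so the gradient is itself differentiable
grad = torch.autograd.grad(loss, x, create_graph=True)[0]

# Second backward pass differentiates (grad · v), yielding H·v in O(|x|)
hvp = torch.autograd.grad(grad @ v, x)[0]

print(torch.allclose(hvp, 6 * x * v))   # True: diag(6x) @ v, no explicit Hessian
```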
Memory Considerations:
Each inner-loop step builds upon the previous computational graph. With $K$ inner steps, we must store $K$ backward graphs before the meta-update. This quickly becomes prohibitive for deep networks with many inner steps.
Solutions include: (1) first-order approximations that drop the second-order term entirely (FOMAML and Reptile, covered in the next section); (2) gradient checkpointing, which recomputes parts of the graph during the backward pass to trade compute for memory; and (3) implicit differentiation (iMAML), which computes the meta-gradient at the adapted solution without backpropagating through the inner-loop trajectory.
The Hessian term captures how the loss landscape curvature affects adaptation. Ignoring it (first-order) assumes the loss landscape is locally flat. Second-order information allows MAML to learn initializations that sit in regions with favorable curvature—where gradient descent works well.
The computational cost of full MAML motivated first-order approximations that drop the Hessian term. Remarkably, these often perform nearly as well as full MAML while being significantly cheaper.
First-Order MAML (FOMAML):
FOMAML approximates the meta-gradient by ignoring second-order terms:
$$\nabla_\theta \mathcal{L}(\theta'_i) \approx \nabla_{\theta'_i} \mathcal{L}(\theta'_i)$$
This treats $\theta'_i$ as independent of $\theta$ during backpropagation—we simply evaluate the gradient at the adapted parameters and use it directly.
Implementation difference:
```python
# Full MAML: create_graph=True (second-order)
grads = torch.autograd.grad(loss, params, create_graph=True)

# FOMAML: create_graph=False (first-order)
grads = torch.autograd.grad(loss, params, create_graph=False)
```
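To show what this means operationally, here is a hedged sketch of a full FOMAML meta-step; it assumes a functional_forward(x, params) helper with the same contract as in the MAML class above, and a task_batch of (support_x, support_y, query_x, query_y) tuples:

```python
import torch
import torch.nn.functional as F


def fomaml_meta_step(model, meta_optimizer, task_batch,
                     functional_forward, inner_lr=0.01, inner_steps=5):
    """One FOMAML meta-update (sketch; helper names are assumptions)."""
    meta_optimizer.zero_grad()

    for support_x, support_y, query_x, query_y in task_batch:
        # Inner loop WITHOUT create_graph: each update is detached,
        # so no second-order graph is ever built
        params = {k: v.detach().clone().requires_grad_(True)
                  for k, v in model.named_parameters()}
        for _ in range(inner_steps):
            loss = F.cross_entropy(functional_forward(support_x, params), support_y)
            grads = torch.autograd.grad(loss, params.values())
            params = {k: (p - inner_lr * g).detach().requires_grad_(True)
                      for (k, p), g in zip(params.items(), grads)}

        # Query gradient evaluated at θ'_i, applied directly to θ
        query_loss = F.cross_entropy(functional_forward(query_x, params), query_y)
        grads = torch.autograd.grad(query_loss, params.values())
        for p, g in zip(model.parameters(), grads):
            p.grad = g.clone() if p.grad is None else p.grad + g

    # Average the accumulated gradients over the task batch, then step
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= len(task_batch)
    meta_optimizer.step()
```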
```python
import torch
import torch.nn as nn
from typing import List, Tuple


class Reptile:
    """
    Reptile: A simpler alternative to MAML.

    Key insight: Instead of computing meta-gradients through adaptation,
    simply move the initialization toward adapted parameters.

    Update rule: θ ← θ + ε(θ'_i - θ)

    This is equivalent to averaging across task-specific fine-tuned
    parameters, providing an implicit regularization effect.
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,   # Learning rate for task adaptation
        outer_lr: float = 0.001,  # Meta-learning rate (step size toward θ')
        inner_steps: int = 5,     # SGD steps per task
    ):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.inner_steps = inner_steps

    def train_step(
        self,
        task_batch: List[Tuple[torch.Tensor, torch.Tensor]]
    ) -> float:
        """
        Reptile training step.

        For each task:
        1. Clone current parameters
        2. Take K gradient steps on task data
        3. Compute θ' - θ (direction toward task optimum)

        Then: Move θ toward average task-adapted parameters
        """
        # Store original parameters
        original_weights = {
            name: param.clone()
            for name, param in self.model.named_parameters()
        }

        # Accumulate (θ' - θ) across tasks
        weight_diffs = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
        }

        total_loss = 0.0

        for support_x, support_y in task_batch:
            # Reset to original weights before each task
            for name, param in self.model.named_parameters():
                param.data.copy_(original_weights[name])

            # Task adaptation: K gradient descent steps
            optimizer = torch.optim.SGD(
                self.model.parameters(), lr=self.inner_lr
            )

            for step in range(self.inner_steps):
                optimizer.zero_grad()
                logits = self.model(support_x)
                loss = nn.functional.cross_entropy(logits, support_y)
                loss.backward()
                optimizer.step()

                if step == self.inner_steps - 1:
                    total_loss += loss.item()

            # Accumulate difference: θ'_i - θ
            for name, param in self.model.named_parameters():
                weight_diffs[name] += param.data - original_weights[name]

        # Average across tasks
        n_tasks = len(task_batch)

        # Update original weights: θ ← θ + ε * mean(θ'_i - θ)
        for name, param in self.model.named_parameters():
            param.data.copy_(
                original_weights[name]
                + self.outer_lr * weight_diffs[name] / n_tasks
            )

        return total_loss / n_tasks

    def adapt(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        adapt_steps: int = None
    ):
        """
        Adapt to a new task at test time.

        Simply fine-tune for K steps (same as inner loop during training).
        """
        steps = adapt_steps or self.inner_steps
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.inner_lr
        )

        for _ in range(steps):
            optimizer.zero_grad()
            logits = self.model(support_x)
            loss = nn.functional.cross_entropy(logits, support_y)
            loss.backward()
            optimizer.step()


# Theoretical connection between Reptile and MAML
"""
Reptile approximates MAML in expectation.

For a single inner step:
    Reptile update: θ ← θ + ε(θ - α∇L(θ) - θ) = θ - εα∇L(θ)
This is just SGD! The magic happens with multiple inner steps.

With K inner steps, Reptile's update direction includes:
1. Average gradient (like SGD)
2. Curvature-aware terms (like second-order methods)

The key insight: Moving toward task-adapted parameters implicitly
captures second-order information about the loss landscape.

Empirically, Reptile performs comparably to FOMAML and sometimes
approaches full MAML, while being simpler to implement.
"""
```

| Method | Second-Order | Memory | Compute | Typical Performance |
|---|---|---|---|---|
| Full MAML | Yes | High (graph per step) | High | Best |
| FOMAML | No | Low | Medium | ~0.5-1% below MAML |
| Reptile | No (implicit) | Very Low | Low | ~1% below MAML |
| iMAML | Implicit | Low | Medium | Near MAML |
Start with Reptile for prototyping (simplest). Move to FOMAML for better performance without much complexity. Use full MAML only if you have resources and need the extra 0.5-1% accuracy. iMAML is best when you need many inner steps.
The original MAML algorithm encounters several practical challenges that limit its effectiveness. MAML++ (Antoniou et al., 2018) addresses these systematically, making MAML significantly more stable and effective.
Key Problems with Vanilla MAML: (1) training instability, particularly with deeper networks, where backpropagating through many inner steps can make the meta-loss oscillate or diverge; (2) a single fixed inner learning rate shared across all layers and steps, which is difficult to tune; (3) a training signal that comes only from the final post-adaptation loss, leaving intermediate steps of the adaptation trajectory unconstrained; and (4) the cost of second-order derivatives throughout all of training. The implementation below addresses these with learnable per-layer learning rates, multi-step loss weighting, and derivative-order annealing.
```python
import torch
import torch.nn as nn
from typing import Dict, List


class MAMLPlusPlus:
    """
    MAML++ with practical improvements for stability and performance.
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr_init: float = 0.01,
        outer_lr: float = 0.001,
        inner_steps: int = 5,
        learn_inner_lr: bool = True,
        per_layer_lr: bool = True,
        multi_step_loss: bool = True,
        msl_decay: float = 0.5,            # Decay factor for earlier-step loss weights
        anneal_second_order: bool = True,
        anneal_epochs: int = 50,
    ):
        self.model = model
        self.inner_steps = inner_steps
        self.multi_step_loss = multi_step_loss
        self.msl_decay = msl_decay
        self.anneal_second_order = anneal_second_order
        self.anneal_epochs = anneal_epochs
        self.current_epoch = 0

        # Learnable per-layer inner learning rates
        if per_layer_lr and learn_inner_lr:
            # One learning rate per layer
            self.inner_lrs = nn.ParameterDict({
                name.replace('.', '_'): nn.Parameter(torch.tensor(inner_lr_init))
                for name, _ in model.named_parameters()
            })
        elif learn_inner_lr:
            # Single learnable learning rate
            self.inner_lrs = nn.Parameter(torch.tensor(inner_lr_init))
        else:
            self.inner_lrs = inner_lr_init

        # Meta-optimizer includes inner learning rates if learnable
        meta_params = list(model.parameters())
        if isinstance(self.inner_lrs, nn.ParameterDict):
            meta_params.extend(self.inner_lrs.values())
        elif isinstance(self.inner_lrs, nn.Parameter):
            meta_params.append(self.inner_lrs)

        self.meta_optimizer = torch.optim.Adam(meta_params, lr=outer_lr)

    def get_inner_lr(self, param_name: str) -> torch.Tensor:
        """Get the inner learning rate for a specific parameter."""
        if isinstance(self.inner_lrs, nn.ParameterDict):
            key = param_name.replace('.', '_')
            return self.inner_lrs[key]
        elif isinstance(self.inner_lrs, nn.Parameter):
            return self.inner_lrs
        else:
            return torch.tensor(self.inner_lrs)

    @property
    def use_second_order(self) -> bool:
        """Determine whether to use second-order gradients."""
        if not self.anneal_second_order:
            return True
        # Derivative-order annealing: first-order for the first
        # anneal_epochs epochs, second-order afterwards
        return self.current_epoch >= self.anneal_epochs

    def multi_step_loss_weights(self) -> List[float]:
        """
        Compute weights for multi-step loss optimization.

        Earlier steps get lower weights, final step gets highest.
        This encourages the model to optimize the adaptation trajectory.
        """
        weights = []
        for k in range(self.inner_steps):
            # Exponential decay: later steps weighted more
            weight = self.msl_decay ** (self.inner_steps - k - 1)
            weights.append(weight)
        # Normalize
        total = sum(weights)
        return [w / total for w in weights]

    def inner_loop_with_msl(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        params: Dict[str, torch.Tensor]
    ) -> tuple:
        """
        Inner loop with multi-step loss tracking.

        Returns adapted parameters AND list of losses at each step.
        """
        adapted_params = {k: v.clone() for k, v in params.items()}
        step_losses = []

        for step in range(self.inner_steps):
            logits = self.functional_forward(support_x, adapted_params)
            loss = nn.functional.cross_entropy(logits, support_y)
            step_losses.append(loss)

            grads = torch.autograd.grad(
                loss,
                adapted_params.values(),
                create_graph=self.use_second_order
            )

            # Per-parameter learning rates; each update builds on the
            # current adapted parameters, not the original initialization
            new_params = {}
            for (name, param), g in zip(adapted_params.items(), grads):
                lr = self.get_inner_lr(name)
                new_params[name] = param - lr * g
            adapted_params = new_params

        return adapted_params, step_losses

    def outer_loop(self, task_batch: List) -> float:
        """
        Outer loop with multi-step loss and other MAML++ improvements.
        """
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0
        msl_weights = self.multi_step_loss_weights()

        params = dict(self.model.named_parameters())

        for support_x, support_y, query_x, query_y in task_batch:
            adapted_params, step_losses = self.inner_loop_with_msl(
                support_x, support_y, params
            )

            # Query loss
            query_logits = self.functional_forward(query_x, adapted_params)
            query_loss = nn.functional.cross_entropy(query_logits, query_y)

            if self.multi_step_loss:
                # Weighted combination of inner step losses + query loss
                task_loss = query_loss
                for weight, step_loss in zip(msl_weights, step_losses):
                    task_loss = task_loss + 0.1 * weight * step_loss
            else:
                task_loss = query_loss

            meta_loss += task_loss

        meta_loss = meta_loss / len(task_batch)
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item()

    def functional_forward(self, x, params):
        """Model-specific functional forward pass."""
        raise NotImplementedError
```

MAML++ improvements typically yield 2-5% accuracy gains over vanilla MAML on miniImageNet. The combination of learned learning rates and multi-step loss optimization contributes most. These improvements are now standard in competitive MAML implementations.
MAML is powerful but not universally optimal. Understanding when it excels versus when alternatives are better helps select the right approach.
MAML Strengths: it is model- and algorithm-agnostic, it shines when tasks genuinely require parameter adaptation (notably reinforcement learning), it handles cross-domain differences through adaptation, and it directly optimizes post-adaptation performance. The table below summarizes when MAML or an alternative is the better choice.
| Scenario | Recommended Approach | Rationale |
|---|---|---|
| Few-shot classification | Prototypical Networks | Simpler, faster, often comparable accuracy |
| Few-shot RL | MAML | Policy needs task-specific adaptation |
| Heavy model + limited compute | Reptile | Low memory, no second-order gradients |
| Cross-domain generalization | MAML | Adaptation handles domain differences |
| Very large models | Adapter-tuning | Only adapt small parameter subset |
| Need theoretical guarantees | MAML | Well-analyzed convergence properties |
Before implementing MAML, try simple baselines: pre-trained encoder + nearest centroid, or fine-tuning with early stopping. These often perform surprisingly well and help establish whether the added complexity of MAML is warranted.
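As an illustration of how little code such a baseline needs, here is a minimal nearest-centroid sketch; the encoder is assumed to be any frozen, pre-trained module that maps inputs to embedding vectors (the encoder itself is not shown):

```python
import torch


@torch.no_grad()
def nearest_centroid_predict(encoder, support_x, support_y, query_x, n_way):
    """Classify query points by distance to per-class mean embeddings.

    encoder: a frozen nn.Module mapping [batch, ...] -> [batch, dim]
    """
    s_emb = encoder(support_x)                     # [n_support, dim]
    q_emb = encoder(query_x)                       # [n_query, dim]

    # One centroid per class: mean of that class's support embeddings
    centroids = torch.stack([
        s_emb[support_y == c].mean(dim=0) for c in range(n_way)
    ])                                             # [n_way, dim]

    # Predict the class with the nearest (Euclidean) centroid
    dists = torch.cdist(q_emb, centroids)          # [n_query, n_way]
    return dists.argmin(dim=1)
```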
Implementing MAML correctly requires attention to details that aren't obvious from the algorithm description. This section provides a practical checklist for successful implementation.
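The training loop below periodically calls an evaluation helper, evaluate_maml, which was not defined in the original listing. Here is a hedged sketch consistent with that call signature; the task-sampling details simply mirror the training setup and are assumptions:

```python
import torch
import learn2learn as l2l


def evaluate_maml(maml, dataset, n_way, k_shot, q_query, inner_steps,
                  device, n_eval_tasks=100):
    """Hypothetical evaluation helper matching the call in train_maml below.

    Samples n_eval_tasks episodes, adapts a clone on each support set,
    and returns the mean query accuracy.
    """
    tasks = l2l.data.TaskDataset(
        dataset,
        task_transforms=[
            l2l.data.transforms.NWays(dataset, n_way),
            l2l.data.transforms.KShots(dataset, k_shot + q_query),
            l2l.data.transforms.LoadData(dataset),
        ],
        num_tasks=n_eval_tasks,
    )

    total_acc = 0.0
    for _ in range(n_eval_tasks):
        batch = tasks.sample()
        support_x = batch[0][:n_way * k_shot].to(device)
        support_y = batch[1][:n_way * k_shot].to(device)
        query_x = batch[0][n_way * k_shot:].to(device)
        query_y = batch[1][n_way * k_shot:].to(device)

        learner = maml.clone()  # adaptation must not touch the meta-parameters
        for _ in range(inner_steps):
            loss = torch.nn.functional.cross_entropy(
                learner(support_x), support_y
            )
            learner.adapt(loss)

        with torch.no_grad():
            preds = learner(query_x).argmax(dim=1)
            total_acc += (preds == query_y).float().mean().item()

    return total_acc / n_eval_tasks
```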
```python
import torch
import learn2learn as l2l
from learn2learn.algorithms import MAML


def train_maml(
    model,
    train_dataset,
    val_dataset,
    n_epochs: int = 100,
    tasks_per_batch: int = 32,
    n_way: int = 5,
    k_shot: int = 5,
    q_query: int = 15,
    inner_lr: float = 0.01,
    outer_lr: float = 0.001,
    inner_steps: int = 5,
    first_order: bool = False,
    grad_clip: float = 10.0,
):
    """
    Complete MAML training loop with best practices.

    Uses the learn2learn library for a clean implementation.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Wrap model with MAML
    maml = MAML(model, lr=inner_lr, first_order=first_order)

    # Meta-optimizer
    optimizer = torch.optim.Adam(maml.parameters(), lr=outer_lr)

    # Task sampler
    train_tasks = l2l.data.TaskDataset(
        train_dataset,
        task_transforms=[
            l2l.data.transforms.NWays(train_dataset, n_way),
            l2l.data.transforms.KShots(train_dataset, k_shot + q_query),
            l2l.data.transforms.LoadData(train_dataset),
        ],
        num_tasks=-1  # Infinite
    )

    best_val_acc = 0.0
    patience = 10
    patience_counter = 0

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_acc = 0.0

        for task_idx in range(tasks_per_batch):
            # Sample task
            batch = train_tasks.sample()

            # Split into support and query
            support_x = batch[0][:n_way * k_shot].to(device)
            support_y = batch[1][:n_way * k_shot].to(device)
            query_x = batch[0][n_way * k_shot:].to(device)
            query_y = batch[1][n_way * k_shot:].to(device)

            # Clone for task-specific adaptation
            learner = maml.clone()

            # Inner loop: adapt to task
            for step in range(inner_steps):
                support_logits = learner(support_x)
                support_loss = torch.nn.functional.cross_entropy(
                    support_logits, support_y
                )
                learner.adapt(support_loss)

            # Outer loop: evaluate on query, accumulate gradients
            query_logits = learner(query_x)
            query_loss = torch.nn.functional.cross_entropy(
                query_logits, query_y
            )
            epoch_loss += query_loss.item()

            # Accuracy
            preds = query_logits.argmax(dim=1)
            epoch_acc += (preds == query_y).float().mean().item()

            # Accumulate gradients
            query_loss.backward()

        # Meta-update with gradient clipping
        torch.nn.utils.clip_grad_norm_(maml.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()

        # Logging
        avg_loss = epoch_loss / tasks_per_batch
        avg_acc = epoch_acc / tasks_per_batch
        print(f"Epoch {epoch + 1}: Loss={avg_loss:.4f}, Acc={avg_acc:.2%}")

        # Validation
        if (epoch + 1) % 5 == 0:
            val_acc = evaluate_maml(maml, val_dataset, n_way, k_shot,
                                    q_query, inner_steps, device)
            print(f"  Validation Acc: {val_acc:.2%}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(maml.state_dict(), 'best_maml.pt')
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping!")
                    break

    return maml
```

You now have a comprehensive understanding of MAML, from intuition to mathematics to implementation. Next, we'll explore Prototypical Networks, which take a fundamentally different approach: learning to compare rather than learning to adapt.
Coming Next: Page 3 covers Prototypical Networks—a metric-based approach that learns an embedding space where classification reduces to nearest-prototype comparison, offering speed and simplicity while achieving competitive performance.