Multi-task learning introduces optimization challenges that don't exist in single-task settings. When training on multiple objectives simultaneously, we face conflicting gradients, imbalanced task scales, varying convergence rates, and the fundamental problem of finding a single solution that performs well across all tasks.
This page provides a comprehensive treatment of MTL optimization: the core challenges, their mathematical characterization, and the state-of-the-art techniques developed to address them. Mastering these concepts is essential for building effective MTL systems in practice.
By the end of this page, you will understand: (1) gradient conflict and interference, (2) task balancing strategies, (3) multi-objective optimization perspectives, (4) advanced gradient manipulation techniques, and (5) practical optimization recipes for MTL.
The fundamental optimization challenge in MTL is gradient conflict: when gradients from different tasks point in different (or opposite) directions in parameter space.
Mathematical Characterization:
For tasks $T_1, ..., T_k$ with losses $\mathcal{L}_1, ..., \mathcal{L}_k$, the gradients with respect to shared parameters $\theta$ are:
$$g_t = \nabla_\theta \mathcal{L}_t, \quad t \in \{1, \dots, k\}$$
The combined gradient in naive MTL is: $$g = \sum_{t=1}^{k} \lambda_t g_t$$
Conflict occurs when: $$g_i \cdot g_j < 0 \quad \text{(tasks disagree on update direction)}$$
In severe cases, the combined gradient $g$ may have negative projection onto some task gradients, meaning the update hurts that task.
Gradient conflict leads to the 'seesaw effect': improving one task degrades another. Training oscillates between favoring different tasks without reliably improving all. This is a key symptom of optimization difficulties in MTL.
Quantifying Conflict:
Gradient Cosine Similarity: $$\cos(g_i, g_j) = \frac{g_i \cdot g_j}{||g_i|| \cdot ||g_j||}$$ Negative values indicate conflict.
Gradient Agreement Ratio: Fraction of parameters where gradients agree on sign.
Conflict Intensity: $$C = \sum_{i<j} \max(0, -g_i \cdot g_j)$$ Measures total magnitude of conflicting gradients.
Sources of Gradient Conflict:
Conflict typically arises from genuinely competing task objectives, from large differences in loss scale and gradient magnitude, and from limited shared capacity that forces tasks to compete for the same parameters. The utility below computes per-task gradients and the conflict statistics defined above.
```python
import torch
from typing import Dict, Tuple


def analyze_gradient_conflict(
    model: torch.nn.Module,
    task_batches: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    loss_fns: Dict[str, torch.nn.Module]
) -> Dict[str, float]:
    """
    Comprehensive analysis of gradient conflicts in MTL.
    """
    task_names = list(task_batches.keys())
    gradients = {}

    # Compute per-task gradients
    for task, (x, y) in task_batches.items():
        model.zero_grad()
        pred = model(x, task)
        loss = loss_fns[task](pred, y)
        loss.backward()
        grad = torch.cat([
            p.grad.flatten()
            for p in model.parameters()
            if p.grad is not None
        ])
        gradients[task] = grad.detach().clone()

    results = {}

    # Pairwise cosine similarities
    n_conflicts = 0
    total_pairs = 0
    total_cos = 0
    for i, t1 in enumerate(task_names):
        for t2 in task_names[i + 1:]:
            g1, g2 = gradients[t1], gradients[t2]
            cos = torch.dot(g1, g2) / (g1.norm() * g2.norm() + 1e-8)
            total_cos += cos.item()
            total_pairs += 1
            if cos < 0:
                n_conflicts += 1
            results[f'cos_{t1}_{t2}'] = cos.item()

    results['conflict_rate'] = n_conflicts / max(total_pairs, 1)
    results['avg_cosine'] = total_cos / max(total_pairs, 1)

    # Combined gradient analysis
    combined = sum(gradients.values())
    for task in task_names:
        # Projection of combined gradient onto task gradient
        proj = torch.dot(combined, gradients[task])
        proj = proj / (gradients[task].norm() + 1e-8)
        results[f'combined_proj_{task}'] = proj.item()
        # If negative, the combined update hurts this task
        results[f'hurts_{task}'] = proj.item() < 0

    return results
```

Tasks often have different loss scales, learning dynamics, and difficulty levels. Without careful balancing, some tasks dominate training while others are neglected.
Static Weighting:
The simplest approach assigns fixed weights $\lambda_t$ to each task: $$\mathcal{L} = \sum_t \lambda_t \mathcal{L}_t$$
Weights can be chosen by grid search against validation metrics, by domain knowledge about task priority, or by normalizing for differences in loss scale.
Uncertainty weighting (Kendall et al., 2018) learns a task-specific homoscedastic uncertainty $\sigma_t$ during training. Tasks with higher uncertainty (harder to predict confidently) receive lower weight, while a $\log \sigma_t$ term keeps the uncertainties from growing without bound. This provides principled automatic balancing with minimal hyperparameters.
Dynamic Weighting:
More sophisticated methods adapt weights during training:
1. GradNorm (Chen et al., 2018): Balance gradient norms across tasks:
$$\tilde{w}_t(i) \leftarrow \tilde{w}_t(i-1) \cdot \left(\frac{r_t(i)}{\bar{r}(i)}\right)^\alpha$$
where $r_t$ is the relative inverse training rate of task $t$.
2. Dynamic Weight Averaging (DWA): $$\lambda_t(i) = \frac{\exp(w_t(i-1)/T)}{\sum_j \exp(w_j(i-1)/T)}$$ where $w_t(i-1) = \mathcal{L}_t(i-1)/\mathcal{L}_t(i-2)$ measures training speed (larger values indicate slower progress); a minimal sketch follows this list.
3. Gradient-Based Meta-Learning: Treat task weights as learnable parameters optimized on validation set.
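As a quick illustration of DWA (item 2 above), here is a minimal sketch, assuming per-task loss histories are tracked across epochs; the helper name `dwa_weights` and the temperature default are our own choices, not from a specific library.

```python
import torch
from typing import Dict, List


def dwa_weights(
    loss_history: Dict[str, List[float]],
    temperature: float = 2.0,
) -> Dict[str, float]:
    """Dynamic Weight Averaging: weight tasks by recent training speed.

    w_t(i-1) = L_t(i-1) / L_t(i-2) is large when task t's loss is falling
    slowly, so slower tasks receive larger weight via the softmax.
    """
    tasks = list(loss_history.keys())
    # Need at least two recorded epochs per task; otherwise use uniform weights
    if any(len(loss_history[t]) < 2 for t in tasks):
        return {t: 1.0 for t in tasks}

    ratios = torch.tensor([
        loss_history[t][-1] / (loss_history[t][-2] + 1e-8) for t in tasks
    ])
    # Softmax over w_t / T, scaled so the weights sum to the number of tasks
    weights = torch.softmax(ratios / temperature, dim=0) * len(tasks)
    return {t: weights[i].item() for i, t in enumerate(tasks)}
```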
```python
import torch
import torch.nn as nn
from typing import Dict


class UncertaintyWeighting(nn.Module):
    """
    Homoscedastic uncertainty weighting for MTL.
    Learns task-specific log variances to balance losses.
    """
    def __init__(self, task_names: list):
        super().__init__()
        # Log variance for numerical stability
        self.log_vars = nn.ParameterDict({
            task: nn.Parameter(torch.zeros(1))
            for task in task_names
        })

    def forward(
        self,
        task_losses: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute uncertainty-weighted combined loss.
        L = sum_t (1/(2*sigma_t^2)) * L_t + log(sigma_t)
        """
        total_loss = 0
        for task, loss in task_losses.items():
            log_var = self.log_vars[task]
            # Precision weighting with regularization
            precision = torch.exp(-log_var)
            total_loss += precision * loss + log_var
        return total_loss


class GradNormBalancer:
    """
    GradNorm: Gradient normalization for balanced MTL.
    """
    def __init__(
        self,
        model: nn.Module,
        task_names: list,
        alpha: float = 1.5
    ):
        self.model = model
        self.task_names = task_names
        self.alpha = alpha
        # Task weights (learnable)
        self.weights = {t: 1.0 for t in task_names}
        self.initial_losses = None

    def update_weights(
        self,
        task_losses: Dict[str, float],
        shared_params: list
    ):
        """Update task weights based on gradient norms."""
        if self.initial_losses is None:
            self.initial_losses = task_losses.copy()
            return

        # Loss ratios (training speed indicators)
        loss_ratios = {
            t: task_losses[t] / (self.initial_losses[t] + 1e-8)
            for t in self.task_names
        }
        mean_ratio = sum(loss_ratios.values()) / len(loss_ratios)

        # Relative inverse training rates
        inv_rates = {
            t: (loss_ratios[t] / (mean_ratio + 1e-8)) ** self.alpha
            for t in self.task_names
        }

        # Target: balance gradient norms
        # Adjust weights to achieve balanced rates
        for t in self.task_names:
            self.weights[t] *= inv_rates[t]

        # Renormalize so weights sum to the number of tasks
        total = sum(self.weights.values())
        for t in self.task_names:
            self.weights[t] /= total
            self.weights[t] *= len(self.task_names)

    def get_weighted_loss(
        self,
        task_losses: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        return sum(
            self.weights[t] * loss
            for t, loss in task_losses.items()
        )
```

MTL can be viewed as multi-objective optimization (MOO), where we seek solutions that are optimal across multiple objectives simultaneously.
Pareto Optimality:
A solution $\theta^*$ is Pareto optimal if no other solution improves at least one task without worsening any other:
$$\nexists\, \theta : \ \mathcal{L}_t(\theta) \leq \mathcal{L}_t(\theta^*) \ \forall t \ \text{ and } \ \mathcal{L}_j(\theta) < \mathcal{L}_j(\theta^*) \text{ for some } j$$
The set of all Pareto optimal solutions forms the Pareto front.
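To make the definition concrete, here is a small helper, a sketch of our own (the name `pareto_dominates` is not from any particular library), that checks whether one candidate's per-task losses dominate another's.

```python
from typing import Sequence


def pareto_dominates(losses_a: Sequence[float], losses_b: Sequence[float]) -> bool:
    """True if A dominates B: no worse on every task, strictly better on at least one."""
    assert len(losses_a) == len(losses_b)
    no_worse = all(a <= b for a, b in zip(losses_a, losses_b))
    strictly_better = any(a < b for a, b in zip(losses_a, losses_b))
    return no_worse and strictly_better


# theta_1 dominates theta_2; theta_1 and theta_3 trade off and neither dominates
print(pareto_dominates([0.5, 0.8], [0.6, 0.9]))  # True
print(pareto_dominates([0.5, 0.8], [0.4, 1.2]))  # False (a trade-off)
```

A solution is Pareto optimal exactly when no candidate dominates it in this sense.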
Multiple Gradient Descent Algorithm (MGDA):
MGDA finds a gradient direction that improves all tasks (if one exists):
$$\min_{d} ||d||^2 \quad \text{s.t.} \quad d = \sum_t \alpha_t g_t, \quad \alpha_t \geq 0, \quad \sum_t \alpha_t = 1$$
The solution $d$ is the minimum-norm element of the convex hull of task gradients.
| Method | Approach | Key Property |
|---|---|---|
| MGDA | Min-norm in gradient convex hull | Guaranteed Pareto improvement |
| CAGrad | Maximize worst-case task improvement | Conflict-averse |
| Nash-MTL | Find Nash equilibrium | Fair to all tasks |
| Pareto-MTL | Explore Pareto front | Diverse solutions |
| IMTL-G | Closed-form scaling for equal gradient projections | Impartial across tasks |
CAGrad (Conflict-Averse Gradient Descent):
Finds update direction within a bounded region around the average gradient:
$$d = \arg\max_{\|u - g_{\text{avg}}\| \leq c\,\|g_{\text{avg}}\|} \; \min_t \; g_t \cdot u$$
Maximizes minimum task improvement within a trust region.
Nash-MTL:
Formulates MTL as a bargaining game and finds the Nash equilibrium, ensuring no task can improve without cooperation.
Modern MTL research has developed sophisticated techniques to manipulate gradients for better optimization.
1. PCGrad (Projecting Conflicting Gradients):
When $g_i \cdot g_j < 0$, project away the conflicting component: $$g_i' = g_i - \frac{g_i \cdot g_j}{||g_j||^2} g_j$$
Removes the component of $g_i$ that conflicts with $g_j$.
2. Gradient Vaccine:
Similar to PCGrad but applies softer projection based on conflict severity.
3. Gradient Surgery:
Only modifies gradients when conflicts actually cause harm, preserving beneficial interactions.
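PCGrad is not covered by the code further below, so here is a minimal sketch of its projection step, applying the formula from item 1 above. It assumes each task gradient has already been flattened into a single vector, and it visits the other tasks in random order, as in the original method.

```python
import random
from typing import List

import torch


def pcgrad(gradients: List[torch.Tensor]) -> List[torch.Tensor]:
    """Project each task gradient away from the gradients it conflicts with.

    Whenever g_i . g_j < 0, the component of g_i along g_j is removed;
    non-conflicting pairs are left untouched.
    """
    projected = [g.clone() for g in gradients]
    for i, g_i in enumerate(projected):
        others = [j for j in range(len(gradients)) if j != i]
        random.shuffle(others)  # random order, as in the original paper
        for j in others:
            g_j = gradients[j]
            dot = torch.dot(g_i, g_j)
            if dot < 0:  # conflict: remove the component of g_i along g_j
                g_i -= dot / (g_j.norm() ** 2 + 1e-8) * g_j
    return projected
```

The final update direction is then typically the sum (or mean) of the projected gradients, applied to the shared parameters.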
```python
import torch
from typing import List


def mgda_direction(gradients: List[torch.Tensor]) -> torch.Tensor:
    """
    MGDA: Find minimum-norm direction in convex hull of gradients.
    Uses Frank-Wolfe optimization.
    """
    n_tasks = len(gradients)
    # Stack gradients: [n_tasks, n_params]
    G = torch.stack([g.flatten() for g in gradients])

    # Frank-Wolfe to find the min-norm point
    # Initialize with uniform weights
    alpha = torch.ones(n_tasks) / n_tasks

    for it in range(20):  # FW iterations
        d = G.T @ alpha  # Current direction
        # Linear minimization oracle: task with smallest inner product
        inner_products = G @ d
        min_task = torch.argmin(inner_products)
        # Frank-Wolfe step
        gamma = 2.0 / (2 + it)  # Step size
        alpha_new = (1 - gamma) * alpha
        alpha_new[min_task] += gamma
        alpha = alpha_new

    return G.T @ alpha


def cagrad_direction(
    gradients: List[torch.Tensor],
    c: float = 0.5
) -> torch.Tensor:
    """
    CAGrad: Conflict-Averse Gradient Descent.
    Finds direction maximizing worst-case improvement.
    """
    G = torch.stack([g.flatten() for g in gradients])
    n_tasks = len(gradients)

    # Average gradient and trust-region radius scale
    g_avg = G.mean(dim=0)
    g_avg_norm = g_avg.norm() + 1e-8

    # Solve the constrained optimization approximately:
    # maximize the minimum task projection within a region around g_avg
    if n_tasks == 2:
        g0, g1 = G[0], G[1]
        if torch.dot(g0, g1) >= 0:
            return g_avg
        # Conflicting pair: blend based on magnitudes
        cos_angle = torch.dot(g0, g1) / (g0.norm() * g1.norm())
        if cos_angle < -0.99:
            return torch.zeros_like(g_avg)
        w0 = g1.norm() / (g0.norm() + g1.norm())
        w1 = 1 - w0
        return w0 * g0 + w1 * g1

    # General case: iterative refinement
    d = g_avg.clone()
    for _ in range(10):
        projections = G @ d
        min_idx = torch.argmin(projections)
        # Move toward improving the worst-off task
        d = d + 0.1 * (G[min_idx] - d.dot(G[min_idx]) * d / (d.norm() ** 2 + 1e-8))
        # Project back to the trust region
        diff = d - g_avg
        if diff.norm() > c * g_avg_norm:
            d = g_avg + c * g_avg_norm * diff / diff.norm()

    return d
```

Based on extensive empirical research, here are practical recommendations for MTL optimization:
Simple approaches (uniform weights, uncertainty weighting) work well for related tasks. Use gradient manipulation (PCGrad, CAGrad) when you observe persistent gradient conflicts or the seesaw effect. For many tasks with varying relationships, consider learned weighting or MOO methods.
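To tie these recommendations together, here is a sketch of one MTL training step that supports both a plain weighted sum and PCGrad-style gradient surgery on conflicts. It is illustrative only: `mtl_training_step`, the `model(x, task)` calling convention, and the reuse of the `pcgrad` helper sketched earlier are our own assumptions, and a production implementation would distinguish shared from task-specific parameters and batch the backward passes more efficiently.

```python
import torch
from typing import Dict, Tuple


def mtl_training_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    task_batches: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    loss_fns: Dict[str, torch.nn.Module],
    task_weights: Dict[str, float],
    use_pcgrad: bool = False,
) -> Dict[str, float]:
    """One optimization step over all tasks.

    Default: one backward pass on the weighted sum of task losses.
    With use_pcgrad=True: per-task backward passes, conflict projection,
    then a manual write-back of the combined gradient.
    """
    optimizer.zero_grad()
    losses = {}

    if not use_pcgrad:
        total = 0.0
        for task, (x, y) in task_batches.items():
            raw_loss = loss_fns[task](model(x, task), y)
            losses[task] = raw_loss.item()
            total = total + task_weights[task] * raw_loss
        total.backward()
    else:
        params = [p for p in model.parameters() if p.requires_grad]
        per_task_grads = []
        for task, (x, y) in task_batches.items():
            model.zero_grad()
            raw_loss = loss_fns[task](model(x, task), y)
            losses[task] = raw_loss.item()
            (task_weights[task] * raw_loss).backward()
            # torch.cat copies, so later zero_grad calls cannot clobber this
            per_task_grads.append(torch.cat([
                (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
                for p in params
            ]))
        # Project conflicts away (pcgrad sketched earlier), then recombine
        combined = sum(pcgrad(per_task_grads))
        offset = 0
        for p in params:
            n = p.numel()
            p.grad = combined[offset:offset + n].view_as(p).clone()
            offset += n

    optimizer.step()
    return losses
```

In practice, start with the simple branch plus uncertainty weighting, and switch on gradient surgery only if the conflict diagnostics above show persistent negative cosine similarities.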
You now understand the optimization challenges in MTL and techniques to address them. The final page covers When MTL Helps—practical guidelines for when multi-task learning provides benefits over single-task alternatives.