The success of multi-task learning fundamentally depends on the relationships between tasks. When tasks are related, sharing information improves learning. When tasks are unrelated or conflicting, sharing can hurt performance through negative transfer.
This page explores how to conceptualize, measure, and leverage task relationships. We examine theoretical frameworks for task relatedness, practical methods for quantifying similarity, and techniques for using task structure to guide MTL architecture design.
By the end of this page, you will understand: (1) theoretical frameworks for task relatedness, (2) methods for measuring task similarity, (3) task clustering and grouping strategies, (4) how to leverage task relationships in architecture design, and (5) techniques for handling task conflicts.
Several theoretical frameworks formalize the notion of task relatedness, each capturing different aspects of how tasks can share structure.
1. Shared Hypothesis Class:
Tasks are related if they share optimal hypotheses from a common class. Formally, tasks $T_1, ..., T_k$ are related if there exists a hypothesis class $\mathcal{H}$ such that the optimal hypothesis for each task lies in $\mathcal{H}$:
$$h_t^* \in \mathcal{H}, \quad \forall t \in \{1, ..., k\}$$
2. Shared Representation:
Tasks are related if they share an optimal representation function. Each task's optimal predictor can be decomposed as:
$$f_t^* = g_t \circ h^*$$
where $h^*$ is the shared representation and $g_t$ are task-specific heads; a code sketch of this decomposition appears after this list of frameworks.
3. Task Covariance:
In the Bayesian view, task parameters $\theta_t$ are drawn from a prior distribution. Related tasks have correlated parameters:
$$\text{Cov}(\theta_i, \theta_j) > 0 \text{ for related tasks}$$
Tasks exist on a spectrum of relatedness. Two tasks might share low-level features but diverge at higher levels, or vice versa. Understanding the granularity and nature of task relationships is crucial for effective MTL design.
4. Transfer Distance:
A more operational definition measures relatedness through transfer performance. The transfer distance from task $i$ to task $j$ is:
$$d(T_i \to T_j) = \mathcal{L}_{T_j}(h_{T_i}) - \mathcal{L}_{T_j}(h_{T_j}^*)$$
where $h_{T_i}$ is trained on $T_i$ and evaluated on $T_j$. Small transfer distance indicates high relatedness.
5. Gradient Alignment:
During MTL training, relatedness manifests in gradient dynamics. Tasks are related if their gradients align:
$$\cos(g_i, g_j) = \frac{g_i \cdot g_j}{||g_i|| \cdot ||g_j||} > 0$$
Positive alignment indicates tasks agree on optimization direction.
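Before turning to measurement, it helps to ground the shared-representation view (framework 2 above) in code: it corresponds to the standard hard-parameter-sharing architecture with a shared trunk $h$ and task-specific heads $g_t$. The sketch below is a minimal illustration; the layer sizes, task names, and output dimensions are assumptions for the example, not part of any particular method.

```python
import torch
import torch.nn as nn

class SharedRepresentationMTL(nn.Module):
    """f_t = g_t ∘ h: one shared trunk h, one head g_t per task."""

    def __init__(self, input_dim: int, hidden_dim: int, task_output_dims: dict):
        super().__init__()
        # Shared representation h (framework 2)
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Task-specific heads g_t
        self.heads = nn.ModuleDict({
            task: nn.Linear(hidden_dim, out_dim)
            for task, out_dim in task_output_dims.items()
        })

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        return self.heads[task](self.trunk(x))

# Illustrative usage with hypothetical task names and dimensions
model = SharedRepresentationMTL(
    input_dim=64, hidden_dim=128,
    task_output_dims={'task_a': 10, 'task_b': 3},
)
logits = model(torch.randn(8, 64), task='task_a')  # shape: (8, 10)
```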
Before training an MTL system, we often want to estimate task similarity to guide architecture decisions. Several practical approaches exist:
Data-Based Measures:
Input Distribution Similarity: Compare the tasks' input distributions, for example with maximum mean discrepancy (MMD) or the accuracy of a domain classifier; tasks drawn from similar input domains are more likely to share useful low-level features.
Label Correlation: When tasks are defined over the same examples, measure the correlation or co-occurrence between their labels (a small sketch of these data-based measures follows this list).
Transfer-Based Measures:
Probe Networks: Train a lightweight model on one task and evaluate or fine-tune it on another; strong transfer indicates high relatedness.
Representation Similarity: Compare the representations learned by single-task models, for example with CKA or CCA, as in the code below.
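As a concrete illustration of the data-based measures, the sketch below uses a crude input-distribution proxy (the gap between per-task feature means, a stand-in for fuller measures such as MMD) and Pearson correlation between labels on shared examples. The function names and synthetic data are illustrative assumptions.

```python
import numpy as np

def input_distribution_similarity(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Crude distribution similarity: gap between per-task feature means,
    mapped into (0, 1]. feats_* are (n_samples, d) arrays of input features."""
    gap = np.linalg.norm(feats_a.mean(axis=0) - feats_b.mean(axis=0))
    return float(1.0 / (1.0 + gap))

def label_correlation(labels_a: np.ndarray, labels_b: np.ndarray) -> float:
    """Pearson correlation between two tasks' labels on the same examples."""
    return float(np.corrcoef(labels_a, labels_b)[0, 1])

# Illustrative usage with synthetic data
rng = np.random.default_rng(0)
feats_a = rng.normal(0.0, 1.0, size=(500, 16))
feats_b = rng.normal(0.5, 1.0, size=(500, 16))
y_a = rng.normal(size=500)
y_b = 0.8 * y_a + 0.2 * rng.normal(size=500)  # labels built to be correlated

print(input_distribution_similarity(feats_a, feats_b))
print(label_correlation(y_a, y_b))  # strongly positive by construction
```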
```python
import torch
import numpy as np
from typing import Dict, Tuple


def compute_task_affinity_matrix(
    task_models: Dict[str, torch.nn.Module],
    task_dataloaders: Dict[str, torch.utils.data.DataLoader],
    metric: str = 'accuracy'
) -> np.ndarray:
    """
    Compute pairwise task affinity through transfer evaluation.

    Returns:
        Affinity matrix A where A[i, j] = performance of model i on task j
    """
    task_names = list(task_models.keys())
    n_tasks = len(task_names)
    affinity = np.zeros((n_tasks, n_tasks))

    for i, source_task in enumerate(task_names):
        model = task_models[source_task]
        model.eval()
        for j, target_task in enumerate(task_names):
            loader = task_dataloaders[target_task]
            correct = 0
            total = 0
            with torch.no_grad():
                for x, y in loader:
                    pred = model(x).argmax(dim=1)
                    correct += (pred == y).sum().item()
                    total += len(y)
            affinity[i, j] = correct / total

    return affinity


def compute_representation_similarity_cka(
    model1: torch.nn.Module,
    model2: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader
) -> float:
    """
    Compute CKA (Centered Kernel Alignment) similarity between
    representations learned by two models.
    """
    model1.eval()
    model2.eval()
    reps1, reps2 = [], []
    with torch.no_grad():
        for x, _ in dataloader:
            reps1.append(model1.get_representation(x).cpu())
            reps2.append(model2.get_representation(x).cpu())

    X = torch.cat(reps1, dim=0).numpy()
    Y = torch.cat(reps2, dim=0).numpy()

    # Compute CKA
    def centering(K):
        n = K.shape[0]
        unit = np.ones([n, n])
        H = np.eye(n) - unit / n
        return H @ K @ H

    def linear_kernel(X):
        return X @ X.T

    K_X = centering(linear_kernel(X))
    K_Y = centering(linear_kernel(Y))

    hsic = np.sum(K_X * K_Y)
    norm = np.sqrt(np.sum(K_X * K_X) * np.sum(K_Y * K_Y))
    return hsic / norm if norm > 0 else 0.0


def compute_gradient_similarity(
    model: torch.nn.Module,
    task_batches: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    loss_fn: torch.nn.Module
) -> Dict[Tuple[str, str], float]:
    """
    Compute pairwise gradient cosine similarity between tasks.
    """
    task_names = list(task_batches.keys())
    gradients = {}

    for task, (x, y) in task_batches.items():
        model.zero_grad()
        pred = model(x, task)
        loss = loss_fn(pred, y)
        loss.backward()
        # Collect gradients from shared parameters
        grad = torch.cat([
            p.grad.flatten()
            for p in model.get_shared_params()
            if p.grad is not None
        ])
        gradients[task] = grad.detach()

    similarities = {}
    for i, task_i in enumerate(task_names):
        for task_j in task_names[i + 1:]:
            g_i, g_j = gradients[task_i], gradients[task_j]
            cos_sim = torch.dot(g_i, g_j) / (g_i.norm() * g_j.norm())
            similarities[(task_i, task_j)] = cos_sim.item()

    return similarities
```

Gradient-Based Measures (During Training):
Gradient Cosine Similarity: $$\cos(g_i, g_j) = \frac{\nabla_{\theta} \mathcal{L}_i \cdot \nabla_{\theta} \mathcal{L}_j}{||\nabla_{\theta} \mathcal{L}_i|| \cdot ||\nabla_{\theta} \mathcal{L}_j||}$$
Gradient Conflict: Count how often gradients point in opposite directions for each parameter.
Task Affinity via Training: Measure how training on one task affects validation loss on others.
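Measuring task affinity via training can be sketched as a lookahead probe: take one gradient step on task $i$ alone and check whether task $j$'s loss improves, in the spirit of lookahead-based inter-task affinity scores used for task grouping. The sketch below assumes a multi-task `model(x, task)` interface (as in the gradient-similarity code above) and a generic `loss_fn`; both are placeholders.

```python
import copy
import torch

def lookahead_affinity(model, loss_fn, batch_i, batch_j,
                       task_i: str, task_j: str, lr: float = 1e-2) -> float:
    """Relative change in task_j's loss after one SGD step on task_i alone.
    Positive values suggest task_i's update helps task_j."""
    x_i, y_i = batch_i
    x_j, y_j = batch_j

    # Task_j's loss before the update
    with torch.no_grad():
        loss_j_before = loss_fn(model(x_j, task_j), y_j).item()

    # One manual SGD step on a copy of the model, using task_i only
    probe = copy.deepcopy(model)
    probe.zero_grad()
    loss_fn(probe(x_i, task_i), y_i).backward()
    with torch.no_grad():
        for p in probe.parameters():
            if p.grad is not None:
                p -= lr * p.grad

    # Task_j's loss after the update
    with torch.no_grad():
        loss_j_after = loss_fn(probe(x_j, task_j), y_j).item()

    return 1.0 - loss_j_after / (loss_j_before + 1e-8)
```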
When dealing with many tasks, not all should share parameters equally. Task clustering groups similar tasks to share parameters within clusters while maintaining separation between clusters.
Clustering Strategies:
Pre-determined Clustering: Group tasks by domain knowledge, e.g., all detection tasks in one cluster and all segmentation tasks in another.
Data-Driven Clustering: Cluster tasks from measured affinities, such as the transfer-based or gradient-based similarities above.
Learned Clustering: Learn task embeddings or gating variables jointly with the model so that sharing patterns emerge during training.
| Approach | When to Use | Pros/Cons |
|---|---|---|
| Domain knowledge | Clear task categories exist | Simple, interpretable / May miss subtle relationships |
| Transfer-based | Can afford pre-training | Captures true transfer / Computationally expensive |
| Gradient-based | During MTL training | Dynamic, adapts / Noisy, varies during training |
| Learned embeddings | Many tasks, complex relationships | Flexible / Adds complexity, may overfit |
import numpy as npfrom sklearn.cluster import AgglomerativeClustering, SpectralClusteringfrom scipy.cluster.hierarchy import dendrogram, linkageimport matplotlib.pyplot as plt def cluster_tasks_hierarchical( affinity_matrix: np.ndarray, task_names: list, n_clusters: int = None, distance_threshold: float = None) -> dict: """ Hierarchical clustering of tasks based on affinity. """ # Convert affinity to distance distance_matrix = 1 - (affinity_matrix + affinity_matrix.T) / 2 np.fill_diagonal(distance_matrix, 0) # Perform clustering clustering = AgglomerativeClustering( n_clusters=n_clusters, distance_threshold=distance_threshold, metric='precomputed', linkage='average' ) labels = clustering.fit_predict(distance_matrix) # Group tasks by cluster clusters = {} for task, label in zip(task_names, labels): if label not in clusters: clusters[label] = [] clusters[label].append(task) return clusters def visualize_task_relationships( affinity_matrix: np.ndarray, task_names: list): """Visualize task relationships with dendrogram and heatmap.""" fig, axes = plt.subplots(1, 2, figsize=(14, 5)) # Heatmap im = axes[0].imshow(affinity_matrix, cmap='RdYlGn') axes[0].set_xticks(range(len(task_names))) axes[0].set_yticks(range(len(task_names))) axes[0].set_xticklabels(task_names, rotation=45, ha='right') axes[0].set_yticklabels(task_names) axes[0].set_title('Task Affinity Matrix') plt.colorbar(im, ax=axes[0]) # Dendrogram distance = 1 - (affinity_matrix + affinity_matrix.T) / 2 condensed = distance[np.triu_indices(len(task_names), k=1)] Z = linkage(condensed, method='average') dendrogram(Z, labels=task_names, ax=axes[1]) axes[1].set_title('Task Hierarchy') axes[1].set_ylabel('Distance') plt.tight_layout() plt.savefig('task_relationships.png', dpi=150) plt.close()Once task relationships are understood, they should inform architecture design:
1. Hierarchical Sharing:
If tasks form a hierarchy (e.g., coarse classification → fine classification), share earlier layers for all tasks, with later layers shared within subtrees:
```
[Shared Base] → [Cluster-1 Layers] → Task-1a, Task-1b
              → [Cluster-2 Layers] → Task-2a, Task-2b
```
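A minimal sketch of this hierarchical layout is shown below, assuming two illustrative clusters of two tasks each; the module sizes, task names, and cluster assignment are placeholders rather than a prescribed architecture.

```python
import torch
import torch.nn as nn

class HierarchicalSharingMTL(nn.Module):
    """Shared base -> per-cluster layers -> per-task heads."""

    def __init__(self, input_dim: int, hidden_dim: int,
                 cluster_tasks: dict, task_output_dims: dict):
        super().__init__()
        self.task_to_cluster = {
            task: cluster
            for cluster, tasks in cluster_tasks.items() for task in tasks
        }
        # Shared base, used by every task
        self.base = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # Cluster-specific middle layers, shared within each cluster
        self.cluster_layers = nn.ModuleDict({
            cluster: nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
            for cluster in cluster_tasks
        })
        # Task-specific heads
        self.heads = nn.ModuleDict({
            task: nn.Linear(hidden_dim, out_dim)
            for task, out_dim in task_output_dims.items()
        })

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        h = self.base(x)
        h = self.cluster_layers[self.task_to_cluster[task]](h)
        return self.heads[task](h)

# Illustrative grouping matching the diagram above
model = HierarchicalSharingMTL(
    input_dim=64, hidden_dim=128,
    cluster_tasks={'cluster_1': ['task_1a', 'task_1b'],
                   'cluster_2': ['task_2a', 'task_2b']},
    task_output_dims={'task_1a': 10, 'task_1b': 5, 'task_2a': 2, 'task_2b': 3},
)
out = model(torch.randn(4, 64), task='task_2a')  # shape: (4, 2)
```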
2. Asymmetric Sharing:
If transfer is asymmetric (task A helps B but not vice versa), use an auxiliary task design: treat A as an auxiliary task whose down-weighted loss shapes the shared representation used for B, rather than forcing both tasks to share symmetrically. A minimal sketch of this weighting follows.
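A minimal sketch of such an objective, assuming `aux_weight` is a hand-tuned hyperparameter and the per-task losses are placeholders:

```python
import torch

def asymmetric_mtl_loss(loss_target: torch.Tensor,
                        loss_aux: torch.Tensor,
                        aux_weight: float = 0.3) -> torch.Tensor:
    """The target task (B) drives training; the auxiliary task (A) only
    nudges the shared representation through a down-weighted loss term."""
    return loss_target + aux_weight * loss_aux

# Illustrative usage with placeholder scalar losses
loss_b = torch.tensor(1.2, requires_grad=True)  # target task B
loss_a = torch.tensor(0.8, requires_grad=True)  # auxiliary task A
total = asymmetric_mtl_loss(loss_b, loss_a)
total.backward()
```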
3. Task-Conditional Computation:
Use task identity to modulate shared computation, for example by feeding a learned task embedding into shared layers that scale, shift, or gate intermediate features, as sketched below.
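A minimal sketch of task-conditional modulation, assuming a FiLM-style design in which a learned task embedding produces a per-task scale and shift applied to shared features; the dimensions and task names are illustrative.

```python
import torch
import torch.nn as nn

class TaskConditionedLayer(nn.Module):
    """Shared linear layer whose output is scaled and shifted per task."""

    def __init__(self, in_dim: int, out_dim: int, task_names: list):
        super().__init__()
        self.shared = nn.Linear(in_dim, out_dim)
        self.task_index = {t: i for i, t in enumerate(task_names)}
        # One (scale, shift) pair per task, read from a task embedding table
        self.task_embedding = nn.Embedding(len(task_names), 2 * out_dim)

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        h = torch.relu(self.shared(x))
        idx = torch.tensor(self.task_index[task], device=x.device)
        scale, shift = self.task_embedding(idx).chunk(2)
        return scale * h + shift

# Illustrative usage
layer = TaskConditionedLayer(in_dim=32, out_dim=64,
                             task_names=['depth', 'segmentation'])
y = layer(torch.randn(4, 32), task='depth')  # shape: (4, 64)
```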
For complex task sets, consider automated architecture search methods that learn optimal sharing patterns. Approaches like AutoML-Zero and neural architecture search can discover effective MTL structures that might not be obvious from task similarity measures alone.
When tasks have conflicting requirements, their gradients interfere during training, leading to negative transfer. Several techniques address this:
Gradient Manipulation:
GradNorm: Dynamically adjust task weights to balance gradient magnitudes
PCGrad (Projecting Conflicting Gradients): When $g_i \cdot g_j < 0$, project $g_i$ onto the normal plane of $g_j$: $$g_i' = g_i - \frac{g_i \cdot g_j}{||g_j||^2} g_j$$
CAGrad (Conflict-Averse Gradient Descent): Find gradient direction in the cone of task gradients that maximizes worst-case improvement
Architecture Solutions:
Task-Specific Adapters: Add small task-specific modules to the shared architecture (a sketch follows this list)
Mixture of Experts: Route different inputs/tasks to specialized sub-networks
Modular Networks: Compose task-specific paths through shared modules
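As an example of the adapter idea from the list above, the sketch below inserts a small residual bottleneck per task on top of a shared backbone; the bottleneck size and residual form are common choices but are assumptions here, not a fixed specification.

```python
import torch
import torch.nn as nn

class TaskAdapter(nn.Module):
    """Small residual bottleneck applied to shared features for one task."""

    def __init__(self, dim: int, bottleneck: int = 16):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.up = nn.Linear(bottleneck, dim)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Residual connection keeps the shared representation intact
        return h + self.up(torch.relu(self.down(h)))

class SharedBackboneWithAdapters(nn.Module):
    """Shared backbone, per-task adapter, per-task head."""

    def __init__(self, input_dim: int, hidden_dim: int, task_output_dims: dict):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.adapters = nn.ModuleDict({
            t: TaskAdapter(hidden_dim) for t in task_output_dims
        })
        self.heads = nn.ModuleDict({
            t: nn.Linear(hidden_dim, d) for t, d in task_output_dims.items()
        })

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        h = self.backbone(x)
        return self.heads[task](self.adapters[task](h))

# Illustrative usage
model = SharedBackboneWithAdapters(64, 128, {'task_a': 10, 'task_b': 2})
out = model(torch.randn(4, 64), task='task_b')  # shape: (4, 2)
```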
```python
import torch
from typing import Dict


def pcgrad_update(
    task_gradients: Dict[str, torch.Tensor]
) -> torch.Tensor:
    """
    PCGrad: project conflicting gradients, then average the results.
    """
    task_names = list(task_gradients.keys())
    modified_grads = {t: g.clone() for t, g in task_gradients.items()}

    for i, task_i in enumerate(task_names):
        g_i = modified_grads[task_i]
        for task_j in task_names:
            if task_i == task_j:
                continue
            g_j = task_gradients[task_j]
            # Check for conflict
            dot = torch.dot(g_i.flatten(), g_j.flatten())
            if dot < 0:
                # Project g_i onto the normal plane of g_j
                g_i = g_i - (dot / (g_j.norm() ** 2 + 1e-8)) * g_j
        modified_grads[task_i] = g_i

    # Average modified gradients
    final_grad = torch.stack(list(modified_grads.values())).mean(dim=0)
    return final_grad


def gradnorm_weights(
    task_losses: Dict[str, torch.Tensor],
    initial_losses: Dict[str, float],
    current_weights: Dict[str, torch.Tensor],
    alpha: float = 1.5
) -> Dict[str, torch.Tensor]:
    """
    GradNorm-style balanced task weights (simplified: weighted losses
    stand in for per-task gradient norms).
    """
    # Compute loss ratios
    loss_ratios = {
        t: task_losses[t] / initial_losses[t]
        for t in task_losses
    }
    mean_ratio = sum(loss_ratios.values()) / len(loss_ratios)

    # Compute relative inverse training rates
    inv_rates = {
        t: (loss_ratios[t] / mean_ratio) ** alpha
        for t in loss_ratios
    }

    # Target gradient norms
    mean_grad_norm = sum(
        current_weights[t] * task_losses[t]
        for t in task_losses
    ).item()

    new_weights = {}
    for t in task_losses:
        target = mean_grad_norm * inv_rates[t]
        new_weights[t] = current_weights[t] * target

    # Normalize weights
    weight_sum = sum(new_weights.values())
    return {t: w / weight_sum for t, w in new_weights.items()}
```

Understanding task relationships is essential for effective MTL. Next, we explore Optimization Challenges: the practical difficulties of training MTL systems and techniques to overcome them.