Multi-task learning (MTL) represents a fundamental paradigm shift in how we approach machine learning problems. Rather than training separate models for each task in isolation, MTL trains a single model on multiple tasks simultaneously, leveraging the inductive bias that tasks share underlying structure. At the heart of this paradigm lies the concept of shared representations—the idea that different tasks can benefit from learning and utilizing common feature representations.
This page provides a rigorous exploration of shared representations: why they work, how to formalize them mathematically, what architectural choices enable effective sharing, and how this principle manifests across different domains from computer vision to natural language processing. Understanding shared representations is essential because they form the theoretical and practical foundation upon which all multi-task learning architectures are built.
By the end of this page, you will understand: (1) the theoretical motivation for shared representations, (2) the mathematical formalization of representation sharing, (3) the relationship between shared representations and generalization, (4) how representation learning differs across domains, and (5) the key principles that guide effective representation sharing in practice.
Consider how humans learn. When a child learns to recognize chairs, they simultaneously develop visual processing capabilities that transfer to recognizing tables, sofas, and countless other objects. The child doesn't develop entirely separate visual systems for each object category—instead, they build a hierarchical representation that captures edges, textures, shapes, and spatial relationships that are useful across many recognition tasks.
This observation motivates the core hypothesis of multi-task learning: related tasks share underlying structure, and by learning this structure jointly, we can achieve better generalization than learning each task independently.
The statistical argument:
When we train a model on a single task with limited data, we risk overfitting to task-specific noise. The model may learn spurious correlations that happen to work on the training set but fail to generalize. By training on multiple related tasks simultaneously, we impose an implicit regularization: the model must learn representations that are useful across all tasks, effectively filtering out task-specific noise and preserving genuinely useful features.
Mathematically, if we denote the true underlying representation as $h^*(x)$ and our learned representation as $\hat{h}(x)$, single-task learning estimates:
$$\hat{h}_{\text{single}}(x) = h^*(x) + \epsilon_{\text{task}} + \epsilon_{\text{noise}}$$
where $\epsilon_{\text{task}}$ represents task-specific biases and $\epsilon_{\text{noise}}$ represents fitting noise. Multi-task learning, by averaging across tasks, can reduce both error terms:
$$\hat{h}_{\text{multi}}(x) \approx h^*(x) + \frac{1}{T}\sum_{t=1}^{T}\epsilon_{\text{task}}^{(t)} + \frac{\epsilon_{\text{noise}}}{\sqrt{T}}$$
The task-specific biases may partially cancel when averaged across tasks, and the fitting-noise term shrinks roughly as $1/\sqrt{T}$.
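This averaging effect can be illustrated with a small simulation. The sketch below is illustrative only: the dimensionality, bias, and noise scales are made-up numbers, and it assumes the favorable case where task-specific biases are independent and zero-mean.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 64           # dimensionality of the representation (hypothetical)
noise_scale = 1.0  # per-task fitting-noise magnitude (hypothetical)
bias_scale = 0.5   # per-task bias magnitude (hypothetical)

for T in [1, 4, 16, 64]:
    # Independent task-specific biases and fitting noise for T tasks
    task_bias = bias_scale * rng.standard_normal((T, dim))
    fit_noise = noise_scale * rng.standard_normal((T, dim))

    # The multi-task estimate averages the per-task error terms
    residual = (task_bias + fit_noise).mean(axis=0)
    print(f"T={T:3d}  residual error per dimension: {np.linalg.norm(residual) / np.sqrt(dim):.3f}")
```

Running this shows the residual error shrinking roughly as $1/\sqrt{T}$, matching the informal argument above.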
Multi-task learning can be viewed as a form of regularization. By constraining the model to use representations that work across multiple tasks, we prevent it from exploiting task-specific shortcuts that don't generalize. This is particularly valuable when individual tasks have limited training data.
The information-theoretic perspective:
From an information-theoretic viewpoint, shared representations compress information that is relevant across tasks while discarding task-irrelevant details. Consider a representation $h(x)$ learned for tasks $\{T_1, T_2, \ldots, T_k\}$. An ideal shared representation maximizes:
$$I(h(X); Y_1, Y_2, ..., Y_k) - \beta \cdot I(h(X); X)$$
where $I(\cdot; \cdot)$ denotes mutual information. The first term encourages the representation to be predictive of all task outputs, while the second term (with coefficient $\beta$) encourages compressing away information from the input that isn't needed. This is a multi-task extension of the Information Bottleneck principle.
A good shared representation therefore retains the information needed to predict every task's output while compressing away the input details that no task requires.
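One way to operationalize this objective is a variational Information Bottleneck: the compression term $I(h(X); X)$ is upper-bounded by a KL divergence to a standard-normal prior, and the prediction term becomes a sum of per-task losses. The sketch below is a minimal illustration under those assumptions; the class name, the Gaussian-encoder choice, and the classification losses are mine, not from the source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VIBMultiTaskLoss(nn.Module):
    """Multi-task variational-IB sketch: sum of task losses + beta * KL(q(h|x) || N(0, I))."""

    def __init__(self, input_dim: int, rep_dim: int, task_output_dims: dict, beta: float = 1e-3):
        super().__init__()
        self.encoder = nn.Linear(input_dim, 2 * rep_dim)  # outputs mean and log-variance
        self.heads = nn.ModuleDict({t: nn.Linear(rep_dim, d) for t, d in task_output_dims.items()})
        self.beta = beta

    def forward(self, x: torch.Tensor, targets: dict) -> torch.Tensor:
        mu, logvar = self.encoder(x).chunk(2, dim=-1)
        h = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterized sample of h(x)

        # First term: the representation must predict every task's output
        pred_loss = sum(F.cross_entropy(self.heads[t](h), y) for t, y in targets.items())

        # Second term: compress away input information (upper bound on I(h; X))
        kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=-1).mean()
        return pred_loss + self.beta * kl
```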
To rigorously understand shared representations, we need a formal mathematical framework. Let's establish the notation and key concepts that underpin multi-task learning theory.
Setup and Notation:
Suppose we have $T$ tasks, each defined by a distribution $\mathcal{D}_t$ over input-output pairs $(x, y_t)$. For each task $t \in \{1, 2, \ldots, T\}$, we have a training set $S_t = \{(x_i^{(t)}, y_i^{(t)})\}_{i=1}^{n_t}$.
In multi-task learning with shared representations, we decompose our model into:
- A shared representation function $h$ that maps each input $x$ to a feature vector $h(x)$ in a common representation space
- Task-specific heads $f_t$ that map the shared representation to each task's output
The prediction for task $t$ given input $x$ is:
$$\hat{y}_t = f_t(h(x))$$
The shared representation $h$ is parameterized by weights $\theta_{\text{shared}}$, and each task head $f_t$ is parameterized by $\theta_t$.
The choice of representation dimensionality involves a trade-off. Higher dimensions can capture more task-relevant information but may lead to less sharing and increased overfitting. Lower dimensions enforce more compression but may lose task-critical information. This hyperparameter requires careful tuning based on task similarity and data availability.
The Multi-Task Objective:
The standard multi-task learning objective combines task-specific losses:
$$\mathcal{L}_{\text{MTL}}(\theta_{\text{shared}}, \theta_1, \ldots, \theta_T) = \sum_{t=1}^{T} \lambda_t \mathcal{L}_t(\theta_{\text{shared}}, \theta_t)$$
where:
- $\mathcal{L}_t(\theta_{\text{shared}}, \theta_t)$ is the loss of task $t$, computed on its training set $S_t$
- $\lambda_t \geq 0$ is the task weight controlling how strongly task $t$ shapes the shared parameters
Gradient Flow Through Shared Representations:
During backpropagation, the shared parameters receive gradients from all tasks:
$$\nabla_{\theta_{\text{shared}}} \mathcal{L}_{\text{MTL}} = \sum_{t=1}^{T} \lambda_t \nabla_{\theta_{\text{shared}}} \mathcal{L}_t$$
This means the shared representation is shaped by the learning signals from all tasks. If tasks have conflicting gradients (pointing in different directions in parameter space), the shared representation must find a compromise—this is both a strength (regularization) and a challenge (gradient interference).
```python
import torch
import torch.nn as nn
from typing import List, Dict


class SharedRepresentationMTL(nn.Module):
    """
    Multi-Task Learning model with explicit shared representation.

    Architecture:
    - Shared encoder: Maps input to shared representation space
    - Task-specific heads: Map representation to task outputs
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        representation_dim: int,
        task_output_dims: Dict[str, int],
        dropout_rate: float = 0.1
    ):
        super().__init__()

        # Build shared encoder
        encoder_layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            encoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = hidden_dim

        # Final representation layer
        encoder_layers.append(nn.Linear(prev_dim, representation_dim))
        self.shared_encoder = nn.Sequential(*encoder_layers)

        # Task-specific heads
        self.task_heads = nn.ModuleDict({
            task_name: nn.Sequential(
                nn.Linear(representation_dim, representation_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(representation_dim // 2, output_dim)
            )
            for task_name, output_dim in task_output_dims.items()
        })

        self.representation_dim = representation_dim

    def get_shared_representation(self, x: torch.Tensor) -> torch.Tensor:
        """Extract the shared representation for input x."""
        return self.shared_encoder(x)

    def forward(self, x: torch.Tensor, task_name: str) -> torch.Tensor:
        """Forward pass for a specific task."""
        # Get shared representation
        h = self.get_shared_representation(x)
        # Apply task-specific head
        return self.task_heads[task_name](h)

    def forward_all_tasks(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass for all tasks (useful for analysis)."""
        h = self.get_shared_representation(x)
        return {
            task_name: head(h)
            for task_name, head in self.task_heads.items()
        }


class MTLTrainer:
    """
    Trainer for multi-task learning with shared representations.
    Handles task weighting and gradient accumulation.
    """

    def __init__(
        self,
        model: SharedRepresentationMTL,
        task_weights: Dict[str, float],
        learning_rate: float = 1e-3
    ):
        self.model = model
        self.task_weights = task_weights
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate
        )

    def train_step(self, task_batches: Dict[str, tuple]) -> Dict[str, float]:
        """
        Single training step across all tasks.

        Args:
            task_batches: Dict mapping task name to (x, y) batches

        Returns:
            Dict of task losses
        """
        self.model.train()
        self.optimizer.zero_grad()

        total_loss = 0
        task_losses = {}

        for task_name, (x, y) in task_batches.items():
            # Forward pass
            predictions = self.model(x, task_name)

            # Compute task loss (assuming classification)
            loss = nn.functional.cross_entropy(predictions, y)

            # Weight and accumulate
            weighted_loss = self.task_weights[task_name] * loss
            total_loss += weighted_loss
            task_losses[task_name] = loss.item()

        # Backward pass - gradients flow through shared representation
        total_loss.backward()

        # Optional: gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        return task_losses
```

Representation Learning Dynamics:
The optimization dynamics of shared representations are governed by how task gradients interact. Let $g_t = \nabla_{\theta_{\text{shared}}} \mathcal{L}_t$ be the gradient for task $t$. The effective update to shared parameters is:
$$\Delta \theta_{\text{shared}} \propto -\sum_{t} \lambda_t g_t$$
Three scenarios can occur:
Aligned gradients ($g_i \cdot g_j > 0$): Tasks agree on the update direction. Learning each task reinforces the others.
Orthogonal gradients ($g_i \cdot g_j \approx 0$): Tasks require different features but don't conflict. The representation can learn features for both.
Conflicting gradients ($g_i \cdot g_j < 0$): Tasks disagree on the update direction. This creates negative transfer where improving one task hurts another.
Understanding these dynamics is crucial for designing effective MTL systems and diagnosing training issues.
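A practical way to see which regime you are in is to compute per-task gradients on the shared parameters and compare their directions. The sketch below assumes the SharedRepresentationMTL model defined earlier and classification losses; the helper name is mine.

```python
import torch
import torch.nn.functional as F


def pairwise_gradient_cosines(model, task_batches):
    """Cosine similarity between per-task gradients w.r.t. the shared encoder."""
    shared_params = [p for p in model.shared_encoder.parameters() if p.requires_grad]
    flat_grads = {}

    for task_name, (x, y) in task_batches.items():
        model.zero_grad()
        loss = F.cross_entropy(model(x, task_name), y)
        # Gradient of this task's loss with respect to the shared parameters only
        grads = torch.autograd.grad(loss, shared_params)
        flat_grads[task_name] = torch.cat([g.reshape(-1) for g in grads])

    names = list(flat_grads)
    cosines = {}
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            cosines[(a, b)] = F.cosine_similarity(flat_grads[a], flat_grads[b], dim=0).item()
    return cosines  # > 0 aligned, ~0 orthogonal, < 0 conflicting
```

Tracking these cosines over training is a simple diagnostic for negative transfer before it shows up in validation metrics.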
The primary motivation for using shared representations is improved generalization. But why exactly does sharing lead to better generalization? This section provides a rigorous analysis of the statistical and computational benefits.
Sample Complexity Reduction:
Consider learning a representation $h$ of complexity $C_h$ and a task-specific head $f$ of complexity $C_f$ (e.g., measured by VC dimension or Rademacher complexity). In single-task learning with $n$ samples:
$$\text{Generalization Error} \leq \hat{\mathcal{L}} + \mathcal{O}\left(\sqrt{\frac{C_h + C_f}{n}}\right)$$
In multi-task learning with $T$ tasks and $n$ samples per task:
$$\text{Generalization Error} \leq \hat{\mathcal{L}} + \mathcal{O}\left(\sqrt{\frac{C_h}{Tn}} + \sqrt{\frac{C_f}{n}}\right)$$
The shared representation complexity $C_h$ is now amortized across $T$ tasks, reducing its contribution to generalization error by a factor of $\sqrt{T}$. This is the fundamental statistical benefit: we learn the expensive representation using data from all tasks, while only the cheap task-specific heads need per-task data.
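To make the scaling concrete, an illustrative calculation (the choice $T = 10$ is arbitrary):

$$\sqrt{\frac{C_h}{Tn}} = \frac{1}{\sqrt{T}}\sqrt{\frac{C_h}{n}}, \qquad T = 10 \;\Longrightarrow\; \text{the encoder term shrinks by a factor of } \sqrt{10} \approx 3.2,$$

while the head term $\sqrt{C_f/n}$ is unchanged, since each head still sees only its own task's data.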
If the shared representation captures most of the model complexity (as in deep networks where the encoder is much larger than the heads), multi-task learning can dramatically improve data efficiency. This is especially valuable when individual tasks have limited labeled data but the overall data pool is substantial.
Inductive Bias as Implicit Regularization:
Beyond sample complexity, shared representations impose a beneficial inductive bias. By forcing the representation to be useful across multiple tasks, we bias the learning toward features that capture invariant, causal structure rather than spurious correlations.
Formally, suppose each task's true function factors as $f_t^* = g_t^* \circ h^*$, where $h^*$ is the ideal shared representation and $g_t^*$ the ideal task-specific head. Single-task learning might find any $\tilde{h}$ such that $f_t \circ \tilde{h} \approx f_t^*$ on the training data—there are many such $\tilde{h}$, and most don't generalize. Multi-task learning restricts the search to representations that work for all tasks, greatly reducing the effective hypothesis space.
Theoretical Guarantees (Ben-David et al.):
Seminal work by Ben-David et al. provides formal guarantees for multi-task learning. Define the task relatedness as:
$$\mathcal{R}(T_1, T_2) = \sup_{h \in \mathcal{H}} \left|\mathcal{L}_{T_1}(h) - \mathcal{L}_{T_2}(h)\right|$$
This measures the maximum difference in loss any representation can achieve across the two tasks. When $\mathcal{R}$ is small, tasks are related and share optimal representations.
The generalization bound becomes:
$$\mathcal{L}_{\text{test}}(h, f_t) \leq \hat{\mathcal{L}}_{\text{train}} + \mathcal{R} + \text{complexity term}$$
When tasks are highly related ($\mathcal{R} \approx 0$), multi-task learning provides strong generalization guarantees.
| Mechanism | Description | When It Helps Most |
|---|---|---|
| Sample Efficiency | Representation complexity amortized across tasks | Many tasks with limited per-task data |
| Regularization | Forces representations to be broadly useful | Tasks prone to overfitting; complex representations |
| Feature Selection | Identifies features useful across tasks (likely causal) | Noisy data with spurious correlations |
| Attention Focusing | Auxiliary tasks highlight informative input regions | Complex inputs with irrelevant dimensions |
| Implicit Data Augmentation | Task variations act like augmented examples | Limited data variety for primary task |
Eavesdropping and Representation Bootstrapping:
An underappreciated benefit is that tasks can "eavesdrop" on each other's learning signals. A task $T_1$ that is difficult to learn directly might become easier when co-trained with a related task $T_2$, because the representation learned for $T_2$ provides useful features.
For example, a segmentation task with few labels may struggle to discover object boundaries on its own, but co-training with depth estimation—whose learning signal also emphasizes boundaries—can supply those features through the shared encoder.
This creates a form of curriculum learning where easier tasks bootstrap harder tasks through shared representations.
Understanding what shared representations actually look like—their internal structure, geometry, and properties—provides insight into why they work and how to design better MTL systems.
Hierarchical Feature Organization:
In deep networks, shared representations typically exhibit hierarchical structure:
- Low-level features (edges, textures, local statistics) that are useful for nearly any task
- Mid-level features (parts, motifs, recurring patterns) shared among related tasks
- High-level features (semantic concepts) that become increasingly task-specific
This hierarchy reflects the compositional structure of the world: complex concepts are built from simpler primitives that recur across domains.
A common heuristic is to share early layers (generic features) while allowing later layers to specialize per task. The optimal split point depends on task similarity: highly related tasks can share more layers; diverse tasks may only share the earliest layers.
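One way to implement this heuristic is to pick a split stage $k$: stages up to $k$ are shared, later stages are cloned per task. The sketch below assumes torchvision's ResNet-18 as the backbone; the class name, the `split` argument, and the stage boundaries are my own illustrative choices.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm


class SplitPointMTL(nn.Module):
    """Share the first `split` ResNet stages; clone the remaining stages per task."""

    def __init__(self, task_output_dims: dict, split: int = 3):
        super().__init__()
        assert 1 <= split <= 4  # each branch must end with layer4 (512 channels)
        base = tvm.resnet18(weights=None)  # older torchvision versions use pretrained=False
        stages = [
            nn.Sequential(base.conv1, base.bn1, base.relu, base.maxpool),  # stem
            base.layer1, base.layer2, base.layer3, base.layer4,
        ]
        self.shared = nn.Sequential(*stages[:split])  # generic early features, shared
        self.branches = nn.ModuleDict({
            name: copy.deepcopy(nn.Sequential(*stages[split:]))  # task-specific later stages
            for name in task_output_dims
        })
        self.heads = nn.ModuleDict({
            name: nn.Linear(512, out_dim)  # ResNet-18 ends with 512 channels
            for name, out_dim in task_output_dims.items()
        })

    def forward(self, x: torch.Tensor, task_name: str) -> torch.Tensor:
        h = self.branches[task_name](self.shared(x))
        h = torch.flatten(F.adaptive_avg_pool2d(h, 1), 1)
        return self.heads[task_name](h)
```

With `split=3`, the stem plus the first two residual stages are shared while the last two stages specialize per task; moving `split` up or down trades sharing against specialization.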
Representation Geometry:
The geometry of shared representations reveals how tasks relate in the learned feature space. Key geometric properties include:
1. Linear Decodability: A well-structured shared representation allows each task to be solved by a simple (often linear) transformation. This suggests the representation organizes information in a way that separates task-relevant dimensions.
$$y_t \approx W_t \cdot h(x) + b_t$$
where $W_t$ and $b_t$ are the task-specific head parameters (a linear-probe check of this property is sketched after this list).
2. Subspace Organization: Tasks often utilize overlapping but distinct subspaces of the representation. Analyses of trained MTL networks suggest that closely related tasks rely on strongly overlapping directions, while weakly related tasks occupy more nearly orthogonal subspaces.
3. Manifold Structure: Data from different tasks may lie on different manifolds within the representation space, connected by shared regions corresponding to common features.
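Linear decodability (property 1 above) can be checked directly by freezing the encoder and fitting a linear probe per task. The sketch below assumes representations and labels have already been extracted (e.g., with `model.get_shared_representation`); the function name is mine.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def linear_decodability_score(reps: np.ndarray, labels: np.ndarray) -> float:
    """Accuracy of a linear probe trained on frozen shared representations."""
    X_tr, X_te, y_tr, y_te = train_test_split(reps, labels, test_size=0.2, random_state=0)
    probe = LogisticRegression(max_iter=1000)  # plays the role of W_t, b_t from the text
    probe.fit(X_tr, y_tr)
    return probe.score(X_te, y_te)
```

A probe accuracy close to that of the full task head suggests the representation already separates the task-relevant directions.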
```python
import torch
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple


def analyze_shared_representation_structure(
    model: torch.nn.Module,
    task_dataloaders: Dict[str, torch.utils.data.DataLoader],
    representation_layer: str = 'shared_encoder'
) -> Dict[str, np.ndarray]:
    """
    Analyze the structure of learned shared representations.

    Computes:
    1. Per-task representation statistics
    2. Cross-task alignment (via CCA)
    3. Shared vs unique variance decomposition
    4. Representation visualization
    """
    model.eval()

    # Extract representations for each task
    task_representations = {}
    task_labels = {}

    with torch.no_grad():
        for task_name, dataloader in task_dataloaders.items():
            reps = []
            labels = []
            for x, y in dataloader:
                # Get shared representation
                h = model.get_shared_representation(x)
                reps.append(h.cpu().numpy())
                labels.append(y.cpu().numpy())
            task_representations[task_name] = np.vstack(reps)
            task_labels[task_name] = np.concatenate(labels)

    results = {}

    # 1. Per-task statistics
    print("=" * 60)
    print("Per-Task Representation Statistics")
    print("=" * 60)

    for task_name, reps in task_representations.items():
        mean = np.mean(reps, axis=0)
        std = np.std(reps, axis=0)

        # Effective dimensionality (participation ratio)
        cov = np.cov(reps.T)
        eigenvalues = np.linalg.eigvalsh(cov)
        eigenvalues = eigenvalues[eigenvalues > 1e-10]
        participation_ratio = (np.sum(eigenvalues) ** 2) / np.sum(eigenvalues ** 2)

        print(f"\nTask: {task_name}")
        print(f"  Samples: {len(reps)}")
        print(f"  Mean norm: {np.linalg.norm(mean):.4f}")
        print(f"  Avg std: {np.mean(std):.4f}")
        print(f"  Effective dimensionality: {participation_ratio:.2f}")

        results[f'{task_name}_effective_dim'] = participation_ratio

    # 2. Cross-task CCA analysis
    task_names = list(task_representations.keys())
    if len(task_names) >= 2:
        print("\n" + "=" * 60)
        print("Cross-Task Canonical Correlation Analysis")
        print("=" * 60)

        for i, task1 in enumerate(task_names):
            for task2 in task_names[i+1:]:
                # Align sample sizes
                n_samples = min(
                    len(task_representations[task1]),
                    len(task_representations[task2])
                )
                X = task_representations[task1][:n_samples]
                Y = task_representations[task2][:n_samples]

                # Fit CCA
                n_components = min(10, X.shape[1])
                cca = CCA(n_components=n_components)
                cca.fit(X, Y)

                # Get canonical correlations
                X_c, Y_c = cca.transform(X, Y)
                correlations = [
                    np.corrcoef(X_c[:, i], Y_c[:, i])[0, 1]
                    for i in range(n_components)
                ]

                print(f"\n{task1} <-> {task2}:")
                print(f"  Top-5 canonical correlations: {correlations[:5]}")
                print(f"  Mean correlation: {np.mean(correlations):.4f}")

                results[f'cca_{task1}_{task2}'] = np.array(correlations)

    # 3. Shared variance decomposition
    print("\n" + "=" * 60)
    print("Shared vs Unique Variance Analysis")
    print("=" * 60)

    # Stack all representations
    all_reps = np.vstack([reps for reps in task_representations.values()])

    # Total variance
    total_var = np.var(all_reps, axis=0).sum()

    # Within-task variance (unique)
    within_var = sum(
        len(reps) * np.var(reps, axis=0).sum()
        for reps in task_representations.values()
    ) / len(all_reps)

    # Between-task variance (shared)
    task_means = np.array([
        np.mean(reps, axis=0)
        for reps in task_representations.values()
    ])
    between_var = np.var(task_means, axis=0).sum()

    shared_ratio = between_var / total_var if total_var > 0 else 0

    print(f"\nTotal variance: {total_var:.4f}")
    print(f"Within-task (unique) variance: {within_var:.4f}")
    print(f"Between-task (shared) variance: {between_var:.4f}")
    print(f"Shared variance ratio: {shared_ratio:.4f}")

    results['shared_variance_ratio'] = shared_ratio

    # 4. Visualization
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # PCA visualization
    pca = PCA(n_components=2)
    all_reps_2d = pca.fit_transform(all_reps)

    colors = plt.cm.tab10(np.linspace(0, 1, len(task_names)))
    start_idx = 0
    for task_idx, (task_name, reps) in enumerate(task_representations.items()):
        end_idx = start_idx + len(reps)
        axes[0].scatter(
            all_reps_2d[start_idx:end_idx, 0],
            all_reps_2d[start_idx:end_idx, 1],
            c=[colors[task_idx]],
            label=task_name,
            alpha=0.5,
            s=10
        )
        start_idx = end_idx

    axes[0].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} var)')
    axes[0].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} var)')
    axes[0].set_title('Shared Representation Structure (PCA)')
    axes[0].legend()

    # Variance spectrum
    eigenvalues = pca.explained_variance_ratio_
    axes[1].bar(range(1, len(eigenvalues) + 1), eigenvalues[:20])
    axes[1].set_xlabel('Principal Component')
    axes[1].set_ylabel('Explained Variance Ratio')
    axes[1].set_title('Representation Variance Spectrum')

    plt.tight_layout()
    plt.savefig('representation_structure.png', dpi=150, bbox_inches='tight')
    plt.close()

    print("\nVisualization saved to 'representation_structure.png'")

    return results
```

Disentanglement and Factorization:
Ideal shared representations exhibit disentanglement: different factors of variation are encoded in separate dimensions. This makes it easier for task-specific heads to extract relevant information.
For multi-task learning, a useful property is task-aware factorization:
$$h(x) = [h_{\text{shared}}(x), h_{\text{task-specific}}(x)]$$
where $h_{\text{shared}}$ captures factors common across all tasks, and $h_{\text{task-specific}}$ captures factors relevant only to specific tasks. While full disentanglement is generally impossible without inductive biases or supervision, MTL naturally encourages partial disentanglement by forcing representations to serve multiple purposes.
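A simple way to encourage this factorization architecturally is to give each task a small private encoder alongside the shared one and concatenate the two before the head. The sketch below is illustrative; the class and argument names are mine. In practice this "shared-private" pattern is often combined with additional losses (e.g., orthogonality or adversarial terms) to keep the two parts from duplicating information.

```python
import torch
import torch.nn as nn


class FactorizedMTL(nn.Module):
    """h(x) = [h_shared(x), h_private_t(x)]: shared plus per-task private representations."""

    def __init__(self, input_dim: int, shared_dim: int, private_dim: int, task_output_dims: dict):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(input_dim, shared_dim), nn.ReLU())
        self.private = nn.ModuleDict({
            t: nn.Sequential(nn.Linear(input_dim, private_dim), nn.ReLU())
            for t in task_output_dims
        })
        self.heads = nn.ModuleDict({
            t: nn.Linear(shared_dim + private_dim, d)
            for t, d in task_output_dims.items()
        })

    def forward(self, x: torch.Tensor, task_name: str) -> torch.Tensor:
        # Concatenate the shared factors with this task's private factors
        h = torch.cat([self.shared(x), self.private[task_name](x)], dim=-1)
        return self.heads[task_name](h)
```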
The nature of shared representations varies dramatically across domains. Understanding these domain-specific patterns is essential for designing effective MTL systems.
Computer Vision:
In vision, shared representations follow the hierarchical feature learning paradigm established by deep convolutional networks:
- Early layers detect edges, colors, and simple textures
- Intermediate layers compose these into parts, patterns, and object components
- Later layers encode object- and scene-level semantics
Typical MTL tasks (object detection, segmentation, depth estimation, surface normal prediction) share early convolutional layers and diverge at later stages. Empirical studies show that sharing up to layer 3-4 of networks like ResNet provides benefit, while sharing beyond that can hurt performance for dissimilar tasks.
Natural Language Processing:
In NLP, shared representations have evolved dramatically with the transformer era:
Pre-transformer: Shared word embeddings (Word2Vec, GloVe) provided basic semantic representations. Recurrent layers learned to compose these into sentence representations, shared across tasks like sentiment analysis, NER, and parsing.
Transformer era: Models like BERT represent the ultimate shared representation—a single pre-trained model provides contextual embeddings that transfer to virtually all NLP tasks. The representation is:
- Contextual: each token's embedding depends on its surrounding text
- Learned at scale: trained on massive unlabeled corpora with self-supervised objectives
- Broadly transferable: small task-specific heads or light fine-tuning suffice for strong downstream performance
Key insight: In NLP, the shared representation often is the model. Fine-tuning adds minimal task-specific parameters on top of massive shared representations.
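As a concrete illustration of "the shared representation is the model", the sketch below (assuming the Hugging Face `transformers` library and `bert-base-uncased`) freezes a pre-trained BERT encoder and adds only a small linear head per task; everything except the heads is shared.

```python
import torch
import torch.nn as nn
from transformers import AutoModel


class FrozenBertMTL(nn.Module):
    """Pre-trained BERT as a frozen shared representation; tiny per-task linear heads."""

    def __init__(self, task_output_dims: dict, model_name: str = "bert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        for p in self.encoder.parameters():   # freeze the shared representation
            p.requires_grad = False
        hidden = self.encoder.config.hidden_size
        self.heads = nn.ModuleDict({t: nn.Linear(hidden, d) for t, d in task_output_dims.items()})

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, task_name: str):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]     # [CLS] token as the sentence representation
        return self.heads[task_name](cls)
```

Unfreezing the encoder turns this into standard multi-task fine-tuning; the shared/task-specific parameter ratio remains heavily skewed toward the shared encoder either way.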
The success of shared representations in NLP has led to the 'foundation model' paradigm: pre-train a massive shared representation on broad data, then adapt to specific tasks with minimal additional parameters. This pattern is now spreading to vision (ViT, CLIP), audio (Whisper), and multimodal models (GPT-4V).
Reinforcement Learning:
In RL, shared representations are particularly valuable because:
- Environment interaction is expensive, so the sample-efficiency gains from sharing matter even more than in supervised learning
- Related tasks or goals typically share the same perception problem and environment dynamics
- Reward signals are often sparse, and auxiliary prediction tasks provide a denser learning signal for the shared encoder

Common sharing patterns include:
- A shared observation encoder with separate policy and value heads (a minimal sketch appears below)
- Shared dynamics or world models reused across tasks
- Goal-conditioned policies, where one network serves a whole family of goals
- Auxiliary tasks such as reward or next-observation prediction that shape the shared representation
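A minimal actor-critic sketch of the first pattern; the class name and layer sizes are illustrative, not from the source.

```python
import torch
import torch.nn as nn


class SharedEncoderActorCritic(nn.Module):
    """One observation encoder shared by the policy and value heads."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(          # shared perception / state representation
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden, n_actions)  # "task" 1: action logits
        self.value_head = nn.Linear(hidden, 1)           # "task" 2: state value

    def forward(self, obs: torch.Tensor):
        h = self.encoder(obs)
        return self.policy_head(h), self.value_head(h).squeeze(-1)
```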
Multi-modal Learning:
Modern systems increasingly learn shared representations across modalities (text, images, audio). Key approaches include:
- Joint embedding spaces trained contrastively, as in CLIP, which aligns image and text encoders (a contrastive-loss sketch follows below)
- Cross-attention fusion, where modality-specific encoders exchange information through attention layers
- Unified sequence models that tokenize all modalities into a common vocabulary and process them with one transformer
The challenge is bridging very different data types while preserving modality-specific information.
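A common way to bridge modalities is a CLIP-style joint embedding trained with a symmetric contrastive (InfoNCE-style) loss. The sketch below assumes precomputed image and text features of the stated (hypothetical) dimensions; the function and projection names are mine.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def clip_style_contrastive_loss(image_feats, text_feats, image_proj, text_proj, temperature=0.07):
    """Project both modalities into a shared space and align matching (image, text) pairs."""
    img = F.normalize(image_proj(image_feats), dim=-1)   # shared embedding space
    txt = F.normalize(text_proj(text_feats), dim=-1)

    logits = img @ txt.t() / temperature                 # pairwise similarities
    targets = torch.arange(img.size(0), device=img.device)

    # Symmetric loss: match image i to text i and vice versa
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


# Example projections into a 512-dim shared space (feature dimensions are hypothetical)
image_proj = nn.Linear(2048, 512)
text_proj = nn.Linear(768, 512)
```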
Based on both theory and extensive empirical evidence, we can distill key principles for designing effective shared representations in multi-task learning systems.
More sharing isn't always better. Excessive sharing can cause negative transfer when tasks have conflicting requirements. The optimal amount of sharing depends on task relatedness, data availability, and model capacity. Treat the degree of sharing as a hyperparameter to be tuned.
Practical Implementation Guidelines:
Start with established backbones: For vision, use pre-trained ResNet/ViT; for NLP, use pre-trained transformers. These provide strong initial shared representations.
Use normalization carefully: Batch normalization can be problematic in MTL because statistics differ across tasks. Consider layer normalization or task-specific batch norm statistics.
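One common remedy is to keep the feature-extraction weights shared while giving each task its own normalization layer, so running statistics never mix across tasks. A minimal sketch (the class name is mine):

```python
import torch
import torch.nn as nn


class TaskSpecificBatchNorm1d(nn.Module):
    """Shared features upstream; per-task BatchNorm statistics and affine parameters."""

    def __init__(self, num_features: int, task_names: list):
        super().__init__()
        self.bns = nn.ModuleDict({t: nn.BatchNorm1d(num_features) for t in task_names})

    def forward(self, x: torch.Tensor, task_name: str) -> torch.Tensor:
        return self.bns[task_name](x)
```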
Initialize properly: Task-specific heads should be initialized such that initial predictions are reasonable for all tasks. Bad initialization can cause one task to dominate early training.
Validate on all tasks: Multi-task performance should be evaluated holistically. Improvement on average with severe degradation on one task may be unacceptable.
Analyze representations: Periodically inspect the learned representations using the techniques described earlier. This provides insight into whether sharing is working as intended.
Shared representations are the conceptual and mathematical foundation of multi-task learning. By learning features that are useful across multiple tasks, we achieve better generalization, improved sample efficiency, and more robust models.
Next Steps:
With a solid understanding of shared representations, we're ready to explore specific mechanisms for implementing them. The next page covers Hard vs Soft Parameter Sharing—two fundamental paradigms for how parameters are shared across tasks in multi-task learning architectures.
You now understand the theoretical foundations, mathematical formalization, and practical implications of shared representations in multi-task learning. This knowledge provides the conceptual framework for understanding all MTL architectures and techniques covered in subsequent pages.