In traditional machine learning, we carefully shuffle our dataset to ensure training batches are IID—Independent and Identically Distributed. Random shuffling means each batch is a representative sample of the full data distribution, enabling unbiased gradient estimates and predictable convergence.
Federated learning shatters this assumption. Each client's local data reflects that client's unique context, so no single client's data is a representative sample of the overall distribution.
This data heterogeneity causes fundamental challenges: local gradients point in different directions, simple averaging produces a model that fits no one well, and convergence becomes erratic or fails entirely. Understanding and addressing non-IID data is the central technical challenge of federated learning.
By the end of this page, you will understand the taxonomy of data heterogeneity in FL, the theoretical analysis of why non-IID causes problems, personalization techniques that adapt models to individual clients, clustering approaches for discovering client groups, and multi-task learning frameworks that embrace heterogeneity as a feature rather than a bug.
Data heterogeneity in federated learning manifests in multiple ways. Understanding these categories is crucial for selecting appropriate mitigation strategies.
1. Label Distribution Skew (Label Imbalance)
Different clients have different proportions of each class. In extreme cases, a client may have samples from only a subset of classes.
Example: In a handwritten digit recognition system, one user writes mostly "1" and "7" (quick notes), while another writes all digits equally (student doing homework).
2. Feature Distribution Skew (Covariate Shift)
The same label appears differently across clients. The conditional distribution P(X|Y) varies.
Example: Different hospitals' X-ray machines produce subtly different images. Even for the same diagnosis (Y), the image features (X) differ.
3. Same Features, Different Labels (Concept Drift)
The same features lead to different labels for different clients. P(Y|X) varies.
Example: In sentiment analysis, "This product is sick!" is positive for younger users but negative for older users.
| Type | What Varies | Mathematical Form | Real-World Example |
|---|---|---|---|
| Label skew | Class proportions | P(Y) differs | Some users write only certain digits |
| Feature skew | Feature distribution per class | P(X\|Y) differs | Different camera qualities |
| Concept shift | Labels for same features | P(Y\|X) differs | Cultural differences in sentiment |
| Quantity skew | Number of samples | \|D_k\| varies | Power users vs. casual users |
| Temporal skew | Data distribution over time | P(X,Y;t) varies | Seasonal behavior changes |
4. Quantity Skew (Unbalanced Clients)
Clients have vastly different amounts of data. Some may have millions of samples, others just dozens.
Challenge: Simple averaging gives equal weight to a client with 10 samples and one with 10,000. Weighted averaging helps but introduces other issues.
5. Temporal Heterogeneity
Data distributions change over time, both within and across clients. New trends emerge; old patterns fade.
Challenge: A model trained on last year's patterns may not apply this year. Continual learning meets federated learning.
In practice, multiple types of heterogeneity occur simultaneously. A federated health system may have label skew (different patient populations), feature skew (different equipment), quantity skew (different hospital sizes), and temporal skew (seasonal diseases). Solutions must address this compound heterogeneity.
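To make these categories concrete, label skew is commonly simulated in FL experiments by splitting a centralized dataset with Dirichlet-distributed class proportions, where a smaller concentration parameter alpha yields more extreme skew. The sketch below is a minimal illustration of that idea; the function name and parameters are ours rather than a specific library's API.

```python
# Minimal sketch: simulating label distribution skew with a Dirichlet prior.
# The partitioning scheme and names here are illustrative, not a library API.
import numpy as np

def dirichlet_label_skew(labels: np.ndarray, num_clients: int,
                         alpha: float = 0.5, seed: int = 0):
    """Split sample indices across clients with Dirichlet-distributed
    class proportions. Smaller alpha => more extreme label skew."""
    rng = np.random.default_rng(seed)
    num_classes = int(labels.max()) + 1
    client_indices = [[] for _ in range(num_clients)]

    for c in range(num_classes):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        # Proportion of class c assigned to each client
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, chunk in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(chunk.tolist())

    return client_indices

# Example: 10,000 samples, 10 classes, 8 clients, strong skew (alpha=0.1)
labels = np.random.default_rng(1).integers(0, 10, size=10_000)
parts = dirichlet_label_skew(labels, num_clients=8, alpha=0.1)
for k, part in enumerate(parts[:3]):
    counts = np.bincount(labels[part], minlength=10)
    print(f"client {k}: class counts {counts.tolist()}")
```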
Let's develop a theoretical understanding of why non-IID data causes convergence issues in federated learning.
The IID Case:
In IID settings, each client's local gradient is an unbiased estimate of the global gradient:
E[∇F_k(w)] = ∇F(w) for all clients k
When we average local gradients, we obtain an unbiased estimate of the true gradient whose variance is reduced by a factor of K (for K participating clients).
The Non-IID Case:
With non-IID data, local gradients are biased:
E[∇F_k(w)] = ∇F_k(w) ≠ ∇F(w)
The local objective F_k differs from the global objective F. Averaging gives us an estimate of the weighted average of local objectives—which may differ substantially from the true global optimum.
```python
# Theoretical Analysis of Non-IID Effects in Federated Learning
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass


@dataclass
class ConvergenceAnalysis:
    """
    Analysis of FedAvg convergence under non-IID data.

    Key theoretical results from Li et al., 2020:
    "On the Convergence of FedAvg on Non-IID Data"

    For K clients, E local epochs, learning rate η:

    Convergence bound includes term: O(η²E²G²)

    where G² = (1/K)Σ_k ||∇F_k(w*) - ∇F(w*)||²

    G² is the "gradient dissimilarity" - measures how different
    local gradients are at the global optimum.

    In IID case: G² = 0 (local gradients same at optimum)
    In non-IID case: G² > 0 (local gradients point different directions)
    """

    @staticmethod
    def gradient_dissimilarity(
        local_gradients: List[np.ndarray],  # Gradients at w*
        client_weights: List[float]         # Proportion of data per client
    ) -> float:
        """
        Compute gradient dissimilarity G².

        G² = Σ_k p_k ||∇F_k(w*) - ∇F(w*)||²

        where p_k = |D_k| / Σ|D_k| is client weight,
        and ∇F(w*) = Σ_k p_k ∇F_k(w*) is weighted average.

        Higher G² ⟹ harder convergence
        """
        # Compute global gradient as weighted average
        global_gradient = sum(
            w * g for w, g in zip(client_weights, local_gradients)
        )

        # Compute dissimilarity
        dissimilarity = sum(
            w * np.linalg.norm(g - global_gradient) ** 2
            for w, g in zip(client_weights, local_gradients)
        )

        return dissimilarity

    @staticmethod
    def fedavg_convergence_bound(
        gradient_dissimilarity: float,  # G²
        num_local_epochs: int,          # E
        learning_rate: float,           # η
        num_rounds: int,                # T
        smoothness: float = 1.0,        # L (smoothness constant)
        variance: float = 1.0           # σ² (stochastic gradient variance)
    ) -> float:
        """
        Upper bound on ||∇F(w_T)||² after T rounds.

        Simplified from Li et al., 2020:

        E[||∇F(w_T)||²] ≤ O(1/√(KT)) + O(η²E²G²)

        First term: Optimization error (decreases with more rounds/clients)
        Second term: Bias from non-IID (increases with local epochs and G²)

        Key insight: More local epochs (E) amplifies non-IID bias!
        This explains why FedAvg struggles with non-IID data when E is large.
        """
        K = 1  # Placeholder (number of clients affects first term)

        # Optimization error (decreases with T)
        opt_error = 1.0 / np.sqrt(K * num_rounds)

        # Non-IID bias (constant, doesn't decrease with T!)
        non_iid_bias = learning_rate ** 2 * num_local_epochs ** 2 * gradient_dissimilarity

        return opt_error + non_iid_bias


def simulate_non_iid_convergence():
    """
    Simulation demonstrating non-IID effects.

    Shows that with non-IID data:
    1. Local updates drift from global optimum
    2. Averaging doesn't yield global optimum
    3. More local epochs can hurt in non-IID setting
    """
    np.random.seed(42)

    # Simulate 4 clients with different local optima
    # Client k's local function: F_k(w) = 0.5 * ||w - w_k*||²
    # where w_k* is client k's local optimum
    local_optima = [
        np.array([1.0, 0.0]),    # Client 0: optimum at (1, 0)
        np.array([0.0, 1.0]),    # Client 1: optimum at (0, 1)
        np.array([-1.0, 0.0]),   # Client 2: optimum at (-1, 0)
        np.array([0.0, -1.0]),   # Client 3: optimum at (0, -1)
    ]

    # True global optimum (if equally weighted): origin (0, 0)
    global_optimum = np.mean(local_optima, axis=0)
    print(f"True global optimum: {global_optimum}")

    # Federated averaging simulation
    w = np.array([0.5, 0.5])  # Initial weights
    learning_rate = 0.1
    num_rounds = 50

    results = {"iid": [], "non_iid_E1": [], "non_iid_E5": []}

    for E in [1, 5]:  # Compare 1 vs 5 local epochs
        w = np.array([0.5, 0.5])

        for round_t in range(num_rounds):
            # Each client does E local gradient steps
            local_updates = []
            for k, w_k_star in enumerate(local_optima):
                w_local = w.copy()
                for _ in range(E):
                    # Gradient of F_k(w) = 0.5 * ||w - w_k*||²
                    # ∇F_k(w) = w - w_k*
                    gradient = w_local - w_k_star
                    w_local = w_local - learning_rate * gradient
                local_updates.append(w_local - w)  # Delta

            # Average updates
            avg_update = np.mean(local_updates, axis=0)
            w = w + avg_update

            # Record distance to global optimum
            key = f"non_iid_E{E}"
            results[key].append(np.linalg.norm(w - global_optimum))

    return results


class ClientDrift:
    """
    Analyze and visualize client drift in non-IID FL.

    Client drift = deviation of local model from global model
    after local training.

    High drift ⟹ clients pulling in different directions
    ⟹ averaging produces suboptimal result
    """

    def __init__(self, num_clients: int):
        self.num_clients = num_clients
        self.drift_history: List[List[float]] = []

    def measure_drift(
        self,
        global_weights: np.ndarray,
        local_weights: List[np.ndarray]
    ) -> Tuple[float, List[float]]:
        """
        Measure drift of each client from global model.

        Returns:
            avg_drift: Average drift across clients
            per_client_drift: List of per-client drift values
        """
        drifts = [
            np.linalg.norm(local - global_weights)
            for local in local_weights
        ]

        self.drift_history.append(drifts)

        return np.mean(drifts), drifts

    def analyze_drift_growth(self, num_local_epochs: int) -> float:
        """
        Analyze how drift grows with local epochs.

        Theoretical expectation: drift ∝ E (local epochs)
        With non-IID: drift ∝ E × G (gradient dissimilarity)
        """
        if len(self.drift_history) < 2:
            return 0.0

        # Compare consecutive rounds
        drift_growth = []
        for i in range(1, len(self.drift_history)):
            prev_avg = np.mean(self.drift_history[i-1])
            curr_avg = np.mean(self.drift_history[i])
            if prev_avg > 0:
                drift_growth.append(curr_avg / prev_avg)

        return np.mean(drift_growth) if drift_growth else 0.0
```

Key Theoretical Insights:
Gradient dissimilarity bounds convergence — The term G² = Σ_k p_k ||∇F_k(w*) - ∇F(w*)||² measures how different the local gradients are at the global optimum. Higher G² means worse convergence.
Local epochs amplify bias — The convergence bound includes O(η²E²G²). More local epochs (E) quadratically amplifies the non-IID effect!
Cannot decrease bias with more rounds — Unlike optimization error (decreases as O(1/√T)), the non-IID bias term is constant. More training rounds don't help.
Trade-off exists — More local epochs reduce communication but amplify non-IID issues. Optimal E depends on heterogeneity level.
Non-IID data introduces bias (local gradients point in the wrong direction), and more local steps accumulate this bias. IID data has no such bias (only variance), so more local steps mainly reduce variance. This explains why FedAvg works well with IID data but struggles with non-IID data.
Instead of forcing all clients to share a single global model, personalization tailors models to each client's local distribution while still benefiting from collaborative learning.
Why Personalization?
A single global model optimizes for the "average" user—which may not exist. If half your users speak English and half speak Spanish, the average model serves neither well. Personalization learns to serve each population effectively.
Personalization Strategies:
```python
# Personalization Methods for Federated Learning
import numpy as np
from typing import List, Dict, Tuple, Optional
from abc import ABC, abstractmethod


class PersonalizationStrategy(ABC):
    """Base class for personalization strategies."""

    @abstractmethod
    def personalize(
        self,
        global_model: np.ndarray,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> np.ndarray:
        """Personalize global model for a specific client."""
        pass


class LocalFineTuning(PersonalizationStrategy):
    """
    Simplest personalization: fine-tune global model locally.

    Process:
    1. Train global model with FedAvg (communication rounds)
    2. Each client downloads final global model
    3. Each client fine-tunes on local data (no communication)

    Advantages:
    - Simple, no changes to FL algorithm
    - Works well in practice

    Disadvantages:
    - Requires enough local data to fine-tune without overfitting
    - Global model may be far from local optimum
    """

    def __init__(
        self,
        num_epochs: int = 5,
        learning_rate: float = 0.01,
        early_stopping: bool = True
    ):
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.early_stopping = early_stopping

    def personalize(
        self,
        global_model: np.ndarray,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> np.ndarray:
        """Fine-tune global model on local data."""
        personal_model = global_model.copy()
        best_model = personal_model.copy()
        best_loss = float('inf')

        # Split local data for validation if early stopping
        if self.early_stopping and len(local_data) > 20:
            split = int(0.8 * len(local_data))
            train_x, val_x = local_data[:split], local_data[split:]
            train_y, val_y = local_labels[:split], local_labels[split:]
        else:
            train_x, train_y = local_data, local_labels
            val_x = val_y = None

        for epoch in range(self.num_epochs):
            # Simple gradient descent
            gradient = self._compute_gradient(personal_model, train_x, train_y)
            personal_model -= self.learning_rate * gradient

            # Early stopping check
            if val_x is not None:
                val_loss = self._compute_loss(personal_model, val_x, val_y)
                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model = personal_model.copy()
                elif self.early_stopping:
                    break  # Stop if validation loss increases

        return best_model if self.early_stopping else personal_model

    def _compute_gradient(self, model, x, y) -> np.ndarray:
        """Compute gradient (placeholder)."""
        return np.random.randn(*model.shape) * 0.01

    def _compute_loss(self, model, x, y) -> float:
        """Compute loss (placeholder)."""
        return np.random.random()


class PartialPersonalization:
    """
    Share base layers globally, personalize head layers.

    Intuition: Early layers learn general features (edges, textures)
    that transfer across clients. Later layers learn task-specific
    patterns that should be personalized.

    Paper: "Think Locally, Act Globally: Federated Learning with
    Local and Global Representations" (Liang et al., 2020)
    """

    def __init__(
        self,
        shared_layer_indices: List[int],
        personal_layer_indices: List[int]
    ):
        self.shared_layers = set(shared_layer_indices)
        self.personal_layers = set(personal_layer_indices)

    def split_model(
        self,
        full_model: Dict[str, np.ndarray]
    ) -> Tuple[Dict, Dict]:
        """
        Split model into shared and personal components.

        Returns:
            shared_weights: Layers to participate in federation
            personal_weights: Layers to keep local
        """
        shared = {}
        personal = {}

        for name, weights in full_model.items():
            layer_idx = self._get_layer_index(name)
            if layer_idx in self.shared_layers:
                shared[name] = weights
            else:
                personal[name] = weights

        return shared, personal

    def merge_model(
        self,
        shared_weights: Dict[str, np.ndarray],
        personal_weights: Dict[str, np.ndarray]
    ) -> Dict[str, np.ndarray]:
        """Merge shared and personal components."""
        return {**shared_weights, **personal_weights}

    def _get_layer_index(self, name: str) -> int:
        """Extract layer index from parameter name."""
        # Simplified: assumes naming like "layer_0.weight"
        import re
        match = re.search(r'layer_(\d+)', name)
        return int(match.group(1)) if match else 0


class FedPer:
    """
    FedPer: Federated Learning with Personalization Layers
    (Arivazhagan et al., 2019)

    Similar to partial personalization but with specific protocol:
    1. Base layers: updated via FedAvg
    2. Personalization layers: never sent to server

    Typically: all convolutional layers are base;
    fully-connected head is personalization.
    """

    def __init__(
        self,
        base_layer_names: List[str],
        personal_layer_names: List[str]
    ):
        self.base_layers = set(base_layer_names)
        self.personal_layers = set(personal_layer_names)
        self.local_personal_weights: Dict[str, Dict] = {}

    def client_get_base_update(
        self,
        client_id: str,
        trained_model: Dict[str, np.ndarray]
    ) -> Dict[str, np.ndarray]:
        """
        Extract base layers for federation.

        Personal layers are stored locally, not sent.
        """
        # Store personal layers locally
        self.local_personal_weights[client_id] = {
            name: weights
            for name, weights in trained_model.items()
            if name in self.personal_layers
        }

        # Return only base layers for aggregation
        return {
            name: weights
            for name, weights in trained_model.items()
            if name in self.base_layers
        }

    def client_receive_global(
        self,
        client_id: str,
        global_base: Dict[str, np.ndarray]
    ) -> Dict[str, np.ndarray]:
        """
        Combine received global base with local personal layers.
        """
        personal = self.local_personal_weights.get(client_id, {})
        return {**global_base, **personal}


class pFedMe:
    """
    pFedMe: Personalized Federated Learning with Moreau Envelopes
    (T Dinh et al., 2020)

    Each client maintains a personalized model θᵢ and contributes
    to learning a global model w that serves as a regularizer.

    Client objective:
        min_{θᵢ} F_i(θᵢ) + (λ/2)||θᵢ - w||²

    The global model w provides a "prior" that prevents personal
    models from overfitting on small local datasets.

    Key insight: The global model is learned to be a good
    initialization for all clients, not the optimum for any.
    """

    def __init__(
        self,
        lambda_reg: float = 15.0,  # Regularization strength
        local_lr: float = 0.01,
        personalization_steps: int = 5
    ):
        self.lambda_reg = lambda_reg
        self.local_lr = local_lr
        self.K = personalization_steps  # Steps to approximate θᵢ*

    def local_personalize(
        self,
        global_weights: np.ndarray,
        personal_weights: np.ndarray,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Bi-level optimization for pFedMe.

        Inner loop: Find personalized θᵢ given current w
            θᵢ = argmin F_i(θ) + (λ/2)||θ - w||²

        Outer loop: Update w using gradient at θᵢ
            w ← w - β * (w - θᵢ)

        Returns:
            updated_personal: New personalized model
            global_gradient: Gradient for global model update
        """
        theta = personal_weights.copy()

        # Inner loop: K steps to approximate θᵢ*(w)
        for k in range(self.K):
            # Gradient of local loss
            task_grad = self._compute_gradient(theta, local_data, local_labels)

            # Gradient of regularization
            reg_grad = self.lambda_reg * (theta - global_weights)

            # Update personal model
            theta -= self.local_lr * (task_grad + reg_grad)

        # Compute update for global model
        # Gradient is simply λ(w - θᵢ) from the envelope
        global_gradient = self.lambda_reg * (global_weights - theta)

        return theta, global_gradient

    def _compute_gradient(self, model, x, y) -> np.ndarray:
        """Compute task gradient (placeholder)."""
        return np.random.randn(*model.shape) * 0.01
```

Personalization is most valuable when (1) significant label or concept shift exists between clients, (2) clients have enough local data to avoid overfitting, and (3) a single global model demonstrably underperforms local-only training. Start with local fine-tuning—it's simple and often sufficient.
When client populations have natural groupings (e.g., different languages, regions, or industries), learning cluster-specific models can outperform both a single global model and fully personalized models.
The Clustering Intuition:
Rather than assuming all clients are the same (global) or completely different (full personalization), clustering finds a middle ground: clients within a cluster share similar distributions and benefit from a shared model, while different clusters have different models.
```python
# Clustered Federated Learning
import numpy as np
from typing import List, Dict, Tuple, Optional
from sklearn.cluster import KMeans
from scipy.spatial.distance import cosine


class ClusteredFL:
    """
    Clustered Federated Learning (CFL)

    Automatically discovers clusters of clients with similar data
    distributions and trains separate models per cluster.

    Key challenge: How to cluster without seeing client data?
    Solution: Use model updates/gradients as proxy for data similarity.

    Reference: "Clustered Federated Learning: Model-Agnostic Distributed
    Multi-Task Optimization" (Sattler et al., 2020)
    """

    def __init__(
        self,
        num_clusters: int,
        similarity_threshold: float = 0.8
    ):
        self.num_clusters = num_clusters
        self.similarity_threshold = similarity_threshold

        # State
        self.cluster_models: Dict[int, np.ndarray] = {}
        self.client_clusters: Dict[str, int] = {}
        self.client_update_history: Dict[str, List[np.ndarray]] = {}

    def initialize_clusters(self, initial_model: np.ndarray):
        """Initialize all cluster models to same starting point."""
        for c in range(self.num_clusters):
            self.cluster_models[c] = initial_model.copy()

    def assign_client_to_cluster(
        self,
        client_id: str,
        client_update: np.ndarray
    ) -> int:
        """
        Assign client to cluster based on update similarity.

        Idea: Clients with similar data distributions produce
        similar gradient updates. Use gradient as feature for clustering.
        """
        # Store update for this client
        if client_id not in self.client_update_history:
            self.client_update_history[client_id] = []
        self.client_update_history[client_id].append(client_update)

        # If client already assigned and update is consistent, keep assignment
        if client_id in self.client_clusters:
            # Check if update is consistent with cluster
            cluster_id = self.client_clusters[client_id]
            cluster_centroid = self._get_cluster_centroid(cluster_id)

            if cluster_centroid is not None:
                similarity = self._compute_similarity(client_update, cluster_centroid)
                if similarity > self.similarity_threshold:
                    return cluster_id  # Keep current assignment

        # Find best matching cluster
        best_cluster = 0
        best_similarity = -1

        for c in range(self.num_clusters):
            centroid = self._get_cluster_centroid(c)
            if centroid is None:
                continue

            similarity = self._compute_similarity(client_update, centroid)
            if similarity > best_similarity:
                best_similarity = similarity
                best_cluster = c

        self.client_clusters[client_id] = best_cluster
        return best_cluster

    def _get_cluster_centroid(self, cluster_id: int) -> Optional[np.ndarray]:
        """Compute centroid of updates from clients in this cluster."""
        cluster_clients = [
            cid for cid, c in self.client_clusters.items()
            if c == cluster_id
        ]

        if not cluster_clients:
            return None

        # Average most recent updates from cluster members
        updates = [
            self.client_update_history[cid][-1]
            for cid in cluster_clients
            if cid in self.client_update_history
        ]

        if not updates:
            return None

        return np.mean(updates, axis=0)

    def _compute_similarity(
        self,
        update1: np.ndarray,
        update2: np.ndarray
    ) -> float:
        """Compute cosine similarity between updates."""
        flat1, flat2 = update1.flatten(), update2.flatten()

        # Handle zero vectors
        if np.linalg.norm(flat1) == 0 or np.linalg.norm(flat2) == 0:
            return 0.0

        return 1 - cosine(flat1, flat2)

    def aggregate_cluster(
        self,
        cluster_id: int,
        client_updates: Dict[str, np.ndarray],
        client_weights: Dict[str, float]
    ) -> np.ndarray:
        """Aggregate updates for a specific cluster."""
        cluster_updates = {
            cid: update for cid, update in client_updates.items()
            if self.client_clusters.get(cid) == cluster_id
        }

        if not cluster_updates:
            return self.cluster_models[cluster_id]

        # Weighted average
        total_weight = sum(
            client_weights[cid] for cid in cluster_updates.keys()
        )

        aggregated = np.zeros_like(list(cluster_updates.values())[0])
        for cid, update in cluster_updates.items():
            aggregated += (client_weights[cid] / total_weight) * update

        # Update cluster model
        self.cluster_models[cluster_id] += aggregated

        return self.cluster_models[cluster_id]


class IFCA:
    """
    Iterative Federated Clustering Algorithm (IFCA)
    Ghosh et al., 2020

    Alternates between:
    1. Cluster assignment: Each client picks best cluster model
    2. Model update: Train each cluster model with assigned clients

    Assumes number of clusters known a priori.
    """

    def __init__(self, num_clusters: int):
        self.k = num_clusters
        self.cluster_models: List[np.ndarray] = []

    def initialize(self, model_shape: Tuple[int, ...]):
        """Initialize K random cluster models."""
        self.cluster_models = [
            np.random.randn(*model_shape) * 0.01
            for _ in range(self.k)
        ]

    def client_select_cluster(
        self,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> int:
        """
        Client selects cluster with lowest loss on local data.

        This is evaluated locally—no data leaves the client.
        """
        losses = []
        for c, model in enumerate(self.cluster_models):
            loss = self._compute_loss(model, local_data, local_labels)
            losses.append(loss)

        return int(np.argmin(losses))

    def server_aggregate(
        self,
        cluster_assignments: Dict[str, int],
        client_updates: Dict[str, np.ndarray],
        client_weights: Dict[str, float]
    ):
        """Aggregate updates per cluster."""
        for c in range(self.k):
            # Get clients assigned to this cluster
            cluster_clients = [
                cid for cid, assigned in cluster_assignments.items()
                if assigned == c
            ]

            if not cluster_clients:
                continue

            # Weighted average of updates
            total_weight = sum(client_weights[cid] for cid in cluster_clients)
            aggregated = np.zeros_like(self.cluster_models[c])

            for cid in cluster_clients:
                aggregated += (
                    (client_weights[cid] / total_weight) * client_updates[cid]
                )

            self.cluster_models[c] += aggregated

    def _compute_loss(self, model, x, y) -> float:
        """Compute loss (placeholder)."""
        return np.random.random()
```

The key insight is that gradient updates encode information about the underlying data distribution. Clients with similar data produce similar gradients. By clustering based on gradients, we indirectly cluster based on data similarity—without ever accessing the data.
Multi-task learning views each client as having a related but distinct task. Rather than assuming all clients should have the same model (FedAvg) or ignoring relationships between clients (local-only), multi-task FL explicitly models the relationships between clients' learning tasks.
The Task Relationship Matrix:
In multi-task FL, we learn not just client models but also a task relationship matrix Ω that captures how clients' optimal models relate to each other. Clients with similar tasks should have similar models.
```python
# Multi-Task Federated Learning
import numpy as np
from typing import List, Dict, Tuple
from scipy.linalg import block_diag


class MOCHA:
    """
    MOCHA: Federated Multi-Task Learning
    Smith et al., NIPS 2017

    Learns client-specific models W = [w₁, ..., wₖ] and a task
    relationship matrix Ω ∈ ℝᵏˣᵏ that encodes how clients'
    optimal models relate.

    Objective:
        min_W Σₖ Fₖ(wₖ) + λ⋅trace(W^T⋅Ω⁻¹⋅W)

    The trace term regularizes model differences based on the
    learned relationships in Ω.

    Key insight: Learn Ω based on similarity of optimal models.
    Clients with similar optima should have high Ω values.
    """

    def __init__(
        self,
        num_clients: int,
        model_dim: int,
        lambda_reg: float = 0.1
    ):
        self.K = num_clients
        self.d = model_dim
        self.lambda_reg = lambda_reg

        # Initialize client models
        self.W = np.random.randn(model_dim, num_clients) * 0.01

        # Initialize relationship matrix as identity (no prior relationships)
        self.Omega = np.eye(num_clients)

    def client_local_update(
        self,
        client_id: int,
        local_data: np.ndarray,
        local_labels: np.ndarray,
        learning_rate: float = 0.01,
        num_steps: int = 10
    ) -> np.ndarray:
        """
        Client performs local updates considering task relationships.

        The gradient includes not just the local loss gradient but
        also a term from the task relationship regularizer.
        """
        w_k = self.W[:, client_id].copy()

        for step in range(num_steps):
            # Local loss gradient
            task_grad = self._compute_gradient(w_k, local_data, local_labels)

            # Regularization gradient from task relationships
            # ∂/∂wₖ [λ⋅trace(W^T⋅Ω⁻¹⋅W)] = 2λ⋅Ω⁻¹[k,:]⋅W^T
            Omega_inv = np.linalg.inv(self.Omega)
            relationship_grad = 2 * self.lambda_reg * (
                self.W @ Omega_inv[client_id, :]
            )

            # Combined gradient
            total_grad = task_grad + relationship_grad
            w_k -= learning_rate * total_grad

        # Update stored model
        self.W[:, client_id] = w_k

        return w_k

    def update_task_relationships(self):
        """
        Update the task relationship matrix Ω based on current models.

        Ω should reflect similarity between client models:
        Ωᵢⱼ ∝ similarity(wᵢ, wⱼ)

        One approach: Ω = W^T⋅W (Gram matrix)
        """
        # Compute model similarity matrix
        self.Omega = self.W.T @ self.W

        # Add small regularization for numerical stability
        self.Omega += 1e-6 * np.eye(self.K)

    def _compute_gradient(self, model, x, y) -> np.ndarray:
        """Compute gradient (placeholder)."""
        return np.random.randn(*model.shape) * 0.01


class FedMTL:
    """
    Federated Multi-Task Learning with hard parameter sharing.

    Architecture:
    - Shared base layers (learned via federated averaging)
    - Task-specific heads (one per client, not shared)

    This is equivalent to partial personalization but framed
    in the multi-task learning paradigm.
    """

    def __init__(
        self,
        shared_dim: int,
        task_head_dim: int,
        num_clients: int
    ):
        self.shared_weights = np.random.randn(shared_dim) * 0.01
        self.task_heads = {
            k: np.random.randn(task_head_dim) * 0.01
            for k in range(num_clients)
        }

    def forward(
        self,
        client_id: int,
        x: np.ndarray
    ) -> np.ndarray:
        """Forward pass through shared + task-specific layers."""
        # Shared representation
        shared_output = x @ self.shared_weights

        # Task-specific output
        task_output = shared_output @ self.task_heads[client_id]

        return task_output

    def aggregate_shared(
        self,
        client_shared_updates: Dict[int, np.ndarray],
        client_weights: Dict[int, float]
    ):
        """FedAvg on shared layers only."""
        total_weight = sum(client_weights.values())

        aggregated = sum(
            (w / total_weight) * update
            for update, w in zip(
                client_shared_updates.values(),
                client_weights.values()
            )
        )

        self.shared_weights += aggregated


class FedEM:
    """
    Federated Expectation Maximization (FedEM)
    Marfoq et al., 2021

    Assumes client data is generated by a mixture of K distributions.
    Each client's data may come from multiple mixture components.

    E-step: Clients estimate posterior over mixture components
    M-step: Server aggregates to update mixture component models

    Unlike hard clustering, this allows soft assignments:
    a client can be 60% cluster A and 40% cluster B.
    """

    def __init__(self, num_components: int, model_dim: int):
        self.K = num_components
        self.component_models = [
            np.random.randn(model_dim) * 0.01
            for _ in range(num_components)
        ]
        self.mixing_weights = np.ones(num_components) / num_components

    def client_e_step(
        self,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> np.ndarray:
        """
        E-step: Compute posterior over components for each sample.

        For each sample x, compute:
            q(k) ∝ π_k ⋅ p(x | θ_k)

        where π_k is mixing weight and θ_k is component model.
        """
        n_samples = len(local_data)
        posteriors = np.zeros((n_samples, self.K))

        for k in range(self.K):
            # Likelihood under component k
            log_likelihood = self._compute_log_likelihood(
                self.component_models[k], local_data, local_labels
            )
            posteriors[:, k] = np.log(self.mixing_weights[k]) + log_likelihood

        # Normalize (softmax)
        posteriors -= posteriors.max(axis=1, keepdims=True)
        posteriors = np.exp(posteriors)
        posteriors /= posteriors.sum(axis=1, keepdims=True)

        return posteriors

    def client_m_step(
        self,
        local_data: np.ndarray,
        local_labels: np.ndarray,
        posteriors: np.ndarray,
        learning_rate: float = 0.01
    ) -> List[np.ndarray]:
        """
        M-step: Update component models weighted by posteriors.

        Returns weighted gradient updates for each component.
        """
        updates = []

        for k in range(self.K):
            # Weight samples by their posterior for component k
            weights = posteriors[:, k]

            # Weighted gradient
            gradient = self._compute_weighted_gradient(
                self.component_models[k], local_data, local_labels, weights
            )

            updates.append(-learning_rate * gradient)

        return updates

    def _compute_log_likelihood(self, model, x, y) -> np.ndarray:
        """Compute log likelihood (placeholder)."""
        return -np.random.random(len(x))

    def _compute_weighted_gradient(
        self, model, x, y, weights
    ) -> np.ndarray:
        """Compute weighted gradient (placeholder)."""
        return np.random.randn(*model.shape) * 0.01
```

| Approach | What It Learns | Complexity | Best For |
|---|---|---|---|
| FedAvg (baseline) | Single global model | Low | Near-IID data |
| Local fine-tuning | Global + adapted per client | Low | Sufficient local data |
| Partial personalization | Shared base + personal head | Medium | Transfer-friendly tasks |
| Clustered FL | One model per cluster | Medium | Natural groupings exist |
| Multi-task FL | Per-client + relationships | High | Heavily non-IID |
| FedEM | Mixture model | High | Multi-modal distributions |
Quantity skew—where clients have vastly different amounts of data—presents unique challenges. Should a client with 10,000 samples have 1000x the influence of a client with 10 samples?
Aggregation Weighting Strategies:
Uniform Weighting — All clients contribute equally regardless of data size. May underweight valuable large datasets.
Data-proportional Weighting — Weight by |D_k| / Σ|D_k|. Standard in FedAvg. May allow large clients to dominate.
Capped Weighting — Proportional but with a maximum weight cap. Balances contribution.
Fairness-aware Weighting — Explicitly optimize for equal performance across clients.
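As a rough sketch of how the first three size-based schemes above differ (the helper and its parameter names are illustrative, not a standard API), each one simply computes per-client weights differently before normalization:

```python
# Sketch of the aggregation weighting strategies above (illustrative only).
import numpy as np

def aggregation_weights(sample_counts, scheme="proportional", cap=0.2):
    """Return normalized per-client weights for FedAvg-style aggregation."""
    n = np.asarray(sample_counts, dtype=float)
    if scheme == "uniform":
        w = np.ones_like(n)                 # every client counts equally
    elif scheme == "proportional":
        w = n                               # standard FedAvg weighting
    elif scheme == "capped":
        w = n / n.sum()
        w = np.minimum(w, cap)              # no client exceeds `cap` share
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return w / w.sum()

counts = [10, 250, 1_000, 10_000]
for scheme in ("uniform", "proportional", "capped"):
    print(scheme, np.round(aggregation_weights(counts, scheme), 3))
```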
q-FedAvg: Fairness-Aware Federation
q-FedAvg (Li et al., 2019) reweights client contributions to improve fairness. Instead of optimizing for average loss, it optimizes a reweighted objective:
min_w Σ_k (p_k / (q+1)) · F_k(w)^(q+1)
For q > 0, clients with higher loss receive proportionally more weight in the aggregate update (q = 0 recovers standard FedAvg), pushing the optimizer to reduce worst-case performance. This prevents the model from ignoring minority clients.
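To make the reweighting concrete, here is a minimal sketch in the spirit of q-FedAvg: each client's update is scaled by its local loss raised to the power q before averaging. It omits the step-size normalization of the full algorithm, and the function and variable names are ours.

```python
# Simplified sketch of loss-reweighted aggregation in the spirit of q-FedAvg.
# Omits the step-size normalization of the full algorithm; names are ours.
import numpy as np

def q_weighted_aggregate(global_model, client_deltas, client_losses, q=1.0):
    """Aggregate client updates, upweighting clients with higher local loss."""
    losses = np.asarray(client_losses, dtype=float)
    weights = losses ** q            # q = 0 recovers plain (unweighted) averaging
    weights = weights / weights.sum()
    update = sum(w * d for w, d in zip(weights, client_deltas))
    return global_model + update

global_model = np.zeros(3)
deltas = [np.array([0.1, 0.0, 0.0]), np.array([0.0, 0.5, 0.0])]
losses = [0.2, 1.5]                  # second client is currently underserved
print(q_weighted_aggregate(global_model, deltas, losses, q=2.0))
```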
Practical Recommendation:
Use proportional weighting as default but monitor per-client metrics. If some clients have significantly worse performance, consider adding fairness constraints or personalization for those clients.
Be careful with data-proportional weighting in privacy-sensitive contexts. Revealing exact dataset sizes can leak information (e.g., hospital patient volume). Consider using size buckets or noisy estimates of client sizes.
Let's synthesize practical guidance for handling heterogeneity in real FL deployments.
| Scenario | Recommended Approach | Rationale |
|---|---|---|
| Mild label skew | FedAvg + more rounds | Standard FedAvg handles small heterogeneity |
| Moderate heterogeneity | FedProx or SCAFFOLD | Regularization/variance reduction helps |
| Distinct client groups | Clustered FL (IFCA, CFL) | Separate models for each group |
| Clients need personalization | pFedMe or local fine-tuning | Personal models with global regularization |
| Transfer learning scenario | Partial personalization (FedPer) | Share base, personalize head |
| Severe heterogeneity | Multi-task FL (MOCHA) | Model relationships between clients |
We've covered the critical challenge of data heterogeneity in federated learning: its taxonomy (label, feature, concept, quantity, and temporal skew), the theory of why non-IID data degrades FedAvg convergence, and the personalization, clustering, and multi-task approaches that address it.
What's Next:
With heterogeneous data challenges addressed, we turn to Federated Optimization in the final page. You'll learn about the convergence theory of federated algorithms, advanced optimizers designed for FL (FedOpt, FedAdam), and how to select and tune optimization strategies for production federated systems.
You now understand the non-IID data challenge in FL and the arsenal of techniques to address it. You can implement personalization strategies, clustered FL, and multi-task approaches. Next, we complete the module with federated optimization algorithms.