Optimization in federated learning faces unique challenges absent from centralized training. The server never sees raw gradients—it receives aggregated updates computed on stale model versions. Clients perform multiple local updates before synchronizing, introducing drift. Non-IID data causes local gradients to diverge from the global direction. Communication constraints limit how often we can synchronize.
These factors demand optimization algorithms specifically designed for the federated setting. Standard SGD with centralized aggregation doesn't directly apply. This page explores the theory and practice of federated optimization—from foundational FedAvg analysis to advanced algorithms like SCAFFOLD and FedAdam that address FL-specific challenges.
By the end of this page, you will understand the convergence theory of FedAvg and its limitations, master algorithms that improve on FedAvg (FedProx, SCAFFOLD, FedNova), learn server-side and client-side adaptive optimization (FedOpt, FedAdam), and gain practical guidance for tuning federated optimization in production.
Federated Averaging (FedAvg) is deceptively simple: clients perform local SGD, server averages the results. Yet its convergence properties have been extensively studied, revealing both strengths and fundamental limitations.
The FedAvg Algorithm Revisited:
For each round t = 1, 2, ..., T:
1. Server broadcasts w_t to selected clients
2. Each client k initializes w_k = w_t
3. Each client performs E epochs of local SGD:
w_k = w_k - η∇F_k(w_k; batch)
4. Server aggregates: w_{t+1} = Σ_k (n_k/n) * w_k
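The four steps above can be sketched end to end in a few lines of NumPy. This is a toy sketch: the gradient function, client sizes, and quadratic objectives are illustrative, not part of the algorithm itself.

```python
import numpy as np

def fedavg_round(w_global, client_grad, client_sizes, eta=0.1, local_steps=5):
    """One FedAvg round: broadcast, local SGD on each client, weighted average."""
    n_total = sum(client_sizes)
    client_models = []
    for k, n_k in enumerate(client_sizes):
        w = w_global.copy()                    # steps 1-2: client starts from w_t
        for _ in range(local_steps):           # step 3: E steps of local SGD
            w = w - eta * client_grad(k, w)
        client_models.append(w)
    # step 4: aggregate, weighting each client by n_k / n
    return sum((n_k / n_total) * w_k for n_k, w_k in zip(client_sizes, client_models))

# Toy non-IID setup: each client's loss (1/2)||w - o_k||² pulls toward its own optimum
optima = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
grad = lambda k, w: w - optima[k]
w = np.zeros(2)
for _ in range(50):
    w = fedavg_round(w, grad, client_sizes=[100, 100])
# With equal weights, w settles near the average of the two optima, [0.5, 0.5]
```

Even in this toy case the non-IID tension is visible: each client's local steps pull toward its own optimum, and only the averaging step pulls the global model back toward a compromise.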
Convergence for Convex Functions:
For smooth convex functions with IID data, FedAvg converges at rate O(1/T) where T is the number of communication rounds, similar to centralized SGD.
```python
# FedAvg Convergence Analysis
import numpy as np
from typing import Tuple, List
from dataclasses import dataclass


@dataclass
class FedAvgConvergenceBound:
    """
    Convergence bound for FedAvg under various assumptions.

    Main result from Li et al., 2020:
    "On the Convergence of FedAvg on Non-IID Data"

    For T rounds, K clients, E local epochs, learning rate η:

        E[F(w_T)] - F(w*) ≤ O(sqrt(E²σ²/KT)) + O(η²E²Γ)

    First term: Optimization error (decreases with T, K)
    - σ²: Variance of stochastic gradients
    - K: Number of clients (more clients = less variance)

    Second term: Non-IID bias (constant, doesn't vanish!)
    - Γ: Gradient dissimilarity (non-IID measure)
    - E: Local epochs (more local steps = more bias)
    - η: Learning rate (smaller η = less bias accumulation)

    Key insight: With non-IID data, FedAvg converges to a
    neighborhood of optimal, not the optimum itself.
    """

    @staticmethod
    def compute_bound(
        num_rounds: int,           # T
        num_clients: int,          # K
        local_epochs: int,         # E
        learning_rate: float,      # η
        gradient_variance: float,  # σ²
        heterogeneity: float,      # Γ (gradient dissimilarity)
        smoothness: float = 1.0,   # L (Lipschitz constant)
    ) -> Tuple[float, float]:
        """
        Compute convergence bound components.

        Returns:
            optimization_error: Decreases with T (can be made small)
            noniid_bias: Constant, doesn't decrease with T
        """
        T, K, E = num_rounds, num_clients, local_epochs
        sigma2, Gamma = gradient_variance, heterogeneity
        eta, L = learning_rate, smoothness

        # Optimization error: O(sqrt(E²σ²/KT))
        optimization_error = np.sqrt(E**2 * sigma2 / (K * T))

        # Non-IID bias: O(η²E²Γ)
        # This term doesn't decrease with T!
        noniid_bias = eta**2 * E**2 * Gamma

        return optimization_error, noniid_bias

    @staticmethod
    def suggest_hyperparameters(
        target_accuracy: float,
        heterogeneity: float,
        num_clients: int,
        budget_rounds: int
    ) -> dict:
        """
        Suggest hyperparameters given constraints.

        Strategy:
        1. Choose E to balance communication and bias
        2. Choose η to keep bias below target
        3. Verify T is sufficient for optimization error
        """
        # For high heterogeneity, reduce E to limit bias
        if heterogeneity > 1.0:
            suggested_E = 1  # Single local epoch for very non-IID
        elif heterogeneity > 0.1:
            suggested_E = 2
        else:
            suggested_E = 5

        # Choose η to satisfy: η²E²Γ < target/2
        max_eta = np.sqrt(target_accuracy / (2 * suggested_E**2 * heterogeneity))
        suggested_eta = min(0.1, max_eta)

        return {
            'local_epochs': suggested_E,
            'learning_rate': suggested_eta,
            'min_rounds': budget_rounds,
            'expected_bias': suggested_eta**2 * suggested_E**2 * heterogeneity
        }


class FedAvgAnalysis:
    """
    Empirical analysis of FedAvg convergence behavior.
    """

    def __init__(self):
        self.round_losses: List[float] = []
        self.client_drifts: List[float] = []

    def analyze_round(
        self,
        global_weights: np.ndarray,
        client_weights_before: List[np.ndarray],
        client_weights_after: List[np.ndarray]
    ) -> dict:
        """
        Analyze a single round of FedAvg.

        Metrics computed:
        - Client drift: How far local models moved from global
        - Gradient variance: Disagreement between client updates
        - Update norm: Magnitude of aggregated update
        """
        # Client drift (average distance from global after local training)
        drifts = [
            np.linalg.norm(after - global_weights)
            for after in client_weights_after
        ]
        avg_drift = np.mean(drifts)

        # Client update variance
        updates = [
            after - before
            for before, after in zip(client_weights_before, client_weights_after)
        ]
        mean_update = np.mean(updates, axis=0)
        update_variance = np.mean([
            np.linalg.norm(u - mean_update)**2 for u in updates
        ])

        # Aggregated update norm
        update_norm = np.linalg.norm(mean_update)

        self.client_drifts.append(avg_drift)

        return {
            'avg_drift': avg_drift,
            'max_drift': max(drifts),
            'update_variance': update_variance,
            'update_norm': update_norm
        }

    def detect_divergence(self, window: int = 10) -> bool:
        """
        Detect if training is diverging.

        Signs of divergence:
        - Client drift increasing over rounds
        - Loss not decreasing
        """
        # Need two full windows of history to compare
        if len(self.client_drifts) < 2 * window:
            return False

        recent = self.client_drifts[-window:]
        earlier = self.client_drifts[-2*window:-window]

        # Drift increasing significantly
        return np.mean(recent) > 2 * np.mean(earlier)
```

Key Insights from Convergence Theory:
Non-IID creates irreducible bias — The O(η²E²Γ) term doesn't vanish with more rounds. FedAvg converges to a neighborhood of optimal, not the optimum.
Local epochs vs. heterogeneity tradeoff — More local epochs (E) reduces communication but amplifies non-IID effects quadratically.
Learning rate must be small enough — Large η accumulates too much bias during local training. Smaller η means slower progress but less drift.
More clients helps variance, not bias — Having K clients reduces the O(1/√(KT)) variance term but doesn't touch the bias term.
No matter how many rounds you run, FedAvg with non-IID data converges to a suboptimal point. The only ways to reduce this bias are: (1) reduce learning rate η, (2) reduce local epochs E, or (3) use algorithms specifically designed to handle non-IID data.
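To make the bound concrete, here is a quick numeric check of the two terms under illustrative values (the numbers below are made up for the example, not measurements):

```python
import numpy as np

# Illustrative values, not measurements
T, K, E = 500, 100, 5                 # rounds, clients, local epochs
eta, sigma2, Gamma = 0.01, 1.0, 0.5   # learning rate, gradient variance, heterogeneity

opt_error = np.sqrt(E**2 * sigma2 / (K * T))  # O(sqrt(E²σ²/KT)) term
bias = eta**2 * E**2 * Gamma                  # O(η²E²Γ) term

# Doubling the round budget shrinks only the optimization error;
# the bias term does not depend on T at all.
opt_error_2T = np.sqrt(E**2 * sigma2 / (K * 2 * T))
assert opt_error_2T < opt_error
```

Halving η to 0.005 would cut the bias term by 4x, matching option (1) above, at the cost of slower local progress; halving E to shrink it matches option (2).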
FedProx (Li et al., 2020) addresses FedAvg's limitations by adding a proximal term to the local objective. This term penalizes deviation from the global model, keeping local updates closer to the global direction.
The FedProx Objective:
Instead of minimizing F_k(w), each client minimizes:
h_k(w; w_t) = F_k(w) + (μ/2)||w - w_t||²
where w_t is the current global model and μ > 0 is the proximal coefficient.
Intuition:
The proximal term acts like a spring pulling local models back toward the global model. Even with non-IID data, clients can't drift too far before the penalty becomes large.
```python
# FedProx: Federated Optimization with Proximal Regularization
import numpy as np
from typing import List, Dict, Tuple


class FedProx:
    """
    FedProx Algorithm Implementation

    Li et al., 2020: "Federated Optimization in Heterogeneous Networks"

    Key insight: Standard FedAvg allows arbitrary client drift.
    FedProx adds proximal regularization to limit this drift.

    Benefits:
    1. Handles non-IID data better than FedAvg
    2. Tolerates partial work (clients can do fewer local steps)
    3. Simpler than variance reduction methods (no extra state)

    Convergence: For μ > 0, FedProx converges to a point where
    the proximal-regularized local objectives balance.
    """

    def __init__(
        self,
        proximal_mu: float = 0.01,
        learning_rate: float = 0.01,
        local_epochs: int = 5
    ):
        """
        Initialize FedProx.

        Args:
            proximal_mu: Proximal coefficient μ. Higher = stronger pull
                toward global model. Too high = slow progress.
                Recommended: 0.001 to 0.1 depending on heterogeneity.
            learning_rate: Local learning rate η
            local_epochs: Number of local epochs E
        """
        self.mu = proximal_mu
        self.eta = learning_rate
        self.E = local_epochs

    def client_local_train(
        self,
        global_weights: np.ndarray,
        local_data: np.ndarray,
        local_labels: np.ndarray,
        batch_size: int = 32
    ) -> np.ndarray:
        """
        Perform local training with proximal regularization.

        At each step, gradient includes:
            ∇F_k(w) + μ(w - w_global)

        The second term pulls w back toward w_global.
        """
        w = global_weights.copy()
        n = len(local_data)

        for epoch in range(self.E):
            # Shuffle data each epoch
            perm = np.random.permutation(n)

            for i in range(0, n, batch_size):
                idx = perm[i:i+batch_size]
                batch_x = local_data[idx]
                batch_y = local_labels[idx]

                # Task gradient
                task_grad = self._compute_gradient(w, batch_x, batch_y)

                # Proximal gradient: μ(w - w_global)
                prox_grad = self.mu * (w - global_weights)

                # Combined update
                w = w - self.eta * (task_grad + prox_grad)

        return w

    def _compute_gradient(
        self,
        weights: np.ndarray,
        x: np.ndarray,
        y: np.ndarray
    ) -> np.ndarray:
        """Compute task gradient (placeholder)."""
        return np.random.randn(*weights.shape) * 0.01


class FedProxAnalysis:
    """
    Analyze FedProx behavior and compare to FedAvg.
    """

    @staticmethod
    def optimal_mu(
        heterogeneity: float,
        learning_rate: float,
        local_epochs: int
    ) -> float:
        """
        Suggest optimal proximal coefficient.

        Rule of thumb from Li et al.:
        - Higher heterogeneity → higher μ needed
        - More local epochs → higher μ needed
        - μ should roughly compensate for η*E drift

        A common heuristic: μ ≈ η / E for moderate heterogeneity
        """
        # Scale with heterogeneity
        base_mu = learning_rate / local_epochs

        # Increase for high heterogeneity
        if heterogeneity > 1.0:
            return base_mu * 2
        elif heterogeneity > 0.1:
            return base_mu
        else:
            return base_mu * 0.5

    @staticmethod
    def compare_drift(
        fedavg_drift: float,
        fedprox_drift: float,
        proximal_mu: float
    ) -> dict:
        """
        Compare client drift between FedAvg and FedProx.

        With proximal term, expected drift reduction is roughly:
            drift_reduction ≈ 1 / (1 + μ/η)
        """
        return {
            'fedavg_drift': fedavg_drift,
            'fedprox_drift': fedprox_drift,
            'drift_reduction': (fedavg_drift - fedprox_drift) / fedavg_drift,
            'expected_reduction': proximal_mu / (1 + proximal_mu)
        }


def fedprox_convergence_bound(
    num_rounds: int,
    num_clients: int,
    local_epochs: int,
    learning_rate: float,
    proximal_mu: float,
    heterogeneity: float
) -> Tuple[float, str]:
    """
    Convergence bound for FedProx.

    Key difference from FedAvg:
    - The bias term includes μ, which allows trading off bias for
      slowdown. With μ > 0, FedProx can converge closer to optimal
      than FedAvg, at the cost of slower progress (more rounds needed).
    """
    T, K, E = num_rounds, num_clients, local_epochs
    eta, mu, Gamma = learning_rate, proximal_mu, heterogeneity

    # Effective learning rate reduced by proximal term
    effective_eta = eta / (1 + mu * eta)

    # Bias is reduced by proximal term
    bias_reduction_factor = 1 / (1 + mu * eta)**2

    # Optimization error (slower due to proximal term)
    opt_error = 1.0 / np.sqrt(K * T * effective_eta)

    # Bias (reduced by proximal term)
    noniid_bias = eta**2 * E**2 * Gamma * bias_reduction_factor

    total_bound = opt_error + noniid_bias

    interpretation = (
        f"With μ={mu}: optimization slower (need ~{1/effective_eta:.1f}x more rounds) "
        f"but bias reduced by {(1 - bias_reduction_factor)*100:.1f}%"
    )

    return total_bound, interpretation
```

| Aspect | FedAvg | FedProx |
|---|---|---|
| Local objective | F_k(w) | F_k(w) + (μ/2)‖w - w_t‖² |
| Client drift | Unbounded | Bounded by 1/μ |
| Non-IID handling | Bias ∝ E² | Bias reduced by (1+μη)² |
| Convergence rate | Optimal for IID | Slower but more stable |
| Extra hyperparameter | None | μ (proximal coefficient) |
| Partial work support | Difficult | Natural (fewer steps = less drift) |
Start with μ = 0.01 and increase if training is unstable or clients have very heterogeneous data. If training is too slow, decrease μ. A good rule of thumb: μ ≈ η/E gives balanced behavior.
SCAFFOLD (Karimireddy et al., 2020) takes a fundamentally different approach: instead of regularizing clients, it uses control variates to correct for the drift caused by non-IID data.
The Variance Reduction Insight:
In non-IID FL, client gradients have high variance because they point toward different local optima. SCAFFOLD reduces this variance by tracking each client's "drift direction" and correcting for it.
Control Variates:
Each client maintains a control variate c_k representing its expected gradient direction. The server maintains a global control variate c representing the expected gradient of the global objective.
Corrected Update:
Instead of w ← w - η∇F_k(w), clients use:
w ← w - η(∇F_k(w) - c_k + c)
The (c - c_k) term corrects for the difference between the client's usual direction and the global direction.
```python
# SCAFFOLD: Stochastic Controlled Averaging for Federated Learning
import numpy as np
from typing import Dict, List, Tuple, Optional


class SCAFFOLD:
    """
    SCAFFOLD Algorithm Implementation

    Karimireddy et al., ICML 2020

    Key insight: The problem with FedAvg under non-IID is that
    client gradients have high variance. SCAFFOLD uses control
    variates to reduce this variance.

    Each client tracks its own "drift" c_k (how its gradients
    differ from the global gradient). When training, the client
    corrects for this drift: gradient - c_k + c_global.

    Result: Clients effectively train toward the global optimum,
    not their local optimum. Eliminates the non-IID bias term!

    Convergence: SCAFFOLD achieves O(1/T) convergence even with
    non-IID data, matching the centralized rate.
    """

    def __init__(
        self,
        model_dim: int,
        num_clients: int,
        learning_rate: float = 0.01,
        local_epochs: int = 5
    ):
        self.d = model_dim
        self.K = num_clients
        self.eta = learning_rate
        self.E = local_epochs

        # Server state
        self.global_model = np.zeros(model_dim)
        self.global_control = np.zeros(model_dim)  # c

        # Client state (would be stored on clients in practice)
        self.client_controls: Dict[int, np.ndarray] = {
            k: np.zeros(model_dim) for k in range(num_clients)
        }

    def client_local_train(
        self,
        client_id: int,
        local_data: np.ndarray,
        local_labels: np.ndarray,
        batch_size: int = 32
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Client local training with SCAFFOLD correction.

        Key modification from FedAvg:
            gradient_used = gradient - c_k + c_global

        This corrects for the client's local bias.

        Returns:
            delta_model: Update to apply to global model
            delta_control: Update to client's control variate
        """
        # Get current state
        w = self.global_model.copy()
        w_init = w.copy()
        c = self.global_control
        c_k = self.client_controls[client_id]

        n = len(local_data)
        num_steps = 0

        for epoch in range(self.E):
            perm = np.random.permutation(n)

            for i in range(0, n, batch_size):
                idx = perm[i:i+batch_size]
                batch_x = local_data[idx]
                batch_y = local_labels[idx]

                # Standard gradient
                gradient = self._compute_gradient(w, batch_x, batch_y)

                # SCAFFOLD correction: use g - c_k + c instead of g
                # This shifts the update direction toward global optimum
                corrected_gradient = gradient - c_k + c

                w = w - self.eta * corrected_gradient
                num_steps += 1

        # Compute update to send to server
        delta_model = w - w_init  # What to add to global model

        # Update control variate (Option II from paper):
        # c_k_new = c_k - c + (w_init - w) / (num_steps * η)
        c_k_new = c_k - c + (w_init - w) / (num_steps * self.eta)
        delta_control = c_k_new - c_k

        # Store updated control
        self.client_controls[client_id] = c_k_new

        return delta_model, delta_control

    def server_aggregate(
        self,
        client_model_deltas: Dict[int, np.ndarray],
        client_control_deltas: Dict[int, np.ndarray],
        client_weights: Dict[int, float]
    ):
        """
        Server aggregation for SCAFFOLD.

        Aggregates both model updates and control variate updates.
        """
        total_weight = sum(client_weights.values())

        # Aggregate model updates (same as FedAvg)
        model_update = sum(
            (client_weights[k] / total_weight) * delta
            for k, delta in client_model_deltas.items()
        )
        self.global_model += model_update

        # Aggregate control updates
        # c_new = c + (1/K) * Σ_k Δc_k for full participation
        control_update = sum(
            delta / self.K  # Divide by total clients, not participating
            for delta in client_control_deltas.values()
        )
        self.global_control += control_update

    def _compute_gradient(
        self,
        weights: np.ndarray,
        x: np.ndarray,
        y: np.ndarray
    ) -> np.ndarray:
        """Compute gradient (placeholder)."""
        return np.random.randn(*weights.shape) * 0.01


class SCAFFOLDAnalysis:
    """
    Analysis and comparison of SCAFFOLD behavior.
    """

    @staticmethod
    def variance_reduction_factor(
        client_gradients: List[np.ndarray],
        client_controls: List[np.ndarray],
        global_control: np.ndarray
    ) -> float:
        """
        Measure variance reduction achieved by SCAFFOLD.

        Compares:
        - Variance of raw gradients (FedAvg-style)
        - Variance of corrected gradients (SCAFFOLD-style)
        """
        # Raw gradient variance
        raw_variance = np.var([np.linalg.norm(g) for g in client_gradients])

        # Corrected gradient variance
        corrected = [
            g - c_k + global_control
            for g, c_k in zip(client_gradients, client_controls)
        ]
        corrected_variance = np.var([np.linalg.norm(g) for g in corrected])

        # Reduction factor
        if raw_variance > 0:
            return 1 - corrected_variance / raw_variance
        return 0.0

    @staticmethod
    def convergence_comparison() -> dict:
        """
        Compare theoretical convergence bounds.

        For T rounds, E local epochs, σ² variance, Γ heterogeneity:
        """
        return {
            'fedavg_bound': 'O(E²σ²/T) + O(η²E²Γ)',
            'fedavg_limitation': 'Non-IID bias term does not vanish with T',
            'scaffold_bound': 'O(Eσ²/T)',
            'scaffold_advantage': 'No non-IID bias term! Matches centralized rate.',
            'scaffold_cost': '2x communication (send control variate updates)'
        }


def compare_scaffold_fedavg(
    num_rounds: int,
    heterogeneity: float,
    local_epochs: int
) -> dict:
    """
    Compare expected performance of SCAFFOLD vs FedAvg.
    """
    # FedAvg: bias doesn't vanish
    fedavg_final_bias = 0.01 ** 2 * local_epochs ** 2 * heterogeneity
    fedavg_convergence = f"Converges to within {fedavg_final_bias:.4f} of optimal"

    # SCAFFOLD: converges to optimal (no bias term)
    scaffold_convergence = f"Converges to optimal at rate O(1/{num_rounds})"

    # Communication cost
    fedavg_comm = 1.0    # Baseline
    scaffold_comm = 2.0  # 2x due to control variates

    return {
        'fedavg_convergence': fedavg_convergence,
        'scaffold_convergence': scaffold_convergence,
        'communication_overhead': f"SCAFFOLD uses {scaffold_comm}x communication",
        'recommendation': (
            "Use SCAFFOLD when heterogeneity is high and communication "
            "is not the primary bottleneck. Use FedAvg/FedProx otherwise."
        )
    }
```

SCAFFOLD's theoretical guarantee is remarkable: it achieves the same O(1/T) convergence rate as centralized SGD, even with arbitrarily non-IID data. The cost is 2x communication (sending control updates) and maintaining per-client state.
Another approach to improving federated optimization is to apply advanced optimizers at the server rather than (or in addition to) at clients. Clients perform vanilla local SGD, but the server uses momentum or adaptive learning rates when aggregating updates.
FedOpt Framework:
FedOpt (Reddi et al., 2021) generalizes FedAvg with a server-side update: the aggregated client update Δ_t is treated as a pseudo-gradient, and the server computes w_{t+1} = ServerOpt(w_t, Δ_t, η_s).
The ServerOpt can be SGD, Momentum, Adam, or any other first-order optimizer.
```python
# FedOpt: Federated Optimization with Server-Side Optimizers
import numpy as np
from typing import Optional, Dict
from abc import ABC, abstractmethod


class ServerOptimizer(ABC):
    """Base class for server-side optimizers."""

    @abstractmethod
    def step(self, params: np.ndarray, update: np.ndarray) -> np.ndarray:
        """Apply update to parameters."""
        pass


class ServerSGD(ServerOptimizer):
    """Standard SGD at the server."""

    def __init__(self, learning_rate: float = 1.0):
        self.lr = learning_rate

    def step(self, params: np.ndarray, update: np.ndarray) -> np.ndarray:
        return params + self.lr * update


class ServerMomentum(ServerOptimizer):
    """
    Server-side SGD with Momentum.

    Even though client updates are "pseudo-gradients" (not true
    gradients), applying momentum at the server can still help:
    - Smooths out noisy updates from client sampling
    - Helps escape local minima
    - Accelerates convergence in consistent directions
    """

    def __init__(
        self,
        learning_rate: float = 1.0,
        momentum: float = 0.9
    ):
        self.lr = learning_rate
        self.momentum = momentum
        self.velocity: Optional[np.ndarray] = None

    def step(self, params: np.ndarray, update: np.ndarray) -> np.ndarray:
        if self.velocity is None:
            self.velocity = np.zeros_like(params)

        # Momentum update
        self.velocity = self.momentum * self.velocity + update
        return params + self.lr * self.velocity


class FedAdam(ServerOptimizer):
    """
    FedAdam: Adam optimizer at the server.

    Reddi et al., 2021: "Adaptive Federated Optimization"

    Benefits:
    - Adaptive learning rates per parameter
    - Handles sparse updates well (useful with sparsification)
    - More robust to heterogeneity in update magnitudes

    Caution: Adam's moment estimates are for "pseudo-gradients"
    (aggregated client updates), not true gradients. Still works
    but convergence guarantees are more complex.
    """

    def __init__(
        self,
        learning_rate: float = 0.01,
        beta1: float = 0.9,
        beta2: float = 0.99,
        epsilon: float = 1e-3,
        tau: float = 1e-3  # Controls adaptivity (FedAdam-specific)
    ):
        self.lr = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.tau = tau  # Additional regularization for FL

        self.m: Optional[np.ndarray] = None  # First moment
        self.v: Optional[np.ndarray] = None  # Second moment
        self.t = 0

    def step(self, params: np.ndarray, update: np.ndarray) -> np.ndarray:
        self.t += 1

        # Initialize moments
        if self.m is None:
            self.m = np.zeros_like(params)
            self.v = np.zeros_like(params)

        # Update moments
        self.m = self.beta1 * self.m + (1 - self.beta1) * update
        self.v = self.beta2 * self.v + (1 - self.beta2) * (update ** 2)

        # Bias correction
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)

        # FedAdam update (note: τ adds stability for FL setting)
        denominator = np.sqrt(v_hat) + self.tau
        step = self.lr * m_hat / denominator

        return params + step


class FedYogi(ServerOptimizer):
    """
    FedYogi: Yogi optimizer variant for FL.

    Yogi (Zaheer et al., 2018) uses additive update for second
    moment instead of exponential moving average. This prevents
    Adam's issue of forgetting early gradients.

        v_t = v_{t-1} + (1-β₂) * sign(g² - v_{t-1}) * g²

    More conservative than Adam, often better for FL.
    """

    def __init__(
        self,
        learning_rate: float = 0.01,
        beta1: float = 0.9,
        beta2: float = 0.99,
        epsilon: float = 1e-3
    ):
        self.lr = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon

        self.m: Optional[np.ndarray] = None
        self.v: Optional[np.ndarray] = None

    def step(self, params: np.ndarray, update: np.ndarray) -> np.ndarray:
        if self.m is None:
            self.m = np.zeros_like(params)
            self.v = np.zeros_like(params)

        # First moment: standard exponential average
        self.m = self.beta1 * self.m + (1 - self.beta1) * update

        # Second moment: Yogi-style additive update
        v_diff = update ** 2 - self.v
        self.v = self.v + (1 - self.beta2) * np.sign(v_diff) * (update ** 2)

        # Update
        step = self.lr * self.m / (np.sqrt(self.v) + self.epsilon)
        return params + step


class FedOptFramework:
    """
    General FedOpt framework supporting arbitrary server optimizers.

    Paper: "Adaptive Federated Optimization" (Reddi et al., 2021)

    Key findings:
    1. Server-side optimization helps, especially with momentum
    2. Adaptive methods (Adam, Yogi) can be unstable in FL
    3. FedYogi often outperforms FedAdam
    4. Server learning rate should be tuned (often η_s = 1.0 works)
    """

    def __init__(
        self,
        server_optimizer: ServerOptimizer,
        client_learning_rate: float = 0.01,
        client_local_epochs: int = 5
    ):
        self.server_opt = server_optimizer
        self.client_lr = client_learning_rate
        self.client_epochs = client_local_epochs
        self.global_model: Optional[np.ndarray] = None

    def round(
        self,
        client_updates: Dict[int, np.ndarray],
        client_weights: Dict[int, float]
    ) -> np.ndarray:
        """
        Execute one round of FedOpt.

        Args:
            client_updates: Δ_k from each participating client
            client_weights: Weighting factors (usually by data size)

        Returns:
            Updated global model
        """
        # Aggregate client updates (same as FedAvg),
        # matching each client's update with its weight by key
        total_weight = sum(client_weights.values())
        aggregated_update = sum(
            (client_weights[k] / total_weight) * delta
            for k, delta in client_updates.items()
        )

        # Apply server optimizer
        self.global_model = self.server_opt.step(
            self.global_model, aggregated_update
        )
        return self.global_model
```

| Optimizer | Server Computation | Stability | Best For |
|---|---|---|---|
| SGD (η_s=1) | Minimal | High | Baseline |
| Momentum | Low (1 extra vector) | High | General use |
| FedAdam | Low (2 extra vectors) | Medium | Sparse updates |
| FedYogi | Low (2 extra vectors) | Medium-High | Conservative adaptive |
| FedAdaGrad | Low (1 extra vector) | High | Sparse, convex problems |
Start with server-side momentum (FedAvgM). It's simple, stable, and often provides significant gains over vanilla FedAvg. If you need adaptivity (e.g., sparse updates), try FedYogi before FedAdam—it's typically more stable in the FL setting.
Learning rate scheduling is more complex in FL due to the interplay between client-side and server-side learning rates, and the non-IID nature of updates.
Two Learning Rates:
In FedOpt-style algorithms, there are two learning rates: the client learning rate η_c, used for the local SGD steps, and the server learning rate η_s, applied when the aggregated update is folded into the global model.
Scheduling Strategies:
```python
# Learning Rate Scheduling for Federated Learning
import numpy as np
from typing import Callable


class FLLearningRateScheduler:
    """
    Learning rate scheduling for federated learning.

    Manages both client and server learning rates independently.
    """

    def __init__(
        self,
        client_lr_schedule: Callable[[int], float],
        server_lr_schedule: Callable[[int], float]
    ):
        self.client_lr_fn = client_lr_schedule
        self.server_lr_fn = server_lr_schedule
        self.round = 0

    def step(self):
        """Advance to next round."""
        self.round += 1

    @property
    def client_lr(self) -> float:
        return self.client_lr_fn(self.round)

    @property
    def server_lr(self) -> float:
        return self.server_lr_fn(self.round)


# Common scheduling functions

def constant(value: float) -> Callable[[int], float]:
    """Constant learning rate."""
    return lambda t: value


def step_decay(
    initial: float,
    decay_rate: float,
    decay_every: int
) -> Callable[[int], float]:
    """Step decay: reduce by factor every N rounds."""
    return lambda t: initial * (decay_rate ** (t // decay_every))


def inverse_decay(
    initial: float,
    decay_rate: float
) -> Callable[[int], float]:
    """Inverse time decay: η(t) = η₀ / (1 + α·t)."""
    return lambda t: initial / (1 + decay_rate * t)


def cosine_annealing(
    initial: float,
    min_lr: float,
    total_rounds: int
) -> Callable[[int], float]:
    """Cosine annealing from initial to min_lr over total_rounds."""
    def schedule(t):
        if t >= total_rounds:
            return min_lr
        cosine_decay = 0.5 * (1 + np.cos(np.pi * t / total_rounds))
        return min_lr + (initial - min_lr) * cosine_decay
    return schedule


def warmup_then_decay(
    warmup_initial: float,
    peak: float,
    warmup_rounds: int,
    decay_fn: Callable[[int], float]
) -> Callable[[int], float]:
    """Warmup to peak, then apply decay function."""
    def schedule(t):
        if t < warmup_rounds:
            # Linear warmup
            return warmup_initial + (peak - warmup_initial) * (t / warmup_rounds)
        else:
            # Apply decay function (shifted to start at 0)
            return decay_fn(t - warmup_rounds)
    return schedule


# Example configuration

def create_recommended_schedule(
    total_rounds: int,
    heterogeneity_level: str = 'medium'  # 'low', 'medium', 'high'
) -> FLLearningRateScheduler:
    """
    Create recommended LR schedule based on heterogeneity.

    High heterogeneity: Use smaller, more decaying client LR
    Low heterogeneity: Can use larger, more stable LRs
    """
    if heterogeneity_level == 'high':
        # Conservative: Small client LR, quick decay
        client_schedule = inverse_decay(initial=0.005, decay_rate=0.02)
        server_schedule = constant(1.0)
    elif heterogeneity_level == 'medium':
        # Balanced: Moderate LR with cosine decay
        client_schedule = cosine_annealing(
            initial=0.01, min_lr=0.001, total_rounds=total_rounds
        )
        server_schedule = constant(1.0)
    else:  # low heterogeneity
        # Aggressive: Larger LR, slower decay
        client_schedule = step_decay(
            initial=0.05, decay_rate=0.5, decay_every=total_rounds // 3
        )
        server_schedule = constant(1.0)

    return FLLearningRateScheduler(client_schedule, server_schedule)
```

In most FL settings, the client learning rate has more impact on convergence than the server learning rate. Focus tuning efforts on η_c first. Server LR is often set to 1.0 (pure averaging) unless using adaptive optimizers.
Let's compare the convergence guarantees of the algorithms we've covered. Understanding these theoretical bounds helps in selecting the right algorithm for your setting.
| Algorithm | Convergence Rate | Non-IID Handling | Communication |
|---|---|---|---|
| Centralized SGD | O(1/T) | N/A | Every iteration |
| FedAvg (IID) | O(1/T) | N/A (assumes IID) | Every E epochs |
| FedAvg (non-IID) | O(1/√T) + O(bias) | Bias doesn't vanish | Every E epochs |
| FedProx | O(1/T) | Reduced bias | Every E epochs |
| SCAFFOLD | O(1/T) | No bias term! | Every E epochs (2x data) |
| FedNova | O(1/T) | Normalized aggregation | Every E epochs |
Key Observations:
SCAFFOLD is theoretically optimal for non-IID data, matching centralized rates. But it requires 2x communication and extra client state.
FedProx trades off bias for slowdown — The proximal term reduces bias but slows progress. Good for moderate heterogeneity.
FedAvg is sufficient for near-IID — If your data is close to IID, the complexity of SCAFFOLD or FedProx may not be worth it.
FedNova handles unbalanced workloads — If clients do different amounts of local work, FedNova's normalization corrects for this.
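FedNova's normalized aggregation can be sketched as follows. This is a simplified sketch of the idea (normalize each client's cumulative update by its step count before averaging); the function and argument names are ours, not from the FedNova paper.

```python
import numpy as np

def fednova_aggregate(w_global, client_deltas, client_steps, client_weights):
    """
    FedNova-style normalized aggregation (simplified sketch).

    Each client's cumulative update Δ_k is divided by its number of
    local steps τ_k, so clients that ran more steps don't dominate
    the average direction. The normalized directions are then
    rescaled by the effective step count τ_eff = Σ p_k τ_k.
    """
    p = np.asarray(client_weights, dtype=float)
    p = p / p.sum()
    # Normalized per-client directions Δ_k / τ_k
    normalized = [delta / tau for delta, tau in zip(client_deltas, client_steps)]
    tau_eff = float(np.dot(p, client_steps))
    combined = sum(p_k * d for p_k, d in zip(p, normalized))
    return w_global + tau_eff * combined
```

A client that ran 10 local steps and one that ran 2 contribute equal direction weight here, whereas plain averaging of the raw Δ_k would let the 10-step client pull the global model five times harder.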
Practical Hierarchy:
Near-IID → FedAvg (simplest)
↓
Moderate heterogeneity → FedProx or FedAvgM
↓
High heterogeneity + tolerate 2x comm → SCAFFOLD
↓
High heterogeneity + comm-constrained → Clustered/Personalized FL
Let's synthesize practical guidance for tuning federated optimization in production.
| Scenario | η_c | E | Algorithm | Notes |
|---|---|---|---|---|
| IID, many clients | 0.01 | 5-10 | FedAvg | Baseline |
| Mild non-IID | 0.01 | 3-5 | FedAvg + momentum | Server momentum helps |
| Moderate non-IID | 0.005 | 2-3 | FedProx (μ=0.01) | Balance drift vs progress |
| Severe non-IID | 0.001 | 1 | SCAFFOLD | Worth 2x comm |
| Unbalanced local work | 0.01 | variable | FedNova | Normalizes by work done |
| Very limited comm | 0.01 | 10+ | FedProx (μ=0.1) | High μ limits drift with many local steps |
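As a starting point, the table above can be encoded as a small lookup helper. The scenario keys and the helper name are ours; the values come straight from the table and are starting points to tune from, not guarantees.

```python
def suggest_config(scenario: str) -> dict:
    """Starting hyperparameters per deployment scenario (values from the table above)."""
    table = {
        'iid':             {'client_lr': 0.01,  'local_epochs': 5,  'algorithm': 'FedAvg'},
        'mild_noniid':     {'client_lr': 0.01,  'local_epochs': 3,  'algorithm': 'FedAvg + server momentum'},
        'moderate_noniid': {'client_lr': 0.005, 'local_epochs': 2,  'algorithm': 'FedProx', 'mu': 0.01},
        'severe_noniid':   {'client_lr': 0.001, 'local_epochs': 1,  'algorithm': 'SCAFFOLD'},
        'unbalanced_work': {'client_lr': 0.01,  'local_epochs': 'variable', 'algorithm': 'FedNova'},
        'limited_comm':    {'client_lr': 0.01,  'local_epochs': 10, 'algorithm': 'FedProx', 'mu': 0.1},
    }
    if scenario not in table:
        raise ValueError(f"unknown scenario: {scenario!r}")
    return table[scenario]
```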
If training diverges: (1) Reduce client LR by 10x, (2) Reduce local epochs to 1, (3) Add/increase FedProx μ, (4) Check for outlier clients with very different data. Often, a single misconfigured client can destabilize global training.
We've covered the optimization algorithms and principles that power federated learning, from FedAvg's convergence behavior through FedProx, SCAFFOLD, FedNova, server-side adaptive methods, and practical tuning guidance.
Module Complete:
You've now completed the comprehensive Federated Learning module.
With this foundation, you're equipped to design, implement, and deploy federated learning systems for privacy-preserving, distributed machine learning.
Congratulations! You've mastered the theory and practice of federated learning. From distributed training fundamentals through privacy guarantees, communication efficiency, heterogeneous data handling, and optimization algorithms, you now have the knowledge to build production federated learning systems.