Consider a federated learning system training a modest neural network with 10 million parameters. With 32-bit floating-point precision, each model update is 40 megabytes. If 1,000 clients participate each round and training requires 500 rounds, the total data transmitted is:
1,000 clients × 500 rounds × 40 MB × 2 (upload + download) = 40 terabytes
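The same arithmetic can be scripted to budget any deployment. A minimal sketch (the helper name and defaults are illustrative, not from any particular library); the same formula reproduces the per-model table later in this section:

```python
# Back-of-the-envelope FL communication budget (illustrative helper).
def fl_traffic_tb(num_params: int, clients_per_round: int, rounds: int,
                  bytes_per_param: int = 4) -> float:
    """Total terabytes moved over training (upload + download, decimal units)."""
    per_round_bytes = num_params * bytes_per_param * clients_per_round * 2
    return per_round_bytes * rounds / 1e12

# 10M parameters, 1,000 clients/round, 500 rounds -> 40.0 TB
print(fl_traffic_tb(10_000_000, 1_000, 500))
```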
For mobile devices on limited data plans, or IoT sensors on cellular networks, this is completely impractical. Even in cross-silo settings with dedicated networks, this bandwidth consumption translates into real costs, both financial and temporal.
Communication is often the dominant bottleneck in federated learning, far exceeding local computation time. This page explores the techniques that reduce communication costs by 10-100x while preserving model quality.
By the end of this page, you will understand the communication cost structure in FL, master gradient compression techniques (sparsification, quantization, sketching), learn about asynchronous communication and local SGD, and be able to select appropriate compression strategies for different FL scenarios.
Understanding where communication costs arise is essential for optimizing them. Let's break down the communication structure of federated learning:
Bidirectional Communication:
Every round moves data in two directions:
- Download (server → clients): the current global model is broadcast to the selected clients.
- Upload (clients → server): each participating client returns its model update.
These directions may have asymmetric constraints. Mobile networks often have much higher download than upload bandwidth. Cross-silo FL may have symmetric high-bandwidth links but still face latency issues.
| Model | Parameters | Size (FP32) | Per-Round Traffic (1K clients, up + down) | 500 Rounds Total |
|---|---|---|---|---|
| MobileNet | 3.4M | 13.6 MB | 27.2 GB | 13.6 TB |
| ResNet-50 | 25M | 100 MB | 200 GB | 100 TB |
| BERT-Base | 110M | 440 MB | 880 GB | 440 TB |
| GPT-2 (Small) | 124M | 496 MB | 992 GB | 496 TB |
| ViT-Large | 304M | 1.2 GB | 2.4 TB | 1.2 PB |
Latency Considerations:
Beyond bandwidth, latency affects FL training time:
- Every synchronization round costs at least one full round trip between server and clients, and training needs hundreds of rounds.
- With synchronous aggregation, each round waits for the slowest participant (the straggler problem discussed later on this page).
- Connection setup and retries on unreliable mobile links add further per-round delay.
The Compression Opportunity:
Deep learning gradients have properties that enable compression:
- Sparsity: most gradient components are near zero, and a small fraction of coordinates carries most of the update's magnitude.
- Precision tolerance: training tolerates gradients represented with only a few bits.
- Temporal redundancy: information dropped in one round can be accumulated locally and transmitted in a later round (error feedback).
These properties enable compression ratios of 10x-1000x with minimal accuracy loss.
Aggressive compression introduces bias and variance into gradient estimates, potentially slowing convergence. The key insight: the per-round accuracy loss from compression is often acceptable because the communication savings let us afford more rounds. For the same communication budget, ten compressed rounds may outperform one uncompressed round.
Gradient sparsification exploits the observation that most gradient components are near zero. By sending only the largest components, we dramatically reduce communication while preserving the most important update information.
Top-K Sparsification:
The simplest approach: keep only the K largest gradient components (by absolute value) and zero the rest. With K = 0.01 × d (1% of dimensions), we achieve roughly 100x compression (somewhat less in practice, since the indices of the kept components must also be transmitted).
Random-K Sparsification:
Alternatively, randomly select K components to transmit. This is unbiased but has higher variance than Top-K.
The Error Accumulation Problem:
Naive sparsification loses information in the zeroed components. Solution: Error Feedback (also called gradient accumulation). Maintain a local error buffer and add it to the next gradient before sparsification.
```python
# Gradient Sparsification Techniques for Federated Learning
import numpy as np
from typing import Tuple, Optional


class GradientSparsifier:
    """
    Implements gradient sparsification with error feedback.

    Key insight: Sparsification is biased (we lose small gradients).
    Error feedback corrects this by accumulating residuals.

    Theorem (Stich et al., 2018): With error feedback, Top-K sparsification
    converges at the same rate as full gradient, up to polylog factors.
    """

    def __init__(
        self,
        compression_ratio: float = 0.01,  # Keep 1% of gradients
        method: str = "top_k",            # "top_k" or "random_k"
        use_error_feedback: bool = True
    ):
        self.k_ratio = compression_ratio
        self.method = method
        self.use_error_feedback = use_error_feedback
        self.error_buffer: Optional[np.ndarray] = None

    def sparsify(
        self,
        gradient: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, int]:
        """
        Sparsify gradient, returning compressed representation.

        With error feedback:
        1. Add accumulated error to current gradient
        2. Sparsify the sum
        3. Store new error (gradient + old_error - sparsified)

        Args:
            gradient: Full gradient vector of dimension d

        Returns:
            values: Non-zero gradient values (k elements)
            indices: Positions of non-zero elements (k elements)
            original_dim: Original dimensionality for reconstruction
        """
        flat_gradient = gradient.flatten()
        d = len(flat_gradient)
        k = int(d * self.k_ratio)

        # Step 1: Apply error feedback
        if self.use_error_feedback and self.error_buffer is not None:
            gradient_with_error = flat_gradient + self.error_buffer
        else:
            gradient_with_error = flat_gradient.copy()

        # Step 2: Select top-k or random-k components
        if self.method == "top_k":
            indices, values = self._top_k_selection(gradient_with_error, k)
        else:
            indices, values = self._random_k_selection(gradient_with_error, k)

        # Step 3: Compute and store residual error
        if self.use_error_feedback:
            sparse_gradient = np.zeros_like(flat_gradient)
            sparse_gradient[indices] = values
            self.error_buffer = gradient_with_error - sparse_gradient

        return values, indices, d

    def _top_k_selection(
        self,
        gradient: np.ndarray,
        k: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Select top-k components by absolute value.

        Time complexity: O(d log k) using a heap, O(d) using partial sort.
        """
        # Find indices of k largest absolute values
        abs_gradient = np.abs(gradient)

        # Use argpartition for O(d) average complexity
        top_k_indices = np.argpartition(abs_gradient, -k)[-k:]
        top_k_values = gradient[top_k_indices]

        return top_k_indices, top_k_values

    def _random_k_selection(
        self,
        gradient: np.ndarray,
        k: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Randomly select k components.

        Unbiased: E[sparse_gradient] = gradient, but higher variance than top-k.
        """
        d = len(gradient)
        indices = np.random.choice(d, size=k, replace=False)

        # Scale up to maintain an unbiased estimate:
        # E[scaled_component] = (d/k) * (k/d) * gradient[i] = gradient[i]
        values = gradient[indices] * (d / k)

        return indices, values

    @staticmethod
    def reconstruct(
        values: np.ndarray,
        indices: np.ndarray,
        original_dim: int,
        original_shape: Tuple[int, ...]
    ) -> np.ndarray:
        """Reconstruct sparse gradient into full form."""
        flat = np.zeros(original_dim)
        flat[indices] = values
        return flat.reshape(original_shape)


class ThresholdSparsifier:
    """
    Threshold-based sparsification: keep components above threshold.

    Advantage: Adaptive—if gradient is truly sparse, send fewer values.
    Challenge: Determining an appropriate threshold.
    """

    def __init__(self, threshold: float = 0.001):
        self.threshold = threshold
        self.error_buffer: Optional[np.ndarray] = None

    def sparsify(self, gradient: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        flat = gradient.flatten()

        # Apply error feedback
        if self.error_buffer is not None:
            flat = flat + self.error_buffer

        # Select components above threshold
        mask = np.abs(flat) > self.threshold
        indices = np.where(mask)[0]
        values = flat[indices]

        # Update error buffer
        sparse = np.zeros_like(flat)
        sparse[indices] = values
        self.error_buffer = flat - sparse

        return values, indices


class DGCSparsifier:
    """
    Deep Gradient Compression (Lin et al., 2018).

    Combines multiple techniques for extreme compression:
    1. Momentum correction (apply momentum before sparsification)
    2. Local gradient clipping
    3. Momentum factor masking
    4. Warm-up training (start with less sparsification)

    Achieves 270x-600x compression with minimal accuracy loss.
    """

    def __init__(
        self,
        compression_ratio: float = 0.001,  # 0.1% sparsity
        momentum: float = 0.9,
        warmup_epochs: int = 4
    ):
        self.k_ratio = compression_ratio
        self.momentum = momentum
        self.warmup_epochs = warmup_epochs

        # State
        self.velocity: Optional[np.ndarray] = None
        self.error: Optional[np.ndarray] = None
        self.epoch = 0  # Caller is expected to advance this once per local epoch

    def sparsify(self, gradient: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        DGC sparsification with momentum correction.

        Key insight: Apply momentum BEFORE sparsification. This prevents
        momentum from amplifying sparse components and losing important
        dense components.
        """
        flat = gradient.flatten()
        d = len(flat)

        # Initialize state
        if self.velocity is None:
            self.velocity = np.zeros_like(flat)
            self.error = np.zeros_like(flat)

        # 1. Update velocity with momentum
        self.velocity = self.momentum * self.velocity + flat

        # 2. Add error feedback
        corrected = self.velocity + self.error

        # 3. Determine k (with warmup: start dense, increase sparsity over epochs)
        warmup_factor = min(1.0, (self.epoch + 1) / self.warmup_epochs)
        effective_ratio = self.k_ratio + (1 - self.k_ratio) * (1 - warmup_factor)
        k = int(d * effective_ratio)

        # 4. Top-k selection
        abs_corrected = np.abs(corrected)
        top_k_indices = np.argpartition(abs_corrected, -k)[-k:]
        top_k_values = corrected[top_k_indices]

        # 5. Update error buffer
        sparse = np.zeros_like(flat)
        sparse[top_k_indices] = top_k_values
        self.error = corrected - sparse

        # 6. Mask momentum (reset momentum for transmitted components)
        self.velocity[top_k_indices] = 0

        return top_k_values, top_k_indices
```

Without error feedback, sparsification creates a biased estimator that causes convergence to a suboptimal point. The accumulated error ensures no information is permanently lost—every gradient component eventually gets transmitted after accumulating sufficiently.
Gradient quantization reduces communication by representing gradient values with fewer bits. Neural networks are remarkably tolerant to low-precision gradients—often 2-4 bits suffice.
Quantization Methods:
Deterministic Quantization — Map values to nearest quantization level. Simple but biased.
Stochastic Quantization — Randomly round to nearby levels with probability proportional to distance. Unbiased in expectation.
Learned Quantization — Optimize quantization levels for the gradient distribution. More complex but can achieve better accuracy.
```python
# Gradient Quantization Techniques for Communication Efficiency
import numpy as np
from typing import Tuple


class UniformQuantizer:
    """
    Uniform quantization with configurable bit width.

    Maps continuous values to discrete levels uniformly spaced
    between min and max values.
    """

    def __init__(self, num_bits: int = 8):
        self.num_bits = num_bits
        self.num_levels = 2 ** num_bits

    def quantize(
        self,
        gradient: np.ndarray,
        stochastic: bool = True
    ) -> Tuple[np.ndarray, float, float]:
        """
        Quantize gradient to num_bits representation.

        Args:
            gradient: Full precision gradient
            stochastic: If True, use stochastic rounding (unbiased)

        Returns:
            quantized: Integer indices into quantization levels
            scale: Scale factor for reconstruction
            zero_point: Offset for reconstruction

        Compression ratio: 32 / num_bits
        """
        # Compute range
        min_val = gradient.min()
        max_val = gradient.max()

        # Avoid division by zero
        if max_val == min_val:
            return np.zeros_like(gradient, dtype=np.uint8), 1.0, min_val

        # Scale to [0, num_levels - 1]
        scale = (max_val - min_val) / (self.num_levels - 1)
        scaled = (gradient - min_val) / scale

        if stochastic:
            # Stochastic rounding: unbiased quantization
            # P(round up) = fractional part
            floor_val = np.floor(scaled)
            prob = scaled - floor_val
            random_vals = np.random.uniform(0, 1, size=gradient.shape)
            quantized = np.where(random_vals < prob, floor_val + 1, floor_val)
        else:
            # Deterministic rounding
            quantized = np.round(scaled)

        # Clip to valid range
        quantized = np.clip(quantized, 0, self.num_levels - 1).astype(np.uint8)

        return quantized, scale, min_val

    def dequantize(
        self,
        quantized: np.ndarray,
        scale: float,
        zero_point: float
    ) -> np.ndarray:
        """Reconstruct gradient from quantized representation."""
        return quantized.astype(np.float32) * scale + zero_point


class TernaryQuantizer:
    """
    Ternary Gradient Quantization (Wen et al., 2017).

    Extreme compression: represent gradients with only 3 values
    {-scale, 0, +scale} using 2 bits per value.

    Compression ratio: 32 / 2 = 16x (theoretical)
    """

    def quantize(
        self,
        gradient: np.ndarray
    ) -> Tuple[np.ndarray, float]:
        """
        Ternarize gradient to {-1, 0, +1} * scale.

        Scale s is chosen to minimize E[(g - s*ternarize(g))²].
        Optimal s = E[|g|] over the components that are quantized.

        Returns:
            ternary: Array with values in {-1, 0, 1}
            scale: Scaling factor for reconstruction
        """
        # Threshold: based on mean absolute value
        threshold = 0.7 * np.mean(np.abs(gradient))

        # Ternarize
        ternary = np.zeros_like(gradient, dtype=np.int8)
        ternary[gradient > threshold] = 1
        ternary[gradient < -threshold] = -1

        # Compute optimal scale:
        # s = ||g restricted to nonzero positions||₁ / ||ternarize(g)||₀
        nonzero_mask = ternary != 0
        if nonzero_mask.sum() > 0:
            scale = np.abs(gradient[nonzero_mask]).sum() / nonzero_mask.sum()
        else:
            scale = 0.0

        return ternary, scale

    def dequantize(
        self,
        ternary: np.ndarray,
        scale: float
    ) -> np.ndarray:
        """Reconstruct from ternary representation."""
        return ternary.astype(np.float32) * scale


class QSGD:
    """
    Quantized SGD (Alistarh et al., 2017).

    Stochastic quantization with theoretical guarantees.

    Key property: Unbiased! E[Q(g)] = g
    This ensures convergence to the same optimum as full precision.
    """

    def __init__(self, num_levels: int = 256):
        self.num_levels = num_levels  # s in the paper

    def quantize(
        self,
        gradient: np.ndarray
    ) -> Tuple[np.ndarray, float]:
        """
        QSGD quantization algorithm.

        For each component gᵢ:
        1. Let l = floor(|gᵢ|/||g||₂ * s)
        2. With prob (|gᵢ|/||g||₂ * s - l), output sign(gᵢ) * (l+1)
        3. Otherwise output sign(gᵢ) * l

        The division by ||g||₂ normalizes to [0, 1].
        Multiplying by s gives s+1 possible levels.
        """
        norm = np.linalg.norm(gradient)
        if norm == 0:
            return np.zeros_like(gradient, dtype=np.int16), 0.0

        # Normalize and scale
        normalized = np.abs(gradient) / norm * self.num_levels

        # Stochastic rounding
        floor_val = np.floor(normalized)
        prob = normalized - floor_val

        # Sample
        random_vals = np.random.uniform(0, 1, size=gradient.shape)
        levels = np.where(random_vals < prob, floor_val + 1, floor_val)

        # Apply sign
        quantized = np.sign(gradient) * levels

        return quantized.astype(np.int16), norm

    def dequantize(
        self,
        quantized: np.ndarray,
        norm: float
    ) -> np.ndarray:
        """Reconstruct: g ≈ (quantized / s) * norm."""
        return (quantized.astype(np.float32) / self.num_levels) * norm


class SignSGD:
    """
    SignSGD (Bernstein et al., 2018).

    Extreme 1-bit quantization: transmit only the sign of gradients.

    Compression ratio: 32x
    Requires majority vote aggregation (not weighted average).
    Converges if the learning rate decays appropriately.
    """

    def quantize(self, gradient: np.ndarray) -> np.ndarray:
        """
        1-bit quantization: keep only the sign.

        Returns:
            Array of {-1, +1}
        """
        return np.sign(gradient).astype(np.int8)

    @staticmethod
    def aggregate(
        client_signs: np.ndarray  # Shape: (num_clients, d)
    ) -> np.ndarray:
        """
        Majority vote aggregation for SignSGD.

        For each dimension, compute the sign of the sum of signs.
        This is equivalent to a majority vote.
        """
        sign_sum = np.sum(client_signs, axis=0)
        return np.sign(sign_sum)


# Combining sparsification and quantization.
# Note: GradientSparsifier is defined in the sparsification listing above.
class SparseQuantizedCompressor:
    """
    Combined sparsification and quantization for extreme compression.

    Pipeline: Top-K selection → Quantization → Encoding
    Achieves 100-1000x compression.
    """

    def __init__(
        self,
        sparsity: float = 0.01,  # Keep 1%
        num_bits: int = 4        # 4-bit quantization
    ):
        self.sparsifier = GradientSparsifier(sparsity)
        self.quantizer = UniformQuantizer(num_bits)

    def compress(
        self,
        gradient: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, float, float, int]:
        """
        Apply combined compression.

        Compression ratio = (1/sparsity) * (32/num_bits)
        Example: 1% sparsity + 4-bit = 100 * 8 = 800x
        """
        # First: sparsification
        values, indices, dim = self.sparsifier.sparsify(gradient)

        # Second: quantization of the sparse values
        quant_values, scale, zero = self.quantizer.quantize(values)

        return quant_values, indices, scale, zero, dim

    def decompress(
        self,
        quant_values: np.ndarray,
        indices: np.ndarray,
        scale: float,
        zero: float,
        dim: int,
        shape: Tuple[int, ...]
    ) -> np.ndarray:
        """Reconstruct full gradient."""
        values = self.quantizer.dequantize(quant_values, scale, zero)
        return GradientSparsifier.reconstruct(values, indices, dim, shape)
```

| Method | Bits | Compression | Biased? | Best For |
|---|---|---|---|---|
| Full precision | 32 | 1x | No | Baseline |
| FP16 | 16 | 2x | No | Standard practice |
| 8-bit uniform | 8 | 4x | No (stochastic) | General purpose |
| 4-bit uniform | 4 | 8x | Yes | Robust models |
| Ternary | 2 | 16x | Yes | Simple models |
| SignSGD | 1 | 32x | Yes | Distributed sync |
| QSGD | ~8 | ~4x | No | When unbiased required |
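As a rough sanity check on these ratios, the sketch below (assuming the GradientSparsifier, UniformQuantizer, and SparseQuantizedCompressor classes from the listings above are in scope) measures the bytes actually transmitted, including the index overhead that the headline "800x" figure ignores:

```python
import numpy as np

# Assumes the classes from the listings above are importable.
d = 1_000_000
gradient = np.random.randn(d).astype(np.float32)

compressor = SparseQuantizedCompressor(sparsity=0.01, num_bits=4)
quant_values, indices, scale, zero, dim = compressor.compress(gradient)

# Bytes on the wire: quantized values (stored here as uint8, one byte each),
# indices encoded as 32-bit ints, plus two floats of metadata.
payload_bytes = quant_values.size * 1 + indices.size * 4 + 8
full_bytes = gradient.nbytes
print(f"compression ratio ≈ {full_bytes / payload_bytes:.0f}x")  # ~80x, index-dominated
```

In practice, bit-packing the 4-bit values and delta-encoding the indices recovers much of the gap between the measured and theoretical ratios.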
Sketching algorithms provide a principled approach to gradient compression based on streaming algorithms and dimensionality reduction. Unlike ad-hoc compression, sketches offer mathematical guarantees on reconstruction error.
Count Sketch for Gradients:
The Count Sketch data structure projects a high-dimensional vector into a lower-dimensional space while preserving the ability to recover heavy hitters (large components) with high probability.
Count-Min Sketch:
Similar to Count Sketch but uses only positive counters and returns upper bounds on frequencies.
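For reference, a minimal Count-Min sketch over non-negative counts (illustrative only, with a simple hash of my own choosing; the gradient-compression code below uses Count Sketch, which handles signed values):

```python
import numpy as np

class CountMinSketch:
    """Minimal Count-Min sketch: point queries return an upper bound on the true count."""

    def __init__(self, num_rows: int = 5, num_cols: int = 2048, seed: int = 0):
        self.r, self.c = num_rows, num_cols
        rng = np.random.default_rng(seed)
        # Per-row parameters for a simple (a*x + b) mod p hash (illustrative).
        self.a = rng.integers(1, 2**31, size=num_rows, dtype=np.int64)
        self.b = rng.integers(0, 2**31, size=num_rows, dtype=np.int64)
        self.table = np.zeros((num_rows, num_cols))

    def _hash(self, key: int, row: int) -> int:
        return int((self.a[row] * key + self.b[row]) % (2**31 - 1)) % self.c

    def add(self, key: int, count: float = 1.0) -> None:
        # Unlike Count Sketch there are no random signs: every counter only grows.
        for row in range(self.r):
            self.table[row, self._hash(key, row)] += count

    def query(self, key: int) -> float:
        # Minimum over rows: collisions only inflate counters, so this upper-bounds the count.
        return min(self.table[row, self._hash(key, row)] for row in range(self.r))
```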
```python
# Sketch-Based Gradient Compression
import numpy as np
from typing import Tuple, Optional


class CountSketch:
    """
    Count Sketch for gradient compression.

    Projects a d-dimensional gradient into an r × c sketch matrix.
    Compression ratio: d / (r * c)

    Properties:
    - Insert: O(r) hash evaluations
    - Query: O(r) hash evaluations
    - Space: O(r * c)
    - Error: O(||g||₂ / √c) for each heavy hitter

    Reference: Charikar et al., "Finding Frequent Items in Data Streams"
    """

    def __init__(self, num_rows: int = 5, num_cols: int = 10000):
        """
        Initialize Count Sketch.

        Args:
            num_rows: Number of hash functions (more = lower variance)
            num_cols: Width of each row (more = lower bias)

        Sketch size: num_rows * num_cols floats
        """
        self.r = num_rows
        self.c = num_cols

        # Generate random hash functions (using random seeds)
        self.hash_seeds = np.random.randint(0, 2**31, size=num_rows)
        self.sign_seeds = np.random.randint(0, 2**31, size=num_rows)

    def sketch(self, gradient: np.ndarray) -> np.ndarray:
        """
        Create sketch of gradient vector.

        For each index i in gradient:
        1. Hash i to column h(i) for each row
        2. Compute random sign s(i) ∈ {-1, +1}
        3. Add s(i) * gradient[i] to sketch[row, h(i)]

        Args:
            gradient: d-dimensional gradient vector

        Returns:
            sketch: r × c matrix
        """
        flat = gradient.flatten()
        d = len(flat)
        sketch = np.zeros((self.r, self.c))
        indices = np.arange(d)

        for row in range(self.r):
            # Hash indices to columns
            col_hashes = self._hash(indices, self.hash_seeds[row]) % self.c
            # Random signs
            signs = 2 * (self._hash(indices, self.sign_seeds[row]) % 2) - 1
            # Accumulate
            np.add.at(sketch[row], col_hashes, signs * flat)

        return sketch

    def recover(
        self,
        sketch: np.ndarray,
        original_dim: int,
        top_k: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Recover approximate heavy hitters from sketch.

        For each candidate index:
        1. Query all rows
        2. Take median of (sign[row] * sketch[row, hash(i)])
        3. Return top-k by estimated magnitude

        This is approximate—small components may be missed,
        and collisions cause noise.
        """
        indices = np.arange(original_dim)

        # Query each index
        row_estimates = np.zeros((self.r, original_dim))
        for row in range(self.r):
            col_hashes = self._hash(indices, self.hash_seeds[row]) % self.c
            signs = 2 * (self._hash(indices, self.sign_seeds[row]) % 2) - 1
            row_estimates[row] = signs * sketch[row, col_hashes]

        # Take median across rows (robust to collisions)
        estimates = np.median(row_estimates, axis=0)

        # Return top-k
        top_k_indices = np.argpartition(np.abs(estimates), -top_k)[-top_k:]
        return estimates[top_k_indices], top_k_indices

    def _hash(self, values: np.ndarray, seed: int) -> np.ndarray:
        """Simple hash function using modular arithmetic."""
        # In practice, use a proper hash function like MurmurHash
        a = (seed * 2654435761) % (2 ** 32)
        b = ((seed + 1) * 2654435761) % (2 ** 32)
        return ((a * values + b) % (2 ** 32)).astype(np.int64)


class FetchSGD:
    """
    FetchSGD: Communication-Efficient Federated Learning with Sketching
    (Rothchild et al., ICML 2020).

    Uses Count Sketch for federated learning:
    1. Each client sketches their gradient update
    2. Server aggregates sketches (sketches are linear!)
    3. Server recovers heavy hitters from the aggregated sketch
    4. Error feedback for unsent components

    Advantages:
    - Aggregation in sketch space (don't decompress at the server)
    - Automatic heavy hitter identification
    - Theoretical guarantees on recovery
    """

    def __init__(
        self,
        num_rows: int = 5,
        num_cols: int = 50000,
        top_k: int = 50000
    ):
        self.sketch = CountSketch(num_rows, num_cols)
        self.top_k = top_k
        self.error_buffer: Optional[np.ndarray] = None

    def client_compress(
        self,
        gradient: np.ndarray
    ) -> np.ndarray:
        """Client-side: compress gradient to sketch."""
        flat = gradient.flatten()

        # Add error feedback
        if self.error_buffer is not None:
            flat = flat + self.error_buffer

        # Create sketch
        sketch = self.sketch.sketch(flat)

        # Recover what we're effectively sending (for error feedback)
        recovered_values, recovered_indices = self.sketch.recover(
            sketch, len(flat), self.top_k
        )

        # Compute error for next round
        sent = np.zeros_like(flat)
        sent[recovered_indices] = recovered_values
        self.error_buffer = flat - sent

        return sketch

    @staticmethod
    def server_aggregate(
        client_sketches: np.ndarray  # Shape: (num_clients, r, c)
    ) -> np.ndarray:
        """
        Server-side: aggregate sketches by simple sum.

        Key property: Sketch(Σ gradients) = Σ Sketch(gradients)
        Linear aggregation in compressed space!
        """
        return np.sum(client_sketches, axis=0)

    def server_recover(
        self,
        aggregated_sketch: np.ndarray,
        original_dim: int
    ) -> np.ndarray:
        """Server-side: recover approximate gradient from sketch."""
        values, indices = self.sketch.recover(
            aggregated_sketch, original_dim, self.top_k
        )
        full_gradient = np.zeros(original_dim)
        full_gradient[indices] = values
        return full_gradient
```

A major advantage of sketch-based compression: aggregation happens in the compressed domain. The server sums sketches directly without decompression. This reduces server-side computation and enables secure aggregation with compressed updates.
The most fundamental communication reduction technique in federated learning is simply communicating less often. Local SGD (also called FedAvg with multiple local epochs) has clients perform multiple gradient updates locally before synchronizing with the server.
Communication Reduction via Local Updates:
If each client performs E local epochs with K gradient steps per epoch before communicating, the model is synchronized once every E × K steps instead of after every step, cutting the number of communication rounds by a factor of E × K. For example, E = 5 and K = 300 means one synchronization per 1,500 local updates.
The Local-Global Divergence Problem:
Local updates cause client models to diverge from each other. With non-IID data, this divergence is amplified—clients pull in different directions based on their local data distributions.
```python
# Local SGD and FedAvg Variants for Communication Efficiency
import numpy as np
from typing import Optional
from dataclasses import dataclass


@dataclass
class LocalTrainingConfig:
    """Configuration for local training."""
    num_epochs: int = 1          # E: local epochs
    batch_size: int = 32         # B: local batch size
    learning_rate: float = 0.01  # η: local learning rate
    momentum: float = 0.0        # Local momentum

    @property
    def local_steps(self) -> int:
        """Approximate number of local gradient steps."""
        return self.num_epochs * (10000 // self.batch_size)  # Assuming 10K samples


class LocalSGDClient:
    """
    Local SGD client performing multiple local updates.

    FedAvg (McMahan et al., 2017) is Local SGD applied to
    federated learning settings.
    """

    def __init__(self, config: LocalTrainingConfig):
        self.config = config
        self.model = None

    def local_train(
        self,
        global_weights: np.ndarray,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> np.ndarray:
        """
        Perform local training and return the model delta.

        Communication cost: O(|model|) per round
        Computation: E * |local_data| / B gradient steps
        """
        self.model = global_weights.copy()
        velocity = np.zeros_like(self.model)

        for epoch in range(self.config.num_epochs):
            # Shuffle data each epoch
            perm = np.random.permutation(len(local_data))

            for i in range(0, len(local_data), self.config.batch_size):
                batch_indices = perm[i:i + self.config.batch_size]
                batch_x = local_data[batch_indices]
                batch_y = local_labels[batch_indices]

                # Compute gradients (simplified)
                gradient = self._compute_gradient(batch_x, batch_y)

                # SGD with momentum
                velocity = self.config.momentum * velocity + gradient
                self.model -= self.config.learning_rate * velocity

        # Return delta (update) rather than full model
        return self.model - global_weights

    def _compute_gradient(self, x, y) -> np.ndarray:
        """Placeholder for actual gradient computation."""
        return np.random.randn(*self.model.shape) * 0.01


class FedProxClient(LocalSGDClient):
    """
    FedProx (Li et al., 2020): Addresses client heterogeneity.

    Key modification: Add a proximal term to the local objective
        Local objective: F_k(w) + (μ/2)||w - w_global||²

    The proximal term keeps the local model close to the global model,
    reducing divergence caused by non-IID data.

    μ = 0 recovers FedAvg
    μ > 0 adds regularization toward the global model
    """

    def __init__(self, config: LocalTrainingConfig, proximal_mu: float = 0.01):
        super().__init__(config)
        self.mu = proximal_mu

    def local_train(
        self,
        global_weights: np.ndarray,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> np.ndarray:
        """
        FedProx local training with proximal regularization.

        Gradient at each step includes: ∇F_k(w) + μ(w - w_global)
        """
        self.model = global_weights.copy()

        for epoch in range(self.config.num_epochs):
            perm = np.random.permutation(len(local_data))

            for i in range(0, len(local_data), self.config.batch_size):
                batch_indices = perm[i:i + self.config.batch_size]
                batch_x = local_data[batch_indices]
                batch_y = local_labels[batch_indices]

                # Task gradient
                task_gradient = self._compute_gradient(batch_x, batch_y)

                # Proximal gradient: μ(w - w_global)
                proximal_gradient = self.mu * (self.model - global_weights)

                # Combined update
                total_gradient = task_gradient + proximal_gradient
                self.model -= self.config.learning_rate * total_gradient

        return self.model - global_weights


class SCAFFOLDClient(LocalSGDClient):
    """
    SCAFFOLD (Karimireddy et al., 2020): Variance reduction for FL.

    Problem: In non-IID settings, client gradients have high variance.
    Solution: Use control variates to reduce variance.

    Each client maintains a control variate c_i.
    The server maintains a global control variate c.

    Local update: w ← w - η(∇F_k(w) - c_i + c)

    The (c - c_i) term corrects for client drift, dramatically
    improving convergence in heterogeneous settings.
    """

    def __init__(self, config: LocalTrainingConfig):
        super().__init__(config)
        self.control_variate: Optional[np.ndarray] = None
        self.server_control: Optional[np.ndarray] = None

    def set_control_variates(
        self,
        client_control: np.ndarray,
        server_control: np.ndarray
    ):
        """Set control variates before the training round."""
        self.control_variate = client_control.copy()
        self.server_control = server_control.copy()

    def local_train(
        self,
        global_weights: np.ndarray,
        local_data: np.ndarray,
        local_labels: np.ndarray
    ) -> tuple:
        """
        SCAFFOLD local training with variance reduction.

        Returns:
            delta: Model update (w_new - w_global)
            control_delta: Control variate update (c_new - c_old)
        """
        self.model = global_weights.copy()

        # Initialize if first round
        if self.control_variate is None:
            self.control_variate = np.zeros_like(global_weights)
        if self.server_control is None:
            self.server_control = np.zeros_like(global_weights)

        # Local training with control variate correction
        for epoch in range(self.config.num_epochs):
            for batch_x, batch_y in self._get_batches(local_data, local_labels):
                gradient = self._compute_gradient(batch_x, batch_y)

                # SCAFFOLD correction: g - c_i + c
                corrected_gradient = (
                    gradient - self.control_variate + self.server_control
                )
                self.model -= self.config.learning_rate * corrected_gradient

        # Update control variate
        # Option 2 from the paper: c_i_new = c_i - c + (w_old - w_new) / (K * η)
        K = self.config.local_steps
        control_update = (
            self.control_variate
            - self.server_control
            + (global_weights - self.model) / (K * self.config.learning_rate)
        )

        delta_model = self.model - global_weights
        delta_control = control_update - self.control_variate

        # Update stored control variate
        self.control_variate = control_update

        return delta_model, delta_control

    def _get_batches(self, data, labels):
        """Iterate over shuffled batches."""
        perm = np.random.permutation(len(data))
        for i in range(0, len(data), self.config.batch_size):
            idx = perm[i:i + self.config.batch_size]
            yield data[idx], labels[idx]
```

| Algorithm | Key Idea | Communication | Handles Non-IID |
|---|---|---|---|
| FedAvg | Multiple local epochs | 1x baseline | Moderate |
| FedProx | Proximal term (μ/2)‖w − w_global‖² keeps local models near the global model | 1x baseline | Good |
| SCAFFOLD | Control variates correct client drift | 2x baseline (send control) | Excellent |
| FedNova | Normalized averaging by local steps | 1x baseline | Good |
| MOON | Contrastive loss between local and global model representations | 1.5x baseline | Excellent |
Synchronous FL requires waiting for all participating clients before aggregation. With heterogeneous clients (different computation speeds, network latencies), this means waiting for the slowest client each round—the straggler problem.
Asynchronous Approaches:
Asynchronous FL allows the server to aggregate updates as they arrive, without waiting for all clients. This improves wall-clock training time but introduces staleness—updates may be computed on outdated model versions.
Bounded Staleness:
A practical middle ground: allow some asynchrony but bound the maximum staleness. If a client's update is based on a model more than τ versions old, discard it or down-weight it significantly.
Adaptive Weighting:
Weight updates inversely proportional to staleness:
weight = 1 / (1 + staleness^α)
This reduces the influence of very stale updates while still benefiting from fast clients.
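A minimal sketch of a staleness-aware server combining both ideas (bounded staleness plus the weighting above); the class and parameter names are illustrative, not from any particular framework:

```python
import numpy as np

class StalenessAwareServer:
    """Async aggregation sketch: drop very stale updates, down-weight the rest."""

    def __init__(self, initial_model: np.ndarray, max_staleness: int = 10,
                 alpha: float = 0.5, base_mixing: float = 0.5):
        self.model = initial_model.copy()
        self.version = 0                   # Incremented for every accepted update
        self.max_staleness = max_staleness  # τ: bound on tolerated staleness
        self.alpha = alpha                  # α in weight = 1 / (1 + staleness^α)
        self.base_mixing = base_mixing

    def submit_update(self, client_delta: np.ndarray, client_version: int) -> bool:
        """Apply one client's update; returns False if it was too stale and dropped."""
        staleness = self.version - client_version
        if staleness > self.max_staleness:
            return False  # Bounded staleness: discard

        # Adaptive weighting: stale updates receive a smaller mixing coefficient.
        weight = self.base_mixing / (1.0 + staleness ** self.alpha)
        self.model += weight * client_delta
        self.version += 1
        return True
```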
Asynchronous FL complicates differential privacy. Standard DP-SGD assumes synchronized batches with known composition. With async updates arriving at irregular intervals, tracking privacy budget becomes non-trivial. Conservative accounting is necessary, potentially losing efficiency gains.
Let's synthesize the techniques into practical strategies for different FL scenarios.
| Setting | Key Constraint | Recommended Approach | Expected Compression |
|---|---|---|---|
| Cross-device (mobile) | Upload bandwidth | FedAvg + Top-K + 8-bit quantization | 50-100x |
| Cross-device (IoT) | Very low bandwidth | SignSGD + extreme sparsity | 500-1000x |
| Cross-silo (healthcare) | Privacy + moderate bandwidth | FedProx + secure aggregation | 1-10x (utility priority) |
| Cross-silo (finance) | Strict privacy + reliability | SCAFFOLD + DP + redundancy | 1-5x (correctness priority) |
| Edge computing | Latency + heterogeneity | FedNova + async + partial updates | 10-50x |
We've covered the critical communication challenges and solutions in federated learning:
- The cost structure of FL communication, and why model size × clients × rounds quickly reaches terabytes
- Gradient sparsification (Top-K, threshold-based, DGC) with error feedback to avoid bias
- Quantization from 8-bit down to ternary and 1-bit SignSGD, and how to combine it with sparsification
- Sketch-based compression (Count Sketch, FetchSGD) that aggregates directly in the compressed domain
- Local SGD and its variants (FedAvg, FedProx, SCAFFOLD, FedNova) that simply communicate less often
- Asynchronous and staleness-aware aggregation to mitigate stragglers
What's Next:
With communication efficiency addressed, we turn to Heterogeneous Data in the next page. You'll learn why non-IID data distributions cause FL algorithms to struggle, the theoretical understanding of this phenomenon, and techniques like personalization, clustering, and multi-task learning that unlock effective learning from heterogeneous client populations.
You now understand communication bottlenecks in FL and the arsenal of techniques to address them. You can implement gradient sparsification, quantization, and local SGD variants. Next, we tackle the challenge of heterogeneous, non-IID data across federated clients.