We have established the theoretical foundations of stochastic variational inference: mini-batch gradient estimation and natural gradients. But theory alone doesn't scale algorithms to billions of data points distributed across hundreds of machines.
Consider the practical demands of modern probabilistic modeling: datasets with billions of observations, models with millions of variational parameters, and computation spread across many machines.
This page bridges theory and practice, addressing the systems-level challenges of scaling SVI to datasets that exceed the capacity of a single machine.
By the end of this page, you will understand memory-efficient data streaming strategies, distributed SVI across multiple machines, gradient compression and communication reduction, asynchronous and parallel update schemes, and the algorithmic modifications needed for extreme-scale inference.
The first scaling challenge is fitting data in memory. When datasets exceed available RAM, naive implementations that load all data fail immediately.
Data streaming architecture:
SVI naturally supports streaming data access because each iteration only requires a mini-batch. The core principle: keep the full dataset on disk and bring only the current mini-batch into memory, discarding it once the gradient step is done.
This enables datasets orders of magnitude larger than RAM to be processed efficiently.
```python
import numpy as np
from typing import Iterator, Tuple
import mmap
import struct


class StreamingDataLoader:
    """
    Memory-mapped data loader for datasets larger than RAM.

    Uses memory-mapped files to stream mini-batches without loading
    the full dataset. Operating system manages caching automatically.
    """

    def __init__(
        self,
        data_path: str,
        batch_size: int,
        feature_dim: int,
        dtype: np.dtype = np.float32,
        shuffle_buffer_size: int = 10000
    ):
        """
        Args:
            data_path: Path to binary data file
            batch_size: Mini-batch size
            feature_dim: Dimensionality of each data point
            dtype: Data type (default float32)
            shuffle_buffer_size: Buffer for reservoir sampling shuffle
        """
        self.batch_size = batch_size
        self.feature_dim = feature_dim
        self.dtype = dtype
        self.bytes_per_sample = feature_dim * np.dtype(dtype).itemsize

        # Memory-map the file (doesn't load into RAM)
        self.file = open(data_path, 'rb')
        self.mmap = mmap.mmap(self.file.fileno(), 0, access=mmap.ACCESS_READ)

        # Calculate dataset size
        self.n_samples = len(self.mmap) // self.bytes_per_sample

        # Shuffle buffer for randomization without full-data shuffle
        self.shuffle_buffer_size = min(shuffle_buffer_size, self.n_samples)
        self.shuffle_buffer = []
        self.buffer_index = 0

    def __len__(self) -> int:
        return self.n_samples // self.batch_size

    def _read_sample(self, index: int) -> np.ndarray:
        """Read a single sample by index."""
        start = index * self.bytes_per_sample
        end = start + self.bytes_per_sample
        data = self.mmap[start:end]
        return np.frombuffer(data, dtype=self.dtype)

    def _fill_buffer(self, start_index: int) -> None:
        """Fill the shuffle buffer starting from given index."""
        self.shuffle_buffer = []
        for i in range(start_index,
                       min(start_index + self.shuffle_buffer_size, self.n_samples)):
            self.shuffle_buffer.append(self._read_sample(i))
        np.random.shuffle(self.shuffle_buffer)
        self.buffer_index = 0

    def __iter__(self) -> Iterator[np.ndarray]:
        """Iterate over mini-batches."""
        # Shuffle start position for each epoch
        epoch_start = np.random.randint(0, self.n_samples)
        self._fill_buffer(epoch_start)

        batch = []
        samples_yielded = 0
        current_position = epoch_start

        while samples_yielded < self.n_samples:
            # Refill buffer if exhausted
            if self.buffer_index >= len(self.shuffle_buffer):
                current_position = (current_position + self.shuffle_buffer_size) % self.n_samples
                self._fill_buffer(current_position)

            # Add sample to batch
            batch.append(self.shuffle_buffer[self.buffer_index])
            self.buffer_index += 1
            samples_yielded += 1

            # Yield complete batch
            if len(batch) == self.batch_size:
                yield np.stack(batch)
                batch = []

        # Yield final partial batch if exists
        if batch:
            yield np.stack(batch)

    def __del__(self):
        """Clean up resources."""
        if hasattr(self, 'mmap'):
            self.mmap.close()
        if hasattr(self, 'file'):
            self.file.close()


class ShardedDataLoader:
    """
    Data loader for datasets sharded across multiple files.

    Useful when data is too large for a single file or when data
    is naturally partitioned (e.g., by date, region, etc.)
    """

    def __init__(
        self,
        shard_paths: list,
        batch_size: int,
        feature_dim: int
    ):
        self.shard_paths = shard_paths
        self.batch_size = batch_size
        self.feature_dim = feature_dim

        # Load shard metadata (sizes)
        self.shard_sizes = []
        for path in shard_paths:
            with open(path, 'rb') as f:
                f.seek(0, 2)  # Seek to end
                size = f.tell() // (feature_dim * 4)  # Assuming float32
                self.shard_sizes.append(size)

        self.total_samples = sum(self.shard_sizes)

    def __iter__(self) -> Iterator[np.ndarray]:
        """Iterate over shards in random order, streaming from each."""
        shard_order = np.random.permutation(len(self.shard_paths))

        for shard_idx in shard_order:
            loader = StreamingDataLoader(
                self.shard_paths[shard_idx],
                self.batch_size,
                self.feature_dim
            )
            yield from loader
```

Memory-mapped I/O lets the operating system manage data caching automatically. Frequently accessed regions stay in RAM; rarely accessed regions are paged out. This provides near-optimal memory usage without explicit cache management.
When datasets span multiple machines—either due to size or data locality constraints—SVI must be adapted for distributed computation. Two primary paradigms exist:
1. Data Parallel SVI:
Data is partitioned across workers. Each worker holds a full replica of the variational parameters, computes gradients on its local data shard, and participates in a global gradient aggregation step before every update.
2. Model Parallel SVI:
Variational parameters are partitioned across workers. Each worker stores and updates only its own slice of the parameters and exchanges intermediate results (activations) with the other workers as needed.
Data parallelism is more common for SVI because parameters are typically smaller than data, and mini-batch processing naturally distributes.
| Aspect | Data Parallel | Model Parallel |
|---|---|---|
| Parameter location | Replicated on all workers | Partitioned across workers |
| Data location | Partitioned across workers | Accessible from all workers |
| Communication | Gradient aggregation | Activation sharing |
| Scaling bottleneck | Gradient sync bandwidth | Cross-worker data movement |
| Best for | Large datasets, moderate params | Huge models (billions of params) |
Synchronous data parallel SVI:
In synchronous training, all workers must complete their gradient computation before the parameters update: each worker computes a gradient on its local mini-batch, the gradients are averaged across workers (e.g., via AllReduce), and every worker then applies the identical update.
Synchronous training is equivalent to single-machine training with effective batch size \(M_{\text{total}} = K \cdot M_{\text{local}}\), where \(K\) is the number of workers, and therefore inherits the same convergence guarantees.
Implementation with AllReduce:
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_distributed(rank: int, world_size: int, backend: str = 'nccl'):
    """Initialize distributed training environment."""
    dist.init_process_group(
        backend=backend,
        init_method='env://',
        world_size=world_size,
        rank=rank
    )
    torch.cuda.set_device(rank)


class DistributedSVI:
    """
    Synchronous distributed SVI using PyTorch's DistributedDataParallel.

    Each worker processes a different data shard, and gradients
    are automatically synchronized via AllReduce.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        local_dataset,
        rank: int,
        world_size: int,
        batch_size: int = 128,
        learning_rate: float = 1e-3
    ):
        self.rank = rank
        self.world_size = world_size

        # Wrap model for distributed training
        self.model = DDP(model.cuda(rank), device_ids=[rank])

        # Each worker gets a different shard of data
        sampler = torch.utils.data.distributed.DistributedSampler(
            local_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )
        self.dataloader = torch.utils.data.DataLoader(
            local_dataset,
            batch_size=batch_size,
            sampler=sampler,
            pin_memory=True
        )

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate
        )

        # Total dataset size across all workers (for ELBO scaling)
        local_size = torch.tensor([len(local_dataset)], device=f'cuda:{rank}')
        dist.all_reduce(local_size, op=dist.ReduceOp.SUM)
        self.total_N = local_size.item()

    def train_epoch(self) -> float:
        """Train for one epoch with synchronized gradients."""
        self.dataloader.sampler.set_epoch(self.epoch if hasattr(self, 'epoch') else 0)

        total_elbo = 0.0
        num_batches = 0

        for batch in self.dataloader:
            x = batch[0].cuda(self.rank)
            batch_size = x.size(0)

            self.optimizer.zero_grad()

            # Forward pass (ELBO computation)
            z, kl_div = self.model.module.encode_and_sample(x)
            log_likelihood = self.model.module.decode_log_prob(x, z)

            # Scale likelihood by total N / effective batch size
            effective_batch = batch_size * self.world_size
            elbo = log_likelihood.sum() * (self.total_N / effective_batch) - kl_div.sum()
            loss = -elbo / batch_size  # Per-sample loss for gradient scaling

            # Backward pass - DDP automatically syncs gradients
            loss.backward()
            self.optimizer.step()

            total_elbo += elbo.item()
            num_batches += 1

        # Average ELBO across workers
        elbo_tensor = torch.tensor([total_elbo / num_batches], device=f'cuda:{self.rank}')
        dist.all_reduce(elbo_tensor, op=dist.ReduceOp.AVG)

        return elbo_tensor.item()
```

Synchronous training has a critical flaw: all workers must wait for the slowest one. When workers have varying speeds (due to hardware differences, network latency, or data complexity), fast workers sit idle.
Asynchronous SGD addresses this by letting each worker pull the current parameters, compute a gradient on its local mini-batch, and push the update immediately, without waiting for the other workers.
In asynchronous training, workers compute gradients on stale parameters. If a gradient is computed using parameters \(\theta_t\) but applied when the parameters have advanced to \(\theta_{t+k}\), the update direction may be incorrect. This staleness can cause divergence if not managed properly.
Hogwild! style updates:
For sparse gradients (where each update touches few parameters), asynchronous updates with lock-free writes work surprisingly well. The probability of conflicting writes is low, and the speedup from avoiding locks outweighs occasional inconsistencies.
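As a concrete illustration, here is a minimal Hogwild!-style sketch: several processes update a shared parameter vector through unsynchronized writes, each step touching only a few coordinates so that collisions are rare. The toy gradient, dimensions, and step counts are illustrative placeholders, not part of any particular library.

```python
import numpy as np
from multiprocessing import Process, RawArray

DIM = 1_000        # number of variational parameters (illustrative)
N_WORKERS = 4
STEPS = 5_000
LR = 0.01


def hogwild_worker(shared_params, seed):
    """Apply sparse, lock-free updates to the shared parameter vector."""
    rng = np.random.default_rng(seed)
    params = np.ctypeslib.as_array(shared_params)  # writable view, no copy, no lock
    for _ in range(STEPS):
        # Sparse "gradient": only a handful of coordinates are touched,
        # so concurrent writes from different workers rarely collide.
        idx = rng.choice(DIM, size=10, replace=False)
        grad = -params[idx] + rng.normal(scale=0.1, size=10)  # toy gradient toward zero
        params[idx] += LR * grad                               # lock-free in-place update


if __name__ == "__main__":
    shared = RawArray('d', DIM)  # shared, unsynchronized memory
    workers = [Process(target=hogwild_worker, args=(shared, s))
               for s in range(N_WORKERS)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    final = np.ctypeslib.as_array(shared)
    print("mean |param| after Hogwild updates:", np.abs(final).mean())
```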
Staleness-aware learning rates:
To compensate for staleness, reduce the learning rate for stale gradients:
$$\alpha_{\text{effective}} = \frac{\alpha}{1 + \lambda \cdot \tau}$$
where \(\tau\) is the staleness (number of updates since gradient was computed) and \(\lambda\) is a damping factor.
Bounded asynchrony:
A middle ground between sync and async: allow at most \(B\) outstanding updates before forcing synchronization. This bounds maximum staleness while preserving most async benefits.
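One way to realize bounded asynchrony is a stale-synchronous gate that blocks any worker that gets more than \(B\) updates ahead of the slowest active worker. The sketch below is a minimal threaded version (the class and method names are hypothetical) that could be paired with worker threads like those in the code that follows: each worker calls `advance` after pushing a gradient and `mark_done` when it finishes.

```python
import threading


class BoundedStalenessGate:
    """
    Stale-synchronous barrier: a worker may proceed only while it is at most
    `max_staleness` iterations ahead of the slowest active worker.
    """

    def __init__(self, num_workers: int, max_staleness: int):
        self.clocks = {w: 0 for w in range(num_workers)}  # per-worker iteration counts
        self.max_staleness = max_staleness
        self.cond = threading.Condition()

    def advance(self, worker_id: int) -> None:
        """Call after each local update; blocks while this worker is too far ahead."""
        with self.cond:
            self.clocks[worker_id] += 1
            self.cond.notify_all()
            while (self.clocks and
                   self.clocks[worker_id] - min(self.clocks.values()) > self.max_staleness):
                self.cond.wait()

    def mark_done(self, worker_id: int) -> None:
        """Remove a finished worker so it no longer holds the others back."""
        with self.cond:
            self.clocks.pop(worker_id, None)
            self.cond.notify_all()
```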
```python
import numpy as np
from threading import Lock
from typing import Dict
import queue
import threading


class AsyncParameterServer:
    """
    Asynchronous parameter server for distributed SVI.

    Workers push gradients; server applies updates immediately.
    Supports staleness-aware learning rate adjustment.
    """

    def __init__(
        self,
        initial_params: Dict[str, np.ndarray],
        base_learning_rate: float = 0.01,
        staleness_decay: float = 0.1,
        max_staleness: int = 10
    ):
        self.params = {k: v.copy() for k, v in initial_params.items()}
        self.locks = {k: Lock() for k in self.params}

        self.base_lr = base_learning_rate
        self.staleness_decay = staleness_decay
        self.max_staleness = max_staleness

        self.global_step = 0
        self.step_lock = Lock()

    def get_params(self) -> tuple:
        """Get current parameters and their version."""
        with self.step_lock:
            version = self.global_step

        params_copy = {}
        for k, v in self.params.items():
            with self.locks[k]:
                params_copy[k] = v.copy()

        return params_copy, version

    def apply_gradient(
        self,
        gradients: Dict[str, np.ndarray],
        version: int
    ) -> bool:
        """
        Apply gradient update with staleness adjustment.

        Args:
            gradients: Gradient for each parameter
            version: Parameter version when gradient was computed

        Returns:
            Whether update was applied (False if too stale)
        """
        with self.step_lock:
            staleness = self.global_step - version

        # Reject extremely stale updates
        if staleness > self.max_staleness:
            return False

        # Compute staleness-adjusted learning rate
        effective_lr = self.base_lr / (1 + self.staleness_decay * staleness)

        # Apply updates (fine-grained locking per parameter)
        for key, grad in gradients.items():
            with self.locks[key]:
                self.params[key] += effective_lr * grad

        # Increment global step
        with self.step_lock:
            self.global_step += 1

        return True


class AsyncSVIWorker:
    """
    Asynchronous SVI worker that pulls parameters, computes gradients,
    and pushes updates without waiting for other workers.
    """

    def __init__(
        self,
        worker_id: int,
        param_server: AsyncParameterServer,
        data_shard: np.ndarray,
        model,
        batch_size: int = 128
    ):
        self.worker_id = worker_id
        self.param_server = param_server
        self.data_shard = data_shard
        self.model = model
        self.batch_size = batch_size
        self.running = False

    def run(self, num_iterations: int):
        """Run the worker for specified iterations."""
        self.running = True
        N_local = len(self.data_shard)

        for iteration in range(num_iterations):
            if not self.running:
                break

            # 1. Pull current parameters
            params, version = self.param_server.get_params()

            # 2. Sample mini-batch from local shard
            batch_idx = np.random.choice(N_local, self.batch_size, replace=False)
            batch = self.data_shard[batch_idx]

            # 3. Compute gradient (using current params)
            self.model.set_params(params)
            gradient = self.model.compute_gradient(batch)

            # 4. Push gradient to parameter server
            applied = self.param_server.apply_gradient(gradient, version)

            if not applied:
                print(f"Worker {self.worker_id}: Gradient rejected (too stale)")

    def stop(self):
        """Signal worker to stop."""
        self.running = False


def run_async_svi(
    model,
    data_shards: list,
    num_workers: int,
    iterations_per_worker: int
):
    """
    Run asynchronous SVI with multiple workers.
    """
    # Initialize parameter server
    param_server = AsyncParameterServer(
        initial_params=model.get_initial_params(),
        base_learning_rate=0.01,
        staleness_decay=0.1
    )

    # Create workers
    workers = []
    for i in range(num_workers):
        worker = AsyncSVIWorker(
            worker_id=i,
            param_server=param_server,
            data_shard=data_shards[i],
            model=model
        )
        workers.append(worker)

    # Launch worker threads
    threads = []
    for worker in workers:
        thread = threading.Thread(
            target=worker.run,
            args=(iterations_per_worker,)
        )
        thread.start()
        threads.append(thread)

    # Wait for completion
    for thread in threads:
        thread.join()

    # Return final parameters
    return param_server.get_params()[0]
```

In distributed settings, communication bandwidth often becomes the bottleneck. Transmitting full gradients with millions of parameters can dominate total training time.
Gradient compression techniques reduce communication by sending approximate gradients:
1. Quantization: Reduce gradient precision from 32-bit floats to a lower-precision representation (e.g., 16-bit floats or 8-bit integers), cutting communication volume by 2-4×.
2. Sparsification: Send only the largest gradient components (e.g., the top 1% by magnitude), transmitting indices and values rather than the full dense vector.
3. Error accumulation: Store the "error" from compression and add it to future gradients:
$$\text{send}_t = Q(g_t + e_t), \qquad e_{t+1} = (g_t + e_t) - Q(g_t + e_t)$$
where \(Q\) is the compression operator. This ensures all gradient information is eventually transmitted.
| Technique | Compression Ratio | Error Introduced | Convergence Impact |
|---|---|---|---|
| Float16 quantization | 2× | ~0.01% | Negligible |
| Int8 quantization | 4× | ~0.1% | Small slowdown |
| Top-1% sparsification | 100× | Varies | Moderate, recovers with error accumulation |
| 1-bit SGD (signSGD) | 32× | Significant | Converges, but possibly to a different solution |
```python
import numpy as np
from typing import Tuple


class GradientCompressor:
    """
    Gradient compression with error feedback for distributed training.
    """

    def __init__(self, compression_ratio: float = 0.01, use_error_feedback: bool = True):
        """
        Args:
            compression_ratio: Fraction of gradients to keep (for Top-K)
            use_error_feedback: Whether to accumulate compression error
        """
        self.compression_ratio = compression_ratio
        self.use_error_feedback = use_error_feedback
        self.error_buffer = None

    def compress_topk(self, gradient: np.ndarray) -> Tuple[np.ndarray, np.ndarray, tuple]:
        """
        Top-K sparsification with error feedback.

        Returns:
            indices: Indices of non-zero elements
            values: Values at those indices
            shape: Original gradient shape (for reconstruction)
        """
        flat_grad = gradient.flatten()

        # Add accumulated error if using error feedback
        if self.use_error_feedback and self.error_buffer is not None:
            flat_grad = flat_grad + self.error_buffer

        # Select top-k by magnitude
        k = max(1, int(len(flat_grad) * self.compression_ratio))
        top_k_indices = np.argpartition(np.abs(flat_grad), -k)[-k:]

        # Extract values
        values = flat_grad[top_k_indices]

        # Compute error (for feedback)
        if self.use_error_feedback:
            compressed = np.zeros_like(flat_grad)
            compressed[top_k_indices] = values
            self.error_buffer = flat_grad - compressed

        return top_k_indices.astype(np.int32), values.astype(np.float16), gradient.shape

    def decompress(
        self,
        indices: np.ndarray,
        values: np.ndarray,
        shape: tuple
    ) -> np.ndarray:
        """Reconstruct gradient from compressed representation."""
        flat_grad = np.zeros(np.prod(shape), dtype=np.float32)
        flat_grad[indices] = values.astype(np.float32)
        return flat_grad.reshape(shape)

    def compress_quantize(
        self,
        gradient: np.ndarray,
        bits: int = 8
    ) -> Tuple[np.ndarray, float, float]:
        """
        Uniform quantization to lower precision.

        Returns:
            quantized: Integer quantized gradient
            scale: Scale factor for dequantization
            min_val: Minimum value for dequantization
        """
        min_val = gradient.min()
        max_val = gradient.max()

        # Scale to [0, 2^bits - 1]
        scale = (max_val - min_val) / (2**bits - 1)
        if scale == 0:
            scale = 1.0  # Handle constant gradients

        quantized = np.round((gradient - min_val) / scale).astype(
            np.uint8 if bits == 8 else np.uint16
        )

        return quantized, scale, min_val

    def dequantize(
        self,
        quantized: np.ndarray,
        scale: float,
        min_val: float
    ) -> np.ndarray:
        """Reconstruct gradient from quantized representation."""
        return quantized.astype(np.float32) * scale + min_val


class SignSGDCompressor:
    """
    Extreme compression: send only gradient signs.

    Each gradient element compressed to 1 bit.
    Requires special aggregation (majority vote).
    """

    def compress(self, gradient: np.ndarray) -> Tuple[np.ndarray, float]:
        """
        Compress to signs only.

        Returns:
            signs: Packed bit array of gradient signs
            scale: Magnitude for reconstruction
        """
        signs = (gradient >= 0).astype(np.uint8)
        scale = np.abs(gradient).mean()  # Average magnitude for reconstruction

        # Pack bits (8 per byte)
        packed = np.packbits(signs.flatten())

        return packed, scale

    def decompress(self, packed: np.ndarray, scale: float, original_size: int) -> np.ndarray:
        """Reconstruct gradient from signs."""
        signs = np.unpackbits(packed)[:original_size]
        return (2 * signs.astype(np.float32) - 1) * scale
```

An alternative to frequent gradient communication is Local SGD: workers take multiple local steps before synchronizing.
Local SGD algorithm: each worker runs \(H\) local gradient steps on its own data shard without communicating, then all workers average their parameters and continue from the averaged point (a simulated sketch appears after the list below).
This reduces communication by a factor of \(H\) compared to synchronous SGD, at the cost of some optimization inefficiency from divergent local trajectories.
Local SGD works well when:
- data is relatively homogeneous across workers (the IID setting),
- the loss surface is smooth (low gradient variance), and
- \(H\) is not too large relative to the scale of the gradient noise.
For non-IID data or highly curved loss surfaces, local trajectories diverge significantly, requiring more frequent synchronization.
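The following is a minimal simulated sketch of Local SGD, run serially for clarity; `grad_fn`, the mini-batch size, and the round structure are illustrative placeholders rather than a production implementation.

```python
import numpy as np


def local_sgd(workers_data, init_params, grad_fn, lr=0.01, H=10, num_rounds=100):
    """
    Simulated Local SGD: each worker takes H local steps on its own shard,
    then all workers average parameters and continue from the average.
    Communication happens once per round instead of once per step.
    """
    K = len(workers_data)
    local_params = [init_params.copy() for _ in range(K)]

    for _ in range(num_rounds):
        for k in range(K):
            for _ in range(H):  # H local steps, no communication
                batch_idx = np.random.choice(len(workers_data[k]), 32)
                batch = workers_data[k][batch_idx]
                # Gradient ascent on the (stochastic) ELBO for worker k
                local_params[k] = local_params[k] + lr * grad_fn(local_params[k], batch)

        # Single communication round: average the K local parameter vectors
        averaged = np.mean(local_params, axis=0)
        local_params = [averaged.copy() for _ in range(K)]

    return local_params[0]
```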
Federated Learning:
Local SGD principles extend to federated learning, where data lives on edge devices (phones, IoT sensors) that cannot share raw data due to privacy constraints.
Federated Averaging (FedAvg): in each round, the server sends the current global parameters to a sampled subset of clients; each client trains locally for several epochs on its own data; the server then averages the returned parameters, weighted by each client's dataset size.
For probabilistic models, federated variational inference applies the same principles:
```python
import numpy as np
from typing import List, Dict


def federated_variational_inference(
    initial_global_params: Dict[str, np.ndarray],
    clients: List,  # List of client objects with local_data and model
    num_rounds: int = 100,
    local_epochs: int = 5,
    client_fraction: float = 0.1
) -> Dict[str, np.ndarray]:
    """
    Federated Variational Inference with FedAvg-style aggregation.

    Args:
        initial_global_params: Initial variational parameters
        clients: List of client objects, each with local data
        num_rounds: Number of federation rounds
        local_epochs: Local training epochs per round
        client_fraction: Fraction of clients to sample each round

    Returns:
        Aggregated variational parameters
    """
    global_params = {k: v.copy() for k, v in initial_global_params.items()}
    num_clients = len(clients)
    clients_per_round = max(1, int(client_fraction * num_clients))

    for round_idx in range(num_rounds):
        # Select subset of clients
        selected_clients = np.random.choice(
            num_clients, clients_per_round, replace=False
        )

        # Collect client updates
        client_params = []
        client_weights = []

        for client_idx in selected_clients:
            client = clients[client_idx]

            # Send global params to client
            local_params = {k: v.copy() for k, v in global_params.items()}

            # Client performs local VI
            for _ in range(local_epochs):
                local_params = client.local_vi_step(local_params)

            client_params.append(local_params)
            client_weights.append(len(client.local_data))

        # Weighted averaging of client parameters
        total_weight = sum(client_weights)
        for key in global_params:
            global_params[key] = np.zeros_like(global_params[key])
            for params, weight in zip(client_params, client_weights):
                global_params[key] += params[key] * (weight / total_weight)

        if round_idx % 10 == 0:
            print(f"Round {round_idx}: Updated from {clients_per_round} clients")

    return global_params


class DifferentiallyPrivateSVI:
    """
    Differentially private SVI for privacy-preserving inference.

    Adds calibrated noise to gradients to provide
    (ε, δ)-differential privacy.
    """

    def __init__(
        self,
        epsilon: float = 1.0,
        delta: float = 1e-5,
        max_grad_norm: float = 1.0
    ):
        """
        Args:
            epsilon: Privacy parameter (lower = more private)
            delta: Privacy failure probability
            max_grad_norm: Gradient clipping threshold
        """
        self.epsilon = epsilon
        self.delta = delta
        self.max_grad_norm = max_grad_norm

        # Compute noise scale using Gaussian mechanism
        # σ ≥ sqrt(2 * ln(1.25/δ)) * Δf / ε
        self.noise_scale = (
            np.sqrt(2 * np.log(1.25 / delta)) * max_grad_norm / epsilon
        )

    def privatize_gradient(self, gradient: np.ndarray) -> np.ndarray:
        """
        Clip and add noise to gradient for differential privacy.
        """
        # Gradient clipping
        grad_norm = np.linalg.norm(gradient)
        if grad_norm > self.max_grad_norm:
            gradient = gradient * (self.max_grad_norm / grad_norm)

        # Add Gaussian noise
        noise = np.random.normal(0, self.noise_scale, gradient.shape)

        return gradient + noise
```

Mini-batch optimization subsamples data points, but for very large models, we can also subsample model components.
Subsampling the ELBO:
For models with many latent variables or mixture components, the ELBO may contain expensive sums:
$$\mathcal{L} = \sum_{k=1}^{K} \text{term}_k(\phi)$$
If \(K\) is large (millions of latent dimensions, mixture components, or neural network parameters), we can estimate \(\mathcal{L}\) by subsampling terms:
$$\hat{\mathcal{L}} = \frac{K}{S} \sum_{k \in \text{sample}} \text{term}_k(\phi)$$
Doubly stochastic variational inference:
Combining data subsampling with model component subsampling yields doubly stochastic estimators:
$$\hat{\mathcal{L}} = \frac{N}{M} \sum_{n \in \text{data sample}} \frac{K}{S} \sum_{k \in \text{model sample}} \ell_{n,k}(\phi)$$
This is unbiased as long as both sampling processes are unbiased. The variance compounds, requiring careful tuning of both sample sizes.
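As a sketch of the doubly stochastic estimator above, the helper below subsamples both data points and model components and rescales each sum so the estimate stays unbiased; `ell` is a hypothetical callback returning the \(\ell_{n,k}\) terms for the chosen indices.

```python
import numpy as np


def doubly_stochastic_elbo(ell, N, K, M=64, S=128, rng=None):
    """
    Unbiased doubly stochastic ELBO estimate.

    ell(n_indices, k_indices) -> (M, S) matrix of per-(data point, component)
    terms; N and K are the full data and component counts, M and S the
    subsample sizes.
    """
    rng = rng or np.random.default_rng()
    n_idx = rng.choice(N, size=M, replace=False)   # data subsample
    k_idx = rng.choice(K, size=S, replace=False)   # model-component subsample
    terms = ell(n_idx, k_idx)                      # shape (M, S)
    # Rescale both sums: unbiased as long as both subsamples are uniform
    return (N / M) * (K / S) * terms.sum()
```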
Example: Large-scale topic models
For LDA with millions of documents and millions of vocabulary terms, documents can be subsampled into mini-batches (data subsampling), and the sums over vocabulary terms or topics within each document's contribution can themselves be subsampled (model-component subsampling).
Each level of subsampling reduces computation while preserving unbiasedness.
Non-uniform sampling with importance weighting can reduce variance dramatically. Sample model components proportional to their expected contribution magnitude, then reweight. This is particularly effective when contributions are highly variable across components.
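A minimal sketch of this importance-weighted component sampling follows; `term_fn` and the vector of estimated contribution magnitudes are hypothetical placeholders used to define the proposal distribution.

```python
import numpy as np


def importance_sampled_sum(term_fn, magnitudes, S=128, rng=None):
    """
    Estimate sum_k term_k by sampling components in proportion to an
    estimate of their contribution magnitude, then reweighting by 1/(S * p_k).

    E[sum_s term_{k_s} / (S * p_{k_s})] = sum_k term_k, so the estimate is
    unbiased, with lower variance when `magnitudes` tracks |term_k| well.
    """
    rng = rng or np.random.default_rng()
    probs = magnitudes / magnitudes.sum()              # proposal over components
    k_idx = rng.choice(len(magnitudes), size=S, p=probs)
    weights = 1.0 / (S * probs[k_idx])                 # importance weights
    return np.sum(weights * term_fn(k_idx))            # unbiased estimate of full sum
```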
Beyond algorithmic improvements, systems-level optimizations significantly impact large-scale SVI performance.
| Optimization | Impact | Implementation |
|---|---|---|
| Mixed-precision training | ~2× lower memory, 2-3× speedup on modern GPUs | Use FP16 for forward/backward, FP32 for accumulation |
| Gradient checkpointing | Trade compute for memory | Recompute activations during backward pass |
| Prefetching and pipelining | Hide data loading latency | Load next batch while computing current |
| Kernel fusion | Reduce memory bandwidth | Combine multiple ops into single GPU kernel |
| Tensor parallelism | Enables models too large for single GPU | Split matrix operations across devices |
Mixed-precision training for SVI:
Modern GPUs (Volta and later) have specialized tensor cores for FP16 matrix multiplication. With mixed precision, the forward and backward passes run in FP16 while a master copy of the parameters and the gradient accumulation remain in FP32, with loss scaling applied to keep small gradients from underflowing.
For variational inference, extra care is needed with numerically sensitive quantities: KL-divergence terms, log-variance parameters, and log-likelihoods of small probabilities are best kept in FP32.
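Below is a minimal sketch of one mixed-precision SVI step using PyTorch's automatic mixed precision; the `encode_and_sample` / `decode_log_prob` interface follows the assumption already made in the DDP example above, and the FP32 casts on the ELBO terms reflect the cautions just listed.

```python
import torch


def mixed_precision_svi_step(model, x, optimizer, scaler, total_N):
    """One SVI step with autocast (FP16 compute) and loss scaling."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():            # forward math in FP16 where safe
        z, kl_div = model.encode_and_sample(x)
        log_lik = model.decode_log_prob(x, z)
        # Keep the ELBO reduction in FP32 to avoid overflow/underflow
        elbo = log_lik.float().sum() * (total_N / x.size(0)) - kl_div.float().sum()
        loss = -elbo / x.size(0)
    scaler.scale(loss).backward()              # loss scaling prevents FP16 gradient underflow
    scaler.step(optimizer)                     # unscales grads, skips step on inf/nan
    scaler.update()
    return elbo.detach()


# Created once and reused across steps:
# scaler = torch.cuda.amp.GradScaler()
```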
Gradient checkpointing for deep probabilistic models:
Memory for activations grows linearly with model depth. Gradient checkpointing trades memory for compute: only a subset of activations (checkpoints) is stored during the forward pass, and the remaining activations are recomputed on the fly during the backward pass.
This enables training deeper Bayesian networks with uncertainty quantification.
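A minimal sketch of activation checkpointing with `torch.utils.checkpoint.checkpoint_sequential`; the 48-layer encoder, layer widths, and segment count are illustrative placeholders.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A deep encoder standing in for a deep probabilistic model (illustrative)
deep_encoder = nn.Sequential(*[
    nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(48)
])


def checkpointed_forward(x, segments=4):
    """
    Store only `segments` activation checkpoints during the forward pass;
    activations inside each segment are recomputed during backward,
    trading extra compute for a large reduction in activation memory.
    """
    return checkpoint_sequential(deep_encoder, segments, x)


x = torch.randn(32, 256, requires_grad=True)
out = checkpointed_forward(x)
out.sum().backward()   # recomputation happens here, segment by segment
```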
Scaling stochastic variational inference to truly large datasets requires a combination of algorithmic and systems-level techniques, each addressing different bottlenecks.
What's next:
With the mechanics of scalable SVI established, we now turn to convergence analysis—the theoretical guarantees that underpin these algorithms. Understanding convergence rates and conditions is essential for tuning hyperparameters and diagnosing optimization issues in practice.
You now understand the systems-level considerations for scaling SVI to massive datasets. From memory-efficient streaming to distributed training to gradient compression, these techniques together enable Bayesian inference at internet scale—making probabilistic machine learning practical for real-world applications with billions of observations.