The standard EM algorithm is remarkably versatile, but various application scenarios call for modifications. Large datasets may require stochastic or online variants. Regularization needs lead to MAP-EM. Computational efficiency motivates hard EM. Complex posterior distributions demand variational EM.
This page surveys the most important EM variants, explaining when each is appropriate, how they differ from standard EM, and their theoretical and practical tradeoffs. Understanding these variants enables you to apply EM effectively across diverse machine learning applications.
By the end of this page, you will understand: (1) Hard EM and its connection to K-means, (2) MAP-EM for regularized estimation, (3) Stochastic EM for escaping local optima, (4) Online EM for streaming data and large datasets, (5) Variational EM for intractable E-steps, and (6) how to choose the right variant for your application.
Hard EM (also called Classification EM or Viterbi EM) replaces the soft probabilistic E-step with a hard assignment step. Instead of computing posterior probabilities $\gamma_{nk} = p(z_{nk} = 1 \mid \mathbf{x}_n, \boldsymbol{\theta})$, we assign each point to its most likely component:
$$z_{nk}^{(t)} = \begin{cases} 1 & \text{if } k = \arg\max_j p(z_{nj} = 1 \mid \mathbf{x}_n, \boldsymbol{\theta}^{(t)}) \\ 0 & \text{otherwise} \end{cases}$$
The M-step then updates parameters using only points assigned to each component.
The Hard EM Algorithm for GMMs
E-step (Hard): For each data point, find the most likely component: $$k^*_n = \arg\max_k \pi_k^{(t)} \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k^{(t)}, \boldsymbol{\Sigma}_k^{(t)})$$
M-step: Update parameters using only assigned points: $$N_k = |\{n : k^*_n = k\}|$$ $$\boldsymbol{\mu}_k^{(t+1)} = \frac{1}{N_k} \sum_{n : k^*_n = k} \mathbf{x}_n$$ $$\boldsymbol{\Sigma}_k^{(t+1)} = \frac{1}{N_k} \sum_{n : k^*_n = k} (\mathbf{x}_n - \boldsymbol{\mu}_k^{(t+1)})(\mathbf{x}_n - \boldsymbol{\mu}_k^{(t+1)})^\top$$
When all covariances are constrained to be σ²I (spherical, equal size), hard EM is exactly equivalent to K-means clustering! The assignment step becomes 'assign to nearest centroid' and the M-step updates centroids as cluster means. This reveals K-means as a special case of GMM fitting with hard assignments.
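To see the equivalence explicitly, assume shared spherical covariances $\boldsymbol{\Sigma}_k = \sigma^2 \mathbf{I}$ and, for simplicity, equal mixing weights $\pi_k = 1/K$. The normalization constant is then identical for every component, so the hard E-step reduces to nearest-centroid assignment:

$$k^*_n = \arg\max_k \frac{1}{K}\,\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \sigma^2 \mathbf{I}) = \arg\min_k \|\mathbf{x}_n - \boldsymbol{\mu}_k\|^2$$

and the M-step mean update is exactly the K-means centroid update.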
```python
import numpy as np
from scipy.stats import multivariate_normal

def hard_em_gmm(X, K, max_iters=100, tol=1e-6):
    """
    Hard EM algorithm for Gaussian Mixture Models.
    Uses hard assignments instead of soft responsibilities.
    """
    N, D = X.shape

    # Initialize with K-means++
    pi, mu, sigma = init_kmeans_plus_plus(X, K)

    prev_assignments = None

    for iteration in range(max_iters):
        # ============ HARD E-STEP ============
        # Compute log-probabilities for numerical stability
        log_probs = np.zeros((N, K))
        for k in range(K):
            log_probs[:, k] = (
                np.log(pi[k] + 1e-300)
                + multivariate_normal.logpdf(X, mu[k], sigma[k])
            )

        # Hard assignment: argmax for each point
        assignments = np.argmax(log_probs, axis=1)

        # Check for convergence (no assignment changes)
        if prev_assignments is not None and np.all(assignments == prev_assignments):
            print(f"Hard EM converged at iteration {iteration}")
            break
        prev_assignments = assignments.copy()

        # ============ M-STEP ============
        for k in range(K):
            mask = (assignments == k)
            N_k = mask.sum()

            if N_k < 2:
                # Handle empty clusters: reinitialize
                mu[k] = X[np.random.randint(N)]
                sigma[k] = np.eye(D) * np.var(X)
                pi[k] = 1.0 / K
            else:
                cluster_data = X[mask]
                pi[k] = N_k / N
                mu[k] = cluster_data.mean(axis=0)
                sigma[k] = np.cov(cluster_data.T) + 1e-6 * np.eye(D)

        # Normalize mixing coefficients
        pi = pi / pi.sum()

    return pi, mu, sigma, assignments
```

MAP-EM modifies the optimization objective from maximum likelihood to maximum a posteriori (MAP) estimation. By incorporating prior distributions over parameters, we regularize the solution and prevent pathologies like singular covariances.
The MAP Objective
Instead of maximizing $p(\mathbf{X} \mid \boldsymbol{\theta})$, we maximize:
$$p(\boldsymbol{\theta} \mid \mathbf{X}) \propto p(\mathbf{X} \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta})$$
Taking logarithms:
$$\log p(\boldsymbol{\theta} \mid \mathbf{X}) = \log p(\mathbf{X} \mid \boldsymbol{\theta}) + \log p(\boldsymbol{\theta}) + \text{const}$$
The prior $p(\boldsymbol{\theta})$ acts as a regularizer on the parameters.
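Within EM, this changes only the M-step: the E-step computes responsibilities exactly as before, while the M-step maximizes the usual expected complete-data log-likelihood $Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)})$ plus the log-prior:

$$\boldsymbol{\theta}^{(t+1)} = \arg\max_{\boldsymbol{\theta}} \; Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)}) + \log p(\boldsymbol{\theta})$$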
Common Priors for GMMs
Dirichlet prior on mixing coefficients: $$p(\boldsymbol{\pi}) = \text{Dir}(\boldsymbol{\pi} \mid \alpha_1, \ldots, \alpha_K)$$ This prevents components from having zero mixing coefficients.
Inverse-Wishart prior on covariances: $$p(\boldsymbol{\Sigma}_k) = \text{IW}(\boldsymbol{\Sigma}_k \mid \nu_0, \mathbf{S}_0)$$ This prevents singular covariances and shrinks toward a prior scale matrix.
Normal prior on means: $$p(\boldsymbol{\mu}_k) = \mathcal{N}(\boldsymbol{\mu}_k \mid \mathbf{m}_0, \mathbf{V}_0)$$ This shrinks means toward a prior location.
MAP-EM Update Equations
With conjugate priors, the M-step has modified closed-form solutions:
Mixing coefficients (with Dirichlet prior $\text{Dir}(\alpha, \ldots, \alpha)$): $$\pi_k^{\text{MAP}} = \frac{N_k + \alpha - 1}{N + K(\alpha - 1)}$$
Covariances (with Inverse-Wishart prior $\text{IW}(\nu_0, \mathbf{S}_0)$): $$\boldsymbol{\Sigma}_k^{\text{MAP}} = \frac{\mathbf{S}_0 + \sum_{n=1}^{N} \gamma_{nk}(\mathbf{x}_n - \boldsymbol{\mu}_k)(\mathbf{x}_n - \boldsymbol{\mu}_k)^\top}{\nu_0 + N_k + D + 2}$$
The prior adds "pseudo-counts" that prevent degeneracies.
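As a quick numerical illustration of the pseudo-count effect: with $N = 100$ points, $K = 4$ components, and a Dirichlet prior with $\alpha = 2$, a component that receives no responsibility mass ($N_k = 0$) still gets

$$\pi_k^{\text{MAP}} = \frac{0 + 2 - 1}{100 + 4(2 - 1)} = \frac{1}{104} \approx 0.0096$$

rather than collapsing to exactly zero.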
For regularization without strong assumptions: (1) Set α = 1 + ε for a nearly uniform Dirichlet (prevents zero mixing), (2) Set S₀ = ε·I and ν₀ = D + 2 for minimal inverse-Wishart (prevents singular covariances), (3) Skip the mean prior unless you have domain knowledge. These 'weakly informative' priors regularize without dominating the likelihood.
```python
def map_em_gmm(X, K, alpha=1.01, nu_0=None, S_0=None, max_iters=100):
    """
    MAP-EM with conjugate priors for regularization.

    Parameters:
        X: (N, D) data matrix
        K: number of components
        alpha: Dirichlet concentration (α > 1 for regularization)
        nu_0: Inverse-Wishart degrees of freedom (default: D + 2)
        S_0: Inverse-Wishart scale matrix (default: small * I)

    Returns:
        pi, mu, sigma: MAP estimates
    """
    N, D = X.shape

    # Default priors
    if nu_0 is None:
        nu_0 = D + 2            # Minimum for valid IW
    if S_0 is None:
        S_0 = 1e-3 * np.eye(D)  # Weak regularization

    # Initialize
    pi, mu, sigma = init_kmeans(X, K)

    for iteration in range(max_iters):
        # E-step (unchanged from standard EM)
        gamma, N_k = e_step(X, pi, mu, sigma)

        # M-step with MAP updates
        # Mixing coefficients with Dirichlet prior
        pi = (N_k + alpha - 1) / (N + K * (alpha - 1))
        pi = np.maximum(pi, 1e-10)  # Prevent zeros
        pi = pi / pi.sum()

        # Means (unchanged, assuming flat prior)
        mu = (gamma.T @ X) / N_k[:, np.newaxis]

        # Covariances with Inverse-Wishart prior
        for k in range(K):
            diff = X - mu[k]
            weighted_scatter = (gamma[:, k:k+1] * diff).T @ diff
            # MAP estimate with IW prior
            sigma[k] = (S_0 + weighted_scatter) / (nu_0 + N_k[k] + D + 2)

    return pi, mu, sigma


def regularized_covariance_em(X, K, reg_covar=1e-3, max_iters=100):
    """
    Simple regularization: add small constant to diagonal.
    Equivalent to MAP-EM with specific IW prior.
    """
    N, D = X.shape
    pi, mu, sigma = init_kmeans(X, K)

    for iteration in range(max_iters):
        gamma, N_k = e_step(X, pi, mu, sigma)

        pi = N_k / N
        mu = (gamma.T @ X) / N_k[:, np.newaxis]

        for k in range(K):
            diff = X - mu[k]
            sigma[k] = (gamma[:, k:k+1] * diff).T @ diff / N_k[k]
            # Simple regularization
            sigma[k] += reg_covar * np.eye(D)

    return pi, mu, sigma
```

Stochastic EM introduces randomness into the algorithm to help escape local optima and explore the parameter space more broadly.
The SEM Algorithm
Unlike standard EM, which uses expected (soft) assignments, SEM samples actual assignments from the responsibilities at each iteration.
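Concretely, after computing the responsibilities $\gamma_{nk}$ as usual, each assignment is drawn from its posterior,

$$z_n^{(t)} \sim \text{Categorical}(\gamma_{n1}, \ldots, \gamma_{nK}),$$

and the M-step treats these sampled assignments as observed labels (this is the `rng.choice` step in the code below).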
Properties of Stochastic EM
Annealing Stochastic EM
A common enhancement is to anneal the stochasticity: $$\gamma_{nk}^{\text{annealed}} \propto \gamma_{nk}^{1/T}$$
Start with a high temperature $T$ (responsibilities near uniform, encouraging exploration) and gradually lower it: at $T = 1$ the updates match standard EM, and as $T \to 0$ the responsibilities become deterministic hard assignments. This combines exploration (high $T$) with exploitation (low $T$).
```python
def stochastic_em_gmm(X, K, max_iters=500, burn_in=100, random_state=None):
    """
    Stochastic EM for GMMs.
    Returns averaged parameters over post-burn-in iterations.
    """
    rng = np.random.default_rng(random_state)
    N, D = X.shape

    # Initialize
    pi, mu, sigma = init_kmeans(X, K)

    # Storage for averaging
    param_samples = {'pi': [], 'mu': [], 'sigma': []}

    for iteration in range(max_iters):
        # Standard E-step: compute responsibilities
        gamma, _ = e_step(X, pi, mu, sigma)

        # STOCHASTIC STEP: Sample hard assignments
        assignments = np.array([
            rng.choice(K, p=gamma[n]) for n in range(N)
        ])

        # M-step with sampled assignments
        for k in range(K):
            mask = (assignments == k)
            N_k = mask.sum()

            if N_k < 2:
                continue  # Keep previous parameters

            cluster_data = X[mask]
            pi[k] = N_k / N
            mu[k] = cluster_data.mean(axis=0)
            sigma[k] = np.cov(cluster_data.T) + 1e-6 * np.eye(D)

        pi = pi / pi.sum()

        # Store samples after burn-in
        if iteration >= burn_in:
            param_samples['pi'].append(pi.copy())
            param_samples['mu'].append(mu.copy())
            param_samples['sigma'].append(sigma.copy())

    # Average over samples
    pi_final = np.mean(param_samples['pi'], axis=0)
    mu_final = np.mean(param_samples['mu'], axis=0)
    sigma_final = np.mean(param_samples['sigma'], axis=0)

    return pi_final, mu_final, sigma_final


def annealed_em_gmm(X, K, temp_schedule, max_iters=200):
    """
    Annealed EM: gradually sharpen responsibilities.

    Parameters:
        temp_schedule: function(iteration) -> temperature T
    """
    N, D = X.shape
    pi, mu, sigma = init_kmeans(X, K)

    for iteration in range(max_iters):
        T = temp_schedule(iteration)

        # Compute raw responsibilities
        gamma_raw = np.zeros((N, K))
        for k in range(K):
            gamma_raw[:, k] = pi[k] * multivariate_normal.pdf(X, mu[k], sigma[k])

        # Anneal: raise to power 1/T
        gamma_annealed = gamma_raw ** (1.0 / T)
        gamma = gamma_annealed / gamma_annealed.sum(axis=1, keepdims=True)

        # Standard M-step
        N_k = gamma.sum(axis=0) + 1e-10
        pi = N_k / N
        mu = (gamma.T @ X) / N_k[:, np.newaxis]
        for k in range(K):
            diff = X - mu[k]
            sigma[k] = (gamma[:, k:k+1] * diff).T @ diff / N_k[k]
            sigma[k] += 1e-6 * np.eye(D)

    return pi, mu, sigma


# Example temperature schedule: exponential decay
def exponential_temp_schedule(iteration, T_init=10.0, T_final=1.0, total_iters=200):
    """Temperature decreases exponentially from T_init to T_final."""
    decay_rate = (T_final / T_init) ** (1.0 / total_iters)
    return T_init * (decay_rate ** iteration)
```

Stochastic EM is most valuable when: (1) standard EM consistently finds poor local optima, (2) you want posterior uncertainty estimates (via sample variance), (3) the data has complex multi-modal structure, or (4) you're willing to trade more iterations for better solutions. For simple problems, standard EM with restarts is usually sufficient.
Online EM processes data incrementally, making it suitable for streaming data and for datasets too large to fit in memory.
The key idea is to update parameters after each observation (or mini-batch) rather than after seeing all data.
Incremental EM Formulation
Define sufficient statistics that summarize the data seen so far, one set per component $k$: the responsibility mass $S_k^{(0)} = \sum_n \gamma_{nk}$, the weighted sum $S_k^{(1)} = \sum_n \gamma_{nk}\,\mathbf{x}_n$, and the weighted outer-product sum $S_k^{(2)} = \sum_n \gamma_{nk}\,\mathbf{x}_n \mathbf{x}_n^\top$ (the `S0`, `S1`, `S2` arrays in the code below).
When a new observation $\mathbf{x}_{\text{new}}$ arrives, compute its responsibilities under the current parameters and blend its statistics into the running ones with a decaying step size $\eta_t$, $$S \leftarrow (1 - \eta_t)\,S + \eta_t\,s(\mathbf{x}_{\text{new}}),$$ then recompute $\boldsymbol{\pi}$, $\boldsymbol{\mu}_k$, and $\boldsymbol{\Sigma}_k$ from the updated statistics.
```python
class OnlineEMGMM:
    """
    Online EM for Gaussian Mixture Models.
    Processes data incrementally with decaying learning rate.
    """

    def __init__(self, K, D, learning_rate='sqrt', mini_batch_size=1):
        """
        Parameters:
            K: number of components
            D: data dimensionality
            learning_rate: 'sqrt' for 1/sqrt(t), 'linear' for 1/t
            mini_batch_size: number of samples per update
        """
        self.K = K
        self.D = D
        self.lr_type = learning_rate
        self.batch_size = mini_batch_size
        self.t = 0  # iteration counter

        # Initialize parameters (will be set on first data)
        self.pi = np.ones(K) / K
        self.mu = None
        self.sigma = None

        # Sufficient statistics (unnormalized)
        self.S0 = np.ones(K)  # counts
        self.S1 = None        # weighted sums
        self.S2 = None        # weighted outer products

    def _learning_rate(self):
        """Compute current learning rate."""
        if self.lr_type == 'sqrt':
            return 1.0 / np.sqrt(self.t + 1)
        else:  # linear
            return 1.0 / (self.t + 1)

    def partial_fit(self, X_batch):
        """
        Update model with a batch of observations.
        """
        N_batch = len(X_batch)

        # Initialize on first call
        if self.mu is None:
            self._initialize(X_batch)
            return

        # E-step: compute responsibilities
        gamma = self._compute_responsibilities(X_batch)

        # Update sufficient statistics
        eta = self._learning_rate()
        gamma_sum = gamma.sum(axis=0)     # (K,)
        weighted_sum = gamma.T @ X_batch  # (K, D)
        weighted_outer = np.array([
            (gamma[:, k:k+1] * X_batch).T @ X_batch
            for k in range(self.K)
        ])                                # (K, D, D)

        self.S0 = (1 - eta) * self.S0 + eta * gamma_sum
        self.S1 = (1 - eta) * self.S1 + eta * weighted_sum
        self.S2 = (1 - eta) * self.S2 + eta * weighted_outer

        # M-step: recompute parameters from sufficient statistics
        self._update_parameters()
        self.t += 1

    def _compute_responsibilities(self, X):
        """Compute responsibilities for given data."""
        N = len(X)
        gamma = np.zeros((N, self.K))
        for k in range(self.K):
            gamma[:, k] = self.pi[k] * multivariate_normal.pdf(
                X, self.mu[k], self.sigma[k]
            )
        gamma = gamma / gamma.sum(axis=1, keepdims=True)
        return gamma

    def _update_parameters(self):
        """Recompute parameters from sufficient statistics."""
        # Normalize
        N_k = self.S0 + 1e-10
        self.pi = N_k / N_k.sum()
        self.mu = self.S1 / N_k[:, np.newaxis]
        for k in range(self.K):
            self.sigma[k] = (
                self.S2[k] / N_k[k]
                - np.outer(self.mu[k], self.mu[k])
            )
            self.sigma[k] += 1e-4 * np.eye(self.D)

    def _initialize(self, X_batch):
        """Initialize from first batch."""
        # Use K-means on first batch
        self.pi, self.mu, self.sigma = init_kmeans(X_batch, self.K)

        # Initialize sufficient statistics
        self.S0 = np.ones(self.K)
        self.S1 = self.mu * self.S0[:, np.newaxis]
        self.S2 = np.array([
            self.sigma[k] + np.outer(self.mu[k], self.mu[k])
            for k in range(self.K)
        ])
```

Online EM trades batch optimality for scalability. Key challenges: (1) Learning rate selection: too fast causes oscillation, too slow causes slow adaptation, (2) Initialization from limited data may be poor, (3) No convergence guarantees in non-stationary settings. Use mini-batches (size 50-500) rather than single observations for stability.
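A minimal usage sketch of the class above, assuming the `init_kmeans` helper from earlier in the module is available; the synthetic data and batch size are purely illustrative:

```python
# Illustrative streaming example (synthetic data; init_kmeans assumed available)
rng = np.random.default_rng(0)
X_stream = np.vstack([
    rng.normal(loc=[0.0, 0.0], scale=0.5, size=(5000, 2)),
    rng.normal(loc=[3.0, 3.0], scale=0.7, size=(5000, 2)),
])
rng.shuffle(X_stream)  # shuffle rows in place

model = OnlineEMGMM(K=2, D=2, learning_rate='sqrt')
for start in range(0, len(X_stream), 200):       # mini-batches of 200
    model.partial_fit(X_stream[start:start + 200])

print(model.pi)  # learned mixing weights
print(model.mu)  # learned component means
```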
Mini-batch EM processes data in small batches rather than all at once (batch EM) or one at a time (online EM). This provides a practical middle ground for large datasets.
Mini-Batch EM Algorithm
Each epoch, the data are shuffled and processed in batches: the E-step accumulates sufficient statistics batch by batch, and the M-step then updates the parameters from the accumulated statistics (after each full pass in the implementation below).
Comparison to Alternatives:
| Method | Memory | Per-Iteration Cost | Convergence | Parallelizable |
|---|---|---|---|---|
| Batch EM | $O(NK)$ | $O(NKD^2)$ | Fast & stable | Yes |
| Online EM | $O(KD^2)$ | $O(KD^2)$ | Slower, noisy | Limited |
| Mini-Batch EM | $O(BK)$ | $O(BKD^2)$ | Good balance | Yes |
```python
def mini_batch_em_gmm(X, K, batch_size=256, n_epochs=10):
    """
    Mini-batch EM for large datasets.
    Processes data in batches, updating after each epoch.
    """
    N, D = X.shape

    # Initialize
    pi, mu, sigma = init_kmeans(X[:min(N, 10000)], K)  # Init from subset

    for epoch in range(n_epochs):
        # Shuffle data
        indices = np.random.permutation(N)

        # Accumulators for sufficient statistics
        S0 = np.zeros(K)          # γ sums
        S1 = np.zeros((K, D))     # weighted x sums
        S2 = np.zeros((K, D, D))  # weighted xx^T sums

        for start_idx in range(0, N, batch_size):
            end_idx = min(start_idx + batch_size, N)
            batch_indices = indices[start_idx:end_idx]
            X_batch = X[batch_indices]

            # E-step for batch
            gamma_batch, _ = e_step(X_batch, pi, mu, sigma)

            # Accumulate sufficient statistics
            S0 += gamma_batch.sum(axis=0)
            S1 += gamma_batch.T @ X_batch
            for k in range(K):
                diff = X_batch - mu[k]
                S2[k] += (gamma_batch[:, k:k+1] * diff).T @ diff

        # M-step: update from accumulated statistics
        N_k = S0 + 1e-10
        pi = N_k / N
        mu = S1 / N_k[:, np.newaxis]
        for k in range(K):
            sigma[k] = S2[k] / N_k[k] + 1e-6 * np.eye(D)

        # Optional: compute log-likelihood for monitoring
        if epoch % 2 == 0:
            ll = compute_log_likelihood(X[:5000], pi, mu, sigma)
            print(f"Epoch {epoch}: log-likelihood (sample) = {ll:.2f}")

    return pi, mu, sigma
```

For some models, the exact E-step—computing $p(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\theta})$—is intractable. Variational EM approximates this posterior with a simpler distribution $q(\mathbf{Z})$.
The Variational Principle
Recall the ELBO decomposition: $$\mathcal{L}(\boldsymbol{\theta}) = \text{ELBO}(q, \boldsymbol{\theta}) + \text{KL}(q(\mathbf{Z}) \,\|\, p(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\theta}))$$
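As a reminder, the ELBO itself is the expected complete-data log-likelihood plus the entropy of $q$:

$$\text{ELBO}(q, \boldsymbol{\theta}) = \mathbb{E}_{q(\mathbf{Z})}\big[\log p(\mathbf{X}, \mathbf{Z} \mid \boldsymbol{\theta})\big] - \mathbb{E}_{q(\mathbf{Z})}\big[\log q(\mathbf{Z})\big]$$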
When the true posterior is intractable, we restrict $q(\mathbf{Z})$ to a tractable family, choose $q$ in the E-step by maximizing the ELBO over that family (equivalently, minimizing the KL term above), and then maximize the ELBO with respect to $\boldsymbol{\theta}$ in the M-step.
Mean-Field Approximation
A common choice is the mean-field factorization: $$q(\mathbf{Z}) = \prod_{n=1}^{N} q_n(\mathbf{z}_n)$$
This assumes latent variables are independent under $q$, making optimization tractable. For GMMs, mean-field variational inference yields similar update equations to standard EM, but with modified responsibility calculations.
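In general (the exact form depends on the model), the mean-field E-step cycles through the factors using the standard coordinate-ascent update,

$$\log q_n^*(\mathbf{z}_n) = \mathbb{E}_{q_{-n}}\big[\log p(\mathbf{X}, \mathbf{Z} \mid \boldsymbol{\theta})\big] + \text{const},$$

where the expectation is taken over all factors except $q_n$; the M-step then maximizes the ELBO over $\boldsymbol{\theta}$ with $q$ held fixed.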
When is Variational EM Needed?
Variational Autoencoders (VAEs) are a deep learning extension of variational EM. The 'encoder' network parameterizes q(z|x), the 'decoder' parameterizes p(x|z), and training maximizes the ELBO. Understanding variational EM provides the theoretical foundation for VAEs and other modern generative models.
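For reference, the standard per-example VAE objective is exactly this ELBO, with the encoder playing the role of the variational E-step:

$$\text{ELBO}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{x})}\big[\log p_\theta(\mathbf{x} \mid \mathbf{z})\big] - \text{KL}\big(q_\phi(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\big)$$

Both the encoder parameters $\phi$ and decoder parameters $\theta$ are updated jointly by gradient ascent rather than in alternating E and M steps.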
Selecting the appropriate EM variant depends on your specific requirements and constraints. Here's a decision framework:
| Scenario | Recommended Variant | Why |
|---|---|---|
| Small-medium data (< 100K), standard use | Standard EM | Best convergence guarantees, simple |
| Discrete clusters needed | Hard EM | Produces hard assignments, faster |
| Regularization needed | MAP-EM | Prevents overfitting, handles singularities |
| Stuck in local optima | Stochastic EM / Annealed EM | Explores solution space |
| Large data (> 1M) that fits in memory | Mini-Batch EM | Efficient, parallelizable |
| Streaming data / memory constrained | Online EM | Constant memory, adaptive |
| Complex models (LDA, VAE) | Variational EM | Handles intractable posteriors |
You have completed the Expectation-Maximization Algorithm module. You now understand: (1) the mathematical foundation of EM for GMMs, (2) the complete derivation of E-step and M-step, (3) convergence properties and guarantees, (4) initialization strategies for practical success, and (5) important variants for different scenarios. This knowledge forms the foundation for applying EM across diverse latent variable models in machine learning.