Standard mixture models—whether Gaussian, Student-t, or otherwise—model the marginal distribution $p(\mathbf{y})$ of observations. Each component has a fixed mixing proportion $\pi_k$ that applies uniformly across the entire data space. But what if the "right" component to use varies depending on context?
Consider predicting house prices. In urban areas, prices might follow one distribution (high mean, moderate variance). In suburban areas, another pattern emerges (lower mean, higher variance). A standard mixture model would blend these patterns into a single, averaged prediction. But intuitively, if we know whether a house is urban or suburban, we should use different models.
Mixture of Experts (MoE) models formalize this intuition. Instead of fixed mixing proportions $\pi_k$, they use input-dependent gating functions $g_k(\mathbf{x})$ that determine which "expert" (component) is most relevant for each input. This transforms mixture models from unconditional density estimators into powerful conditional models $p(\mathbf{y} | \mathbf{x})$.
MoE models are foundational to modern machine learning: they connect classical mixture models to neural networks, form the basis of the sparsely gated MoE layers used in large language models, and provide a principled framework for combining specialized models.
By the end of this page, you will understand: (1) The fundamental formulation of MoE as input-dependent mixtures; (2) Gating networks and softmax parameterization; (3) Expert networks for regression and classification; (4) EM algorithm derivation for MoE training; (5) Hierarchical MoE architectures; (6) Modern applications in deep learning including sparse MoE layers.
Before diving into Mixture of Experts, let's understand precisely what problem it solves by examining the limitations of standard approaches.
In supervised learning, we want to model the conditional distribution $p(y | \mathbf{x})$ or, at minimum, predict $\mathbb{E}[y | \mathbf{x}]$. Standard approaches:
Linear regression: Assumes $p(y | \mathbf{x}) = \mathcal{N}(y | \mathbf{w}^T\mathbf{x}, \sigma^2)$. This is a single, global model.
Gaussian Mixture Model: Models the marginal $p(y) = \sum_k \pi_k \mathcal{N}(y | \mu_k, \sigma_k^2)$. This captures multimodality in $y$ but ignores $\mathbf{x}$ entirely.
Neither approach directly addresses input-dependent multimodality: situations where the conditional distribution $p(y | \mathbf{x})$ has multiple modes, and where the location and relative weight of those modes change with $\mathbf{x}$.
Consider a robot arm: given a target position $(x, y)$ in Cartesian coordinates, find joint angles $\theta$ that reach that position. For most targets, multiple solutions exist (the arm can reach the same point in different configurations). The mapping from position to angles is one-to-many.
A standard regression model averages over solutions, predicting joint angles that might not correspond to any valid configuration. What we need is a model that can represent several candidate outputs for the same input and select among them based on context. The table below summarizes how the approaches compare.
| Approach | Models | Handles Multimodality? | Input-Dependent? |
|---|---|---|---|
| Linear Regression | $\mathbb{E}[y \mid \mathbf{x}]$ | No | Yes (via $\mathbf{x}$) |
| Gaussian Mixture | $p(y)$ | Yes | No |
| Joint GMM on $(\mathbf{x}, y)$ | $p(\mathbf{x}, y)$ | Yes | Yes (via conditioning) |
| Mixture of Experts | $p(y \mid \mathbf{x})$ directly | Yes | Yes (gating network) |
One could model p(x, y) jointly with a GMM and compute p(y|x) by conditioning. However, this requires modeling p(x), which may be unnecessary and wasteful. MoE models p(y|x) directly, focusing capacity where it matters.
The Mixture of Experts model defines the conditional distribution as a mixture where both the mixing proportions and the component densities depend on the input.
$$p(y | \mathbf{x}, \boldsymbol{\Theta}) = \sum_{k=1}^K g_k(\mathbf{x} | \boldsymbol{\theta}_g) \cdot p(y | \mathbf{x}, \boldsymbol{\theta}_k)$$
Components:
Gating network $g_k(\mathbf{x} | \boldsymbol{\theta}_g)$: Computes the input-dependent probability of selecting expert $k$. Must satisfy $g_k(\mathbf{x} | \boldsymbol{\theta}_g) \geq 0$ and $\sum_{k=1}^K g_k(\mathbf{x} | \boldsymbol{\theta}_g) = 1$ for every input $\mathbf{x}$.
Expert networks $p(y | \mathbf{x}, \boldsymbol{\theta}_k)$: Each expert provides a conditional distribution over outputs given inputs. Common choices include Gaussian linear regression models $\mathcal{N}(y | \mathbf{w}_k^T \mathbf{x}, \sigma_k^2)$ for regression, logistic or softmax models for classification, and neural networks in modern implementations.
Latent variable interpretation: Let $z \in \{1, \ldots, K\}$ indicate which expert generated the output: $p(z = k | \mathbf{x}) = g_k(\mathbf{x} | \boldsymbol{\theta}_g)$ and $p(y | \mathbf{x}, z = k) = p(y | \mathbf{x}, \boldsymbol{\theta}_k)$, so marginalizing over $z$ recovers the mixture above.
The most common gating function uses a softmax over linear combinations:
$$g_k(\mathbf{x}) = \frac{\exp(\mathbf{v}_k^T \mathbf{x})}{\sum_{j=1}^K \exp(\mathbf{v}_j^T \mathbf{x})} = \text{softmax}(\mathbf{V}^T \mathbf{x})_k$$
where $\mathbf{V} = [\mathbf{v}_1, \ldots, \mathbf{v}_K]$ are the gating parameters.
Properties of softmax gating: the probabilities are strictly positive and sum to one for every input; the log-odds between two experts, $\log(g_k(\mathbf{x}) / g_j(\mathbf{x})) = (\mathbf{v}_k - \mathbf{v}_j)^T \mathbf{x}$, is linear in $\mathbf{x}$, so experts are separated by soft linear boundaries; and larger parameter magnitudes produce sharper transitions between experts.
With Gaussian linear experts:
$$p(y | \mathbf{x}) = \sum_{k=1}^K g_k(\mathbf{x}) \cdot \mathcal{N}(y | \mathbf{w}_k^T \mathbf{x}, \sigma_k^2)$$
This is a piecewise linear model with soft boundaries. In regions where $g_k(\mathbf{x}) \approx 1$, the model behaves like expert $k$'s linear regression. Transition regions blend multiple experts.
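As a small illustration, here is a minimal sketch (with made-up gating and expert parameters, not fitted to data) of how the conditional mean blends two linear experts:

```python
import numpy as np

# Hypothetical 1-D example: two linear experts with softmax gating (values are illustrative)
V = np.array([[4.0, -2.0], [-4.0, 2.0]])  # gating parameters v_k (bias, slope)
W = np.array([[1.0, 0.5], [3.0, -1.5]])   # expert weights w_k (bias, slope)

x = np.linspace(-3.0, 3.0, 7)
X = np.column_stack([np.ones_like(x), x])  # prepend a bias feature

logits = X @ V.T
g = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # gating probabilities g_k(x)

mean = (g * (X @ W.T)).sum(axis=1)  # E[y|x] = sum_k g_k(x) * w_k^T x
print(np.round(mean, 2))            # near x = -3 it follows expert 1, near x = 3 expert 2
```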
Parameter count: For $D$-dimensional inputs and $K$ experts, the gating network has $KD$ parameters ($\mathbf{V}$), and each Gaussian linear expert has $D + 1$ parameters ($\mathbf{w}_k$ and $\sigma_k^2$), for a total of $KD + K(D + 1) = K(2D + 1)$.
Softmax produces 'soft' gating where multiple experts contribute. For efficiency (especially in large models), 'hard' gating selects only top-k experts. This sparse MoE approach is crucial for scaling—we'll discuss it when covering modern applications.
Training Mixture of Experts models follows the EM framework, treating the expert assignment as latent. However, because the gating function is parameterized, the M-step typically requires gradient-based optimization.
Given data $\{(\mathbf{x}_n, y_n)\}_{n=1}^N$, we want to maximize:
$$\mathcal{L}(\boldsymbol{\Theta}) = \sum_{n=1}^N \log p(y_n | \mathbf{x}_n, \boldsymbol{\Theta}) = \sum_{n=1}^N \log \sum_{k=1}^K g_k(\mathbf{x}_n) \cdot p(y_n | \mathbf{x}_n, \boldsymbol{\theta}_k)$$
For each observation, compute the posterior probability that expert $k$ generated $(\mathbf{x}_n, y_n)$:
$$h_{nk} = p(z_n = k | y_n, \mathbf{x}_n) = \frac{g_k(\mathbf{x}_n) \cdot p(y_n | \mathbf{x}_n, \boldsymbol{\theta}_k)}{\sum_{j=1}^K g_j(\mathbf{x}_n) \cdot p(y_n | \mathbf{x}_n, \boldsymbol{\theta}_j)}$$
Note the key difference from standard GMM: the responsibilities $h_{nk}$ depend on both $\mathbf{x}_n$ (through gating) and $y_n$ (through expert likelihood).
The expected complete-data log-likelihood is:
$$Q(\boldsymbol{\Theta} | \boldsymbol{\Theta}^{(t)}) = \sum_{n=1}^N \sum_{k=1}^K h_{nk} \left[ \log g_k(\mathbf{x}_n) + \log p(y_n | \mathbf{x}_n, \boldsymbol{\theta}_k) \right]$$
This decomposes into two independent optimization problems:
Gating parameters: Maximize $$Q_g = \sum_{n=1}^N \sum_{k=1}^K h_{nk} \log g_k(\mathbf{x}_n)$$
For softmax gating, this is equivalent to weighted multinomial logistic regression with soft labels $h_{nk}$. There is no closed-form solution; it is solved with gradient-based optimization (or iteratively reweighted least squares), using the gradient:
$$\frac{\partial Q_g}{\partial \mathbf{v}_k} = \sum_{n=1}^N (h_{nk} - g_k(\mathbf{x}_n)) \mathbf{x}_n$$
Expert parameters: For each expert $k$, maximize $$Q_k = \sum_{n=1}^N h_{nk} \log p(y_n | \mathbf{x}_n, \boldsymbol{\theta}_k)$$
For Gaussian linear experts, this is weighted least squares:
$$\mathbf{w}_k^{\text{new}} = \left( \sum_{n=1}^N h_{nk} \mathbf{x}_n \mathbf{x}_n^T \right)^{-1} \sum_{n=1}^N h_{nk} \mathbf{x}_n y_n$$
$$\sigma_k^{2, \text{new}} = \frac{\sum_{n=1}^N h_{nk} (y_n - \mathbf{w}_k^T \mathbf{x}_n)^2}{\sum_{n=1}^N h_{nk}}$$
```python
import numpy as np
from scipy.special import softmax


class MixtureOfExperts:
    """Mixture of Experts with linear experts and softmax gating."""

    def __init__(self, n_experts=3, max_iter=100, tol=1e-4):
        self.K = n_experts
        self.max_iter = max_iter
        self.tol = tol

    def _initialize(self, X, y):
        """Initialize parameters randomly."""
        N, D = X.shape
        self.D = D
        # Gating parameters (K x D matrix)
        self.V = np.random.randn(self.K, D) * 0.1
        # Expert parameters
        self.W = np.random.randn(self.K, D) * 0.1  # regression weights
        self.sigmas = np.ones(self.K) * np.std(y)  # noise std devs

    def _gating(self, X):
        """Compute gating probabilities (N x K)."""
        logits = X @ self.V.T  # N x K
        return softmax(logits, axis=1)

    def _expert_log_likelihood(self, X, y):
        """Compute log p(y|x, expert k) for each expert (N x K)."""
        N = X.shape[0]
        log_lik = np.zeros((N, self.K))
        for k in range(self.K):
            mean = X @ self.W[k]  # N predictions
            log_lik[:, k] = (-0.5 * np.log(2 * np.pi * self.sigmas[k]**2)
                             - 0.5 * ((y - mean) / self.sigmas[k])**2)
        return log_lik

    def _e_step(self, X, y):
        """Compute responsibilities h_nk."""
        g = self._gating(X)  # N x K
        log_lik = self._expert_log_likelihood(X, y)  # N x K
        # Log responsibilities (unnormalized)
        log_h = np.log(g + 1e-10) + log_lik
        # Normalize (log-sum-exp)
        log_h_max = log_h.max(axis=1, keepdims=True)
        log_sum = log_h_max + np.log(np.exp(log_h - log_h_max).sum(axis=1, keepdims=True))
        log_h -= log_sum
        h = np.exp(log_h)
        # Log-likelihood for monitoring
        ll = log_sum.sum()
        return h, ll

    def _m_step(self, X, y, h):
        """Update gating and expert parameters."""
        N = X.shape[0]
        # Update gating via gradient ascent
        g = self._gating(X)
        grad_V = (h - g).T @ X  # K x D
        self.V += 0.1 * grad_V / N  # learning rate 0.1
        # Update experts via weighted least squares
        for k in range(self.K):
            weights = h[:, k]
            W_sum = weights.sum()
            # Weighted least squares
            XtWX = (X.T * weights) @ X  # D x D
            XtWy = (X.T * weights) @ y  # D
            # Regularized solve
            self.W[k] = np.linalg.solve(XtWX + 1e-6 * np.eye(self.D), XtWy)
            # Weighted variance
            residuals = y - X @ self.W[k]
            self.sigmas[k] = np.sqrt((weights * residuals**2).sum() / W_sum + 1e-6)

    def fit(self, X, y):
        """Fit the MoE model using EM."""
        # Add bias term
        X = np.column_stack([np.ones(len(X)), X])
        self._initialize(X, y)
        prev_ll = -np.inf
        for iteration in range(self.max_iter):
            # E-step
            h, ll = self._e_step(X, y)
            # Check convergence
            if abs(ll - prev_ll) < self.tol:
                print(f"Converged at iteration {iteration}")
                break
            prev_ll = ll
            # M-step
            self._m_step(X, y, h)
        return self

    def predict(self, X, return_std=False):
        """Predict using weighted expert outputs."""
        X = np.column_stack([np.ones(len(X)), X])
        g = self._gating(X)  # N x K
        # Weighted mean prediction
        predictions = np.zeros(len(X))
        variances = np.zeros(len(X))
        for k in range(self.K):
            mean_k = X @ self.W[k]
            predictions += g[:, k] * mean_k
            variances += g[:, k] * (self.sigmas[k]**2 + mean_k**2)
        variances -= predictions**2  # Var = E[X^2] - E[X]^2
        if return_std:
            return predictions, np.sqrt(variances)
        return predictions
```
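A brief usage sketch of the class above on synthetic two-regime data (the data-generating process here is made up for illustration):

```python
import numpy as np

# Synthetic data: the linear trend changes depending on the sign of x
rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=500)
y = np.where(x < 0, 2.0 * x + 1.0, -1.5 * x + 2.0) + rng.normal(0, 0.3, size=500)

moe = MixtureOfExperts(n_experts=2, max_iter=200).fit(x.reshape(-1, 1), y)

x_test = np.array([[-2.0], [2.0]])
mean, std = moe.predict(x_test, return_std=True)
print(mean, std)  # each test point should be dominated by a different expert
```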
A single level of gating may not capture complex input-output relationships. Hierarchical Mixture of Experts (HME) extends the framework by organizing experts in a tree structure, with gating functions at each internal node.

In a two-level HME, a top-level gating network first chooses among $M$ branches; within branch $m$, a second-level gating network then chooses among that branch's $K_m$ experts.
The conditional distribution becomes:
$$p(y | \mathbf{x}) = \sum_{m=1}^M g_m^{(1)}(\mathbf{x}) \sum_{k=1}^{K_m} g_{mk}^{(2)}(\mathbf{x}) \cdot p(y | \mathbf{x}, \boldsymbol{\theta}_{mk})$$
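To make the nested sum concrete, here is a tiny numpy sketch with placeholder gating probabilities and expert likelihoods (not learned values):

```python
import numpy as np

# Hypothetical two-level HME with M = 2 branches of K_m = 2 experts, evaluated at one (x, y)
g1 = np.array([0.7, 0.3])                     # top-level gates g_m^(1)(x), sum to 1
g2 = np.array([[0.6, 0.4], [0.2, 0.8]])       # second-level gates g_mk^(2)(x), each row sums to 1
lik = np.array([[0.05, 0.30], [0.10, 0.60]])  # expert likelihoods p(y | x, theta_mk)

p_y_given_x = (g1[:, None] * g2 * lik).sum()  # sum_m g_m^(1) sum_k g_mk^(2) p(y | x, theta_mk)
print(p_y_given_x)
```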
Advantages of hierarchy: the nested gates partition the input space in a coarse-to-fine, divide-and-conquer fashion; each individual gating decision is a simple choice among a few alternatives; and the resulting tree structure, as discussed below, connects naturally to decision trees and aids interpretability.
The E-step computes responsibilities at all levels:
$$h_{nmk} = p(z^{(1)} = m, z^{(2)} = k | y_n, \mathbf{x}_n) = \frac{g_m^{(1)}(\mathbf{x}_n) \cdot g_{mk}^{(2)}(\mathbf{x}_n) \cdot p(y_n | \mathbf{x}_n, \boldsymbol{\theta}_{mk})}{\sum_{m', k'} g_{m'}^{(1)}(\mathbf{x}_n) \cdot g_{m'k'}^{(2)}(\mathbf{x}_n) \cdot p(y_n | \mathbf{x}_n, \boldsymbol{\theta}_{m'k'})}$$
We can also compute marginal responsibilities: $h_{nm} = \sum_{k} h_{nmk}$ for the top-level gate, and conditional responsibilities $h_{nk|m} = h_{nmk} / h_{nm}$ for the second-level gate within branch $m$.
The M-step updates each gating network and expert independently, using the appropriate responsibilities.
HME provides a probabilistic generalization of decision trees:
| Decision Tree | HME |
|---|---|
| Hard splits at nodes | Soft probabilistic splits (gating) |
| Constant predictions at leaves | Parametric models at leaves (experts) |
| Greedy construction | Global optimization (EM) |
| Axis-aligned splits | Oblique (linear) splits |
| No uncertainty quantification | Full probabilistic predictions |
This connection makes HME attractive for interpretable modeling—the tree structure provides some insight into the decision process.
Deeper hierarchies enable more complex partitioning but increase computational cost and risk of overfitting. In practice, 2-3 levels often suffice. Width (number of experts per level) controls local approximation quality. The optimal structure depends on the problem's inherent hierarchical nature.
The Mixture of Experts framework has experienced a remarkable resurgence in deep learning, particularly in large language models (LLMs). Modern implementations differ from classical MoE in several important ways.
In classical MoE, all experts contribute to every prediction (soft gating). This becomes prohibitive with hundreds or thousands of experts. Sparse MoE addresses this:
$$p(y | \mathbf{x}) = \sum_{k \in \text{Top-}K} g_k(\mathbf{x}) \cdot p(y | \mathbf{x}, \boldsymbol{\theta}_k)$$
where Top-$K$ selects only the $K$ experts with highest gating weights. Common choice: $K = 1$ or $K = 2$.
Benefits: the compute per token depends only on the few selected experts rather than on the total number of experts, so parameter count (and hence capacity for specialization) can grow far beyond what dense computation would allow.

Challenges: the discrete routing decision complicates gradient-based training, routing tends to collapse onto a few overloaded experts unless it is explicitly balanced, and spreading many experts across devices adds communication overhead. Several techniques address these issues:
1. Noisy Top-K Gating (Shazeer et al., 2017)
Add noise before taking top-K to encourage exploration: $$g(\mathbf{x}) = \text{softmax}(\mathbf{x}^T W_g + \epsilon \cdot \text{Softplus}(\mathbf{x}^T W_{\text{noise}}))$$
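A sketch of this gating step in PyTorch (function and parameter names are illustrative, not from a specific library; the top-K mask is applied before the softmax):

```python
import torch
import torch.nn.functional as F

def noisy_top_k_gates(x, W_g, W_noise, k=2):
    """x: (n_tokens, d_model); W_g, W_noise: (d_model, n_experts)."""
    clean = x @ W_g
    noise_std = F.softplus(x @ W_noise)
    noisy = clean + torch.randn_like(clean) * noise_std  # epsilon ~ N(0, 1)

    # Keep only the top-k scores per token; everything else gets -inf before the softmax
    top_vals, top_idx = noisy.topk(k, dim=-1)
    masked = torch.full_like(noisy, float('-inf')).scatter(-1, top_idx, top_vals)
    return F.softmax(masked, dim=-1)  # zero weight on non-selected experts

gates = noisy_top_k_gates(torch.randn(4, 16), torch.randn(16, 8), torch.randn(16, 8))
```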
2. Load Balancing Losses
Add auxiliary losses that penalize uneven expert usage: $$\mathcal{L}_{\text{balance}} = N \cdot \sum_k f_k \cdot P_k$$
where $f_k$ is the fraction of tokens routed to expert $k$, $P_k$ is the average gating probability for expert $k$, and $N$ here denotes the number of experts (matching the Switch Transformer formulation).
3. Expert Capacity Limits
Limit the number of tokens each expert can process: $$\text{capacity} = \frac{\text{tokens per batch}}{\text{number of experts}} \times \text{capacity factor}$$
Tokens exceeding capacity are dropped or processed by a residual connection.
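A simplified sketch of applying such a limit under top-1 routing (names are illustrative; real systems also decide which tokens to keep and how overflow is routed):

```python
import torch

def keep_within_capacity(expert_indices, n_experts, capacity_factor=1.25):
    """expert_indices: (n_tokens,) chosen expert per token (top-1 routing)."""
    n_tokens = expert_indices.shape[0]
    capacity = int(capacity_factor * n_tokens / n_experts)

    keep = torch.zeros(n_tokens, dtype=torch.bool)
    for e in range(n_experts):
        token_ids = (expert_indices == e).nonzero(as_tuple=True)[0]
        keep[token_ids[:capacity]] = True  # tokens beyond capacity are dropped (or take a residual path)
    return keep

mask = keep_within_capacity(torch.randint(0, 8, (64,)), n_experts=8)
```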
4. Switch Transformer (Fedus et al., 2021)
Simplified MoE with $K = 1$ routing and careful initialization. Demonstrated that sparse MoE can scale to trillion-parameter models while maintaining training efficiency.
| Model | Parameters (Total) | Experts | Routing | Key Innovation |
|---|---|---|---|---|
| GShard (Google) | 600B | 2048 | Top-2 | Distributed MoE training |
| Switch Transformer | 1.6T | 2048 | Top-1 | Simplified routing |
| GLaM (Google) | 1.2T | 64 | Top-2 | Focused on efficiency |
| Mixtral 8x7B | 47B | 8 | Top-2 | Open-source MoE |
A model like Mixtral 8x7B has 8 experts of 7B parameters each (roughly 47B total). But with Top-2 routing, only ~14B parameters are activated per token—similar compute to a dense 14B model, but with 47B capacity for specialization. This separation of capacity from compute is the key insight.
Beyond linear models, experts can be arbitrary neural networks. This dramatically increases expressiveness but requires end-to-end gradient training rather than EM.
Standard form: Replace each expert $p(y | \mathbf{x}, \boldsymbol{\theta}_k)$ with a neural network $f_k(\mathbf{x})$:
$$\hat{y} = \sum_{k=1}^K g_k(\mathbf{x}) \cdot f_k(\mathbf{x})$$
where $g_k(\mathbf{x})$ is produced by a gating network (typically a softmax over a linear or small neural function of the input) and each expert $f_k$ is a neural network with its own parameters.
In transformer layers: MoE typically replaces the feedforward network (FFN) sublayer:
$$\text{FFN}_{\text{MoE}}(\mathbf{x}) = \sum_{k \in \text{Top-}K} g_k(\mathbf{x}) \cdot \text{FFN}_k(\mathbf{x})$$
Each expert is a separate FFN with distinct parameters, enabling specialization.
End-to-end backpropagation: With soft gating, gradients flow through all experts. With hard (Top-K) gating, gradients only flow through selected experts.
Straight-through estimator: For hard gating, use soft gating in the backward pass to enable gradient flow to non-selected experts (for the gating network).
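A sketch of the straight-through pattern for top-1 gating (an illustrative snippet rather than a specific library's API):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, requires_grad=True)  # (n_tokens, n_experts)
probs = F.softmax(logits, dim=-1)               # soft gates

hard = F.one_hot(probs.argmax(dim=-1), num_classes=8).float()  # top-1 hard gates
gates = hard + probs - probs.detach()  # forward value equals `hard`; gradients flow through `probs`

expert_out = torch.randn(4, 8)         # stand-in for per-expert scalar outputs
(gates * expert_out).sum().backward()  # logits.grad is computed via the soft path
print(logits.grad.shape)
```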
Gradient scaling: Experts receiving fewer tokens may have unstable gradients. Some implementations scale gradients by routing frequency.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    """Mixture of Experts layer for transformers."""

    def __init__(self, d_model, d_ff, n_experts, top_k=2, capacity_factor=1.25):
        super().__init__()
        self.d_model = d_model
        self.n_experts = n_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        # Gating network
        self.gate = nn.Linear(d_model, n_experts, bias=False)
        # Expert networks (each is a 2-layer FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(n_experts)
        ])

    def forward(self, x, training=True):
        """
        Args:
            x: Input tensor of shape (batch, seq_len, d_model)
        Returns:
            Output tensor of same shape, plus auxiliary loss
        """
        batch_size, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)  # (batch * seq_len, d_model)
        n_tokens = x_flat.shape[0]

        # Compute gating scores
        logits = self.gate(x_flat)  # (n_tokens, n_experts)

        # Add noise during training for exploration
        if training:
            noise = torch.randn_like(logits) * 0.1
            logits = logits + noise

        # Soft gating probabilities (for auxiliary loss)
        probs = F.softmax(logits, dim=-1)

        # Top-K selection
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
        top_k_gates = F.softmax(top_k_logits, dim=-1)  # Renormalize

        # Compute expert outputs (naive implementation)
        # Production code would batch this efficiently
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]  # Which expert for each token
            gates = top_k_gates[:, k:k+1]  # Weight for this selection
            # Group tokens by expert
            for e in range(self.n_experts):
                mask = (expert_indices == e)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[e](expert_input)
                    output[mask] += gates[mask] * expert_output

        # Auxiliary load balancing loss
        # f_k: fraction of tokens assigned to expert k
        # P_k: average gating probability for expert k
        f = torch.zeros(self.n_experts, device=x.device)
        for k in range(self.top_k):
            for e in range(self.n_experts):
                f[e] += (top_k_indices[:, k] == e).float().mean() / self.top_k
        P = probs.mean(dim=0)
        aux_loss = self.n_experts * (f * P).sum()

        output = output.view(batch_size, seq_len, d_model)
        return output, aux_loss
```
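A quick usage sketch, assuming the MoELayer class above (dimensions are arbitrary):

```python
import torch

layer = MoELayer(d_model=64, d_ff=256, n_experts=8, top_k=2)
x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)

out, aux_loss = layer(x, training=True)
print(out.shape, aux_loss.item())  # torch.Size([2, 10, 64]) and a scalar

# During training, the auxiliary loss is typically added to the task loss with a small weight,
# e.g. total_loss = task_loss + 0.01 * aux_loss
```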
Mixture of Experts models find application wherever different subsets of the input space benefit from specialized treatment.

Large Language Models: MoE enables scaling to trillion+ parameters while maintaining tractable inference. Different experts may specialize in particular languages, domains such as code versus natural-language text, or topical areas, though the specialization that emerges in practice is learned rather than prescribed.
Machine Translation: Language pairs with different linguistic properties benefit from expert specialization. One expert might handle Romance languages, another Germanic, etc.
Vision-Language Models: Experts can specialize in different modalities or content types, for example visual versus textual features, or particular categories of images.
User Modeling: Different user segments (casual browsers, power users, bargain hunters) exhibit distinct behavior patterns. MoE naturally segments users and applies appropriate models.
MoE differs from ensemble methods like random forests or boosting. Ensembles combine predictions from independently trained models. MoE trains experts jointly with a learned routing mechanism. The gating network is trained to identify which expert is most appropriate for each input, enabling specialization that ensembles cannot achieve.
This page has provided a comprehensive treatment of Mixture of Experts, from classical formulations to modern neural network implementations.
What's Next: Hidden Markov Models
Mixture of Experts addresses input-dependent gating but assumes observations are independent. Many real-world problems involve sequential data where past observations inform future ones. The next page introduces Hidden Markov Models (HMMs), which extend mixture models to sequences by introducing temporal dependencies between latent states. This leads to elegant dynamic programming algorithms for inference and learning.
You now understand Mixture of Experts from classical probabilistic formulations to modern neural implementations. The key insight is input-dependent gating: different regions of the input space route to specialized experts. This framework underpins many state-of-the-art large-scale models. Next, we'll explore how latent variable models handle sequential dependencies.