Sparse Mixture-of-Experts (MoE) architectures represent one of the most transformative paradigms in modern deep learning, enabling models to achieve unprecedented scale while maintaining computational efficiency. Unlike dense neural networks where every parameter participates in every computation, MoE models employ a conditional computation strategy—activating only a subset of specialized "experts" for each input token.
In a Sparse MoE layer, the architecture consists of two key components: a set of expert networks (typically independent feed-forward networks) and a router (gating network) that decides which experts process each token.
Modern flagship models leverage this architecture at massive scale.
The routing process operates as follows:
Step 1 - Score Computation: For each input token, the router computes raw logit scores for all N experts: $$\text{router\_logits} = W_r \cdot x$$
Step 2 - Expert Selection: Select the top-K experts based on their logit scores. Let $\mathcal{T}_k$ denote the set of indices for the top-K experts.
Step 3 - Weight Normalization: Apply softmax only over the selected experts to compute normalized routing weights: $$w_i = \frac{\exp(\text{logit}_i)}{\sum_{j \in \mathcal{T}_k} \exp(\text{logit}_j)} \quad \text{for } i \in \mathcal{T}_k$$
Step 4 - Expert Output Aggregation: Combine the outputs of selected experts using the computed weights: $$\text{output} = \sum_{i \in \mathcal{T}_k} w_i \cdot E_i(x)$$
where $E_i(x)$ represents the output of expert $i$ for input $x$.
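The four steps above can be sketched in NumPy. This is a minimal illustration, not a reference solution; the function name `moe_topk_route` and the `argsort`-based selection are assumptions, not part of the problem statement:

```python
import numpy as np

def moe_topk_route(router_logits, expert_outputs, k):
    """Top-K routing for a Sparse MoE layer (illustrative sketch).

    router_logits:  (num_tokens, num_experts)
    expert_outputs: (num_tokens, num_experts, d_model), pre-computed E_i(x)
    k:              number of experts to activate per token
    """
    logits = np.asarray(router_logits, dtype=float)
    outputs = np.asarray(expert_outputs, dtype=float)

    # Step 2: indices of the top-K experts per token (order within the K is irrelevant)
    topk_idx = np.argsort(logits, axis=-1)[:, -k:]

    # Step 3: softmax over the selected logits only
    topk_logits = np.take_along_axis(logits, topk_idx, axis=-1)
    topk_logits -= topk_logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(topk_logits)
    weights = exp / exp.sum(axis=-1, keepdims=True)         # (num_tokens, k)

    # Step 4: weighted sum of the selected experts' outputs
    selected = np.take_along_axis(outputs, topk_idx[:, :, None], axis=1)
    return (weights[:, :, None] * selected).sum(axis=1)     # (num_tokens, d_model)
```

Subtracting the per-token maximum before exponentiating leaves the softmax result unchanged but avoids overflow for large logits.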
Implement the complete top-K routing logic for a Sparse MoE layer. Given the router logits for each token, the pre-computed outputs from all experts, and the sparsity parameter K, your function should perform the steps above: select the top-K experts for each token, compute softmax weights over only those experts, and return the weighted combination of their outputs.
Key Insight: The softmax normalization occurs after expert selection, meaning we compute a probability distribution only over the K selected experts, not over all N experts. This ensures routing weights sum to 1.0 for each token.
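This ordering matters. A short sketch contrasting the correct order (top-K, then softmax) with the incorrect one (softmax over all N, then top-K); the values are from this problem's first example:

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.5, 0.1])
k = 2

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Correct order: select top-K first, then softmax over the K survivors.
topk = np.argsort(logits)[-k:]
w_after = softmax(logits[topk])     # sums to 1.0 (up to float rounding)

# Wrong order: softmax over all N experts, then keep the top-K entries.
w_before = softmax(logits)[topk]    # sums to only ~0.786 here

print(w_after.sum(), w_before.sum())
```

With the wrong order the kept weights no longer form a probability distribution, so the expert outputs would be systematically down-scaled unless renormalized.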
router_logits = [[2.0, 1.0, 0.5, 0.1]]
expert_outputs = [[[1, 0], [0, 1], [1, 1], [0, 0]]]
k = 2

Output: [[0.731, 0.269]]

Single Token Routing with K=2
For the single token, we have 4 experts with logits [2.0, 1.0, 0.5, 0.1].
Step 1 - Expert Selection: Top-2 experts are: • Expert 0: logit = 2.0 (highest) • Expert 1: logit = 1.0 (second highest)
Step 2 - Weight Computation: Softmax over selected logits [2.0, 1.0]: • w₀ = exp(2.0) / (exp(2.0) + exp(1.0)) = 7.389 / (7.389 + 2.718) ≈ 0.731 • w₁ = exp(1.0) / (exp(2.0) + exp(1.0)) = 2.718 / (7.389 + 2.718) ≈ 0.269
Step 3 - Output Aggregation: • Expert 0 output: [1, 0] • Expert 1 output: [0, 1] • Final = 0.731 × [1, 0] + 0.269 × [0, 1] = [0.731, 0.269]
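The arithmetic in this example can be checked directly (a small verification snippet, not part of the required solution):

```python
import numpy as np

# Top-2 logits from the example, in selection order (Expert 0, Expert 1)
logits = np.array([2.0, 1.0])
w = np.exp(logits) / np.exp(logits).sum()        # [0.731..., 0.269...]

# Weighted combination of the selected experts' outputs
final = w[0] * np.array([1, 0]) + w[1] * np.array([0, 1])
print(np.round(final, 3))                        # [0.731 0.269]
```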
router_logits = [[1.0, 2.0, 0.5], [3.0, 1.0, 0.5]]
expert_outputs = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 1, 0], [0, 1, 1], [1, 0, 1]]]
k = 1

Output: [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]

Multi-Token Routing with K=1 (Hard Routing)
With K=1, each token is routed to exactly one expert (hard routing).
Token 0: Logits = [1.0, 2.0, 0.5] • Top expert: Expert 1 (logit = 2.0) • Softmax of single value = 1.0 • Output = 1.0 × [0, 1, 0] = [0.0, 1.0, 0.0]
Token 1: Logits = [3.0, 1.0, 0.5] • Top expert: Expert 0 (logit = 3.0) • Softmax of single value = 1.0 • Output = 1.0 × [1, 1, 0] = [1.0, 1.0, 0.0]
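For K=1 the softmax over a single logit is always 1.0, so the whole routine collapses to an argmax lookup. A sketch of that special case, using this example's values:

```python
import numpy as np

logits = np.array([[1.0, 2.0, 0.5], [3.0, 1.0, 0.5]])
outputs = np.array([[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                    [[1, 1, 0], [0, 1, 1], [1, 0, 1]]], dtype=float)

# Hard routing: each token simply copies its best expert's output.
best = logits.argmax(axis=-1)                  # [1, 0]
out = outputs[np.arange(len(best)), best]      # [[0,1,0], [1,1,0]]
print(out)
```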
router_logits = [[1.5, 2.5, 3.5, 0.5], [4.0, 3.0, 2.0, 1.0]]
expert_outputs = [[[1, 0], [0, 1], [1, 1], [0, 0]], [[2, 1], [1, 2], [0, 1], [1, 0]]]
k = 3

Output: [[0.755, 0.91], [1.575, 1.245]]

Multi-Token Routing with K=3
Token 0: Logits = [1.5, 2.5, 3.5, 0.5] • Top-3 experts: Expert 2 (3.5), Expert 1 (2.5), Expert 0 (1.5) • Softmax over [3.5, 2.5, 1.5]: w₂ ≈ 0.665, w₁ ≈ 0.245, w₀ ≈ 0.090 • Final = 0.665 × [1, 1] + 0.245 × [0, 1] + 0.090 × [1, 0] ≈ [0.755, 0.91]
Token 1: Logits = [4.0, 3.0, 2.0, 1.0] • Top-3 experts: Expert 0 (4.0), Expert 1 (3.0), Expert 2 (2.0) • Softmax over [4.0, 3.0, 2.0]:
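Both tokens of this example can be verified in one vectorized pass (a check snippet; the variable names are illustrative):

```python
import numpy as np

logits = np.array([[1.5, 2.5, 3.5, 0.5], [4.0, 3.0, 2.0, 1.0]])
outputs = np.array([[[1, 0], [0, 1], [1, 1], [0, 0]],
                    [[2, 1], [1, 2], [0, 1], [1, 0]]], dtype=float)
k = 3

idx = np.argsort(logits, axis=-1)[:, -k:]                   # top-3 per token
sel = np.take_along_axis(logits, idx, axis=-1)
w = np.exp(sel) / np.exp(sel).sum(axis=-1, keepdims=True)   # softmax over the 3
out = (w[:, :, None] *
       np.take_along_axis(outputs, idx[:, :, None], axis=1)).sum(axis=1)
print(np.round(out, 3))   # ≈ [[0.755, 0.91], [1.575, 1.245]]
```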
Constraints