In modern large-scale neural network architectures, Mixture-of-Experts (MoE) models have emerged as a powerful paradigm for achieving massive model capacity while maintaining computational efficiency. The key innovation lies in conditional computation—instead of using all model parameters for every input, an intelligent routing mechanism selects only a subset of specialized "expert" sub-networks to process each input.
At the heart of MoE architectures is the gating mechanism, which determines how to distribute inputs across experts. The Stochastic Top-K Expert Routing mechanism combines several sophisticated components to achieve both efficiency and training stability:
Step 1: Computing Raw Gating Logits
The raw affinity of each input for each expert is computed through a linear transformation:
$$H(x) = x \cdot W_g$$
where x is an input vector of dimension d, W_g ∈ ℝ^(d × e) is the gating weight matrix, and e is the number of experts.
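This step is a single matrix multiply. A minimal NumPy sketch, using the input and gating weights from Example 1 below:

```python
import numpy as np

# One raw gating logit per (input, expert) pair.
X = np.array([[1.0, 2.0]])       # (n, d) batch of input vectors
W_g = np.array([[1.0, 0.0],
                [0.0, 1.0]])     # (d, e) gating weight matrix
H = X @ W_g                      # (n, e) raw gating logits
print(H)                         # [[1. 2.]]
```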
Step 2: Adding Controlled Noise
To encourage exploration during training and prevent routing collapse (where only a few experts receive all inputs), the mechanism injects tunable noise into the gating logits:
$$\text{NoisyLogits}(x) = H(x) + \text{Softplus}(x \cdot W_{noise}) \odot N$$
where:
• W_noise ∈ ℝ^(d × e) is a learned noise weight matrix,
• Softplus(z) = log(1 + e^z) keeps the noise scale positive,
• N is a vector of pre-sampled noise values (typically standard normal), one per expert,
• ⊙ denotes element-wise multiplication.
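A sketch of the noise injection, again with Example 1's values; computing Softplus via `np.logaddexp(0, z)` is one numerically stable implementation choice:

```python
import numpy as np

def softplus(z):
    # log(1 + e^z), computed stably for large |z|
    return np.logaddexp(0.0, z)

X = np.array([[1.0, 2.0]])
W_noise = np.array([[0.5, 0.5],
                    [0.5, 0.5]])
N = np.array([[1.0, -1.0]])            # pre-sampled noise
H = np.array([[1.0, 2.0]])             # raw logits from Step 1

noise_scale = softplus(X @ W_noise)    # always positive, per-expert scale
noisy_logits = H + noise_scale * N
print(noisy_logits)                    # approximately [[2.701, 0.299]]
```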
Step 3: Sparse Top-K Selection
Rather than computing a full softmax over all experts (which would be expensive and defeat the purpose of sparsity), we select only the top-k experts:
$$G(x) = \text{Softmax}(\text{KeepTopK}(\text{NoisyLogits}(x), k))$$
Here KeepTopK keeps the k largest entries of NoisyLogits(x) and masks the rest to -∞, so they receive exactly zero probability after the softmax. The resulting gating vector G(x) is sparse—exactly k positions have non-zero values that sum to 1.
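A sketch of the masking-then-softmax step. The -∞ mask becomes an exact zero after exponentiation; the third logit here is purely illustrative:

```python
import numpy as np

def keep_top_k(logits, k):
    # Keep the k largest logits per row; mask the rest to -inf.
    out = np.full_like(logits, -np.inf)
    for i, row in enumerate(logits):
        top = np.argsort(row, kind='stable')[::-1][:k]  # descending indices
        out[i, top] = row[top]
    return out

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)     # shift for numerical stability
    e = np.exp(z)                             # exp(-inf) = 0, so masked
    return e / e.sum(axis=-1, keepdims=True)  # entries get zero probability

noisy = np.array([[2.701, 0.299, -1.0]])      # illustrative 3-expert logits
G = softmax(keep_top_k(noisy, 2))
print(np.round(G, 3).tolist())                # [[0.917, 0.083, 0.0]]
```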
Your Task:
Implement the stochastic top-k expert routing function that takes an input feature matrix, gating weights, noise weights, pre-sampled noise, and a sparsity parameter k, and returns the final sparse routing probabilities for each input.
Note: Round all final probability values to 3 decimal places for comparison purposes.
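Putting the three steps together, a minimal NumPy sketch of the required function (one possible implementation, not necessarily the reference one; the higher-index tie-break matches Example 3 below):

```python
import numpy as np

def noisy_top_k_gating(X, W_g, W_noise, N, k):
    X, W_g, W_noise, N = map(np.asarray, (X, W_g, W_noise, N))
    # Step 1: raw gating logits.
    H = X @ W_g
    # Step 2: add Softplus-scaled, pre-sampled noise.
    noisy = H + np.logaddexp(0.0, X @ W_noise) * N
    # Step 3: mask all but the top-k logits per row to -inf.
    masked = np.full_like(noisy, -np.inf)
    for i, row in enumerate(noisy):
        top = np.argsort(row, kind='stable')[::-1][:k]  # ties favour the higher index
        masked[i, top] = row[top]
    # Softmax over the surviving logits; exp(-inf) = 0.
    e = np.exp(masked - masked.max(axis=-1, keepdims=True))
    G = e / e.sum(axis=-1, keepdims=True)
    return np.round(G, 3)

# Values from Example 1 below:
print(noisy_top_k_gating([[1.0, 2.0]],
                         [[1.0, 0.0], [0.0, 1.0]],
                         [[0.5, 0.5], [0.5, 0.5]],
                         [[1.0, -1.0]], k=2))   # [[0.917 0.083]]
```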
Example 1
Input:
X = [[1.0, 2.0]]
W_g = [[1.0, 0.0], [0.0, 1.0]]
W_noise = [[0.5, 0.5], [0.5, 0.5]]
N = [[1.0, -1.0]]
k = 2
Output:
[[0.917, 0.083]]
Explanation:
Step 1: Compute raw gating logits
H = X @ W_g = [[1.0, 2.0]] @ [[1.0, 0.0], [0.0, 1.0]] = [[1.0, 2.0]]
Step 2: Compute noise scales using Softplus
noise_input = X @ W_noise = [[1.0, 2.0]] @ [[0.5, 0.5], [0.5, 0.5]] = [[1.5, 1.5]]
noise_scales = Softplus([[1.5, 1.5]]) = log(1 + e^1.5) ≈ [[1.701, 1.701]]
Step 3: Add scaled noise
scaled_noise = noise_scales ⊙ N = [[1.701, 1.701]] ⊙ [[1.0, -1.0]] = [[1.701, -1.701]]
noisy_logits = H + scaled_noise = [[1.0 + 1.701, 2.0 - 1.701]] = [[2.701, 0.299]]
Step 4: Apply Top-K selection
Since k = 2 and there are only 2 experts, all logits are kept.
Step 5: Apply Softmax
exp_logits = [e^2.701, e^0.299] ≈ [14.90, 1.35]
sum ≈ 16.25
probabilities = [14.90/16.25, 1.35/16.25] ≈ [0.917, 0.083]
Rounded to 3 decimals: [[0.917, 0.083]]
The first expert receives ~91.7% of the routing weight due to its higher noisy logit value.
Example 2
Input:
X = [[1.0, 0.0], [0.0, 1.0]]
W_g = [[1.0, 0.5, 0.0], [0.0, 0.5, 1.0]]
W_noise = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]
N = [[0.5, -0.5, 0.2], [0.1, 0.0, -0.1]]
k = 2
Output:
[[0.773, 0.0, 0.227], [0.0, 0.395, 0.605]]
Explanation:
For the first input [1.0, 0.0]: • Raw logits: [1.0, 0.5, 0.0] • Noise scales: Softplus([0.1, 0.1, 0.1]) ≈ [0.744, 0.744, 0.744], so the noisy logits are ≈ [1.372, 0.128, 0.149] • The top 2 noisy logits correspond to experts 0 and 2 • Expert 1's logit is masked to -∞, giving it 0 probability • Softmax over the remaining logits: [0.773, 0.0, 0.227]
For the second input [0.0, 1.0]: • Raw logits: [0.0, 0.5, 1.0] • Noisy logits: ≈ [0.074, 0.5, 0.926], so the top 2 correspond to experts 1 and 2 • Expert 0's logit is masked to -∞, giving it 0 probability • Softmax over the remaining logits: [0.0, 0.395, 0.605]
Each input is routed to exactly k=2 experts, with probabilities summing to 1.
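As a numeric check of this example under the Softplus formula from Step 2, a direct end-to-end computation (no helper function):

```python
import numpy as np

# Recompute Example 2 with Softplus-scaled noise.
X = np.array([[1.0, 0.0], [0.0, 1.0]])
W_g = np.array([[1.0, 0.5, 0.0], [0.0, 0.5, 1.0]])
W_noise = np.full((2, 3), 0.1)
N = np.array([[0.5, -0.5, 0.2], [0.1, 0.0, -0.1]])
k = 2

noisy = X @ W_g + np.logaddexp(0.0, X @ W_noise) * N
masked = np.full_like(noisy, -np.inf)          # top-k mask per row
for i, row in enumerate(noisy):
    top = np.argsort(row, kind='stable')[::-1][:k]
    masked[i, top] = row[top]
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
probs = np.round(e / e.sum(axis=-1, keepdims=True), 3)
print(probs.tolist())   # [[0.773, 0.0, 0.227], [0.0, 0.395, 0.605]]
```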
Example 3
Input:
X = [[1.0, 1.0]]
W_g = [[2.0, 1.0, 0.5], [0.0, 1.0, 0.5]]
W_noise = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]
N = [[0.0, 0.0, 0.0]]
k = 1
Output:
[[0.0, 1.0, 0.0]]
Explanation:
Step 1: Compute raw gating logits
H = [[1.0, 1.0]] @ [[2.0, 1.0, 0.5], [0.0, 1.0, 0.5]] = [[2.0, 2.0, 1.0]]
Step 2: Add noise
With zero noise (N = [[0.0, 0.0, 0.0]]), the noisy logits equal the raw logits:
noisy_logits = [[2.0, 2.0, 1.0]]
Step 3: Apply Top-K selection with k = 1
The maximum value 2.0 appears at both positions 0 and 1 (a tie). On ties, this implementation keeps the expert with the higher index, so only expert 1 is kept; the other logits are masked to -∞.
Step 4: Apply Softmax
With only one unmasked expert, it receives 100% of the probability: [[0.0, 1.0, 0.0]]
This demonstrates extreme sparsity where each token is processed by a single expert.
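The higher-index tie-break in this example falls out naturally from one common implementation pattern: a stable ascending argsort that is then reversed (assuming NumPy's `argsort` with `kind='stable'`):

```python
import numpy as np

logits = np.array([2.0, 2.0, 1.0])
# Stable ascending argsort puts index 0 before index 1 for the tied 2.0s;
# reversing then puts the higher index first in descending order.
order = np.argsort(logits, kind='stable')[::-1]
print(order.tolist())   # [1, 0, 2]  -> with k = 1, expert 1 is selected
```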
Constraints