The Mixture-of-Experts (MoE) architecture is a powerful paradigm in deep learning that enables efficient scaling of neural networks by conditionally activating only a subset of model parameters for each input. Instead of processing every input through the entire network, MoE employs a gating mechanism to dynamically route inputs to the most relevant "expert" sub-networks, dramatically improving computational efficiency while maintaining or even increasing model capacity.
In a Sparse MoE layer, we have:
- E expert networks, each with its own weight matrix of shape (d_input, d_output)
- A gating network with weight matrix W_g of shape (d_input, num_experts) that scores how relevant each expert is to a given token
- A top-k selection step, so that only the k highest-scoring experts process each token
This sparsity is the key innovation—while the model has massive total capacity (all experts combined), the computation cost per input scales with k, not E.
Given an input tensor x of shape (batch_size, sequence_length, d_input), the Sparse MoE forward pass proceeds as follows:
For each token, compute the raw gating scores by multiplying with the gating weight matrix: $$\text{logits} = x \cdot W_g$$ where W_g has shape (d_input, num_experts), producing logits of shape (batch_size, seq_len, num_experts).
Convert logits to probabilities using the softmax function: $$g_i = \frac{e^{\text{logits}_i}}{\sum_{j=1}^{E} e^{\text{logits}_j}}$$
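As a quick numerical check of the softmax step (using NumPy here purely as an illustration; the values match Example 3 below, where the logits are [2.0, 0.0]):

```python
import numpy as np

logits = np.array([2.0, 0.0])
# Subtract the max before exponentiating for numerical stability;
# this shifts all logits equally and does not change the result.
z = np.exp(logits - logits.max())
gates = z / z.sum()
# gates ≈ [0.8808, 0.1192] — a strong preference for expert 0
```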
Select the indices of the top-k experts with the highest gating probabilities for each token. This ensures sparse activation—only k experts are utilized per token.
Renormalize the gating probabilities for only the selected experts so they sum to 1: $$\hat{g}_i = \frac{g_i}{\sum_{j \in \text{top-k}} g_j}$$
For each selected expert, compute its output by applying its weight matrix to the input token: $$\text{expert\_output}_i = x \cdot W_{e_i}$$ where W_e has shape (num_experts, d_input, d_output).
Combine the expert outputs using the renormalized gating weights: $$y = \sum_{i \in \text{top-k}} \hat{g}_i \cdot \text{expert_output}_i$$
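The steps above can be sketched end-to-end in NumPy. This is one possible reference implementation, not the prescribed one; in particular, the tie-breaking rule (stable sort, so the lower expert index wins on equal scores) is an assumption:

```python
import numpy as np

def sparse_moe_forward(x, We, Wg, top_k):
    """Sparse MoE forward pass (illustrative sketch)."""
    x = np.asarray(x, dtype=float)    # (batch, seq, d_input)
    We = np.asarray(We, dtype=float)  # (num_experts, d_input, d_output)
    Wg = np.asarray(Wg, dtype=float)  # (d_input, num_experts)

    # Step 1-2: gating logits, then a numerically stable softmax over experts
    logits = x @ Wg                                          # (batch, seq, E)
    z = np.exp(logits - logits.max(axis=-1, keepdims=True))
    gates = z / z.sum(axis=-1, keepdims=True)

    # Step 3: top-k expert indices per token
    # (stable sort keeps the lower expert index on ties)
    topk_idx = np.argsort(-gates, axis=-1, kind="stable")[..., :top_k]

    # Step 4: renormalize the selected gating probabilities
    topk_g = np.take_along_axis(gates, topk_idx, axis=-1)
    topk_g = topk_g / topk_g.sum(axis=-1, keepdims=True)

    # Steps 5-6: apply each selected expert and combine with gating weights
    batch, seq, _ = x.shape
    y = np.zeros((batch, seq, We.shape[-1]))
    for b in range(batch):
        for t in range(seq):
            for w, e in zip(topk_g[b, t], topk_idx[b, t]):
                y[b, t] += w * (x[b, t] @ We[e])
    return y
```

The per-token loop keeps the routing explicit; a production implementation would instead batch tokens by expert, but the arithmetic is the same.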
Implement the sparse_moe_forward function that computes the forward pass of a Sparse Mixture-of-Experts layer. Your function should:
- Compute the gating logits and softmax probabilities for every token
- Select the top_k experts with the highest gating probability per token
- Renormalize the selected probabilities so they sum to 1
- Apply each selected expert's weight matrix to the token and combine the expert outputs using the renormalized weights
- Return the output tensor of shape (batch_size, seq_len, d_output)
Example 1:
x = [[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], [[6.0, 7.0], [8.0, 9.0], [10.0, 11.0]]]
We = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]
Wg = [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
top_k = 1

Output: [[[1.0, 1.0], [5.0, 5.0], [9.0, 9.0]], [[13.0, 13.0], [17.0, 17.0], [21.0, 21.0]]]

Input Analysis: All four experts share identical weights and Wg produces identical logits, so the softmax is uniform (0.25 per expert); whichever expert top-1 selects, the output is the same.
Step-by-step Computation (for first token [0.0, 1.0]):
logits = [0.0, 1.0] · Wg = [1.0, 1.0, 1.0, 1.0]; softmax → [0.25, 0.25, 0.25, 0.25]; top-1 selects a single expert with renormalized weight 1.0; output = [0.0, 1.0] · We[0] = [0+1, 0+1] = [1.0, 1.0]
For token [2.0, 3.0]: output = [2+3, 2+3] = [5.0, 5.0]
For token [4.0, 5.0]: output = [4+5, 4+5] = [9.0, 9.0]
And similarly for the second batch, producing the final output tensor.
Example 2:
x = [[[1.0, 2.0], [3.0, 4.0]]]
We = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]
Wg = [[0.5, 0.5], [0.5, 0.5]]
top_k = 2

Output: [[[3.0, 3.0], [7.0, 7.0]]]

Input Analysis: Both experts have identical weights and Wg gives identical logits, so each expert receives gating weight 0.5; with top_k = 2 both are selected, so no probability mass is discarded during renormalization.
Step-by-step Computation (for token [1.0, 2.0]):
logits = [1.0, 2.0] · Wg = [1.5, 1.5]; softmax → [0.5, 0.5]; both experts are in the top-2, so the renormalized weights remain 0.5 each; each expert output = [1+2, 1+2] = [3.0, 3.0]; combined: 0.5 × [3.0, 3.0] + 0.5 × [3.0, 3.0] = [3.0, 3.0]
For token [3.0, 4.0]: 0.5 × [7.0, 7.0] + 0.5 × [7.0, 7.0] = [7.0, 7.0]
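The same arithmetic for this example, written out in NumPy (purely illustrative, computing one token directly rather than via the full function):

```python
import numpy as np

x_tok = np.array([1.0, 2.0])
Wg = np.array([[0.5, 0.5], [0.5, 0.5]])
We = np.array([[[1.0, 1.0], [1.0, 1.0]],
               [[1.0, 1.0], [1.0, 1.0]]])

logits = x_tok @ Wg   # [1.5, 1.5] -> softmax gives 0.5 per expert
# With top_k = 2 both experts are kept, so the weights stay 0.5 each.
out = 0.5 * (x_tok @ We[0]) + 0.5 * (x_tok @ We[1])   # [3.0, 3.0]
```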
Example 3:
x = [[[1.0, 1.0]]]
We = [[[2.0, 0.0], [0.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]]]
Wg = [[1.0, 0.0], [1.0, 0.0]]
top_k = 1

Output: [[[2.0, 2.0]]]

Input Analysis: Here the experts differ: Expert 0 scales its input by 2 (We[0] = 2I) while Expert 1 is the identity, and Wg = [[1.0, 0.0], [1.0, 0.0]] pushes the gating logits toward Expert 0.
Step-by-step Computation:
logits = [1.0, 1.0] · Wg = [2.0, 0.0]; softmax → [0.881, 0.119]; top-1 selects Expert 0 with renormalized weight 1.0; output = [1.0, 1.0] · We[0] = [2.0, 2.0]
The result [2.0, 2.0] shows that Expert 0's 2× scaling transformation was applied.
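This example can be verified directly with a few lines of NumPy (illustrative only; top-1 selection is done here with argmax):

```python
import numpy as np

x_tok = np.array([1.0, 1.0])
Wg = np.array([[1.0, 0.0], [1.0, 0.0]])
We = np.array([[[2.0, 0.0], [0.0, 2.0]],   # Expert 0: scale by 2
               [[1.0, 0.0], [0.0, 1.0]]])  # Expert 1: identity

logits = x_tok @ Wg                          # [2.0, 0.0]
gates = np.exp(logits) / np.exp(logits).sum()   # ≈ [0.881, 0.119]
# Top-1 keeps only Expert 0; after renormalization its weight is 1.0.
y = x_tok @ We[int(np.argmax(gates))]        # [2.0, 2.0]
```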
Constraints