The Muon optimizer (MomentUm Orthogonalized by Newton-schulz) is a recent approach to neural network optimization that combines classical momentum with matrix orthogonalization for faster and more stable convergence.
Traditional optimizers such as SGD with momentum accumulate gradients over time, but they can suffer from ill-conditioned updates when the gradient matrix has highly skewed singular values. The Muon optimizer addresses this by orthogonalizing the momentum-accumulated gradient matrix before applying it, so that the parameter update has well-balanced singular values.
Given:
- theta: the current parameter matrix
- grad: the gradient matrix G
- B_prev: the previous momentum buffer
- mu (μ): the momentum coefficient
- lr (η): the learning rate
The Muon update proceeds as follows:
Step 1: Momentum Accumulation $$B_{new} = \mu \cdot B_{prev} + G$$
Step 2: Newton-Schulz Orthogonalization Apply 5 iterations of the quintic Newton-Schulz method to orthogonalize (B_{new}):
Starting with (X_0 = B_{new} / \|B_{new}\|_F) (Frobenius normalization keeps the singular values inside the iteration's basin of convergence), iterate: $$X_{k+1} = \frac{1}{8} X_k \cdot \left(15I - 10 A_k + 3 A_k^2\right)$$
where (A_k = X_k^T \cdot X_k). After 5 iterations, denote the result as O.
Step 3: Parameter Update $$\theta_{new} = \theta - \eta \cdot O$$
The Newton-Schulz iteration is an iterative method for computing the orthogonal polar factor of a matrix. Unlike SVD-based approaches, it avoids explicit eigenvalue decomposition and converges quadratically when the input matrix has singular values within a certain range.
The quintic variant used here (polynomial degree 5) converges faster than the classical cubic variant, making it suitable for real-time optimization loops in deep learning.
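The quintic iteration can be illustrated numerically. The sketch below assumes the textbook quintic Newton-Schulz polynomial (1/8)(15I − 10A + 3A²) with A = XᵀX and Frobenius pre-normalization; the matrix values are made-up demo numbers, not part of the problem.

```python
import numpy as np

def newton_schulz_quintic(B, steps=5):
    """Orthogonalize B with the (assumed) textbook quintic
    Newton-Schulz polynomial X <- (1/8) X (15I - 10A + 3A^2),
    where A = X^T X.  Frobenius normalization keeps the singular
    values inside the iteration's basin of convergence."""
    X = B / np.linalg.norm(B)          # Frobenius norm for 2-D input
    I = np.eye(B.shape[1])
    for _ in range(steps):
        A = X.T @ X
        X = X @ (15 * I - 10 * A + 3 * A @ A) / 8
    return X

# A well-conditioned 2x2 demo matrix: after 5 iterations X^T X ~ I.
B = np.array([[0.6, 0.2], [0.1, 0.5]])
O = newton_schulz_quintic(B)
print(np.round(O.T @ O, 4))            # close to the identity
```

Each iteration pushes every singular value of X toward 1 while leaving the singular vectors untouched, which is exactly the orthogonalization the optimizer needs.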
Implement the muon_update function that:
- Accumulates the gradient into the momentum buffer (Step 1)
- Orthogonalizes the buffer with 5 quintic Newton-Schulz iterations (Step 2)
- Applies the learning-rate-scaled update to the parameters (Step 3)
- Returns theta_new, B_new, and O
Note: Round all output values to 4 decimal places.
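The three steps can be sketched as follows. This is a minimal sketch, not the grader's reference solution: it assumes the textbook quintic Newton-Schulz polynomial (1/8)(15I − 10A + 3A²) with Frobenius pre-normalization, so O comes out approximately proportional to the orthogonal polar factor of B_new; the exact listed output values depend on the precise coefficients and scaling the reference uses.

```python
import numpy as np

def muon_update(theta, grad, B_prev, mu, lr, ns_steps=5):
    """One Muon step: momentum accumulation, Newton-Schulz
    orthogonalization, parameter update.  The quintic polynomial
    below is an assumption; swap in the reference coefficients if
    they differ."""
    theta, grad, B_prev = (np.asarray(m, dtype=float)
                           for m in (theta, grad, B_prev))

    # Step 1: momentum accumulation
    B_new = mu * B_prev + grad

    # Step 2: quintic Newton-Schulz orthogonalization
    X = B_new / (np.linalg.norm(B_new) + 1e-12)  # Frobenius normalization
    I = np.eye(X.shape[1])
    for _ in range(ns_steps):
        A = X.T @ X
        X = X @ (15 * I - 10 * A + 3 * A @ A) / 8
    O = X

    # Step 3: parameter update
    theta_new = theta - lr * O

    # Round all outputs to 4 decimal places, as the problem requires
    return np.round(theta_new, 4), np.round(B_new, 4), np.round(O, 4)
```

With B_prev = 0 the momentum step reduces to B_new = grad exactly, and theta_new − theta = −lr · O holds by construction regardless of the Newton-Schulz coefficients.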
theta = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
grad = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
B_prev = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
mu = 0.9
lr = 0.01
theta_new = [[1.0041, 0.9996, 0.9951], [0.9996, 0.998, 0.9964], [0.9951, 0.9964, 0.9978]]
B_new = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
O = [[-0.4144, 0.0395, 0.4933], [0.0395, 0.1973, 0.3552], [0.4933, 0.3552, 0.2171]]
Step 1: Momentum Update Since B_prev is all zeros, B_new = 0.9 × 0 + grad = grad
Step 2: Newton-Schulz Orthogonalization The gradient matrix is orthogonalized through 5 quintic Newton-Schulz iterations. Starting from the raw gradient, each iteration progressively transforms the matrix toward an orthogonal form. The resulting O matrix has more balanced singular values.
Step 3: Parameter Update θ_new = θ - 0.01 × O Each element of theta is adjusted by the corresponding orthogonalized gradient scaled by the learning rate.
The orthogonalization ensures the update direction is well-conditioned, even when the original gradient has highly skewed singular values (as in this rank-deficient gradient matrix).
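The rank deficiency mentioned above is easy to confirm: each row of this gradient differs from the previous row by the constant vector (3, 3, 3), so the rows span only a two-dimensional subspace.

```python
import numpy as np

grad = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

# The singular values are highly skewed: roughly 16.85, 1.07, and 0,
# so the raw gradient is rank 2 and badly conditioned.
print(np.round(np.linalg.svd(grad, compute_uv=False), 4))
print(np.linalg.matrix_rank(grad))   # 2
```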
theta = [[1.0, 2.0], [3.0, 4.0]]
grad = [[0.1, 0.2], [0.3, 0.4]]
B_prev = [[0.0, 0.0], [0.0, 0.0]]
mu = 0.9
lr = 0.01
theta_new = [[1.0036, 1.9939], [2.9939, 3.9964]]
B_new = [[0.1, 0.2], [0.3, 0.4]]
O = [[-0.3638, 0.6063], [0.6063, 0.3638]]
With a 2×2 gradient matrix, the Newton-Schulz iteration transforms the momentum buffer into an approximately orthogonal matrix O.
The momentum buffer B_new simply equals the gradient since B_prev is zero.
The orthogonalized matrix O shows the characteristic structure of a matrix that is orthogonal up to scale: OᵀO is close to a multiple of the identity (here OᵀO ≈ 0.5·I, i.e. both singular values sit near 1/√2). Notice how O has values that create balanced row and column norms.
The parameter update θ_new = θ - lr × O applies this well-conditioned update to the original parameters.
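Plugging the listed O back in makes its structure concrete: OᵀO comes out proportional to the identity, with both diagonal entries equal and the off-diagonal entries vanishing.

```python
import numpy as np

# The orthogonalized matrix from this example's expected output.
O = np.array([[-0.3638, 0.6063], [0.6063, 0.3638]])

OtO = O.T @ O
print(np.round(OtO, 4))   # a multiple of the identity, about 0.5 * I
```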
theta = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
grad = [[0.5, -0.5, 0.3], [0.2, 0.1, -0.2], [-0.1, 0.4, 0.2]]
B_prev = [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.1, 0.1, 0.1]]
mu = 0.9
lr = 0.1
theta_new = [[1.9552, 2.0352, 1.9581], [1.9459, 1.9633, 2.027], [2.0083, 1.9508, 1.9499]]
B_new = [[0.59, -0.41, 0.39], [0.38, 0.28, -0.02], [-0.01, 0.49, 0.29]]
O = [[0.4479, -0.3516, 0.4192], [0.5408, 0.3669, -0.2701], [-0.0832, 0.4917, 0.5013]]
This example demonstrates the full Muon algorithm with non-zero previous momentum.
Momentum Accumulation: B_new = 0.9 × B_prev + grad For example: B_new[0][0] = 0.9 × 0.1 + 0.5 = 0.59
Newton-Schulz Orthogonalization: The non-trivial momentum matrix undergoes 5 quintic iterations. The resulting O matrix maintains directional information while normalizing the update magnitude.
Parameter Update: With a larger learning rate (0.1), the updates are more substantial. θ_new[0][0] = 2.0 - 0.1 × 0.4479 ≈ 1.9552
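The element-wise arithmetic quoted above for the (0, 0) entry can be checked directly:

```python
mu, lr = 0.9, 0.1
B_prev_00, grad_00 = 0.1, 0.5
theta_00, O_00 = 2.0, 0.4479

# Step 1 for the (0, 0) entry: B_new = mu * B_prev + grad
B_new_00 = round(mu * B_prev_00 + grad_00, 4)
print(B_new_00)                          # 0.59

# Step 3 for the (0, 0) entry: theta_new = theta - lr * O
theta_new_00 = round(theta_00 - lr * O_00, 4)
print(theta_new_00)                      # 1.9552
```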
The Muon optimizer's orthogonalization prevents any single direction from dominating the update, leading to more stable training dynamics.