In modern deep learning optimization, one of the most widely adopted techniques is the decoupled weight decay regularization approach, commonly known as AdamW. This algorithm addresses a critical flaw in traditional Adam optimization: the coupling of weight decay with the gradient-based update, which can lead to suboptimal regularization behavior.
The standard Adam optimizer combines adaptive learning rates with momentum, but when L2 regularization is applied by adding the regularization gradient directly to the loss gradient, the regularization term is divided by the same adaptive scaling (the square root of the second-moment estimate) as the gradient itself. Parameters with large historical gradients therefore receive weaker effective regularization, while rarely updated parameters receive stronger regularization: the opposite of what's typically desired.
AdamW solves this by applying weight decay directly to the parameters rather than incorporating it into the gradient computation. This decoupling ensures consistent regularization regardless of the gradient history.
Given a parameter vector w and its gradient g at time step t, the algorithm proceeds as follows:
Step 1: Update biased first moment estimate (momentum) $$m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t$$
Step 2: Update biased second moment estimate (adaptive learning rate) $$v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
Step 3: Compute bias-corrected estimates $$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$$ $$\hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
Step 4: Update parameters with decoupled weight decay $$w_t = w_{t-1} - \eta \cdot \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \cdot w_{t-1} \right)$$
Where: $m_t$ and $v_t$ are the first and second moment estimates, $g_t$ is the gradient at step $t$, $\beta_1$ and $\beta_2$ are the exponential decay rates for the moments, $\eta$ is the learning rate, $\epsilon$ is a small constant for numerical stability, and $\lambda$ is the weight decay coefficient.
Implement a function adamw_update(w, g, m, v, t, lr, beta1, beta2, epsilon, weight_decay) that performs one complete optimization step using the decoupled weight decay algorithm. The function should update the biased first and second moment estimates, apply bias correction, perform the decoupled parameter update, and return the updated values w_new, m_new, and v_new.
w = [1.0, 2.0]
g = [0.1, -0.2]
m = [0.0, 0.0]
v = [0.0, 0.0]
t = 1
lr = 0.01
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
weight_decay = 0.1

w_new = [0.989, 2.008]
m_new = [0.01, -0.02]
v_new = [0.0, 0.0]

First Iteration Analysis:
First moment update: m_new = 0.9 · [0.0, 0.0] + 0.1 · [0.1, -0.2] = [0.01, -0.02]
Second moment update: v_new = 0.999 · [0.0, 0.0] + 0.001 · [0.01, 0.04] = [0.00001, 0.00004] ≈ [0.0, 0.0]
Bias correction (t=1): m_hat = m_new / (1 - 0.9) = [0.1, -0.2]; v_hat = v_new / (1 - 0.999) = [0.01, 0.04]
Parameter update with decoupled weight decay: w_new = w - 0.01 · (m_hat / (sqrt(v_hat) + 1e-8) + 0.1 · w) = [1.0 - 0.011, 2.0 + 0.008] = [0.989, 2.008]
The parameters move opposite to the gradient direction while being regularized toward zero by the weight decay term.
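The four steps above can be sketched as a Python function following the signature from the problem statement (a sketch using plain lists; a real implementation would typically vectorize with NumPy or PyTorch):

```python
import math

def adamw_update(w, g, m, v, t, lr, beta1, beta2, epsilon, weight_decay):
    """One AdamW step; returns (w_new, m_new, v_new) as new lists."""
    # Step 1: biased first moment (momentum)
    m_new = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    # Step 2: biased second moment (adaptive scaling)
    v_new = [beta2 * vi + (1 - beta2) * gi ** 2 for vi, gi in zip(v, g)]
    # Step 3: bias correction (t is 1-indexed)
    m_hat = [mi / (1 - beta1 ** t) for mi in m_new]
    v_hat = [vi / (1 - beta2 ** t) for vi in v_new]
    # Step 4: decoupled weight decay -- lambda * w is added OUTSIDE the
    # adaptive scaling, so decay strength does not depend on v_hat.
    w_new = [wi - lr * (mh / (math.sqrt(vh) + epsilon) + weight_decay * wi)
             for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w_new, m_new, v_new

# First-iteration example from above: one step from zero-initialized moments.
w_new, m_new, v_new = adamw_update(
    [1.0, 2.0], [0.1, -0.2], [0.0, 0.0], [0.0, 0.0],
    t=1, lr=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8, weight_decay=0.1)
# w_new ≈ [0.989, 2.008], m_new = [0.01, -0.02]
```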
w = [0.5, -0.5, 1.0]
g = [0.2, 0.3, -0.1]
m = [0.01, -0.02, 0.03]
v = [0.001, 0.002, 0.001]
t = 5
lr = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
weight_decay = 0.01

w_new = [0.4998, -0.5, 0.9999]
m_new = [0.029, 0.012, 0.017]
v_new = [0.001, 0.0021, 0.001]

Mid-Training Update (t=5):
With pre-existing momentum (m) and adaptive learning rate history (v), this represents a typical mid-training scenario where the optimizer has already "warmed up."
First moment update: m_new = 0.9 · [0.01, -0.02, 0.03] + 0.1 · [0.2, 0.3, -0.1] = [0.029, 0.012, 0.017]
Second moment update: v_new = 0.999 · [0.001, 0.002, 0.001] + 0.001 · [0.04, 0.09, 0.01] ≈ [0.001, 0.0021, 0.001]
Bias correction factor at t=5: 1 - 0.9^5 ≈ 0.4095 for the first moment and 1 - 0.999^5 ≈ 0.00499 for the second, so the corrected estimates m_hat and v_hat are substantially larger than the raw running averages.
The smaller learning rate (0.001) and lower weight decay (0.01) result in more conservative parameter updates, typical of fine-tuning scenarios.
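The bias-correction factors and the resulting update can be checked numerically (a quick sketch reproducing only the first coordinate of the t=5 example; variable names are illustrative):

```python
import math

beta1, beta2, t = 0.9, 0.999, 5
bc1 = 1 - beta1 ** t   # first-moment correction factor, ≈ 0.40951
bc2 = 1 - beta2 ** t   # second-moment correction factor, ≈ 0.00499

# First coordinate: w = 0.5, g = 0.2, m_new = 0.029, v = 0.001 before the step.
m_hat = 0.029 / bc1
v_hat = (0.999 * 0.001 + 0.001 * 0.2 ** 2) / bc2
w_new = 0.5 - 0.001 * (m_hat / (math.sqrt(v_hat) + 1e-8) + 0.01 * 0.5)
# w_new ≈ 0.4998, matching the expected output above
```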
w = [1.0, -1.0, 0.5, -0.5]
g = [0.1, 0.2, -0.1, -0.2]
m = [0.0, 0.0, 0.0, 0.0]
v = [0.0, 0.0, 0.0, 0.0]
t = 1
lr = 0.01
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
weight_decay = 0.05

w_new = [0.9895, -1.0095, 0.5097, -0.4898]
m_new = [0.01, 0.02, -0.01, -0.02]
v_new = [0.0, 0.0, 0.0, 0.0]

4-Dimensional Parameter Space:
This example demonstrates the algorithm's behavior across a higher-dimensional parameter space with mixed positive and negative values.
Notice the asymmetric updates: at t=1 with zero-initialized moments, m_hat / sqrt(v_hat) ≈ sign(g), so each gradient term contributes about ±lr = ±0.01, while the decay term lr · λ · w differs per parameter. Where the gradient step and the pull toward zero agree (w = 1.0 and w = -0.5), the total step is larger (0.0105 and 0.01025); where they oppose (w = -1.0 and w = 0.5), it is smaller (0.0095 and 0.00975).
The decoupled weight decay uniformly shrinks all parameters toward zero by a factor of (1 - lr × weight_decay) = (1 - 0.01 × 0.05) = 0.9995, independent of the gradient magnitude.
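This t=1 case can be verified with a one-liner, since with zero-initialized moments the bias-corrected estimates reduce to m_hat = g and v_hat = g², making the adaptive step ≈ sign(g) (a sketch; epsilon is neglected as its effect is far below the displayed precision):

```python
# t = 1, zero moments: update reduces to w - lr * (sign(g) + wd * w)
lr, wd = 0.01, 0.05
w = [1.0, -1.0, 0.5, -0.5]
g = [0.1, 0.2, -0.1, -0.2]
w_new = [wi - lr * ((1 if gi > 0 else -1) + wd * wi) for wi, gi in zip(w, g)]
# w_new ≈ [0.9895, -1.0095, 0.5097, -0.4898]
```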
This is the key insight of AdamW: regularization strength is consistent across all parameters, regardless of their update frequency or gradient variance.
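The contrast with coupled L2 regularization can be made concrete with a small numerical comparison (a sketch, not part of the required function; it isolates the decay contribution by taking the current gradient and momentum to be zero, so only the regularization term remains, and the names are illustrative):

```python
import math

def effective_decay(w, v_hat, lr=0.01, wd=0.1, eps=1e-8, decoupled=True):
    """Shrinkage applied to w in one step when the current gradient is zero.

    v_hat stands in for the bias-corrected second-moment history.
    """
    if decoupled:
        return lr * wd * w                          # AdamW: independent of v_hat
    # Coupled L2: the decay gradient wd*w passes through the adaptive scaling.
    return lr * (wd * w) / (math.sqrt(v_hat) + eps)

w = 1.0
rare     = effective_decay(w, v_hat=1e-4, decoupled=False)  # tiny gradient history
frequent = effective_decay(w, v_hat=4.0,  decoupled=False)  # large gradient history
adamw    = effective_decay(w, v_hat=4.0)                    # always lr * wd * w

# Coupled decay: rare ≈ 0.1, frequent = 0.0005 -- regularization varies
# inversely with sqrt(v_hat). Decoupled decay: 0.001 for every parameter.
```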
Constraints