In modern transformer architectures, attention mechanisms compute compatibility scores between queries and keys to determine how much each position should attend to others. However, when these scores grow too large, they can cause the softmax function to saturate, leading to vanishing gradients and numerical instability during training.
The Attention Score Clipping Stabilizer is a sophisticated technique that dynamically rescales the query (W_q) and key (W_k) projection weight matrices to ensure that the maximum pre-softmax attention score remains bounded within a safe threshold. This prevents numerical overflow and maintains healthy gradient flow throughout the network.
Given input features x of shape (batch, seq_len, d_model) and projection matrices W_q and W_k, the attention score for positions i and j is computed as:
$$\text{score}_{ij} = \frac{(x_i W_q) \cdot (x_j W_k)^T}{\sqrt{d_{head}}}$$
where d_head is the dimension of the projected queries/keys (number of columns in W_q or W_k).
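As a concrete check of the formula, the score for one pair of positions can be computed directly in NumPy (a minimal sketch using toy values, with d_head = 2):

```python
import numpy as np

# Toy single-batch input: (seq_len, d_model) = (2, 2)
x = np.array([[1.0, 0.0], [0.0, 1.0]])
W_q = np.array([[2.0, 0.0], [0.0, 2.0]])
W_k = np.array([[2.0, 0.0], [0.0, 2.0]])
d_head = W_q.shape[1]   # number of columns in W_q / W_k

i, j = 0, 0
# Scaled dot product between the projected query at i and key at j
score_ij = (x[i] @ W_q) @ (x[j] @ W_k) / np.sqrt(d_head)
print(round(float(score_ij), 4))   # 4.0 / sqrt(2) ≈ 2.8284
```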
Your function must:
Compute the maximum pre-clip score: Project all input features through W_q and W_k, compute all pairwise dot products, scale by 1/√d_head, and find the maximum absolute value across all batches and positions.
Apply conditional rescaling: If the maximum score exceeds the threshold t, compute the shrink factor η = t / max_score, then scale W_q by η^α and W_k by η^(1−α).
This fractional distribution ensures the combined effect reduces the maximum score exactly to t, since: $$\text{new\_score}_{max} = \text{old\_score}_{max} \cdot \eta^{\alpha} \cdot \eta^{1-\alpha} = \text{old\_score}_{max} \cdot \eta = t$$
Return the results: Provide the (possibly rescaled) weight matrices, a boolean indicating whether clipping occurred, and the final maximum score.
The alpha parameter controls how the scaling is distributed between W_q and W_k: with α = 0.5 the adjustment is split evenly (both matrices scaled by √η), values of α closer to 1 shift more of the scaling onto W_q, and values closer to 0 shift more onto W_k.
Important: Round all floating-point values in the returned matrices and the final score to 4 decimal places for reproducibility.
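Assembled end to end, the procedure might look like the following sketch (the function name `clip_attention_scores` is illustrative; the return ordering follows the example outputs below):

```python
import numpy as np

def clip_attention_scores(x, W_q, W_k, t, alpha):
    """Rescale W_q/W_k so the max pre-softmax score is at most t (sketch)."""
    x = np.asarray(x, dtype=float)        # (batch, seq_len, d_model)
    W_q = np.asarray(W_q, dtype=float)    # (d_model, d_head)
    W_k = np.asarray(W_k, dtype=float)
    d_head = W_q.shape[1]

    # Step 1: project, form all pairwise scores, scale by 1/sqrt(d_head)
    Q = x @ W_q                            # (batch, seq_len, d_head)
    K = x @ W_k
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    max_score = float(np.max(np.abs(scores)))

    # Step 2: conditionally rescale, splitting the shrink factor via alpha
    clipped = max_score > t
    if clipped:
        eta = t / max_score
        W_q = W_q * eta ** alpha
        W_k = W_k * eta ** (1 - alpha)
        max_score = t                      # eta^alpha * eta^(1-alpha) = eta

    # Step 3: round all returned floats to 4 decimal places
    return (np.round(W_q, 4).tolist(),
            np.round(W_k, 4).tolist(),
            bool(clipped),
            round(max_score, 4))
```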
Example 1:

W_q = [[2.0, 0.0], [0.0, 2.0]]
W_k = [[2.0, 0.0], [0.0, 2.0]]
x = [[[1.0, 0.0], [0.0, 1.0]]]
t = 1.0
alpha = 0.5

Output: ([[1.1892, 0.0], [0.0, 1.1892]], [[1.1892, 0.0], [0.0, 1.1892]], True, 1.0)

Step 1 - Compute Projections:
• Q = x @ W_q gives us query vectors [[2.0, 0.0], [0.0, 2.0]] for the sequence
• K = x @ W_k gives us key vectors [[2.0, 0.0], [0.0, 2.0]]
Step 2 - Calculate Attention Scores:
• The QK^T matrix = [[4.0, 0.0], [0.0, 4.0]]
• Scaling by 1/√d_head = 1/√2 ≈ 0.7071
• Scaled scores = [[2.8284, 0.0], [0.0, 2.8284]]
• Maximum score = 2.8284, which exceeds threshold t = 1.0
Step 3 - Apply Clipping:
• η = t / max_score = 1.0 / 2.8284 ≈ 0.3536
• With α = 0.5, both matrices are scaled by √η ≈ 0.5946
• New W_q = W_k = 2.0 × 0.5946 ≈ 1.1892
The maximum score is now exactly 1.0, matching the threshold.
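This worked example can be reproduced numerically (a sanity-check sketch, not part of the required solution):

```python
import numpy as np

x = np.array([[[1.0, 0.0], [0.0, 1.0]]])
W = np.array([[2.0, 0.0], [0.0, 2.0]])          # W_q == W_k in this example
scores = (x @ W) @ (x @ W).transpose(0, 2, 1) / np.sqrt(2)
eta = 1.0 / scores.max()                        # threshold t = 1.0
W_new = W * eta ** 0.5                          # alpha = 0.5 -> multiply each by sqrt(eta)
new_scores = (x @ W_new) @ (x @ W_new).transpose(0, 2, 1) / np.sqrt(2)
print(np.round(W_new, 4))                       # [[1.1892 0.] [0. 1.1892]]
print(round(float(new_scores.max()), 4))        # 1.0
```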
Example 2:

W_q = [[0.5, 0.0], [0.0, 0.5]]
W_k = [[0.5, 0.0], [0.0, 0.5]]
x = [[[1.0, 0.0], [0.0, 1.0]]]
t = 10.0
alpha = 0.5

Output: ([[0.5, 0.0], [0.0, 0.5]], [[0.5, 0.0], [0.0, 0.5]], False, 0.1768)

Step 1 - Compute Projections:
• Q = x @ W_q = [[0.5, 0.0], [0.0, 0.5]]
• K = x @ W_k = [[0.5, 0.0], [0.0, 0.5]]
Step 2 - Calculate Attention Scores:
• QK^T = [[0.25, 0.0], [0.0, 0.25]]
• Scaled by 1/√2: [[0.1768, 0.0], [0.0, 0.1768]]
• Maximum score = 0.1768
Step 3 - Check Threshold:
• 0.1768 < 10.0, so no clipping is needed
• Weights remain unchanged, clipped = False
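The no-clipping branch is equally easy to check (sketch):

```python
import numpy as np

x = np.array([[[1.0, 0.0], [0.0, 1.0]]])
W = np.array([[0.5, 0.0], [0.0, 0.5]])          # W_q == W_k in this example
scores = (x @ W) @ (x @ W).transpose(0, 2, 1) / np.sqrt(2)
max_score = round(float(scores.max()), 4)
print(max_score, max_score > 10.0)              # 0.1768 False -> weights untouched
```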
Example 3:

W_q = [[1.5, 0.5], [0.5, 1.5]]
W_k = [[1.5, 0.5], [0.5, 1.5]]
x = [[[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]]
t = 2.0
alpha = 0.3

Output: ([[1.0981, 0.366], [0.366, 1.0981]], [[0.7245, 0.2415], [0.2415, 0.7245]], True, 2.0)

Step 1 - Asymmetric Alpha Distribution: With α = 0.3, more scaling is applied to W_k than W_q.
Step 2 - Compute Maximum Score: After projecting x through both weight matrices and computing scaled dot products, the maximum attention score (8/√2 ≈ 5.6569) exceeds the threshold of 2.0.
Step 3 - Apply Asymmetric Scaling:
• W_q is scaled by η^0.3 ≈ 0.732 (smaller adjustment)
• W_k is scaled by η^0.7 ≈ 0.483 (larger adjustment)
• The product η^0.3 × η^0.7 = η ensures the final max score equals exactly t = 2.0
This asymmetric distribution can be useful when you want to preserve more of the original query space while adjusting keys more aggressively.
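The asymmetric factors can be verified numerically (sketch; η = t / max_score as defined earlier):

```python
import numpy as np

x = np.array([[[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]])
W = np.array([[1.5, 0.5], [0.5, 1.5]])          # W_q == W_k before clipping
scores = (x @ W) @ (x @ W).transpose(0, 2, 1) / np.sqrt(2)
max_score = scores.max()                        # 8 / sqrt(2) ≈ 5.6569
eta = 2.0 / max_score                           # t = 2.0
W_q_new = W * eta ** 0.3                        # gentler adjustment to queries
W_k_new = W * eta ** 0.7                        # heavier adjustment to keys
print(np.round(W_q_new, 4))                     # [[1.0981 0.366 ] [0.366  1.0981]]
print(np.round(W_k_new, 4))                     # [[0.7245 0.2415] [0.2415 0.7245]]
```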
Constraints