In deep learning optimization, gradients can explode during backpropagation—growing to extremely large magnitudes that destabilize training and cause parameter updates to overshoot optimal values. This phenomenon, known as the exploding gradient problem, is particularly prevalent in recurrent neural networks (RNNs), deep transformers, and networks with many layers.
Global Norm Gradient Rescaling is a powerful stabilization technique that addresses this issue by constraining the total magnitude of all gradients across an entire model. Unlike per-parameter clipping methods, this approach maintains the relative proportions between different gradient components while ensuring the combined magnitude stays within safe bounds.
Given a collection of gradient arrays (G = {g_1, g_2, ..., g_k}) representing gradients for different parameters and a maximum norm threshold (\tau), the algorithm works as follows:
Step 1: Compute the Global L2 Norm
Calculate the global norm by treating all gradient values across all arrays as a single flattened vector:
$$\|G\|_2 = \sqrt{\sum_{i=1}^{k} \sum_{j} g_{i,j}^2}$$
This represents the total Euclidean magnitude of all gradients combined.
Step 2: Determine the Scaling Factor
If the global norm exceeds the threshold, compute a scaling factor to bring it exactly to the threshold:
$$\alpha = \frac{\tau}{\|G\|_2} \quad \text{if } \|G\|_2 > \tau$$
If the global norm is already within bounds ((\|G\|_2 \leq \tau)), set (\alpha = 1) (no scaling needed).
Step 3: Apply Uniform Rescaling
Multiply every gradient value by the scaling factor:
$$g'_{i,j} = \alpha \cdot g_{i,j}$$
This ensures all gradients are scaled proportionally, preserving their relative directions and magnitudes while constraining the total norm.
Write a function that takes a list of gradient arrays and a maximum norm threshold, then returns the rescaled gradients. The function should preserve the original structure of the input (same number of arrays, same dimensions per array).
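One way to implement this is the following plain-Python sketch (the function name `rescale_gradients` and the flat list-of-lists layout are assumptions chosen to match the examples below):

```python
import math

def rescale_gradients(gradients, max_norm):
    """Rescale nested gradient lists so their global L2 norm is at most max_norm.

    `gradients` is a list of flat lists of floats; the input structure
    (number of arrays, length of each array) is preserved in the output.
    """
    # Step 1: global L2 norm over every value in every array
    global_norm = math.sqrt(sum(g * g for arr in gradients for g in arr))
    # Step 2: scale only when the norm exceeds the threshold
    # (this also safely handles all-zero gradients, where global_norm == 0)
    alpha = max_norm / global_norm if global_norm > max_norm else 1.0
    # Step 3: uniform rescaling, preserving shapes and relative directions
    return [[g * alpha for g in arr] for arr in gradients]
```

For instance, `rescale_gradients([[3.0, 4.0], [0.0, 0.0]], 1.0)` yields `[[0.6, 0.8], [0.0, 0.0]]` up to floating-point rounding, matching the first worked example.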
Example 1:
gradients = [[3.0, 4.0], [0.0, 0.0]]
max_norm = 1.0
Expected output: [[0.6, 0.8], [0.0, 0.0]]
First, compute the global L2 norm across all gradient values:
||G||₂ = √(3.0² + 4.0² + 0.0² + 0.0²) = √(9 + 16 + 0 + 0) = √25 = 5.0
Since 5.0 > 1.0 (global norm exceeds threshold), we need to rescale.
The scaling factor is: α = max_norm / global_norm = 1.0 / 5.0 = 0.2
Apply this factor to all gradients:
• First array: [3.0 × 0.2, 4.0 × 0.2] = [0.6, 0.8]
• Second array: [0.0 × 0.2, 0.0 × 0.2] = [0.0, 0.0]
The resulting rescaled gradients are [[0.6, 0.8], [0.0, 0.0]].
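The arithmetic above can be verified in a few lines of plain Python (a standalone check of this example, not the required function itself):

```python
import math

values = [3.0, 4.0, 0.0, 0.0]            # all gradient entries, flattened
global_norm = math.sqrt(sum(v * v for v in values))
print(global_norm)                        # 5.0
alpha = 1.0 / global_norm                 # max_norm / global_norm = 0.2
# round to suppress floating-point noise in the printed result
print([round(v * alpha, 10) for v in values])  # [0.6, 0.8, 0.0, 0.0]
```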
Example 2:
gradients = [[0.3, 0.4], [0.0, 0.0]]
max_norm = 1.0
Expected output: [[0.3, 0.4], [0.0, 0.0]]
Compute the global L2 norm:
||G||₂ = √(0.3² + 0.4² + 0.0² + 0.0²) = √(0.09 + 0.16) = √0.25 = 0.5
Since 0.5 ≤ 1.0 (global norm is within the threshold), no rescaling is needed.
The gradients are returned unchanged: [[0.3, 0.4], [0.0, 0.0]].
Example 3:
gradients = [[1.0, 2.0, 2.0], [4.0, 0.0]]
max_norm = 2.5
Expected output: [[0.5, 1.0, 1.0], [2.0, 0.0]]
This example demonstrates handling arrays of different sizes.
Compute the global L2 norm: ||G||₂ = √(1.0² + 2.0² + 2.0² + 4.0² + 0.0²) = √(1 + 4 + 4 + 16 + 0) = √25 = 5.0
Since 5.0 > 2.5, rescaling is required.
Scaling factor: α = 2.5 / 5.0 = 0.5
Apply to all gradients:
• First array: [1.0 × 0.5, 2.0 × 0.5, 2.0 × 0.5] = [0.5, 1.0, 1.0]
• Second array: [4.0 × 0.5, 0.0 × 0.5] = [2.0, 0.0]
Result: [[0.5, 1.0, 1.0], [2.0, 0.0]]
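The ragged-array case can be checked end to end by inlining the three steps; flattening with a nested comprehension handles arrays of different lengths (a standalone sketch of this example):

```python
import math

gradients = [[1.0, 2.0, 2.0], [4.0, 0.0]]  # arrays of different sizes
max_norm = 2.5

# Step 1: flatten across arrays and compute the global L2 norm
flat = [g for arr in gradients for g in arr]
global_norm = math.sqrt(sum(g * g for g in flat))           # sqrt(25) = 5.0
# Step 2: scaling factor (5.0 > 2.5, so alpha = 0.5)
alpha = max_norm / global_norm if global_norm > max_norm else 1.0
# Step 3: uniform rescale, keeping the original per-array structure
rescaled = [[g * alpha for g in arr] for arr in gradients]
print(rescaled)  # [[0.5, 1.0, 1.0], [2.0, 0.0]]
```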
Constraints