In modern machine learning systems, particularly large language models (LLMs), there's often a need to selectively forget or retain specific knowledge without retraining the entire model. This technique, known as Selective Gradient Masking (SGM), enables fine-grained control over which parts of a neural network are updated during training.
Neural network parameters can be partitioned into distinct groups based on their functional role: for example, a 'forget' group whose parameters encode knowledge to be removed, and a 'retain' group whose parameters encode knowledge to be preserved.
The key insight is that by applying binary masks to gradients before the parameter update step, we can control which neurons are affected by each training batch.
Given a parameter vector θ, gradient vector g, a binary forget mask M (where 1 indicates forget parameters), and learning rate η, the update depends on the batch type:
For 'forget' batches (containing data we want the model to unlearn): $$\theta_{new} = \theta - \eta \cdot (M \odot g)$$
For 'retain' batches (containing data we want to preserve): $$\theta_{new} = \theta - \eta \cdot ((1 - M) \odot g)$$
For 'unlabeled' batches (general training data): $$\theta_{new} = \theta - \eta \cdot g$$
Where ⊙ denotes element-wise multiplication (Hadamard product).
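As a concrete illustration of the masking, here is a minimal NumPy sketch of the Hadamard products above (the array values are taken from the worked examples later in this problem):

```python
import numpy as np

g = np.array([0.1, 0.2, 0.3, 0.4])   # gradient vector
M = np.array([1, 1, 0, 0])           # binary forget mask

# 'forget' batch: M ⊙ g keeps gradients only where M = 1,
# so retain-parameter gradients are zeroed
forget_grad = M * g

# 'retain' batch: (1 - M) ⊙ g keeps gradients only where M = 0,
# so forget-parameter gradients are zeroed
retain_grad = (1 - M) * g

print(forget_grad)
print(retain_grad)
```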
Implement the function masked_gradient_step that applies a single gradient descent update to a 1D parameter vector using the selective masking approach described above. Based on the batch_type, your function should:
• For 'forget' batches, update only the parameters where forget_mask[i] = 1 (mask the gradients for retain parameters to zero).
• For 'retain' batches, update only the parameters where forget_mask[i] = 0 (mask the gradients for forget parameters to zero).
• For 'unlabeled' batches, apply the full gradient with no masking.
Return the updated parameter vector after applying the masked gradient step.
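One way to implement this is sketched below. The function name and behavior follow the problem statement; the use of NumPy arrays and the exact input types are assumptions for illustration.

```python
import numpy as np

def masked_gradient_step(params, grad, forget_mask, lr, batch_type):
    """Apply one selectively masked gradient descent update to a 1D vector."""
    params = np.asarray(params, dtype=float)
    grad = np.asarray(grad, dtype=float)
    mask = np.asarray(forget_mask, dtype=float)

    if batch_type == 'forget':
        effective_grad = mask * grad          # update forget parameters only
    elif batch_type == 'retain':
        effective_grad = (1.0 - mask) * grad  # update retain parameters only
    elif batch_type == 'unlabeled':
        effective_grad = grad                 # no masking
    else:
        raise ValueError(f"unknown batch_type: {batch_type!r}")

    return params - lr * effective_grad
```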
params = [1.0, 1.0, 1.0, 1.0]
grad = [0.1, 0.2, 0.3, 0.4]
forget_mask = [1, 1, 0, 0]
lr = 0.1
batch_type = 'forget'
Expected output: [0.99, 0.98, 1.0, 1.0]
The forget_mask indicates that the first two parameters are in the 'forget' group (mask value = 1) and the last two are in the 'retain' group (mask value = 0).
For a 'forget' batch, only forget parameters should be updated. The effective gradients become: • Masked grad = [0.1 × 1, 0.2 × 1, 0.3 × 0, 0.4 × 0] = [0.1, 0.2, 0.0, 0.0]
Applying the update θ_new = θ - lr × masked_grad: • θ₀ = 1.0 - 0.1 × 0.1 = 0.99 • θ₁ = 1.0 - 0.1 × 0.2 = 0.98 • θ₂ = 1.0 - 0.1 × 0.0 = 1.0 (unchanged) • θ₃ = 1.0 - 0.1 × 0.0 = 1.0 (unchanged)
Result: [0.99, 0.98, 1.0, 1.0]
params = [1.0, 1.0, 1.0, 1.0]
grad = [0.1, 0.2, 0.3, 0.4]
forget_mask = [1, 1, 0, 0]
lr = 0.1
batch_type = 'retain'
Expected output: [1.0, 1.0, 0.97, 0.96]
For a 'retain' batch, only retain parameters (where forget_mask = 0) should be updated.
The retain mask is the complement of forget_mask: [0, 0, 1, 1] The effective gradients become: • Masked grad = [0.1 × 0, 0.2 × 0, 0.3 × 1, 0.4 × 1] = [0.0, 0.0, 0.3, 0.4]
Applying the update θ_new = θ - lr × masked_grad: • θ₀ = 1.0 - 0.1 × 0.0 = 1.0 (unchanged) • θ₁ = 1.0 - 0.1 × 0.0 = 1.0 (unchanged) • θ₂ = 1.0 - 0.1 × 0.3 = 0.97 • θ₃ = 1.0 - 0.1 × 0.4 = 0.96
Result: [1.0, 1.0, 0.97, 0.96]
params = [1.0, 1.0, 1.0, 1.0]
grad = [0.1, 0.2, 0.3, 0.4]
forget_mask = [1, 1, 0, 0]
lr = 0.1
batch_type = 'unlabeled'
Expected output: [0.99, 0.98, 0.97, 0.96]
For an 'unlabeled' batch, all parameters are updated normally without any masking.
Applying the standard gradient descent update θ_new = θ - lr × grad: • θ₀ = 1.0 - 0.1 × 0.1 = 0.99 • θ₁ = 1.0 - 0.1 × 0.2 = 0.98 • θ₂ = 1.0 - 0.1 × 0.3 = 0.97 • θ₃ = 1.0 - 0.1 × 0.4 = 0.96
Result: [0.99, 0.98, 0.97, 0.96]
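The three worked examples can be checked end to end with the short, self-contained sketch below (masked_gradient_step here is a minimal reference implementation of the update rules described above, not necessarily the expected submission):

```python
import numpy as np

def masked_gradient_step(params, grad, forget_mask, lr, batch_type):
    params = np.asarray(params, dtype=float)
    grad = np.asarray(grad, dtype=float)
    mask = np.asarray(forget_mask, dtype=float)
    if batch_type == 'forget':
        grad = mask * grad          # keep forget-parameter gradients
    elif batch_type == 'retain':
        grad = (1.0 - mask) * grad  # keep retain-parameter gradients
    # 'unlabeled': gradient is left unmasked
    return params - lr * grad

params = [1.0, 1.0, 1.0, 1.0]
grad = [0.1, 0.2, 0.3, 0.4]
forget_mask = [1, 1, 0, 0]
lr = 0.1

for batch_type in ('forget', 'retain', 'unlabeled'):
    print(batch_type, masked_gradient_step(params, grad, forget_mask, lr, batch_type))
```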
Constraints