When fine-tuning large language models (LLMs) with Reinforcement Learning from Human Feedback (RLHF), a critical challenge emerges: the model may optimize so aggressively for the reward signal that it loses its foundational language capabilities (catastrophic forgetting) or exploits flaws in the reward model (reward hacking).
To address this, modern RLHF systems incorporate a pre-training regularization term (often called the PTX term) that encourages the model to maintain its performance on the original pre-training distribution while simultaneously optimizing for human preferences.
The combined loss function is defined as:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{RL}} + \beta \cdot \mathcal{L}_{\text{PTX}}$$
Where:
- $\mathcal{L}_{\text{RL}}$ is the reinforcement learning (reward-optimization) loss,
- $\mathcal{L}_{\text{PTX}}$ is the cross-entropy loss on a batch of pre-training data,
- $\beta$ is a coefficient controlling the strength of the pre-training regularization.
For each sample in the pre-training batch, the cross-entropy loss is computed as:
$$\mathcal{L}_{\text{CE}}^{(i)} = -\log\left(\frac{\exp(z_y)}{\sum_{j=1}^{V} \exp(z_j)}\right) = -z_y + \log\left(\sum_{j=1}^{V} \exp(z_j)\right)$$
Where:
- $z_j$ is the logit for class $j$,
- $z_y$ is the logit corresponding to the true label $y$,
- $V$ is the vocabulary size (number of classes).
The average cross-entropy across the batch is:
$$\mathcal{L}_{\text{PTX}} = \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}_{\text{CE}}^{(i)}$$
For numerical stability, use the log-sum-exp trick when computing softmax probabilities:
$$\log\sum_j \exp(z_j) = \max_j(z_j) + \log\sum_j \exp(z_j - \max_j(z_j))$$
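To see why the shift matters, here is a small sketch (the helper name `log_sum_exp` is illustrative): computing `math.log(sum(math.exp(v) for v in z))` directly overflows for large logits, while the shifted form stays finite.

```python
import math

def log_sum_exp(z):
    """Numerically stable log(sum(exp(z))) via the max-shift identity."""
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z))

# Naive: math.exp(1000.0) raises OverflowError.
# Shifted: 1000 + log(2) ≈ 1000.693147
print(round(log_sum_exp([1000.0, 1000.0]), 6))
```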
Your Task: Implement a function that computes the RLHF pre-training regularization loss. Given the RL loss, model logits on a pre-training batch, true labels, and the beta coefficient, return a tuple containing:
1. the total combined loss $\mathcal{L}_{\text{total}}$,
2. the PTX component $\beta \cdot \mathcal{L}_{\text{PTX}}$,
3. the average cross-entropy loss $\mathcal{L}_{\text{PTX}}$.
All values should be rounded to 6 decimal places.
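A minimal sketch of such a function in plain Python (the name `rlhf_ptx_loss` and the exact signature are assumptions, not prescribed by the task):

```python
import math

def rlhf_ptx_loss(rl_loss, logits, labels, beta):
    """Return (total_loss, ptx_component, avg_cross_entropy), rounded to 6 dp.

    Assumed layout: `logits` is a list of per-sample logit lists and
    `labels` holds the true class index for each sample.
    """
    ce_sum = 0.0
    for z, y in zip(logits, labels):
        m = max(z)  # log-sum-exp shift for numerical stability
        lse = m + math.log(sum(math.exp(v - m) for v in z))
        ce_sum += lse - z[y]  # -z_y + log(sum_j exp(z_j))
    avg_ce = ce_sum / len(logits)
    ptx = beta * avg_ce
    return (round(rl_loss + ptx, 6), round(ptx, 6), round(avg_ce, 6))

# First example below:
print(rlhf_ptx_loss(0.5, [[10, 0, 0], [0, 10, 0]], [0, 1], 0.1))
# → (0.500009, 0.000009, 0.000091)
```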
Example 1:
rl_loss = 0.5
logits = [[10, 0, 0], [0, 10, 0]]
labels = [0, 1]
beta = 0.1

Output: (0.500009, 0.000009, 0.000091)

Explanation: The logits strongly favor the correct classes (10 vs 0). After softmax:
• Sample 1: P(class 0) = e¹⁰/(e¹⁰ + 2) ≈ 0.999909, CE ≈ 0.0000908
• Sample 2: P(class 1) ≈ 0.999909, CE ≈ 0.0000908

Average CE ≈ 0.0000908
PTX component = 0.1 × 0.0000908 ≈ 0.0000091
Total loss = 0.5 + 0.0000091 ≈ 0.500009
The minimal PTX penalty reflects that the model already performs well on this pre-training data.
Example 2:
rl_loss = 1.0
logits = [[1, 1, 1], [1, 1, 1]]
labels = [0, 2]
beta = 0.5

Output: (1.549306, 0.549306, 1.098612)

Explanation: With uniform logits [1, 1, 1], softmax gives P = [1/3, 1/3, 1/3] for each class.
• Cross-entropy = -log(1/3) = log(3) ≈ 1.098612 for each sample

Average CE ≈ 1.098612
PTX component = 0.5 × 1.098612 ≈ 0.549306
Total loss = 1.0 + 0.549306 ≈ 1.549306
The significant PTX penalty indicates the model is uncertain on this pre-training data.
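Because uniform logits make the softmax uniform, this result can be checked directly with a few lines of arithmetic:

```python
import math

avg_ce = -math.log(1.0 / 3.0)   # CE is -log(1/3) = log(3) for every sample
ptx = 0.5 * avg_ce              # beta × average cross-entropy
total = 1.0 + ptx               # rl_loss + PTX component
print(round(total, 6), round(ptx, 6), round(avg_ce, 6))
# → 1.549306 0.549306 1.098612
```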
Example 3:
rl_loss = 0.25
logits = [[5, 0, 0], [0, 0, 5]]
labels = [0, 1]
beta = 0.2

Output: (0.752677, 0.502677, 2.513386)

Explanation:
• Sample 1: Logits [5, 0, 0] → P(class 0) = e⁵/(e⁵ + 2) ≈ 0.98670 → CE ≈ 0.013386 (correct prediction)
• Sample 2: Logits [0, 0, 5] → P(class 1) = 1/(2 + e⁵) ≈ 0.00665 → CE ≈ 5.013386 (wrong prediction)

Average CE ≈ (0.013386 + 5.013386) / 2 ≈ 2.513386
PTX component = 0.2 × 2.513386 ≈ 0.502677
Total loss = 0.25 + 0.502677 ≈ 0.752677
The high PTX penalty, driven by the second sample, reflects a mismatch between the model's prediction and the label.
Constraints