In modern large language model (LLM) training, inference efficiency is just as crucial as model quality. During reinforcement learning from human feedback (RLHF) or similar policy optimization phases, it's often desirable to encourage models to generate concise, efficient responses that don't waste computational resources.
This problem introduces a budget-aware policy optimization objective that combines three key components:
When a generated response exceeds a predefined token budget B, a linear penalty is applied proportional to the excess tokens:
$$\text{adjusted\_reward}_i = r_i - \beta \cdot \max(0, L_i - B)$$
Where:
- rᵢ is the original reward for response i
- Lᵢ is the token length of response i
- B is the token budget threshold
- β is the budget penalty coefficient

Responses within the budget (Lᵢ ≤ B) receive no penalty.
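The penalty step can be sketched in a few lines of NumPy; the function name `adjusted_rewards` and its signature are illustrative choices, not part of the problem statement.

```python
import numpy as np

def adjusted_rewards(rewards, lengths, budget, beta):
    """Subtract a linear penalty for each token past the budget.

    Responses at or under the budget have zero excess, so their
    reward passes through unchanged.
    """
    rewards = np.asarray(rewards, dtype=float)
    excess = np.maximum(0, np.asarray(lengths) - budget)
    return rewards - beta * excess
```

For instance, `adjusted_rewards([1.0, 0.5], [150, 80], 100, 0.01)` penalizes only the first response (50 excess tokens × 0.01 = 0.5), yielding `[0.5, 0.5]`.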
For each prompt, compute a baseline as the mean of adjusted rewards across all sampled responses. The advantage for each response is:
$$A_i = \text{adjusted\_reward}_i - \text{baseline}$$
This centers the advantages around zero, reducing variance in the policy gradient.
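The baseline subtraction is a one-liner per prompt; `group_advantages` is an assumed helper name for this sketch.

```python
import numpy as np

def group_advantages(adjusted_rewards):
    """Center each prompt's adjusted rewards on their per-prompt mean.

    Operates on a 2D array of shape (num_prompts, num_responses);
    the row mean is the baseline, so each row of advantages sums to zero.
    """
    adjusted_rewards = np.asarray(adjusted_rewards, dtype=float)
    baseline = adjusted_rewards.mean(axis=-1, keepdims=True)
    return adjusted_rewards - baseline
```

For example, `group_advantages([[1.0, 2.0]])` gives `[[-0.5, 0.5]]`: zero-mean, as required for the variance reduction described above.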
The final loss is the mean squared difference between the advantage and a KL-divergence term that compares the current policy against a reference policy:
$$\mathcal{L} = \mathbb{E}\left[(A_i - \alpha \cdot (\log \pi_\theta(y|x) - \log \pi_{\text{ref}}(y|x)))^2\right]$$
Where:
- α is the KL regularization coefficient
- log πθ(y|x) is the current policy's log probability
- log πref(y|x) is the reference policy's log probability

The KL term prevents the policy from deviating too far from the reference, ensuring stable training.
Implement the complete budget-aware RL loss computation, combining the three components above. Round your final answer to 4 decimal places.
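One way to put the steps together is sketched below, with the `old_log_probs` input playing the role of the reference policy, as in the worked examples that follow. The function name `budget_aware_rl_loss` and the argument order are assumptions chosen to match the example inputs.

```python
import numpy as np

def budget_aware_rl_loss(rewards, log_probs, old_log_probs,
                         response_lengths, token_budget,
                         kl_coef, budget_penalty_coef):
    """Budget penalty -> per-prompt baseline -> KL-regularized squared loss."""
    losses = []
    for r, lp, olp, lens in zip(rewards, log_probs, old_log_probs,
                                response_lengths):
        r = np.asarray(r, dtype=float)
        # Step 1: linear penalty for tokens beyond the budget
        excess = np.maximum(0, np.asarray(lens) - token_budget)
        adjusted = r - budget_penalty_coef * excess
        # Step 2: advantages relative to the per-prompt mean baseline
        adv = adjusted - adjusted.mean()
        # Step 3: KL term against the reference (old) policy
        kl = kl_coef * (np.asarray(lp) - np.asarray(olp))
        # Step 4: squared loss per response
        losses.extend((adv - kl) ** 2)
    # Mean over all responses of all prompts, rounded to 4 decimal places
    return round(float(np.mean(losses)), 4)
```

Running it on the first example's inputs returns 0.0004, matching the walkthrough below.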
Example 1:

rewards = [[1.0, 0.5]]
log_probs = [[-1.0, -1.5]]
old_log_probs = [[-1.2, -1.3]]
response_lengths = [[150, 80]]
token_budget = 100
kl_coef = 0.1
budget_penalty_coef = 0.01

Output: 0.0004

Step 1: Apply Budget Penalties
- Response 1: length = 150 > budget (100), excess = 50; penalty = 0.01 × 50 = 0.5; adjusted reward = 1.0 - 0.5 = 0.5
- Response 2: length = 80 ≤ budget (100), no penalty; adjusted reward = 0.5

Step 2: Compute Baseline and Advantages
- Baseline = mean([0.5, 0.5]) = 0.5
- Advantage₁ = 0.5 - 0.5 = 0.0
- Advantage₂ = 0.5 - 0.5 = 0.0

Step 3: Compute KL Terms
- KL₁ = kl_coef × (log_prob₁ - old_log_prob₁) = 0.1 × (-1.0 - (-1.2)) = 0.1 × 0.2 = 0.02
- KL₂ = kl_coef × (log_prob₂ - old_log_prob₂) = 0.1 × (-1.5 - (-1.3)) = 0.1 × (-0.2) = -0.02

Step 4: Compute Squared Loss
- Loss₁ = (A₁ - KL₁)² = (0.0 - 0.02)² = 0.0004
- Loss₂ = (A₂ - KL₂)² = (0.0 - (-0.02))² = 0.0004
- Final Loss = mean([0.0004, 0.0004]) = 0.0004
Example 2:

rewards = [[1.0, 2.0]]
log_probs = [[-0.5, -0.8]]
old_log_probs = [[-0.5, -0.8]]
response_lengths = [[50, 60]]
token_budget = 100
kl_coef = 0.1
budget_penalty_coef = 0.01

Output: 0.25

Step 1: Apply Budget Penalties
Both responses are within the budget (50 < 100 and 60 < 100), so no penalties apply.
- Adjusted rewards = [1.0, 2.0]

Step 2: Compute Baseline and Advantages
- Baseline = mean([1.0, 2.0]) = 1.5
- Advantage₁ = 1.0 - 1.5 = -0.5
- Advantage₂ = 2.0 - 1.5 = 0.5

Step 3: Compute KL Terms
Current and reference policies are identical, so:
- KL₁ = 0.1 × (-0.5 - (-0.5)) = 0.0
- KL₂ = 0.1 × (-0.8 - (-0.8)) = 0.0

Step 4: Compute Squared Loss
- Loss₁ = (-0.5 - 0.0)² = 0.25
- Loss₂ = (0.5 - 0.0)² = 0.25
- Final Loss = mean([0.25, 0.25]) = 0.25
Example 3:

rewards = [[1.0, 0.5], [0.8, 1.2]]
log_probs = [[-1.0, -1.2], [-0.8, -0.9]]
old_log_probs = [[-1.1, -1.1], [-0.7, -1.0]]
response_lengths = [[120, 80], [90, 110]]
token_budget = 100
kl_coef = 0.05
budget_penalty_coef = 0.02

Output: 0.0055

Step 1: Apply Budget Penalties
Prompt 1:
- Response 1: 120 > 100, penalty = 0.02 × 20 = 0.4 → adjusted = 1.0 - 0.4 = 0.6
- Response 2: 80 ≤ 100, no penalty → adjusted = 0.5

Prompt 2:
- Response 1: 90 ≤ 100, no penalty → adjusted = 0.8
- Response 2: 110 > 100, penalty = 0.02 × 10 = 0.2 → adjusted = 1.2 - 0.2 = 1.0

Step 2: Compute Baselines and Advantages
- Prompt 1: baseline = mean([0.6, 0.5]) = 0.55 → advantages = [0.05, -0.05]
- Prompt 2: baseline = mean([0.8, 1.0]) = 0.9 → advantages = [-0.1, 0.1]

Step 3: Compute KL Terms
- Prompt 1: KL = 0.05 × [0.1, -0.1] = [0.005, -0.005]
- Prompt 2: KL = 0.05 × [-0.1, 0.1] = [-0.005, 0.005]

Step 4: Compute Squared Loss
- All losses: [(0.05 - 0.005)², (-0.05 - (-0.005))², (-0.1 - (-0.005))², (0.1 - 0.005)²] = [0.002025, 0.002025, 0.009025, 0.009025]
- Final Loss = mean = 0.0055
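The four steps of this multi-prompt example can be checked numerically with a short self-contained script; here `old_log_probs` is again treated as the reference policy, which is an assumption carried over from the problem's input names.

```python
import numpy as np

# Inputs from the third example
rewards = np.array([[1.0, 0.5], [0.8, 1.2]])
log_probs = np.array([[-1.0, -1.2], [-0.8, -0.9]])
old_log_probs = np.array([[-1.1, -1.1], [-0.7, -1.0]])
lengths = np.array([[120, 80], [90, 110]])

adjusted = rewards - 0.02 * np.maximum(0, lengths - 100)  # Step 1: penalties
adv = adjusted - adjusted.mean(axis=1, keepdims=True)     # Step 2: advantages
kl = 0.05 * (log_probs - old_log_probs)                   # Step 3: KL terms
loss = round(float(np.mean((adv - kl) ** 2)), 4)          # Step 4: squared loss
print(loss)  # 0.0055
```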
Constraints