In reinforcement learning from human feedback (RLHF), policy optimization objectives balance exploration and stability by combining reward maximization with constraints that prevent the policy from deviating too far from a reference distribution. This problem introduces a sophisticated optimization framework that merges clipped surrogate objectives with importance-weighted divergence penalties.
You are tasked with implementing a Group Policy Optimization (GPO) objective function that combines two critical components:
The clipped surrogate objective constrains policy updates to a trust region, preventing catastrophically large changes. For each sample, the surrogate is computed as:
$$L^{CLIP}_i = \min\left(\rho_i \cdot A_i,\; \text{clip}(\rho_i,\, 1-\epsilon,\, 1+\epsilon) \cdot A_i\right)$$
Where:
• ρᵢ is the likelihood ratio for sample i (the `ratios` input),
• Aᵢ is the advantage estimate for sample i,
• ε is the clipping range.
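As a sketch, the per-sample surrogate can be written in plain Python (the function name and signature are illustrative, not a required interface):

```python
def clipped_surrogate(ratio, advantage, epsilon):
    """Per-sample clipped surrogate: min(rho*A, clip(rho, 1-eps, 1+eps)*A)."""
    clipped = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    return min(ratio * advantage, clipped * advantage)
```

For a positive advantage the clip caps how much an increased ratio can raise the objective; for a negative advantage the outer min keeps the more pessimistic of the two terms.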
To maintain proximity to a reference policy, we apply a KL divergence penalty using an unbiased importance-weighted estimator:
$$D_{KL,i} = \rho_i \cdot \left(r_i - \log(r_i) - 1\right)$$
Where:
• ρᵢ is the same likelihood ratio used in the clipped surrogate,
• rᵢ = π_current[i] / π_ref[i] is the ratio of current to reference policy probabilities.
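A minimal sketch of this estimator in Python (illustrative names; `rho` is the same likelihood ratio used in the surrogate):

```python
import math

def kl_penalty_term(rho, pi_current, pi_ref):
    """Importance-weighted KL estimator: rho * (r - log(r) - 1), r = pi_current/pi_ref."""
    r = pi_current / pi_ref
    return rho * (r - math.log(r) - 1.0)
```

Since r - log(r) - 1 ≥ 0 with equality only at r = 1, each term is non-negative (for positive ρ) and vanishes exactly when the two policies agree on that sample.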
The complete objective averages the clipped surrogate across all samples and subtracts the β-weighted KL penalty:
$$J(\theta) = \frac{1}{N} \sum_{i=1}^{N} L^{CLIP}_i - \beta \cdot \frac{1}{N} \sum_{i=1}^{N} D_{KL,i}$$
Where β (beta) controls the strength of the KL regularization.
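Combining the two terms, a self-contained pure-Python sketch of the whole objective might look like this (function and argument names are illustrative, not a required interface):

```python
import math

def gpo_objective(ratios, advantages, pi_current, pi_ref, epsilon, beta):
    """J = mean(clipped surrogate) - beta * mean(importance-weighted KL)."""
    n = len(ratios)
    surrogate_sum = 0.0
    kl_sum = 0.0
    for rho, adv, pc, pr in zip(ratios, advantages, pi_current, pi_ref):
        clipped = max(1.0 - epsilon, min(rho, 1.0 + epsilon))
        surrogate_sum += min(rho * adv, clipped * adv)  # L^CLIP_i
        r = pc / pr
        kl_sum += rho * (r - math.log(r) - 1.0)         # D_KL,i
    return surrogate_sum / n - beta * (kl_sum / n)
```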
Implement a function that takes `ratios`, `advantages`, `pi_current`, `pi_ref`, `epsilon`, and `beta`, and returns the objective J(θ). Example inputs:
ratios = [1.2, 0.8, 1.1]
advantages = [1.0, 1.0, 1.0]
pi_current = [0.9, 1.1, 1.0]
pi_ref = [1.0, 0.5, 1.5]
epsilon = 0.2
beta = 0.01

Expected output: 1.03195

Step 1: Compute Clipped Surrogate Objective
For each sample, we compute min(ρ × A, clip(ρ, 0.8, 1.2) × A):
• Sample 1: ρ₁=1.2, A₁=1.0 → clip(1.2, 0.8, 1.2)=1.2 → min(1.2×1.0, 1.2×1.0) = 1.2
• Sample 2: ρ₂=0.8, A₂=1.0 → clip(0.8, 0.8, 1.2)=0.8 → min(0.8×1.0, 0.8×1.0) = 0.8
• Sample 3: ρ₃=1.1, A₃=1.0 → clip(1.1, 0.8, 1.2)=1.1 → min(1.1×1.0, 1.1×1.0) = 1.1
Average surrogate = (1.2 + 0.8 + 1.1) / 3 = 1.0333...
Step 2: Compute KL Divergence Penalty
For each sample, compute ρᵢ × (rᵢ - log(rᵢ) - 1) where rᵢ = π_current / π_ref:
• Sample 1: r₁=0.9/1.0=0.9, KL₁=1.2×(0.9-log(0.9)-1) ≈ 0.00643
• Sample 2: r₂=1.1/0.5=2.2, KL₂=0.8×(2.2-log(2.2)-1) ≈ 0.32923
• Sample 3: r₃=1.0/1.5≈0.667, KL₃=1.1×(0.667-log(0.667)-1) ≈ 0.07934

Average KL = (0.00643 + 0.32923 + 0.07934) / 3 ≈ 0.13834
Step 3: Final Objective
J(θ) = 1.03333 - 0.01 × 0.13834 ≈ 1.03195
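The arithmetic above can be reproduced with a short script (a numerical check under the stated inputs, not a required implementation):

```python
import math

ratios = [1.2, 0.8, 1.1]
advantages = [1.0, 1.0, 1.0]
pi_current = [0.9, 1.1, 1.0]
pi_ref = [1.0, 0.5, 1.5]
epsilon, beta = 0.2, 0.01

# Step 1: per-sample clipped surrogates, then their mean
surr = [min(p * a, max(1 - epsilon, min(p, 1 + epsilon)) * a)
        for p, a in zip(ratios, advantages)]
# Step 2: per-sample importance-weighted KL terms, then their mean
kl = [p * (c / r - math.log(c / r) - 1)
      for p, c, r in zip(ratios, pi_current, pi_ref)]
# Step 3: combine the averages
J = sum(surr) / len(surr) - beta * sum(kl) / len(kl)
print(round(J, 5))  # 1.03195
```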
ratios = [1.0, 1.0, 1.0]
advantages = [1.0, 1.0, 1.0]
pi_current = [0.5, 0.5, 0.5]
pi_ref = [0.5, 0.5, 0.5]
epsilon = 0.2
beta = 0.01

Expected output: 1.0

Perfect Alignment Scenario
When all likelihood ratios equal 1.0 and current policy matches the reference policy exactly:
Clipped Surrogate:
• All ratios are 1.0, so clip(1.0, 0.8, 1.2) = 1.0
• Each surrogate = 1.0 × 1.0 = 1.0
• Average surrogate = 1.0
KL Divergence:
• All rᵢ = π_current[i] / π_ref[i] = 0.5/0.5 = 1.0
• For r = 1.0: KL = ρ × (1.0 - log(1.0) - 1) = ρ × (1 - 0 - 1) = 0
• Average KL = 0
Final Objective: J(θ) = 1.0 - 0.01 × 0 = 1.0
This demonstrates that when policies are perfectly aligned, the KL penalty is zero.
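This zero-penalty property follows directly from the estimator, as a one-line check shows (variable names are illustrative):

```python
import math

# When pi_current == pi_ref, r = 1 and r - log(r) - 1 = 0, so every
# KL term vanishes regardless of the likelihood ratio rho.
rho, pi_current, pi_ref = 1.0, 0.5, 0.5
r = pi_current / pi_ref
kl = rho * (r - math.log(r) - 1)
print(kl)  # 0.0
```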
ratios = [1.1, 0.9, 1.0]
advantages = [2.0, -1.5, 0.5]
pi_current = [0.4, 0.6, 0.5]
pi_ref = [0.5, 0.5, 0.5]
epsilon = 0.2
beta = 0.05

Expected output: 0.449311

Mixed Advantages with Varying Policies
Step 1: Clipped Surrogate Objective
• Sample 1: ρ₁=1.1, A₁=2.0 (positive advantage) → clip(1.1, 0.8, 1.2)=1.1 → min(1.1×2.0, 1.1×2.0) = 2.2
• Sample 2: ρ₂=0.9, A₂=-1.5 (negative advantage) → clip(0.9, 0.8, 1.2)=0.9 → min(0.9×(-1.5), 0.9×(-1.5)) = -1.35
• Sample 3: ρ₃=1.0, A₃=0.5 (small positive) → clip(1.0, 0.8, 1.2)=1.0 → min(1.0×0.5, 1.0×0.5) = 0.5
Average surrogate = (2.2 + (-1.35) + 0.5) / 3 = 0.45
Step 2: KL Divergence Penalty
• r₁ = 0.4/0.5 = 0.8, KL₁ = 1.1 × (0.8 - log(0.8) - 1) ≈ 0.02546
• r₂ = 0.6/0.5 = 1.2, KL₂ = 0.9 × (1.2 - log(1.2) - 1) ≈ 0.01591
• r₃ = 0.5/0.5 = 1.0, KL₃ = 1.0 × (1.0 - log(1.0) - 1) = 0

Average KL = (0.02546 + 0.01591 + 0) / 3 ≈ 0.01379
Step 3: Final Objective
J(θ) = 0.45 - 0.05 × 0.01379 ≈ 0.449311
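As with the first example, the mixed-advantage case can be verified numerically (a sketch using the stated inputs):

```python
import math

ratios = [1.1, 0.9, 1.0]
advantages = [2.0, -1.5, 0.5]
pi_current = [0.4, 0.6, 0.5]
pi_ref = [0.5, 0.5, 0.5]
epsilon, beta = 0.2, 0.05

# Clipped surrogates: the clip is inactive here since every rho is in [0.8, 1.2]
surr = [min(p * a, max(1 - epsilon, min(p, 1 + epsilon)) * a)
        for p, a in zip(ratios, advantages)]
# Importance-weighted KL terms
kl = [p * (c / r - math.log(c / r) - 1)
      for p, c, r in zip(ratios, pi_current, pi_ref)]
J = sum(surr) / 3 - beta * sum(kl) / 3
print(round(J, 6))  # 0.449311
```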
Constraints