On the previous page, we discovered REINFORCE's Achilles heel: variance. Gradient estimates fluctuate wildly from episode to episode, making learning slow and unreliable. But variance isn't a death sentence; it's a problem with solutions.
The field of variance reduction for policy gradients has developed remarkably elegant techniques. The most important is the baseline—a simple subtraction that can reduce variance by orders of magnitude without introducing any bias. Understanding baselines is crucial because they're the bridge to Actor-Critic methods, where learned value functions serve as sophisticated baselines.
On this page, we'll derive why baselines work, prove that they preserve unbiasedness, explore optimal baseline selection, and implement these techniques: the mathematical foundations that make modern policy gradient algorithms practical.
By the end of this page, you will understand: why baselines reduce variance without bias, how to derive the optimal baseline, the connection between baselines and control variates from statistics, practical implementation strategies, and how these concepts lead naturally to Actor-Critic methods.
The baseline is arguably the single most important variance reduction technique in policy gradients. The idea is deceptively simple: instead of weighting gradients by raw returns G_t, we weight by returns minus a baseline b(s_t).
The Modified Policy Gradient:
```
# Original REINFORCE gradient:
∇_θ J(θ) = E[Σ_t ∇_θ log π_θ(a_t|s_t) · G_t]

# With baseline b(s):
∇_θ J(θ) = E[Σ_t ∇_θ log π_θ(a_t|s_t) · (G_t - b(s_t))]

# Key properties:
# 1. Still unbiased (we'll prove this)
# 2. Variance can be dramatically reduced
# 3. b(s) can be any function of state (not dependent on action)
```

Why Does Subtracting a Baseline Help?
Intuition: In REINFORCE, if all returns are positive (common in many tasks), we increase the probability of all actions taken. The relative increase determines which actions become more likely. But in absolute terms, we're doing large positive updates.
With a baseline, an action is reinforced only if its return exceeds b(s_t): actions that do better than the baseline become more likely, and actions that do worse become less likely. This is like grading on a curve: we care about whether an action is better or worse than expected, not just whether its return is positive.
The baseline centers the gradient weights around zero. Without a baseline, gradients have a large mean component (pushing all probabilities in one direction) plus useful signal. With a good baseline, we remove the mean and amplify the signal-to-noise ratio of the gradient.
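To see this numerically, here is a tiny simulation of my own (a toy scalar model, not from the original page): returns are always positive and correlated with the score, and subtracting the mean return leaves the expected update unchanged while shrinking its spread.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy scalar model: score g with E[g] = 0, and always-positive returns G correlated with g.
g = rng.normal(size=100_000)
G = 10.0 + 2.0 * g + rng.normal(size=100_000)

b = G.mean()  # constant baseline ≈ E[G]
print("mean of g·G          :", np.mean(g * G))        # ≈ 2.0
print("mean of g·(G - b)    :", np.mean(g * (G - b)))  # ≈ 2.0 (same expectation: unbiased)
print("variance of g·G      :", np.var(g * G))         # ≈ 109
print("variance of g·(G - b):", np.var(g * (G - b)))   # ≈ 9   (an order of magnitude smaller)
```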
The mathematical elegance of baselines lies in the fact that they can be any function of state without introducing bias. Let's prove this rigorously.
Theorem: State-Dependent Baselines Are Unbiased
```
# We want to prove:
E[∇_θ log π_θ(a|s) · b(s)] = 0   for any function b(s)

# Proof:
E_{s,a}[∇_θ log π_θ(a|s) · b(s)]
= Σ_s ρ(s) Σ_a π_θ(a|s) · ∇_θ log π_θ(a|s) · b(s)

# Since b(s) doesn't depend on a, factor it out:
= Σ_s ρ(s) · b(s) · Σ_a π_θ(a|s) · ∇_θ log π_θ(a|s)

# Now apply the identity: π(a|s) · ∇ log π(a|s) = ∇π(a|s)
= Σ_s ρ(s) · b(s) · Σ_a ∇_θ π_θ(a|s)

# Interchange sum and gradient:
= Σ_s ρ(s) · b(s) · ∇_θ Σ_a π_θ(a|s)

# Probabilities sum to 1:
= Σ_s ρ(s) · b(s) · ∇_θ 1
= Σ_s ρ(s) · b(s) · 0
= 0 ∎

# This is why the expected score is zero!
# E_a[∇_θ log π_θ(a|s)] = 0 for any fixed s
```

Key Insight: The Score Has Zero Mean
The proof relies on the fact that E_a~π[∇_θ log π_θ(a|s)] = 0. This is the zero-mean property of the score function. When we multiply the baseline by this zero-mean quantity, the expected contribution is zero.
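This property is easy to verify numerically. The sketch below (my own check, using a hypothetical 4-action softmax policy for a single state) computes E_a[∇_θ log π_θ(a|s)] exactly by summing over actions with autograd:

```python
import torch

# Hypothetical softmax policy over 4 actions for one state, parameterized by logits θ.
theta = torch.randn(4, requires_grad=True)
probs = torch.softmax(theta, dim=0).detach()

# E_{a~π}[∇_θ log π_θ(a)] = Σ_a π_θ(a) ∇_θ log π_θ(a), summed over all 4 actions.
expected_score = torch.zeros(4)
for a in range(4):
    log_prob = torch.log_softmax(theta, dim=0)[a]
    (grad,) = torch.autograd.grad(log_prob, theta)
    expected_score += probs[a] * grad

print(expected_score)  # ≈ tensor([0., 0., 0., 0.]) up to floating-point error
```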
What If the Baseline Depends on Action?
```
# If b depends on action, we CANNOT factor it out!

E[∇_θ log π_θ(a|s) · b(s, a)]
= Σ_a π_θ(a|s) · ∇_θ log π_θ(a|s) · b(s, a)

# This does NOT equal zero in general!
# The baseline interacts with the gradient in action-specific ways

# Exception: If b(s, a) = c for all a (constant w.r.t. action),
# then it's effectively a state-only baseline

# Important: Action-dependent baselines CAN be used but require
# additional corrections (importance-weighted baselines)
```

For a baseline to be unbiased without correction, it must NOT depend on the action taken. This is why we use V(s) as a baseline, not Q(s,a). Using Q(s,a) directly as a baseline would introduce bias unless special corrections are applied.
Since any state-dependent baseline is unbiased, which one minimizes variance? This is an optimization problem with a beautiful closed-form solution.
Setting Up the Optimization:
```
# Variance of the gradient estimator (for a single state-action):
Var[∇_θ log π_θ(a|s) · (G - b)]

# For the scalar case, expand:
= E[(∇_θ log π · (G - b))²] - (E[∇_θ log π · (G - b)])²

# Since E[∇_θ log π · b] = 0, the second term simplifies:
= E[(∇_θ log π)² · (G - b)²] - (E[∇_θ log π · G])²

# Let g = ∇_θ log π and expand the first term:
= E[g² · (G² - 2Gb + b²)] - (E[gG])²
= E[g²G²] - 2b·E[g²G] + b²·E[g²] - (E[gG])²

# To minimize, take the derivative w.r.t. b and set it to zero:
d/db Var = -2·E[g²G] + 2b·E[g²] = 0

# Solve for the optimal b:
b* = E[g²G] / E[g²]
   = E[(∇_θ log π)² · G] / E[(∇_θ log π)²]

# This is a weighted average of returns!
# Weights: squared gradient magnitudes
```

Interpretation of the Optimal Baseline:
The optimal baseline b* is a weighted average of returns, where the weights are the squared gradient magnitudes. Intuitively, returns collected where the score's magnitude is large contribute most to the variance, so the optimal baseline weights those returns more heavily when deciding what to subtract.
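As a sanity check, here is a toy simulation of my own in which the returns are deliberately correlated with the squared score, so that b* differs from the plain mean return and visibly wins on variance:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy scalar case: score g, and returns that depend on g² so that b* ≠ E[G].
g = rng.normal(size=200_000)
G = 5.0 + 3.0 * g**2 + rng.normal(size=200_000)

b_mean = G.mean()                              # constant baseline E[G] ≈ 8
b_star = (g**2 * G).mean() / (g**2).mean()     # optimal baseline E[g²G]/E[g²] ≈ 14

for name, b in [("no baseline", 0.0), ("mean baseline", b_mean), ("optimal b*", b_star)]:
    print(f"{name:14s} variance of g·(G - b): {np.var(g * (G - b)):.1f}")
```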
In Practice: Why Use V(s) Instead?
```
# The truly optimal baseline b* is complex to compute exactly

# Approximation 1: Constant baseline
# b = E[G] (average return across all episodes)
# Simple but not state-dependent

# Approximation 2: State value function
# b(s) = V(s) = E[G_t | s_t = s]
# Expected return from state s under policy π

# Why V(s) is approximately optimal:
# - It's the expected return conditioned on state
# - G_t - V(s_t) measures how much better/worse the actual
#   trajectory was compared to expectation
# - This is exactly the "advantage" intuition!

# The advantage function:
A(s, a) = Q(s, a) - V(s)
# Expected improvement over the average action

# Using V(s) as baseline:
∇_θ J ≈ E[∇_θ log π(a|s) · (G_t - V(s_t))]
      ≈ E[∇_θ log π(a|s) · A(s, a)]      # Advantage interpretation
```

Using V(s) as a baseline is the foundation of Actor-Critic methods. The 'Critic' learns V(s), providing a baseline that: (1) is approximately optimal, (2) reduces variance dramatically, (3) enables us to interpret gradients as weighting by advantages. This insight connects policy gradients to value-based concepts.
Baselines in policy gradients are a specific instance of a general statistical technique called control variates. Understanding this broader perspective reveals additional variance reduction opportunities.
Control Variates in Statistics:
```
# Goal: Estimate E[X] with low variance

# Idea: Find a random variable C (the control variate) where:
# 1. E[C] is known (or zero)
# 2. C is correlated with X

# New estimator:
X̃ = X - α(C - E[C])

# Properties:
E[X̃]   = E[X] - α·E[C - E[C]] = E[X]          (still unbiased)
Var[X̃] = Var[X] + α²·Var[C] - 2α·Cov[X, C]

# Optimal α:
α* = Cov[X, C] / Var[C]

# Resulting variance reduction:
Var[X̃*] = Var[X] · (1 - ρ²_{X,C})
# where ρ_{X,C} is the correlation coefficient
```

Baselines as Control Variates:
In policy gradients, the target X is the score-weighted return ∇_θ log π_θ(a|s) · G, and the control variate C is ∇_θ log π_θ(a|s) · b(s), whose expectation is zero by the result proved above.
The more correlated b(s) is with G, the more variance reduction we get. Since G depends heavily on the current state s, a state-dependent baseline like V(s) captures much of this correlation.
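The same machinery works far beyond RL. The sketch below (a standard textbook-style illustration, not from the original page) estimates E[e^U] for U ~ Uniform(0, 1), using U itself as the control variate since its mean is known:

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.uniform(size=100_000)

x = np.exp(u)   # target: E[X] = e - 1 ≈ 1.718
c = u           # control variate with known mean E[C] = 0.5

alpha = np.cov(x, c)[0, 1] / np.var(c)   # α* = Cov[X, C] / Var[C]
x_cv = x - alpha * (c - 0.5)

print("plain MC       :", x.mean(), "variance", x.var())
print("control variate:", x_cv.mean(), "variance", x_cv.var())  # same mean, far lower variance
```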
Advanced Control Variates:
While baselines maintain unbiasedness, more sophisticated control variates may introduce small biases for larger variance reduction. Modern algorithms like GAE (Generalized Advantage Estimation) explicitly tune this tradeoff via a parameter λ.
Let's implement various baseline strategies, from simple to sophisticated.
Implementation 1: Moving Average Baseline
```python
import torch
import torch.optim as optim

# Assumes the PolicyNetwork class (and its get_action method) from the
# previous page's REINFORCE implementation.


class MovingAverageBaseline:
    """
    Simple baseline: exponential moving average of returns.
    Not state-dependent, but captures the overall scale of returns.
    Good for removing the mean while being simple to implement.
    """

    def __init__(self, alpha: float = 0.99):
        self.alpha = alpha
        self.baseline = 0.0
        self.initialized = False

    def update(self, return_value: float):
        """Update the running average."""
        if not self.initialized:
            self.baseline = return_value
            self.initialized = True
        else:
            self.baseline = self.alpha * self.baseline + (1 - self.alpha) * return_value

    def get_baseline(self, state=None) -> float:
        """Get baseline value (ignores state for this simple version)."""
        return self.baseline


class REINFORCEWithBaseline:
    """REINFORCE with a simple moving average baseline."""

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99):
        self.policy = PolicyNetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma
        self.baseline = MovingAverageBaseline()
        self.log_probs = []
        self.rewards = []

    # select_action and store_reward are identical to the vanilla REINFORCE
    # agent from the previous page and are omitted here.

    def compute_returns(self) -> torch.Tensor:
        returns = []
        G = 0
        for reward in reversed(self.rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        return torch.tensor(returns, dtype=torch.float32)

    def update(self):
        returns = self.compute_returns()

        # Update baseline with the episode return (the return from t=0)
        self.baseline.update(returns[0].item())

        # Subtract baseline
        baseline_value = self.baseline.get_baseline()
        advantages = returns - baseline_value

        # Optional: normalize advantages
        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        log_probs = torch.cat(self.log_probs)
        loss = -(log_probs * advantages).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_probs = []
        self.rewards = []

        return loss.item()
```

Implementation 2: Learned State-Dependent Baseline (Value Network)
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import List


class ValueNetwork(nn.Module):
    """
    Neural network that estimates V(s), the expected return from state s.
    Used as a learned, state-dependent baseline.
    """

    def __init__(self, state_dim: int, hidden_dims: List[int] = [128, 128]):
        super().__init__()
        layers = []
        prev_dim = state_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1))  # Single value output
        self.network = nn.Sequential(*layers)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.network(state).squeeze(-1)


class REINFORCEWithLearnedBaseline:
    """
    REINFORCE with a learned value function baseline.
    This is essentially a simplified Actor-Critic!
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        policy_lr: float = 1e-3,
        value_lr: float = 1e-3,
        gamma: float = 0.99
    ):
        # Policy network (actor); PolicyNetwork comes from the previous page
        self.policy = PolicyNetwork(state_dim, action_dim)
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=policy_lr)

        # Value network (critic/baseline)
        self.value_net = ValueNetwork(state_dim)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=value_lr)

        self.gamma = gamma

        # Episode storage
        self.states = []
        self.log_probs = []
        self.rewards = []

    def select_action(self, state: np.ndarray) -> int:
        state_tensor = torch.FloatTensor(state)
        self.states.append(state_tensor)

        action, log_prob = self.policy.get_action(state)
        self.log_probs.append(log_prob)
        return action

    def store_reward(self, reward: float):
        self.rewards.append(reward)

    def compute_returns(self) -> torch.Tensor:
        returns = []
        G = 0
        for reward in reversed(self.rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        return torch.tensor(returns, dtype=torch.float32)

    def update(self):
        returns = self.compute_returns()
        states = torch.stack(self.states)
        log_probs = torch.cat(self.log_probs)

        # Get value predictions (baseline)
        values = self.value_net(states)

        # Compute advantages: G_t - V(s_t)
        advantages = returns - values.detach()  # Detach so the baseline doesn't bias the policy gradient

        # Normalize advantages for stability
        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy loss: negative because we want to maximize
        policy_loss = -(log_probs * advantages).sum()

        # Value loss: MSE between predicted V(s) and actual returns
        value_loss = F.mse_loss(values, returns)

        # Update policy
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Update value network
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        # Clear episode data
        self.states = []
        self.log_probs = []
        self.rewards = []

        return policy_loss.item(), value_loss.item()
```

Notice that REINFORCEWithLearnedBaseline is essentially an Actor-Critic algorithm! The 'actor' (policy network) decides actions, and the 'critic' (value network) evaluates states. The key difference from full Actor-Critic is that we still wait for complete episodes to compute returns, rather than bootstrapping.
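To show how the pieces fit together, here is a hedged training-loop sketch for the class above. It assumes Gymnasium's CartPole-v1 environment and the PolicyNetwork class from the previous page; the episode count and other details are illustrative only.

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
agent = REINFORCEWithLearnedBaseline(state_dim=4, action_dim=2)

for episode in range(500):
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.select_action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        agent.store_reward(reward)
        done = terminated or truncated

    # One Monte Carlo update per episode: policy and value losses
    policy_loss, value_loss = agent.update()
```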
When using V(s) as a baseline, the quantity G_t - V(s_t) is an estimate of the advantage function A(s, a). This connection is profound and leads to Generalized Advantage Estimation (GAE), a cornerstone technique in modern RL.
From Returns to Advantages:
```
# Advantage function definition:
A^π(s, a) = Q^π(s, a) - V^π(s)

# Interpretation: How much better is action a than the average action in state s?

# Monte Carlo advantage estimate:
Â_t^MC = G_t - V(s_t)
       = (r_t + γr_{t+1} + γ²r_{t+2} + ...) - V(s_t)
# This has low bias but high variance (all future rewards included)

# One-step TD advantage estimate:
Â_t^TD(1) = r_t + γV(s_{t+1}) - V(s_t) = δ_t   (the TD error)
# This has low variance but high bias (depends on V accuracy)
```

The Bias-Variance Tradeoff:
Different advantage estimates trade off bias and variance, as summarized in the table below (a short n-step sketch follows it):
| Method | Formula | Bias | Variance |
|---|---|---|---|
| Monte Carlo | Â = G_t - V(s_t) | Unbiased | High |
| TD(1) | Â = r_t + γV(s_{t+1}) - V(s_t) | High (depends on V) | Low |
| TD(n) | Â = Σ_{k=0}^{n-1} γ^k r_{t+k} + γ^n V(s_{t+n}) - V(s_t) | Medium | Medium |
| GAE(λ) | Â = Σ_{k=0}^{∞} (γλ)^k δ_{t+k} | Tunable via λ | Tunable via λ |
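To make the TD(n) row concrete, here is a small helper of my own (illustrative, not part of the original page) that computes the n-step advantage, bootstrapping with V(s_{t+n}) and truncating at the end of the episode:

```python
import torch

def n_step_advantage(rewards, values, next_value, n=5, gamma=0.99):
    """Â_t = Σ_{k=0}^{n-1} γ^k r_{t+k} + γ^n V(s_{t+n}) - V(s_t), truncated at episode end."""
    T = len(rewards)
    values_ext = torch.cat([values, torch.tensor([next_value])])
    advantages = torch.zeros(T)
    for t in range(T):
        horizon = min(n, T - t)  # fewer than n steps may remain near the end
        g = sum(gamma**k * rewards[t + k] for k in range(horizon))
        g = g + gamma**horizon * values_ext[t + horizon]  # bootstrap with V (or next_value)
        advantages[t] = g - values_ext[t]
    return advantages

# Same toy episode as the GAE example further below:
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 10.0])
values = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0])
print(n_step_advantage(rewards, values, next_value=0.0, n=2))
```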
Generalized Advantage Estimation (GAE):
GAE, introduced by Schulman et al. (2016), elegantly interpolates between these estimates:
```
# Define TD errors:
δ_t = r_t + γV(s_{t+1}) - V(s_t)

# GAE with parameter λ ∈ [0, 1]:
Â_t^GAE(γ,λ) = Σ_{k=0}^{∞} (γλ)^k δ_{t+k}
             = δ_t + γλ·δ_{t+1} + (γλ)²·δ_{t+2} + ...

# Special cases:
# λ = 0: Â = δ_t                            (TD(1) estimate: low variance, high bias)
# λ = 1: Â = Σ γ^k δ_{t+k} = G_t - V(s_t)   (MC estimate: low bias, high variance)

# Recursive computation:
Â_t^GAE = δ_t + γλ · Â_{t+1}^GAE   (with Â_T = 0)

# Practical: λ ≈ 0.95 works well for most tasks
# Balances bias and variance effectively
```
```python
import torch
from typing import Tuple


def compute_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    next_value: float,
    gamma: float = 0.99,
    lam: float = 0.95
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute Generalized Advantage Estimation.

    Args:
        rewards: Tensor of rewards [r_0, r_1, ..., r_T]
        values: Tensor of value estimates [V(s_0), V(s_1), ..., V(s_T)]
        next_value: Value of the state after the last step (0 if the episode ended)
        gamma: Discount factor
        lam: GAE lambda for the bias-variance tradeoff

    Returns:
        advantages: GAE advantage estimates
        returns: Targets for value function training (advantages + values)
    """
    T = len(rewards)
    advantages = torch.zeros(T)

    # The last advantage bootstraps from next_value (0 if terminal)
    gae = 0.0
    values_extended = torch.cat([values, torch.tensor([next_value])])

    # Compute advantages backwards
    for t in reversed(range(T)):
        # TD error: δ_t = r_t + γV(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values_extended[t + 1] - values_extended[t]

        # GAE recursion: Â_t = δ_t + γλ · Â_{t+1}
        gae = delta + gamma * lam * gae
        advantages[t] = gae

    # Returns for value training: R_t = Â_t + V(s_t)
    returns = advantages + values

    return advantages, returns


# Example usage:
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 10.0])  # Episode rewards
values = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0])    # Value predictions
advantages, returns = compute_gae(rewards, values, next_value=0.0)
print(f"Advantages: {advantages}")
print(f"Returns: {returns}")
```

GAE is used in virtually all modern policy gradient implementations, including PPO and A2C. The λ parameter (typically 0.95-0.99) allows tuning for specific tasks. Higher λ reduces bias but increases variance; lower λ does the opposite. GAE is one of the key innovations that made deep policy gradients practical.
Let's quantify how much each variance reduction technique actually helps. Understanding these tradeoffs is crucial for algorithm design.
Experimental Comparison:
| Method | Gradient Variance | Episodes to Solve | Stability |
|---|---|---|---|
| Vanilla REINFORCE | 1.0× (baseline) | ~1000 | Poor |
| | 0.5× | ~700 | Good |
| | 0.3× | ~500 | Better |
| | 0.1× | ~300 | Excellent |
| | 0.08× | ~250 | Excellent |
Visualization of Variance Reduction:
```python
import numpy as np


def compare_variance_reduction(env, num_episodes=100):
    """
    Compare gradient variance across different baseline strategies.

    The compute_gradients_* functions are placeholders: each should run one
    episode in `env` with the corresponding agent and return the resulting
    policy-gradient norm.
    """
    methods = {
        'No Baseline': compute_gradients_no_baseline,
        'Mean Baseline': compute_gradients_mean_baseline,
        'State-Dependent V(s)': compute_gradients_value_baseline,
        'GAE': compute_gradients_gae
    }

    results = {}
    for name, compute_fn in methods.items():
        gradient_norms = []
        for _ in range(num_episodes):
            # Collect an episode and compute the gradient norm
            grad_norm = compute_fn(env)
            gradient_norms.append(grad_norm)

        results[name] = {
            'norms': gradient_norms,
            'mean': np.mean(gradient_norms),
            'std': np.std(gradient_norms),
            'cv': np.std(gradient_norms) / np.mean(gradient_norms)  # Coefficient of variation
        }

    # Print comparison
    print("Variance Reduction Comparison:")
    print("-" * 60)
    for name, stats in results.items():
        print(f"{name:25} | Mean: {stats['mean']:8.2f} | "
              f"Std: {stats['std']:8.2f} | CV: {stats['cv']:.3f}")

    return results


def visualize_gradient_distributions(results):
    """Visualize the gradient-norm distribution for each method."""
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.flatten()

    for idx, (name, stats) in enumerate(results.items()):
        gradients = stats['norms']
        ax = axes[idx]
        ax.hist(gradients, bins=50, density=True, alpha=0.7)
        ax.axvline(np.mean(gradients), color='r', linestyle='--',
                   label=f'Mean: {np.mean(gradients):.2f}')
        ax.set_title(f'{name}\nStd: {np.std(gradients):.2f}')
        ax.set_xlabel('Gradient Norm')
        ax.set_ylabel('Density')
        ax.legend()

    plt.suptitle('Gradient Distribution by Variance Reduction Method')
    plt.tight_layout()
    plt.show()
```

Variance reduction techniques are complementary and stack multiplicatively! Using return normalization + learned baseline + GAE can reduce variance by 90%+ compared to vanilla REINFORCE. This is why modern algorithms like PPO combine multiple techniques.
While baselines are the primary variance reduction tool, several other techniques are used in modern policy gradient algorithms.
Technique 1: Reward Shaping
```
# Reward shaping adds intermediate rewards to guide learning
# Must be done carefully to preserve the optimal policy!

# Potential-based reward shaping (Ng et al., 1999):
r'(s, a, s') = r(s, a, s') + γΦ(s') - Φ(s)

where Φ(s) is a potential function

# Key theorem: This transformation:
# 1. Preserves the optimal policy
# 2. Can dramatically reduce variance
# 3. Φ(s) ≈ V*(s) is ideal (but we don't know V*)

# In practice: Use learned V(s) for shaping
# This is related to advantage estimation!
```

Technique 2: Advantage Normalization and Gradient Clipping
```python
import torch
import torch.nn as nn
from typing import Tuple


def normalized_policy_gradient(
    policy: nn.Module,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    max_grad_norm: float = 0.5
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the policy gradient with multiple normalization strategies.
    """
    # 1. Advantage normalization (zero mean, unit variance)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # 2. Compute the loss
    loss = -(log_probs * advantages).mean()

    # 3. Backward pass
    loss.backward()

    # 4. Gradient clipping (limit gradient magnitude)
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)

    return loss, grad_norm
```

Technique 3: Entropy Regularization for Exploration
```
# Add an entropy bonus to prevent premature convergence
J_entropy(θ) = J(θ) + β · H(π_θ)

where H(π_θ) = -E[log π_θ(a|s)]

# Benefits:
# 1. Encourages exploration (maintains action diversity)
# 2. Prevents policy collapse to a deterministic policy
# 3. Smoother optimization landscape
# 4. Indirectly reduces variance by preventing overconfident policies

# Typical β values: 0.01 to 0.1
# Decrease over training for exploitation
```

State-of-the-art policy gradient algorithms (PPO, SAC) combine: learned value baseline, GAE for advantage estimation, advantage normalization, gradient clipping, entropy regularization, multiple parallel workers, and trust region constraints. Each addresses a specific variance or stability issue.
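To make the entropy term from Technique 3 concrete, here is a minimal runnable sketch of my own on a dummy batch (the batch tensors and β value are purely illustrative):

```python
import torch
from torch.distributions import Categorical

# Dummy batch standing in for policy(states), sampled actions, and advantages.
logits = torch.randn(32, 4, requires_grad=True)
actions = torch.randint(0, 4, (32,))
advantages = torch.randn(32)

dist = Categorical(logits=logits)
log_probs = dist.log_prob(actions)

beta = 0.01                                   # entropy coefficient β
policy_loss = -(log_probs * advantages).mean()
entropy_bonus = dist.entropy().mean()

loss = policy_loss - beta * entropy_bonus     # maximizing H(π) means subtracting it from the loss
loss.backward()
```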
We've explored the essential techniques for making policy gradients practical. To consolidate: subtracting a state-dependent baseline from returns leaves the gradient unbiased while slashing its variance; the value function V(s) is the natural, near-optimal baseline and turns returns into advantages; GAE's λ parameter tunes the bias-variance tradeoff of those advantage estimates; and advantage normalization, gradient clipping, and entropy regularization add further stability.
What's next:
With variance reduction techniques in hand, we're ready to combine policy learning with value learning. The next page covers Actor-Critic methods—where a learned value function (critic) provides both a baseline for variance reduction and bootstrapped targets for faster learning.
You now understand how to tame policy gradient variance. Baselines are the key insight: subtracting an appropriate value from returns eliminates the mean gradient component while preserving the useful signal. This enables practical policy learning. Next, we'll see how learned value functions serve as both baselines and sources of bootstrapped targets in Actor-Critic methods.