While the reparameterization trick is preferred when applicable, the score function estimator (also known as REINFORCE or likelihood ratio gradient) is the universal tool that works for any distribution—including discrete latents where reparameterization fails.
This page covers the score function estimator in depth: its derivation, implementations, variance reduction techniques, and modern extensions that bridge the gap with reparameterization.
By the end of this page, you will: (1) Derive and implement the score function estimator from first principles, (2) Apply variance reduction techniques including baselines, control variates, and Rao-Blackwellization, (3) Implement modern extensions like REBAR and RELAX, (4) Choose between score function and pathwise gradients appropriately.
We want to compute $\nabla_\phi \mathbb{E}_{q_\phi(\mathbf{z})}[f(\mathbf{z})]$ where $q_\phi$ depends on parameters $\phi$.
The Log-Derivative Trick:
The key identity is: $$\nabla_\phi q_\phi(\mathbf{z}) = q_\phi(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})$$
This follows from $\nabla \log q = \nabla q / q$.
Derivation:
$$\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] = \nabla_\phi \int q_\phi(\mathbf{z}) f(\mathbf{z}) \, d\mathbf{z}$$
By the Leibniz rule (exchanging gradient and integral): $$= \int \nabla_\phi q_\phi(\mathbf{z}) f(\mathbf{z}) \, d\mathbf{z}$$
Applying the log-derivative trick: $$= \int q_\phi(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z}) f(\mathbf{z}) \, d\mathbf{z}$$
$$= \mathbb{E}_{q_\phi(\mathbf{z})}[f(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})]$$
This is now an expectation we can estimate by sampling!
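The derivation can be verified numerically. The sketch below (plain NumPy, a toy Bernoulli with $f(z) = z$, so the true gradient of $\mathbb{E}[f]$ with respect to the parameter $\theta$ is exactly 1; the setup is illustrative, not from the original text) checks that the sample average of $f(z)\,\nabla_\theta \log q_\theta(z)$ matches the analytic gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = 0.3          # Bernoulli parameter (toy choice)
n = 200_000

z = rng.random(n) < theta            # samples z ~ Bernoulli(theta)
f = z.astype(float)                  # f(z) = z, so E[f] = theta and d/dtheta E[f] = 1

# Score function: d/dtheta log q(z) = z/theta - (1-z)/(1-theta)
score = z / theta - (~z) / (1 - theta)

grad_est = (f * score).mean()
print(grad_est)                      # close to the true gradient, 1.0
```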
∇_φ log q_φ(z) is called the 'score function' in statistics. It measures the sensitivity of log-probability to parameter changes. The estimator weights function values f(z) by how parameter changes affect their probability.
```python
import torch
from torch.distributions import Categorical

def score_function_estimator(q_distribution, f, n_samples=10):
    """
    Basic score function gradient estimator.
    ∇_φ E_q[f(z)] = E_q[f(z) ∇_φ log q_φ(z)]
    """
    gradient_estimates = []
    for _ in range(n_samples):
        # Sample from q (non-differentiable for discrete)
        z = q_distribution.sample()

        # Compute log probability (differentiable w.r.t. parameters)
        log_prob = q_distribution.log_prob(z)

        # Compute function value (doesn't need gradients)
        with torch.no_grad():
            reward = f(z)

        # Score function estimate: f(z) * ∇ log q(z)
        # Gradient of (reward * log_prob) w.r.t. params gives reward * ∇ log q
        gradient_estimates.append(reward * log_prob)

    # Average over samples
    return torch.stack(gradient_estimates).mean()

# Example: optimizing a discrete distribution
logits = torch.randn(5, requires_grad=True)
q = Categorical(logits=logits)

# Objective: make high-index samples more likely
f = lambda z: z.float()  # Reward = sample value

estimate = score_function_estimator(q, f, n_samples=100)
estimate.backward()
print(f"Gradient: {logits.grad}")
```

The score function estimator is unbiased but can have extremely high variance, making training slow or unstable.
Why High Variance?
The gradient estimate is: $$\hat{g} = f(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z})$$
The variance depends on:
| Source | Why It's Bad | Mitigation |
|---|---|---|
| High reward variance | Large f(z) values dominate signal | Baseline subtraction |
| Score function variance | Different z give different ∇ directions | More samples, control variates |
| Rare but important samples | High-reward samples rarely seen | Importance sampling |
| High-dimensional z | Score is high-dimensional vector | Rao-Blackwellization |
Raw score function gradients can require 100-10,000x more samples than reparameterization to achieve similar gradient accuracy. This translates directly to training time and computational cost.
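The variance gap can be made concrete with a toy Gaussian example (a NumPy sketch, not from the original text): for $\nabla_\mu \mathbb{E}_{\mathcal{N}(\mu,1)}[z^2]$, whose true value is $2\mu$, both estimators are unbiased, but their per-sample variances differ by an order of magnitude or more:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, n = 3.0, 100_000
eps = rng.standard_normal(n)
z = mu + eps                         # z ~ N(mu, 1)

# True gradient of E[z^2] w.r.t. mu is 2*mu = 6.

# Score function estimator: f(z) * d/dmu log N(z; mu, 1) = z^2 * (z - mu)
g_score = z**2 * (z - mu)

# Pathwise (reparameterization) estimator: d/dmu (mu + eps)^2 = 2*(mu + eps)
g_path = 2 * (mu + eps)

print(g_score.mean(), g_score.var())   # unbiased, but high variance
print(g_path.mean(), g_path.var())     # unbiased, much lower variance
```

Both means land near 6, but the score function samples are far noisier; the gap widens as $\mu$ grows or as $f$ becomes steeper.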
The most important variance reduction technique is baseline subtraction: subtracting a constant $b$ from the reward.
$$\hat{g} = (f(\mathbf{z}) - b) \nabla_\phi \log q_\phi(\mathbf{z})$$
Why It's Unbiased:
$$\mathbb{E}_q[b \nabla_\phi \log q_\phi(\mathbf{z})] = b \nabla_\phi \int q_\phi(\mathbf{z}) \, d\mathbf{z} = b \nabla_\phi 1 = 0$$
Subtracting $b$ doesn't change the expected gradient—it only reduces variance.
Optimal Baseline:
The variance-minimizing baseline is: $$b^* = \frac{\mathbb{E}[f(\mathbf{z}) \, \|\nabla_\phi \log q_\phi(\mathbf{z})\|^2]}{\mathbb{E}[\|\nabla_\phi \log q_\phi(\mathbf{z})\|^2]}$$
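The variance-minimizing baseline $b^* = \mathbb{E}[f\,\|\nabla \log q\|^2]/\mathbb{E}[\|\nabla \log q\|^2]$ can be sanity-checked on a toy Gaussian (a NumPy sketch; the setup is illustrative, not from the original text). Subtracting the empirical $b^*$ leaves the mean untouched but shrinks the variance:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, n = 3.0, 100_000
z = mu + rng.standard_normal(n)      # z ~ N(mu, 1)
f = z**2
score = z - mu                       # d/dmu log N(z; mu, 1)

# Empirical optimal baseline: b* = E[f * score^2] / E[score^2]
b_star = (f * score**2).mean() / (score**2).mean()

g_raw = f * score
g_base = (f - b_star) * score
print(g_raw.mean(), g_base.mean())   # both estimate 2*mu = 6 (unbiased)
print(g_raw.var(), g_base.var())     # baseline cuts the variance
```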
In practice, simpler approximations work well:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScoreFunctionWithBaseline:
    """Score function estimator with running mean baseline."""

    def __init__(self, decay=0.99):
        self.baseline = 0.0
        self.decay = decay

    def estimate(self, q, f, n_samples=10):
        log_probs = []
        rewards = []

        for _ in range(n_samples):
            z = q.sample()
            log_probs.append(q.log_prob(z))
            with torch.no_grad():
                rewards.append(f(z))

        log_probs = torch.stack(log_probs)
        rewards = torch.stack(rewards)

        # Center rewards with baseline
        centered_rewards = rewards - self.baseline

        # Score function estimate
        estimate = (centered_rewards * log_probs).mean()

        # Update baseline (exponential moving average)
        self.baseline = self.decay * self.baseline + (1 - self.decay) * rewards.mean().item()

        return estimate

class LearnedBaseline(nn.Module):
    """Neural network baseline for input-dependent centering."""

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

    def compute_loss(self, x, actual_rewards):
        predicted = self.forward(x)
        return F.mse_loss(predicted, actual_rewards)
```

Control variates generalize baselines by subtracting functions with known expectations.
$$\hat{g} = f(\mathbf{z}) \nabla_\phi \log q_\phi(\mathbf{z}) - c \cdot (h(\mathbf{z}) - \mathbb{E}_q[h(\mathbf{z})]) \nabla_\phi \log q_\phi(\mathbf{z})$$
If $h$ is correlated with $f$, this can dramatically reduce variance. The scaling $c$ is chosen to minimize variance.
Rao-Blackwellization:
For structured latents, we can analytically marginalize some components: $$\hat{g} = \mathbb{E}_{q(\mathbf{z}_1)}\left[\mathbb{E}_{q(\mathbf{z}_2|\mathbf{z}_1)}[f] \, \nabla_\phi \log q(\mathbf{z}_1)\right]$$
The inner expectation reduces the variance of the outer estimate.
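A minimal Rao-Blackwellization sketch for a two-variable chain (toy probabilities assumed, not from the original text): the inner expectation over $\mathbf{z}_2$ is computed analytically, so only $\mathbf{z}_1$ contributes sampling noise:

```python
import numpy as np

rng = np.random.default_rng(0)
theta1, n = 0.4, 100_000
p2 = lambda z1: np.where(z1, 0.9, 0.2)        # q(z2=1 | z1), toy structure

z1 = rng.random(n) < theta1                   # z1 ~ Bernoulli(theta1)
z2 = rng.random(n) < p2(z1)                   # z2 | z1

score1 = z1 / theta1 - (~z1) / (1 - theta1)   # d/dtheta1 log q(z1)

# Naive estimator: sample both variables, use f(z1, z2) = z1 + z2
f = z1.astype(float) + z2.astype(float)
g_naive = f * score1

# Rao-Blackwellized: replace f with E[f | z1] = z1 + p2(z1), computed analytically
f_rb = z1.astype(float) + p2(z1)
g_rb = f_rb * score1

print(g_naive.mean(), g_rb.mean())   # both estimate the true gradient, 1.7
print(g_naive.var(), g_rb.var())     # RB variance is never larger
```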
Local Baselines:
For structured models, using different baselines for different components can be more effective than a global baseline.
```python
import torch

def control_variate_estimator(q, f, h, h_expectation, n_samples=10):
    """
    Score function with control variate.
    Uses h(z) - E[h(z)] as control variate.
    Optimal c minimizes variance.
    """
    f_samples = []
    h_samples = []
    log_probs = []

    for _ in range(n_samples):
        z = q.sample()
        log_probs.append(q.log_prob(z))
        with torch.no_grad():
            f_samples.append(f(z))
            h_samples.append(h(z))

    f_samples = torch.stack(f_samples)
    h_samples = torch.stack(h_samples)
    log_probs = torch.stack(log_probs)

    # Centered control variate
    h_centered = h_samples - h_expectation

    # Compute optimal scaling c
    # c* = Cov(f, h) / Var(h)
    cov_fh = ((f_samples - f_samples.mean()) * h_centered).mean()
    var_h = h_centered.var()
    c_optimal = cov_fh / (var_h + 1e-8)

    # Control variate adjusted estimate
    adjusted_f = f_samples - c_optimal * h_centered
    estimate = (adjusted_f * log_probs).mean()

    return estimate
```

REBAR (REinforce with BAckpropagable Relaxation) and RELAX combine the best of score function and reparameterization for discrete latents.
Core Idea:
Use a continuous relaxation $\tilde{\mathbf{z}}$ of the discrete $\mathbf{z}$ as a control variate: gradients flow through $f(\tilde{\mathbf{z}})$ via reparameterization, and a score function term corrects the resulting bias.
REBAR Estimator:
$$\hat{g} = (f(\mathbf{z}) - \eta f(\tilde{\mathbf{z}})) \nabla_\phi \log q_\phi(\mathbf{z}) + \eta \nabla_\phi f(\tilde{\mathbf{z}})$$
where $\tilde{\mathbf{z}}$ is the Gumbel-Softmax relaxation and $\eta$ is a learned scaling coefficient (the relaxation temperature can be learned as well).
RELAX generalizes REBAR by learning the control variate function directly, achieving even lower variance.
```python
import torch
import torch.nn.functional as F

def rebar_gradient(logits, f, temperature=1.0, eta=0.5):
    """
    REBAR gradient estimator for discrete distributions.
    Combines score function with Gumbel-Softmax control variate.
    """
    # Sample Gumbel noise
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u + 1e-8) + 1e-8)

    # Continuous relaxation (Gumbel-Softmax)
    z_tilde = torch.softmax((logits + g) / temperature, dim=-1)

    # Discrete sample (Gumbel-Max)
    z_discrete = F.one_hot(torch.argmax(logits + g, dim=-1),
                           num_classes=logits.size(-1)).float()

    # Conditional relaxation given discrete (for control variate)
    v = torch.rand_like(logits)
    # ... (conditional sampling math omitted for brevity)

    # Evaluate function at discrete and relaxed points
    f_discrete = f(z_discrete)
    f_tilde = f(z_tilde)

    # Score function term
    log_prob = torch.log_softmax(logits, dim=-1)
    score_term = (f_discrete - eta * f_tilde.detach()) * (z_discrete * log_prob).sum(dim=-1)

    # Reparameterization term
    reparam_term = eta * f_tilde

    return score_term + reparam_term
```

| Estimator | Bias | Variance | Complexity |
|---|---|---|---|
| REINFORCE | None | Very High | Simple |
| REINFORCE + Baseline | None | High | Simple |
| Gumbel-Softmax (soft) | Yes | Low | Moderate |
| Straight-Through | Yes | Moderate | Simple |
| REBAR | None | Moderate | Complex |
| RELAX | None | Low | Complex |
Start with reparameterization when possible (continuous latents). If forced to use score function: (1) Always use baselines, (2) Use many samples (10-100+), (3) Consider REBAR/RELAX for discrete, (4) Monitor gradient variance during training.
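The monitoring advice in point (4) can be implemented as a simple signal-to-noise diagnostic (a NumPy sketch; `gradient_snr` is a hypothetical helper name, not from any library):

```python
import numpy as np

def gradient_snr(per_sample_grads):
    """Signal-to-noise ratio of a gradient estimator: |mean| / stderr-of-mean.
    An SNR near or below 1 means the batch gradient is mostly noise, so
    more samples or a better baseline are needed."""
    g = np.asarray(per_sample_grads, dtype=float)
    sem = g.std(ddof=1) / np.sqrt(len(g))
    return abs(g.mean()) / (sem + 1e-12)

# Toy check: per-sample score-function estimates of d/dmu E[z^2], z ~ N(mu, 1)
rng = np.random.default_rng(0)
mu = 3.0
z = mu + rng.standard_normal(1000)
snr = gradient_snr(z**2 * (z - mu))
print(snr)  # well above 1 here; track this quantity over training
```

Logging this ratio per training step makes variance blow-ups visible long before the loss curve degrades.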
Congratulations! You've mastered ELBO optimization—from its fundamental decomposition through the practical gradient estimation techniques that make variational inference work. You now have the tools to implement, train, and debug variational models effectively.