Training neural networks requires computing gradients of the loss with respect to parameters. For feedforward networks, this is straightforward backpropagation. For RNNs, where parameters are shared across timesteps and hidden states form chains of dependencies, we need Backpropagation Through Time (BPTT).
BPTT is conceptually simple: unroll the RNN into a feedforward network, then apply standard backpropagation. But this unrolling creates challenges—gradients must flow backward through potentially hundreds of timesteps, multiplying through the same weight matrix repeatedly. This multiplicative gradient flow is the source of both RNNs' power and their notorious training difficulties.
By the end of this page, you will understand: (1) the mathematical derivation of BPTT, (2) how gradients accumulate across timesteps, (3) truncated BPTT for practical training, (4) the computational cost of BPTT, and (5) implementation considerations.
Consider an RNN with loss computed at each timestep: $L = \sum_{t=1}^{T} L_t$ where $L_t = \ell(y_t, \hat{y}_t)$.
The chain rule across time:
For the hidden-to-hidden weights $W_{hh}$, the gradient must account for how $W_{hh}$ affects all future timesteps:
$$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial W_{hh}}$$
Each term requires summing contributions through all paths from $W_{hh}$ to $L_t$:
$$\frac{\partial L_t}{\partial W_{hh}} = \sum_{k=1}^{t} \frac{\partial L_t}{\partial h_t} \frac{\partial h_t}{\partial h_k} \frac{\partial^+ h_k}{\partial W_{hh}}$$
where $\frac{\partial^+ h_k}{\partial W_{hh}}$ is the immediate derivative (not through earlier hidden states).
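Concretely, for the tanh recurrence $h_k = \tanh(W_{hh} h_{k-1} + W_{xh} x_k + b_h)$, the immediate derivative collapses in the backward pass to an outer product. Summing over timesteps, the full gradient can be written as

$$\frac{\partial L}{\partial W_{hh}} = \sum_{k=1}^{T} \delta_k \, h_{k-1}^\top, \qquad \delta_k = \text{diag}\big(1 - h_k \odot h_k\big) \, \frac{\partial L}{\partial h_k},$$

where $\frac{\partial L}{\partial h_k}$ already includes contributions from all losses $L_k, \dots, L_T$. This outer-product form is exactly what an implementation accumulates step by step.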
The key recursive relation:
$$\frac{\partial h_t}{\partial h_k} = \prod_{i=k+1}^{t} \frac{\partial h_i}{\partial h_{i-1}} = \prod_{i=k+1}^{t} \text{diag}(f'(z_i)) W_{hh}$$
This product of Jacobians is where vanishing/exploding gradients originate.
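To make the recursion concrete, the sketch below builds $\partial h_t / \partial h_k$ from the per-step Jacobians and checks it against a finite-difference probe. It uses a simplified input-free recurrence $h_i = \tanh(W_{hh} h_{i-1})$; the dimensions and weight scale are arbitrary illustrative choices:

```python
import numpy as np

# Build dh_t/dh_k as the product of per-step Jacobians diag(f'(z_i)) W_hh,
# then verify one column against a finite-difference perturbation of h_k.
rng = np.random.default_rng(0)
H, steps = 5, 8
W_hh = rng.normal(size=(H, H)) * 0.4

def roll(h, n):
    """Run n recurrence steps h <- tanh(W_hh h), returning all states."""
    hs = [h]
    for _ in range(n):
        hs.append(np.tanh(W_hh @ hs[-1]))
    return hs

h_k = rng.normal(size=H)
hs = roll(h_k, steps)

# Analytic Jacobian via the product formula; note f'(z_i) = 1 - tanh(z_i)^2
J = np.eye(H)
for h_i in hs[1:]:
    J = np.diag(1 - h_i**2) @ W_hh @ J

# Finite-difference probe on coordinate 0 of h_k
eps = 1e-6
e0 = np.zeros(H); e0[0] = eps
numeric_col0 = (roll(h_k + e0, steps)[-1] - roll(h_k - e0, steps)[-1]) / (2 * eps)
```

The loop multiplies in forward order, so each new per-step Jacobian lands on the left, matching the product in the equation above.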
```python
import numpy as np

class RNNWithBPTT:
    """RNN with explicit BPTT implementation for educational purposes."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        scale = 0.01
        self.W_xh = np.random.randn(hidden_dim, input_dim) * scale
        self.W_hh = np.random.randn(hidden_dim, hidden_dim) * scale
        self.W_hy = np.random.randn(output_dim, hidden_dim) * scale
        self.b_h = np.zeros((hidden_dim, 1))
        self.b_y = np.zeros((output_dim, 1))
        self.hidden_dim = hidden_dim

    def forward(self, X):
        """Forward pass, storing values for BPTT."""
        T = len(X)
        self.h = {0: np.zeros((self.hidden_dim, 1))}
        self.z = {}  # Pre-activations
        self.y_hat = {}
        for t in range(1, T + 1):
            x_t = X[t-1].reshape(-1, 1)
            self.z[t] = self.W_xh @ x_t + self.W_hh @ self.h[t-1] + self.b_h
            self.h[t] = np.tanh(self.z[t])
            self.y_hat[t] = self.W_hy @ self.h[t] + self.b_y
        return self.y_hat

    def bptt(self, X, Y):
        """Full BPTT: compute gradients through all timesteps."""
        T = len(X)
        # Initialize gradient accumulators
        dW_xh = np.zeros_like(self.W_xh)
        dW_hh = np.zeros_like(self.W_hh)
        dW_hy = np.zeros_like(self.W_hy)
        db_h = np.zeros_like(self.b_h)
        db_y = np.zeros_like(self.b_y)

        # Backward pass
        dh_next = np.zeros((self.hidden_dim, 1))
        for t in reversed(range(1, T + 1)):
            x_t = X[t-1].reshape(-1, 1)
            y_t = Y[t-1].reshape(-1, 1)

            # Output gradient
            dy = self.y_hat[t] - y_t  # MSE derivative
            dW_hy += dy @ self.h[t].T
            db_y += dy

            # Hidden state gradient (from output + from future)
            dh = self.W_hy.T @ dy + dh_next

            # Pre-activation gradient
            dz = dh * (1 - self.h[t]**2)  # tanh derivative

            # Parameter gradients
            dW_xh += dz @ x_t.T
            dW_hh += dz @ self.h[t-1].T
            db_h += dz

            # Gradient to previous hidden state
            dh_next = self.W_hh.T @ dz

        return {'W_xh': dW_xh, 'W_hh': dW_hh, 'W_hy': dW_hy,
                'b_h': db_h, 'b_y': db_y}
```

Because parameters are shared, gradients from every timestep contribute to the total gradient. This accumulation has important implications.
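A standard way to validate a hand-written BPTT implementation is a finite-difference gradient check: nudge one weight, re-run the forward pass, and compare the loss change against the analytic gradient. The sketch below does this for a minimal inline tanh RNN, a self-contained, stripped-down version of the same logic; the dimensions and the half-squared-error loss on the hidden states are illustrative choices:

```python
import numpy as np

# Finite-difference check of the BPTT gradient for W_hh on a tiny tanh RNN.
rng = np.random.default_rng(0)
H, D, T = 4, 3, 5
W_xh = rng.normal(size=(H, D)) * 0.5
W_hh = rng.normal(size=(H, H)) * 0.5
X = rng.normal(size=(T, D))
Y = rng.normal(size=(T, H))

def loss_and_grad(W_hh):
    # Forward: h_t = tanh(W_xh x_t + W_hh h_{t-1}); L = 0.5 * sum ||h_t - y_t||^2
    hs = [np.zeros(H)]
    for t in range(T):
        hs.append(np.tanh(W_xh @ X[t] + W_hh @ hs[-1]))
    L = 0.5 * sum(np.sum((hs[t+1] - Y[t])**2) for t in range(T))
    # Backward (BPTT): accumulate dL/dW_hh over all timesteps
    dW = np.zeros_like(W_hh)
    dh_next = np.zeros(H)
    for t in reversed(range(T)):
        dh = (hs[t+1] - Y[t]) + dh_next    # direct loss term + future timesteps
        dz = dh * (1 - hs[t+1]**2)         # tanh derivative
        dW += np.outer(dz, hs[t])
        dh_next = W_hh.T @ dz
    return L, dW

L0, analytic = loss_and_grad(W_hh)

# Central finite differences, one weight at a time
numeric = np.zeros_like(W_hh)
eps = 1e-6
for i in range(H):
    for j in range(H):
        Wp, Wm = W_hh.copy(), W_hh.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        numeric[i, j] = (loss_and_grad(Wp)[0] - loss_and_grad(Wm)[0]) / (2 * eps)

rel_err = np.abs(analytic - numeric).max() / np.abs(numeric).max()
```

A relative error near machine precision (well below `1e-5`) indicates the backward pass is consistent with the forward pass.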
How gradients accumulate:
For $W_{hh}$, timestep $t$ contributes gradients from:

- the loss $L_t$ at timestep $t$ itself (path length 1), and
- every future loss $L_{t+1}, \dots, L_T$, reached through the chain of hidden states $h_t \to h_{t+1} \to \cdots$.

The total gradient is the sum of all these contributions.
| From Step | Affects Steps | Path Length | Gradient Magnitude |
|---|---|---|---|
| t=T | L_T only | 1 | Strong (no decay) |
| t=T-1 | L_{T-1}, L_T | 1-2 | Strong to moderate |
| t=T-k | L_{T-k}, ..., L_T | 1 to k+1 | Decays with k |
| t=1 | L_1, ..., L_T | 1 to T | Weak (maximal decay) |
Gradients from early timesteps pass through $T - t$ matrix multiplications to reach the loss. If the eigenvalues (more precisely, the singular values) of $W_{hh}$ are below 1, these gradients vanish exponentially; if above 1, they explode. This is why vanilla RNNs struggle with long-range dependencies.
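The scaling argument can be seen with a toy calculation. With the nonlinearity removed, the gradient reaching a timestep $n$ steps in the past has been multiplied by $W_{hh}^\top$ $n$ times, so its norm scales roughly like $\rho^n$, where $\rho$ is the spectral radius. The numbers below are illustrative:

```python
import numpy as np

def grad_norm_after(rho, n_steps, dim=8):
    """Norm of a gradient vector after n_steps backward steps through a
    linear recurrence whose weight matrix has spectral radius rho."""
    W_hh = rho * np.eye(dim)   # simplest case: every eigenvalue equals rho
    g = np.ones(dim)           # stand-in for dL/dh at the final timestep
    for _ in range(n_steps):
        g = W_hh.T @ g         # one backward step through time
    return np.linalg.norm(g)

vanished = grad_norm_after(0.9, 100)   # ~0.9**100: effectively zero
exploded = grad_norm_after(1.1, 100)   # ~1.1**100: astronomically large
```

One hundred steps is enough to separate the two regimes by many orders of magnitude, which is why clipping handles explosion but nothing so simple rescues vanishing.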
For long sequences, full BPTT is impractical: memory grows linearly with sequence length, and gradients from early timesteps vanish anyway. Truncated BPTT limits gradient flow to the most recent $k$ timesteps.
How it works: the sequence is processed in fixed-size chunks. The hidden state is carried forward from one chunk to the next, but it is detached from the computation graph at each chunk boundary, so gradients propagate backward through at most $k$ timesteps.
Tradeoffs:
| Truncation Length | Memory | Long-Range Learning | Speed |
|---|---|---|---|
| Full (k=T) | O(T) | Theoretically possible | Slow |
| k=100 | O(100) | Limited to ~100 steps | Moderate |
| k=20 | O(20) | Short-range only | Fast |
```python
import torch
import torch.nn as nn

def truncated_bptt_training(model, sequence, targets, k1, k2, optimizer):
    """
    Truncated BPTT training loop.

    Args:
        model: RNN model
        sequence: Full input sequence (seq_len, batch, features)
        targets: Target values
        k1: Forward chunk size
        k2: Backward truncation length (how far to backprop)
    """
    seq_len = sequence.shape[0]
    hidden = None
    total_loss = 0

    for start in range(0, seq_len, k1):
        end = min(start + k1, seq_len)
        chunk = sequence[start:end]
        chunk_targets = targets[start:end]

        # Detach hidden state from previous chunk's graph
        if hidden is not None:
            hidden = hidden.detach()

        # Forward pass on chunk
        output, hidden = model(chunk, hidden)

        # Compute loss
        loss = nn.functional.mse_loss(output, chunk_targets)
        total_loss += loss.item()

        # Backward pass (only through k2 steps due to detach)
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (essential for RNNs)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return total_loss

class TruncatedBPTTTrainer:
    """Trainer implementing truncated BPTT with configurable truncation."""

    def __init__(self, model, optimizer, k1=50, k2=50, clip_norm=1.0):
        self.model = model
        self.optimizer = optimizer
        self.k1 = k1  # Chunk size for forward
        self.k2 = k2  # Truncation length for backward
        self.clip_norm = clip_norm

    def train_epoch(self, dataloader):
        self.model.train()
        total_loss = 0

        for batch_x, batch_y in dataloader:
            seq_len = batch_x.shape[1]
            hidden = None
            batch_loss = 0

            # Process in chunks
            for t in range(0, seq_len, self.k1):
                # Get chunk
                chunk_x = batch_x[:, t:t+self.k1]
                chunk_y = batch_y[:, t:t+self.k1]

                # Detach hidden to truncate gradient flow
                if hidden is not None:
                    hidden = (tuple(h.detach() for h in hidden)
                              if isinstance(hidden, tuple) else hidden.detach())

                # Forward
                output, hidden = self.model(chunk_x, hidden)
                loss = nn.functional.cross_entropy(
                    output.reshape(-1, output.shape[-1]),
                    chunk_y.reshape(-1)
                )

                # Backward with gradient clipping
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.clip_norm)
                self.optimizer.step()

                batch_loss += loss.item()

            total_loss += batch_loss

        return total_loss / len(dataloader)
```

BPTT has similar computational complexity to the forward pass, but with additional memory requirements.
Time complexity: $O(T \cdot d^2)$ for sequence length $T$ and hidden dimension $d$. The backward pass costs roughly the same as the forward pass, so BPTT approximately doubles the work of a forward-only run.

Memory complexity: $O(T \cdot d)$ per sequence, because every hidden state (and pre-activation) from the forward pass must be kept until the backward pass consumes it.
Key insight: Memory, not computation, is usually the bottleneck for long sequences.
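A back-of-envelope calculation shows why. The numbers below are hypothetical but typical in scale:

```python
# Activation memory for full BPTT: all hidden states must be stored.
T, H, B = 1000, 1024, 64           # timesteps, hidden size, batch size
bytes_per_float = 4                # float32
hidden_state_bytes = T * H * B * bytes_per_float
mib = hidden_state_bytes / 2**20
print(f"{mib:.0f} MiB just for hidden states")
```

Pre-activations, outputs, and optimizer state add more on top. Truncated BPTT with chunk size $k$ replaces $T$ with $k$ in this calculation, which is exactly the memory column in the table above.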
What's next:
The final page explores Computational Considerations—practical aspects of implementing and deploying RNNs including parallelization strategies, hardware utilization, and optimization techniques.
You now understand BPTT: how gradients flow backward through time, why truncation is practical, and the computational tradeoffs involved in training RNNs.