Recurrent neural networks process sequences by iterating the same computation—$h_t = f(h_{t-1}, x_t)$—across time. While this recursive definition is compact and elegant, understanding how RNNs actually compute (and how they're trained) requires a different perspective: unfolding the recurrence into an explicit computational graph.
When we unfold an RNN across time, we reveal its true structure: a very deep feedforward network with shared weights. This perspective illuminates both the power and the challenges of recurrent processing. The depth—one 'layer' per timestep—explains why gradients can vanish or explode. The weight sharing explains why RNNs can generalize across positions.
In this page, we develop the unfolded computation graph in detail. We trace the forward pass, set up backpropagation through time (BPTT), and connect the computational structure to the gradient flow that makes learning possible—or, in some cases, impossible.
By the end of this page, you will: (1) Visualize RNNs as unfolded computational graphs, (2) Trace the forward pass through time, (3) Understand backpropagation through time (BPTT), (4) Analyze gradient flow and identify bottlenecks, and (5) Connect computational depth to learning challenges.
RNNs can be described in two equivalent but conceptually different ways:
The Compact (Recursive) View:
$$h_t = f(h_{t-1}, x_t; \theta)$$ $$y_t = g(h_t; \phi)$$
This view emphasizes the recursive relationship: each hidden state depends on the previous one. The parameters $\theta$ and $\phi$ are shared across all time steps. This is computationally how RNNs are implemented—a loop that repeatedly applies the same function.
The Unfolded (Unrolled) View:
For a sequence of length $T$, we can 'unroll' the recursion into an explicit computational graph:
$$h_1 = f(h_0, x_1; \theta)$$ $$h_2 = f(h_1, x_2; \theta)$$ $$\vdots$$ $$h_T = f(h_{T-1}, x_T; \theta)$$
This view makes explicit that the RNN is equivalent to a $T$-layer deep network where each 'layer' corresponds to a time step, and all layers share the same weights $\theta$.
The unfolded view reveals that training an RNN on a sequence of length T is equivalent to training a T-layer deep network. This explains both the power of RNNs (deep representations can be very expressive) and their challenges (training very deep networks is notoriously difficult due to gradient issues).
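To make the equivalence concrete, here is a minimal sketch (using PyTorch's `nn.RNNCell` as the step function $f$, with toy dimensions) showing that the compact loop and the explicit nested composition compute exactly the same hidden state:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(input_size=3, hidden_size=5)   # one shared-weight step f
x = [torch.randn(1, 3) for _ in range(4)]        # sequence of length T = 4
h0 = torch.zeros(1, 5)

# Compact (recursive) view: a loop applying the same function
h = h0
for t in range(4):
    h = cell(x[t], h)

# Unfolded view: a 4-'layer' network with shared weights
h_unfolded = cell(x[3], cell(x[2], cell(x[1], cell(x[0], h0))))

print(torch.allclose(h, h_unfolded))  # True
```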
Visualizing the Unfolded Graph:
```
   x_1         x_2         x_3                 x_T
    │           │           │                   │
    ▼           ▼           ▼                   ▼
┌───────┐   ┌───────┐   ┌───────┐           ┌───────┐
│   f   │──▶│   f   │──▶│   f   │──▶ ... ──▶│   f   │
│   θ   │   │   θ   │   │   θ   │           │   θ   │
└───────┘   └───────┘   └───────┘           └───────┘
    │           │           │                   │
    ▼           ▼           ▼                   ▼
   h_1         h_2         h_3                 h_T
    │           │           │                   │
    ▼           ▼           ▼                   ▼
   y_1         y_2         y_3                 y_T
```
Each box represents a copy of the same function $f$ with shared parameters $\theta$. Horizontal arrows show hidden state flow; vertical arrows show input/output connections. The entire structure forms a DAG (directed acyclic graph) suitable for automatic differentiation.
The forward pass computes the sequence of hidden states and outputs from the input sequence. For a simple (vanilla) RNN, this is:
Step 1: Initialize $$h_0 = \mathbf{0} \quad \text{(or learned initial state)}$$
Step 2: Iterate through time $$\text{For } t = 1, 2, \ldots, T:$$ $$\quad a_t = W_{hh} h_{t-1} + W_{xh} x_t + b_h$$ $$\quad h_t = \sigma(a_t) \quad \text{(typically } \sigma = \tanh \text{)}$$ $$\quad o_t = W_{hy} h_t + b_y$$ $$\quad y_t = \text{softmax}(o_t) \quad \text{(for classification)}$$
where:
- $x_t$ is the input and $h_t$ the hidden state at time $t$,
- $W_{xh}$, $W_{hh}$, and $W_{hy}$ are the input-to-hidden, hidden-to-hidden, and hidden-to-output weight matrices,
- $b_h$ and $b_y$ are bias vectors, and $\sigma$ is the elementwise activation function.
```python
import torch
import torch.nn as nn


class VanillaRNN(nn.Module):
    """
    Vanilla RNN implementation showing explicit forward pass.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Shared parameters across all timesteps
        self.W_xh = nn.Linear(input_size, hidden_size)   # Input to hidden
        self.W_hh = nn.Linear(hidden_size, hidden_size)  # Hidden to hidden
        self.W_hy = nn.Linear(hidden_size, output_size)  # Hidden to output

    def forward(self, x_sequence, h_0=None):
        """
        Forward pass through entire sequence.

        Args:
            x_sequence: [batch_size, seq_length, input_size]
            h_0: Optional initial hidden state [batch_size, hidden_size]

        Returns:
            outputs: [batch_size, seq_length, output_size]
            hidden_states: [batch_size, seq_length, hidden_size]
            h_final: [batch_size, hidden_size]
        """
        batch_size, seq_length, _ = x_sequence.shape

        # Initialize hidden state
        if h_0 is None:
            h = torch.zeros(batch_size, self.hidden_size,
                            device=x_sequence.device)
        else:
            h = h_0

        # Storage for outputs
        outputs = []
        hidden_states = []

        # Iterate through time (the unfolded loop)
        for t in range(seq_length):
            # Get input at time t
            x_t = x_sequence[:, t, :]

            # Compute pre-activation: a_t = W_xh @ x_t + W_hh @ h_{t-1} + b
            a_t = self.W_xh(x_t) + self.W_hh(h)

            # Apply activation (tanh is standard for vanilla RNN)
            h = torch.tanh(a_t)

            # Compute output
            o_t = self.W_hy(h)

            hidden_states.append(h)
            outputs.append(o_t)

        # Stack outputs: [batch, seq_len, dim]
        outputs = torch.stack(outputs, dim=1)
        hidden_states = torch.stack(hidden_states, dim=1)

        return outputs, hidden_states, h  # h is final hidden state


# Example usage
batch_size, seq_length, input_size = 32, 100, 50
hidden_size, output_size = 256, 10000  # e.g., vocabulary size

model = VanillaRNN(input_size, hidden_size, output_size)
x = torch.randn(batch_size, seq_length, input_size)
outputs, hidden_states, h_final = model(x)

print(f"Input shape: {x.shape}")
print(f"Outputs shape: {outputs.shape}")
print(f"Hidden states shape: {hidden_states.shape}")
print(f"Final hidden shape: {h_final.shape}")
```

Key Observations:
Sequential computation: Each $h_t$ depends on $h_{t-1}$, so hidden states must be computed in order.
Fixed computation per step: The same operations (matrix multiplications, activation) are performed at each timestep.
Memory accumulation: Hidden states form a 'memory' that accumulates information as processing progresses.
Output flexibility: We can produce outputs at every timestep, only at the end, or at selected positions depending on the task.
Training RNNs requires computing gradients of the loss with respect to all parameters. Since parameters are shared across time, we must accumulate contributions from all timesteps. Backpropagation Through Time (BPTT) is the algorithm for this—it's simply standard backpropagation applied to the unfolded computational graph.
The Total Loss:
For a sequence with losses at each timestep:
$$\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t(y_t, \hat{y}_t)$$
where $\mathcal{L}_t$ is the loss at time $t$ (e.g., cross-entropy for language modeling).
Gradient with Respect to Shared Parameters:
Because $W_{hh}$ (for example) is used at every timestep, its gradient is the sum of contributions from all timesteps:
$$\frac{\partial \mathcal{L}}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial \mathcal{L}}{\partial W_{hh}^{(t)}}$$
where $W_{hh}^{(t)}$ denotes the 'copy' of $W_{hh}$ at time $t$ in the unfolded graph.
Despite its special name, BPTT is just backpropagation applied to the unfolded graph. The 'through time' emphasizes that gradients flow backward through the sequence (from t=T to t=1), but the mechanics are identical to backprop in any deep network. Modern auto-diff frameworks (PyTorch, TensorFlow) handle this automatically.
The Backward Pass:
Define $\delta_t = \frac{\partial \mathcal{L}}{\partial h_t}$ as the gradient of the total loss with respect to the hidden state at time $t$. This gradient has two sources:
$$\delta_t = \frac{\partial \mathcal{L}_t}{\partial h_t} + \delta_{t+1} \cdot \frac{\partial h_{t+1}}{\partial h_t}$$
with boundary condition $\delta_{T+1} = 0$ (no future beyond $T$).
The Jacobian $\frac{\partial h_{t+1}}{\partial h_t}$ involves:
$$\frac{\partial h_{t+1}}{\partial h_t} = \text{diag}(\sigma'(a_{t+1})) \cdot W_{hh}$$
where $\sigma'$ is the derivative of the activation function.
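To see this recursion and the Jacobian formula in action, here is a minimal sketch under illustrative assumptions (toy sizes, a bias-free tanh step $h_t = \tanh(W_{hh} h_{t-1} + x_t)$, and a toy loss $\mathcal{L} = \sum_t \mathbf{1}^\top h_t$) that runs the $\delta_t$ recursion by hand and checks the resulting gradient against autograd:

```python
import torch

torch.manual_seed(0)
d_h, T = 4, 6
W_hh = torch.randn(d_h, d_h) * 0.5
xs = [torch.randn(d_h) for _ in range(T)]

# Forward pass: h_t = tanh(W_hh h_{t-1} + x_t), loss = sum_t sum(h_t)
h0 = torch.zeros(d_h, requires_grad=True)
hs = []
for t in range(T):
    prev = hs[-1] if hs else h0
    hs.append(torch.tanh(W_hh @ prev + xs[t]))
loss = sum(h_t.sum() for h_t in hs)
loss.backward()  # autograd reference: populates h0.grad

# Manual backward: delta_t = dL_t/dh_t + delta_{t+1} @ J_{t+1},
# with J_{t+1} = diag(sigma'(a_{t+1})) @ W_hh, sigma'(a) = 1 - tanh(a)^2 = 1 - h^2,
# and boundary condition delta_{T+1} = 0
delta = torch.zeros(d_h)
for t in reversed(range(T)):
    if t + 1 < T:
        J_next = torch.diag(1 - hs[t + 1].detach() ** 2) @ W_hh
        delta = torch.ones(d_h) + delta @ J_next  # local term + future term
    else:
        delta = torch.ones(d_h)                   # no future beyond T

# Push delta_1 through the first step's Jacobian to get dL/dh_0
J_1 = torch.diag(1 - hs[0].detach() ** 2) @ W_hh
manual_grad_h0 = delta @ J_1

print(torch.allclose(manual_grad_h0, h0.grad, atol=1e-6))  # True
```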
The utilities below make this concrete: the first inspects $\partial \mathcal{L}/\partial h_t$ at every timestep (unrolling the forward pass locally so each per-step hidden state can retain its gradient), and the second shows how the shared-parameter gradient accumulates across timesteps.

```python
import torch
import torch.nn as nn


def bptt_gradient_analysis(model, x_sequence, targets):
    """
    Analyze gradient flow through BPTT.
    Shows how gradients accumulate from future to past.

    Unrolls the forward pass here (rather than calling model.forward)
    so retain_grad() can be set on each per-step hidden state.
    Assumes `model` exposes W_xh, W_hh, W_hy, and hidden_size,
    as in the VanillaRNN above.
    """
    model.train()
    batch_size, seq_length, _ = x_sequence.shape

    h = torch.zeros(batch_size, model.hidden_size, device=x_sequence.device)
    hidden_states = []
    losses_per_step = []
    total_loss = 0

    for t in range(seq_length):
        h = torch.tanh(model.W_xh(x_sequence[:, t, :]) + model.W_hh(h))
        h.retain_grad()  # keep dL/dh_t for inspection after backward
        hidden_states.append(h)

        loss_t = nn.functional.cross_entropy(model.W_hy(h), targets[:, t])
        total_loss += loss_t
        losses_per_step.append(loss_t.item())

    # Backward pass - gradients flow from t=T back to t=1
    total_loss.backward()

    # dL/dh_t combines the local loss term with everything
    # flowing back from future timesteps
    grad_magnitudes = [h_t.grad.norm().item() for h_t in hidden_states]

    return {
        'total_loss': total_loss.item(),
        'losses_per_step': losses_per_step,
        'grad_magnitudes': grad_magnitudes,
    }


def demonstrate_gradient_accumulation():
    """
    Show how gradients from shared parameters accumulate.
    """
    # Simple RNN
    input_size, hidden_size = 10, 20
    W_hh = torch.randn(hidden_size, hidden_size, requires_grad=True)
    W_xh = torch.randn(hidden_size, input_size, requires_grad=True)

    # Input sequence
    T = 5
    x = [torch.randn(1, input_size) for _ in range(T)]
    h = torch.zeros(1, hidden_size)

    # Forward pass (storing intermediate values)
    hidden_states = []
    for t in range(T):
        h = torch.tanh(x[t] @ W_xh.T + h @ W_hh.T)
        hidden_states.append(h)

    # Simple loss: sum of final hidden state
    loss = hidden_states[-1].sum()

    # Backward
    loss.backward()

    print("W_hh gradient shape:", W_hh.grad.shape)
    print("W_hh gradient norm:", W_hh.grad.norm().item())
    print("\nThis gradient is the SUM of contributions from all timesteps")
    print("Each timestep contributes based on how its h_t affects the loss")
```

Understanding gradient flow through the unfolded RNN is essential for diagnosing training issues. The key quantity is the product of Jacobians that gradients must traverse to flow from time $T$ to time $1$.
The Jacobian Chain:
The gradient of loss at time $T$ with respect to hidden state at time $1$ involves:
$$\frac{\partial \mathcal{L}_T}{\partial h_1} = \frac{\partial \mathcal{L}_T}{\partial h_T} \cdot \prod_{t=1}^{T-1} \frac{\partial h_{t+1}}{\partial h_t}$$
Each Jacobian $J_t = \frac{\partial h_{t+1}}{\partial h_t}$ can be written as:
$$J_t = \text{diag}(\sigma'(a_{t+1})) \cdot W_{hh}$$
For tanh activation: $$\sigma'(a) = 1 - \tanh^2(a) \in (0, 1]$$
The gradient product becomes:
$$\prod_{t=1}^{T-1} J_t \approx \left( \prod_{t=1}^{T-1} \text{diag}(\sigma'(a_{t+1})) \right) W_{hh}^{T-1}$$
(approximately, since the diagonal matrices differ at each step and do not commute with $W_{hh}$)
Let $\lambda$ be a typical eigenvalue of $J_t$. The magnitude of the gradient after $k$ steps scales as $|\lambda|^k$. If $|\lambda| < 1$: gradients vanish exponentially (can't learn long-range dependencies). If $|\lambda| > 1$: gradients explode exponentially (training becomes unstable). Only $|\lambda| = 1$ preserves gradient magnitude—a knife-edge condition impossible to maintain exactly.
| Steps Back | \|λ\| = 0.9 | \|λ\| = 0.99 | \|λ\| = 1.01 | \|λ\| = 1.1 |
|---|---|---|---|---|
| 10 | 0.35 | 0.90 | 1.10 | 2.59 |
| 50 | 0.005 | 0.61 | 1.64 | 117 |
| 100 | 0.00003 | 0.37 | 2.70 | 13,781 |
| 500 | ≈0 | 0.007 | 144 | ≈∞ |
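The table is simply $|\lambda|^k$; a few lines of Python reproduce it:

```python
# Gradient magnitude after k steps back, for a typical Jacobian
# eigenvalue lambda: scales as |lambda|^k
for lam in (0.9, 0.99, 1.01, 1.1):
    for k in (10, 50, 100, 500):
        print(f"|lambda| = {lam}, k = {k}: {lam ** k:.6g}")
```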
Implications for Learning:

Vanishing gradients ($|\lambda| < 1$): contributions from distant timesteps shrink exponentially toward zero, so the network effectively receives training signal only from short-range dependencies; long-range structure goes unlearned.

Exploding gradients ($|\lambda| > 1$): gradient magnitudes grow exponentially over long spans, producing loss spikes and numerical overflow that destabilize training. The standard mitigation is gradient clipping, sketched after this list.

The tanh squeeze: since $\sigma'(a) = 1 - \tanh^2(a) \leq 1$, each Jacobian factor can only shrink gradient magnitudes further, and saturated units (large $|a|$) drive $\sigma'(a)$ toward zero, compounding the vanishing problem.
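A minimal sketch of the clipping mitigation mentioned above, using PyTorch's `clip_grad_norm_` (the model and loss here are illustrative stand-ins):

```python
import torch
import torch.nn as nn

# clip_grad_norm_ rescales all gradients in place so their global norm
# is at most max_norm, and returns the norm measured before clipping.
torch.manual_seed(0)
model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(32, 10)
loss = 1e3 * model(x).pow(2).sum()  # deliberately large -> large gradients
loss.backward()

norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {float(norm_before):.1f} (now <= 1.0)")
optimizer.step()
```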
For very long sequences, full BPTT becomes impractical: every hidden state and activation must be stored for the backward pass, so memory grows linearly with $T$, and each parameter update requires a complete forward-backward sweep over the entire sequence.

Truncated BPTT addresses this by limiting how far back gradients flow: the sequence is processed in fixed-size windows, gradients propagate only within a window, and the hidden state is carried forward between windows without a gradient path.

Truncation Strategies:
| Strategy | $k_1$ | $k_2$ | Description |
|---|---|---|---|
| Full BPTT | $T$ | $T$ | Complete sequence, full gradients |
| $(k, k)$-truncated | $k$ | $k$ | Common choice: process and backprop $k$ steps |
| $(k_1, k_2)$-truncated | $k_1$ | $k_2 < k_1$ | Forward $k_1$, backprop only $k_2$ |
| $(k, 1)$-truncated | $k$ | $1$ | Extreme: only learn from immediate prediction |
Typical values: $k_1 = k_2 \in \{32, 64, 128, 256\}$
Tradeoffs: smaller windows cost less memory and give more frequent parameter updates, but the model receives no direct gradient signal for dependencies longer than the window; larger windows give more faithful gradients at proportionally higher memory and compute cost.
```python
import torch
import torch.nn as nn


def truncated_bptt_training(model, sequence, targets, optimizer,
                            chunk_size=64):
    """
    Train RNN using truncated BPTT.

    Process the sequence in chunks, backpropagating within each chunk
    but detaching the hidden state between chunks.

    Args:
        model: RNN model
        sequence: Full input sequence [batch, full_length, input_size]
        targets: Full target sequence [batch, full_length]
        optimizer: Optimizer for parameter updates
        chunk_size: Number of timesteps per BPTT chunk
    """
    batch_size, full_length, _ = sequence.shape
    total_loss = 0

    # Initialize hidden state
    hidden = torch.zeros(batch_size, model.hidden_size,
                         device=sequence.device)

    # Process in chunks
    for start_idx in range(0, full_length, chunk_size):
        end_idx = min(start_idx + chunk_size, full_length)

        # Get chunk of sequence
        chunk_input = sequence[:, start_idx:end_idx, :]
        chunk_targets = targets[:, start_idx:end_idx]

        # CRITICAL: Detach hidden state from previous chunk's graph.
        # This prevents gradients from flowing further back.
        hidden = hidden.detach()

        # Forward pass through chunk
        outputs, _, hidden = model(chunk_input, h_0=hidden)

        # Compute loss for this chunk
        chunk_loss = nn.functional.cross_entropy(
            outputs.reshape(-1, outputs.size(-1)),
            chunk_targets.reshape(-1)
        )

        # Backward only through this chunk
        optimizer.zero_grad()
        chunk_loss.backward()

        # Gradient clipping (common with RNNs)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update parameters
        optimizer.step()

        total_loss += chunk_loss.item() * (end_idx - start_idx)

    return total_loss / full_length


class StatefulRNNTrainer:
    """
    Trainer that maintains hidden state across batches.
    Useful for language modeling on long documents.
    """
    def __init__(self, model, chunk_size=64):
        self.model = model
        self.chunk_size = chunk_size
        self.hidden = None

    def reset_hidden(self, batch_size, device):
        """Reset hidden state (e.g., at document boundaries)."""
        self.hidden = torch.zeros(
            batch_size, self.model.hidden_size, device=device
        )

    def train_batch(self, batch_input, batch_targets, optimizer):
        """Train on one batch, maintaining state."""
        if self.hidden is None:
            self.reset_hidden(batch_input.size(0), batch_input.device)

        # Detach to prevent BPTT across batches
        self.hidden = self.hidden.detach()

        # Forward
        outputs, _, self.hidden = self.model(batch_input, h_0=self.hidden)

        # Loss and backward
        loss = nn.functional.cross_entropy(
            outputs.reshape(-1, outputs.size(-1)),
            batch_targets.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        optimizer.step()

        return loss.item()
```

The key to truncated BPTT is `hidden.detach()`. This operation keeps the hidden state's values but removes it from the computational graph: the hidden state still carries information forward through the sequence, but gradients cannot flow backward through it beyond the current chunk. This breaks the gradient chain at chunk boundaries.
Understanding the computational costs of RNNs is important for practical deployment and comparison with alternatives.
Time Complexity:
| Operation | Complexity | Notes |
|---|---|---|
| Forward pass (one step) | $O(d_h^2 + d_h d_x)$ | Matrix multiplications |
| Forward pass (full sequence) | $O(T(d_h^2 + d_h d_x))$ | Sequential over $T$ steps |
| Backward pass (full sequence) | $O(T(d_h^2 + d_h d_x))$ | Same as forward |
| Total training iteration | $O(T \cdot d_h^2)$ | Dominated by hidden-hidden |
where $d_x$ is input dimension, $d_h$ is hidden dimension, $T$ is sequence length.
Space Complexity:
| Requirement | Size | Notes |
|---|---|---|
| Parameters | $O(d_h^2 + d_h d_x + d_h d_y)$ | Weight matrices |
| Hidden states (for BPTT) | $O(T \cdot d_h \cdot B)$ | Batch size $B$ |
| Activations (for BPTT) | $O(T \cdot d_h \cdot B)$ | Needed for gradients |
| Inference only | $O(d_h \cdot B)$ | Just current hidden state |
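A back-of-envelope sketch (plain Python, using the illustrative dimensions from the forward-pass example earlier on this page) that turns these tables into concrete numbers:

```python
d_x, d_h, d_y = 50, 256, 10000   # input, hidden, output dims
B, T = 32, 100                   # batch size, sequence length

# Parameters: O(d_h^2 + d_h*d_x + d_h*d_y), counting weights + biases
n_params = (d_x * d_h + d_h) + (d_h * d_h + d_h) + (d_h * d_y + d_y)
print(f"parameters: {n_params:,}")

# Rough multiply-accumulates per forward step and per full sequence
macs_per_step = d_h * d_h + d_x * d_h + d_h * d_y
print(f"MACs per step: {macs_per_step:,}; per sequence: {T * macs_per_step:,}")

# Hidden states stored for BPTT: O(T * d_h * B) floats
bptt_floats = T * d_h * B
print(f"BPTT hidden-state storage: {bptt_floats:,} floats "
      f"({bptt_floats * 4 / 1e6:.1f} MB in float32)")

# Inference needs only the current hidden state: O(d_h * B)
print(f"inference state: {d_h * B:,} floats")
```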
The fundamental bottleneck of RNNs is that hidden state computation is inherently sequential: $h_t$ cannot be computed until $h_{t-1}$ is known. While matrix operations within each step can be parallelized (batching, GPU ops), steps cannot be parallelized with respect to each other. For T=1000, that's 1000 sequential dependencies—limiting GPU utilization compared to Transformers.
Comparison with Transformers:
| Aspect | RNN | Transformer |
|---|---|---|
| Training time complexity | $O(T \cdot d^2)$ | $O(T^2 \cdot d)$ |
| Training parallelization | Low (sequential) | High (fully parallel) |
| Inference (incremental) | $O(d^2)$ per token | $O(T \cdot d)$ per token |
| Memory (inference) | $O(d)$ | $O(T \cdot d)$ (KV cache) |
The key tradeoff: RNNs do less total computation (linear rather than quadratic in $T$) and need only constant memory per token at inference, but their timesteps must execute serially; Transformers spend more compute yet can process all positions in parallel during training.
This explains why Transformers dominate when GPU compute is abundant and sequences are moderate (NLP), while RNNs remain relevant for streaming applications and resource-constrained settings.
The basic RNN structure can be configured differently depending on the task. The relationship between input sequence, hidden states, and output sequence determines the configuration:
Many-to-Many (Sequence-to-Sequence of Same Length)
Input: $(x_1, x_2, \ldots, x_T)$ Output: $(y_1, y_2, \ldots, y_T)$
Each input position produces an output. Used for:
- Part-of-speech tagging and named-entity recognition (one label per token)
- Character- or word-level language modeling (predict the next token at every position)
```
x_1 → h_1 → y_1
       ↓
x_2 → h_2 → y_2
       ↓
x_3 → h_3 → y_3
```
This is the 'synchronous' many-to-many—input and output are aligned.
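A minimal sketch of selecting outputs for the two most common configurations, here using PyTorch's built-in `nn.RNN` for brevity (dimensions are illustrative):

```python
import torch
import torch.nn as nn

# The same unfolded RNN serves different configurations depending on
# which outputs feed the loss.
rnn = nn.RNN(input_size=50, hidden_size=256, batch_first=True)
head = nn.Linear(256, 10)

x = torch.randn(32, 100, 50)          # [batch, T, input_size]
h_seq, h_final = rnn(x)               # h_seq: [32, 100, 256]

# Many-to-many (synchronous): one prediction per timestep
per_step_logits = head(h_seq)         # [32, 100, 10]

# Many-to-one (e.g., sequence classification): use only the final state
final_logits = head(h_seq[:, -1, :])  # [32, 10]

print(per_step_logits.shape, final_logits.shape)
```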
This page has developed the unfolded computational perspective on RNNs—essential for understanding how they compute and how they're trained.
Module Complete:

This concludes Module 1: Sequence Modeling. We have built a comprehensive foundation: the sequence modeling problem itself, the autoregressive training framework, and the unfolded computational graph with BPTT developed on this page.
In the next module, we dive into the RNN Architecture itself—the specific mathematical formulations, implementation details, and the challenges that motivate advanced architectures like LSTM and GRU.
You now have a complete conceptual foundation for understanding recurrent neural networks. You understand the problem they solve (sequence modeling), the framework for training them (autoregressive + BPTT), and the computational structure that underlies both their power and their challenges (the unfolded graph). Next module: the architecture itself.