Understanding how data flows through an RNN—from raw input tokens to final predictions—is essential for both implementing and debugging recurrent models. Unlike feedforward networks where data flows in a single pass through layers, RNNs process data through both spatial (layer-wise) and temporal (timestep-wise) dimensions.
The forward pass in an RNN is a careful dance: at each timestep, the network must combine the current input with the hidden state carried over from the previous timestep, pass the result through a nonlinearity to produce the new hidden state, and optionally emit an output.
This process repeats for every position in the sequence, with information accumulating and transforming as it flows through time. Understanding this flow is crucial for implementing RNNs correctly, debugging recurrent models, and reasoning about how gradients will later flow backward through the same graph.
By the end of this page, you will understand: (1) the complete forward pass algorithm step-by-step, (2) how to build and interpret computational graphs for RNNs, (3) different output strategies for various tasks, (4) the relationship between unrolled and rolled representations, and (5) practical implementation considerations including batching and efficiency.
Let's trace through the complete forward computation for a vanilla RNN. Given an input sequence $x_1, \ldots, x_T$ with $x_t \in \mathbb{R}^n$, a hidden state of dimension $d$, weight matrices $W_{xh} \in \mathbb{R}^{d \times n}$, $W_{hh} \in \mathbb{R}^{d \times d}$, $W_{hy} \in \mathbb{R}^{o \times d}$, and biases $b_h$, $b_y$:
The algorithm:
1. Initialize: h_0 ← zeros or learned initial state
2. For t = 1 to T:
a. Compute linear combination: z_t = W_xh · x_t + W_hh · h_{t-1} + b_h
b. Apply activation: h_t = tanh(z_t)
c. (Optional) Compute output: y_t = W_hy · h_t + b_y
3. Return: Hidden states {h_1, ..., h_T}, Outputs {y_1, ..., y_T}
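Before the detailed implementation, the algorithm above can be sketched in a few lines of NumPy (a minimal, illustrative version; the weight shapes follow the conventions used on this page):

```python
import numpy as np

def rnn_forward(X, W_xh, W_hh, W_hy, b_h, b_y):
    """Minimal vanilla RNN forward pass over a sequence X of shape (T, n)."""
    T = X.shape[0]
    d = W_hh.shape[0]
    h = np.zeros(d)                       # Step 1: h_0 = zeros
    hs, ys = [], []
    for t in range(T):                    # Step 2: for t = 1..T
        z = W_xh @ X[t] + W_hh @ h + b_h  # 2a: linear combination
        h = np.tanh(z)                    # 2b: activation
        hs.append(h)
        ys.append(W_hy @ h + b_y)         # 2c: output
    return np.stack(hs), np.stack(ys)     # Step 3: all h_t and y_t

# Tiny example: T=4, n=3, d=5, o=2
rng = np.random.default_rng(0)
H, Y = rnn_forward(rng.normal(size=(4, 3)),
                   rng.normal(size=(5, 3)) * 0.1,
                   rng.normal(size=(5, 5)) * 0.1,
                   rng.normal(size=(2, 5)) * 0.1,
                   np.zeros(5), np.zeros(2))
print(H.shape, Y.shape)  # (4, 5) (4, 2)
```

Note that every `h_t` lies in $[-1, 1]$ because of the tanh activation; the full implementation below exposes the same three steps with intermediate values for inspection.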
Let's break down each step in detail.
```python
import numpy as np
from typing import List, Tuple, Optional


class VanillaRNN:
    """
    Complete vanilla RNN implementation with detailed forward pass.

    This implementation prioritizes clarity and educational value,
    exposing every step of the forward computation.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: Optional[int] = None,
        activation: str = 'tanh'
    ):
        """
        Initialize RNN with all necessary parameters.

        Args:
            input_dim: Dimension of input vectors (n)
            hidden_dim: Dimension of hidden state (d)
            output_dim: Dimension of output (o), None for no output layer
            activation: Activation function ('tanh' or 'relu')
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Initialize weights using Xavier initialization
        # Input-to-hidden weights
        scale_xh = np.sqrt(2.0 / (input_dim + hidden_dim))
        self.W_xh = np.random.randn(hidden_dim, input_dim) * scale_xh

        # Hidden-to-hidden weights
        scale_hh = np.sqrt(2.0 / (2 * hidden_dim))
        self.W_hh = np.random.randn(hidden_dim, hidden_dim) * scale_hh

        # Hidden bias
        self.b_h = np.zeros((hidden_dim, 1))

        # Optional output layer
        if output_dim is not None:
            scale_hy = np.sqrt(2.0 / (hidden_dim + output_dim))
            self.W_hy = np.random.randn(output_dim, hidden_dim) * scale_hy
            self.b_y = np.zeros((output_dim, 1))
        else:
            self.W_hy = None
            self.b_y = None

        # Activation function
        if activation == 'tanh':
            self.activation = np.tanh
        elif activation == 'relu':
            self.activation = lambda x: np.maximum(0, x)
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward_step(
        self,
        x_t: np.ndarray,
        h_prev: np.ndarray,
        return_intermediates: bool = False
    ) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[dict]]:
        """
        Compute single timestep forward pass.

        Args:
            x_t: Input at time t, shape (input_dim, 1) or (input_dim,)
            h_prev: Hidden state from t-1, shape (hidden_dim, 1)
            return_intermediates: If True, return intermediate computations

        Returns:
            h_t: New hidden state, shape (hidden_dim, 1)
            y_t: Output (if output layer exists), shape (output_dim, 1)
            intermediates: Dictionary of intermediate values (if requested)
        """
        # Ensure column vector format
        if x_t.ndim == 1:
            x_t = x_t.reshape(-1, 1)

        # Step 2a: Compute linear combination
        # Separate the components for clarity
        input_contribution = self.W_xh @ x_t      # (hidden_dim, 1)
        memory_contribution = self.W_hh @ h_prev  # (hidden_dim, 1)
        z_t = input_contribution + memory_contribution + self.b_h  # Pre-activation

        # Step 2b: Apply activation
        h_t = self.activation(z_t)  # (hidden_dim, 1)

        # Step 2c: Optional output computation
        y_t = None
        if self.W_hy is not None:
            y_t = self.W_hy @ h_t + self.b_y  # (output_dim, 1)

        if return_intermediates:
            intermediates = {
                'input_contribution': input_contribution,
                'memory_contribution': memory_contribution,
                'z_t': z_t,
                'h_t': h_t,
                'y_t': y_t,
            }
            return h_t, y_t, intermediates

        return h_t, y_t, None

    def forward_sequence(
        self,
        X: np.ndarray,
        h_0: Optional[np.ndarray] = None,
        return_all_intermediates: bool = False
    ) -> dict:
        """
        Complete forward pass over a sequence.

        Args:
            X: Input sequence, shape (seq_len, input_dim) or (seq_len, input_dim, 1)
            h_0: Initial hidden state, defaults to zeros
            return_all_intermediates: If True, return all intermediate computations

        Returns:
            Dictionary containing:
                - hidden_states: List of hidden states [h_1, ..., h_T]
                - outputs: List of outputs [y_1, ..., y_T] if output layer exists
                - final_hidden: Final hidden state h_T
                - intermediates: List of intermediate dicts (if requested)
        """
        seq_len = X.shape[0]

        # Step 1: Initialize hidden state
        if h_0 is None:
            h_0 = np.zeros((self.hidden_dim, 1))

        # Storage for results
        hidden_states = []
        outputs = []
        all_intermediates = []

        # Current hidden state
        h_t = h_0

        # Step 2: Process each timestep
        for t in range(seq_len):
            x_t = X[t]
            h_t, y_t, intermediates = self.forward_step(
                x_t, h_t, return_intermediates=return_all_intermediates
            )
            hidden_states.append(h_t.copy())
            if y_t is not None:
                outputs.append(y_t.copy())
            if intermediates is not None:
                intermediates['t'] = t
                all_intermediates.append(intermediates)

        # Step 3: Return results
        result = {
            'hidden_states': hidden_states,
            'final_hidden': h_t,
        }
        if outputs:
            result['outputs'] = outputs
        if all_intermediates:
            result['intermediates'] = all_intermediates

        return result


def trace_forward_pass():
    """
    Demonstrate and trace a complete forward pass with detailed output.
    """
    np.random.seed(42)

    # Create small RNN for visibility
    rnn = VanillaRNN(input_dim=3, hidden_dim=4, output_dim=2)

    # Create short sequence
    seq_len = 5
    X = np.random.randn(seq_len, 3)

    print("=" * 70)
    print("COMPLETE FORWARD PASS TRACE")
    print("=" * 70)
    print("RNN Configuration:")
    print(f"  Input dimension:  {rnn.input_dim}")
    print(f"  Hidden dimension: {rnn.hidden_dim}")
    print(f"  Output dimension: {rnn.output_dim}")
    print(f"  Sequence length:  {seq_len}")

    # Forward pass with intermediates
    result = rnn.forward_sequence(X, return_all_intermediates=True)

    print("\n" + "-" * 70)
    print("STEP-BY-STEP COMPUTATION")
    print("-" * 70)

    for t, inter in enumerate(result['intermediates']):
        prev_h = result['hidden_states'][t - 1] if t > 0 else np.zeros((4, 1))
        print(f"\n>>> Timestep t = {t + 1}")
        print(f"  Input x_{t+1}: {X[t].round(3)}")
        print(f"  Previous hidden h_{t}: norm = {np.linalg.norm(prev_h):.4f}")
        print(f"  Input contribution (W_xh @ x):  norm = {np.linalg.norm(inter['input_contribution']):.4f}")
        print(f"  Memory contribution (W_hh @ h): norm = {np.linalg.norm(inter['memory_contribution']):.4f}")
        print(f"  Pre-activation z_{t+1}: range = [{inter['z_t'].min():.3f}, {inter['z_t'].max():.3f}]")
        print(f"  Hidden state h_{t+1}:   range = [{inter['h_t'].min():.3f}, {inter['h_t'].max():.3f}]")
        print(f"  Output y_{t+1}: {inter['y_t'].squeeze().round(3)}")

    print("\n" + "-" * 70)
    print("FINAL RESULTS")
    print("-" * 70)
    print(f"Final hidden state h_T: {result['final_hidden'].squeeze().round(3)}")
    print(f"All hidden state norms: {[round(np.linalg.norm(h), 4) for h in result['hidden_states']]}")


if __name__ == "__main__":
    trace_forward_pass()
```

Understanding the computational graph is essential for reasoning about gradient flow, memory usage, and automatic differentiation. RNN computational graphs have a unique structure that reflects their temporal nature.
Rolled vs Unrolled representations:
An RNN can be visualized in two equivalent ways:
1. Rolled representation (compact): Shows a single cell with a self-loop representing the recurrent connection. This emphasizes parameter sharing—there's only one set of weights.
2. Unrolled representation (explicit): Shows the network 'unfolded' through time, with one copy of the cell for each timestep. Edges between timesteps show the hidden state flow. This is what actually happens during computation.
The unrolled computational graph:
```
 x_1          x_2          x_3                 x_T
  |            |            |                   |
  v            v            v                   v
h_0 → [RNN Cell] → [RNN Cell] → [RNN Cell] → ... → [RNN Cell]
          |            |            |                   |
          v            v            v                   v
         h_1          h_2          h_3                 h_T
          |            |            |                   |
          v            v            v                   v
         y_1          y_2          y_3                 y_T
```
Each [RNN Cell] uses the same parameters—this is the parameter sharing. But in the computational graph, we see $T$ separate computations.
For gradient computation, we must reason about the unrolled graph. The gradient of the loss with respect to W_hh flows through every timestep, accumulating contributions. Understanding this graph structure is essential for understanding backpropagation through time (BPTT).
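This accumulation can be verified directly with autograd (a small illustrative check, not from the original text): because W_hh is shared across all unrolled copies, the gradient of a summed loss equals the sum of the gradients from each per-timestep loss.

```python
import torch

torch.manual_seed(0)
d, n, T = 4, 3, 5
W_xh = torch.randn(d, n, requires_grad=True)
W_hh = torch.randn(d, d, requires_grad=True)
X = torch.randn(T, n)

def unroll(W_xh, W_hh, X):
    """Unrolled forward pass; returns all hidden states."""
    h = torch.zeros(d)
    hs = []
    for t in range(X.shape[0]):
        h = torch.tanh(W_xh @ X[t] + W_hh @ h)
        hs.append(h)
    return hs

# Gradient of the total loss sum_t ||h_t||^2 with respect to W_hh ...
hs = unroll(W_xh, W_hh, X)
loss = sum((h ** 2).sum() for h in hs)
loss.backward()
grad_total = W_hh.grad.clone()

# ... equals the sum of gradients from each per-timestep loss,
# because W_hh is shared across all T copies in the unrolled graph.
grad_sum = torch.zeros_like(W_hh)
for t in range(T):
    W_hh.grad = None
    hs = unroll(W_xh, W_hh, X)
    (hs[t] ** 2).sum().backward()
    grad_sum += W_hh.grad

print(torch.allclose(grad_total, grad_sum, atol=1e-5))  # True
```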
```python
import torch
import torch.nn as nn
from graphviz import Digraph


class RNNForward:
    """
    RNN implementation that explicitly builds the computational graph.

    Uses PyTorch's autograd to track dependencies, demonstrating
    how the graph is constructed during forward computation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Parameters with requires_grad for graph tracking
        self.W_xh = nn.Parameter(torch.randn(hidden_dim, input_dim) * 0.01)
        self.W_hh = nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01)
        self.b_h = nn.Parameter(torch.zeros(hidden_dim))
        self.W_hy = nn.Parameter(torch.randn(output_dim, hidden_dim) * 0.01)
        self.b_y = nn.Parameter(torch.zeros(output_dim))

    def forward(self, X: torch.Tensor, trace_graph: bool = False):
        """
        Forward pass with optional graph tracing.

        Args:
            X: Input sequence (seq_len, input_dim)
            trace_graph: If True, return graph structure

        Returns:
            outputs: Output predictions
            graph_info: Graph structure (if trace_graph=True)
        """
        seq_len = X.shape[0]
        h_t = torch.zeros(self.hidden_dim)
        outputs = []

        graph_nodes = [] if trace_graph else None
        graph_edges = [] if trace_graph else None

        for t in range(seq_len):
            x_t = X[t]

            # Forward step with explicit operations
            Wx = torch.matmul(self.W_xh, x_t)   # Input projection
            Wh = torch.matmul(self.W_hh, h_t)   # Hidden projection
            z = Wx + Wh + self.b_h              # Linear combination
            h_t = torch.tanh(z)                 # Activation
            y_t = torch.matmul(self.W_hy, h_t) + self.b_y  # Output

            outputs.append(y_t)

            if trace_graph:
                # Record nodes and edges
                graph_nodes.extend([
                    f'x_{t+1}', f'Wx_{t+1}', f'Wh_{t+1}',
                    f'z_{t+1}', f'h_{t+1}', f'y_{t+1}'
                ])
                graph_edges.extend([
                    (f'x_{t+1}', f'Wx_{t+1}'),
                    (f'Wx_{t+1}', f'z_{t+1}'),
                    (f'Wh_{t+1}', f'z_{t+1}'),
                    (f'z_{t+1}', f'h_{t+1}'),
                    (f'h_{t+1}', f'y_{t+1}'),
                ])
                if t > 0:
                    graph_edges.append((f'h_{t}', f'Wh_{t+1}'))
                else:
                    graph_nodes.append('h_0')
                    graph_edges.append(('h_0', f'Wh_{t+1}'))

        if trace_graph:
            return torch.stack(outputs), {'nodes': graph_nodes, 'edges': graph_edges}
        return torch.stack(outputs)


def visualize_computational_graph(seq_len: int = 3):
    """
    Create a visual representation of the RNN computational graph.
    """
    dot = Digraph(comment='RNN Computational Graph')
    dot.attr(rankdir='LR')  # Left to right
    dot.attr('node', shape='box')

    # Create subgraphs for each timestep
    for t in range(seq_len):
        with dot.subgraph(name=f'cluster_{t}') as c:
            c.attr(label=f'Timestep {t+1}')

            # Input
            c.node(f'x_{t+1}', f'x_{t+1}', shape='ellipse',
                   style='filled', fillcolor='lightblue')

            # Hidden state operations
            c.node(f'Wx_{t+1}', f'W_xh·x_{t+1}')
            c.node(f'Wh_{t+1}', f'W_hh·h_{t}')
            c.node(f'z_{t+1}', '+')
            c.node(f'tanh_{t+1}', 'tanh')
            c.node(f'h_{t+1}', f'h_{t+1}', style='filled', fillcolor='lightgreen')
            c.node(f'y_{t+1}', f'y_{t+1}', shape='ellipse',
                   style='filled', fillcolor='lightyellow')

            # Internal edges
            c.edge(f'x_{t+1}', f'Wx_{t+1}')
            c.edge(f'Wx_{t+1}', f'z_{t+1}')
            c.edge(f'Wh_{t+1}', f'z_{t+1}')
            c.edge(f'z_{t+1}', f'tanh_{t+1}')
            c.edge(f'tanh_{t+1}', f'h_{t+1}')
            c.edge(f'h_{t+1}', f'y_{t+1}')

    # Initial hidden state
    dot.node('h_0', 'h_0 (zeros)', style='filled', fillcolor='lightgray')
    dot.edge('h_0', 'Wh_1')

    # Cross-timestep edges (recurrent connections)
    for t in range(seq_len - 1):
        dot.edge(f'h_{t+1}', f'Wh_{t+2}', color='red', style='bold')

    # Parameter annotations
    dot.node('W_xh', 'W_xh (shared)', shape='box', style='dashed')
    dot.node('W_hh', 'W_hh (shared)', shape='box', style='dashed', color='red')
    dot.node('W_hy', 'W_hy (shared)', shape='box', style='dashed')

    return dot


def analyze_graph_structure():
    """
    Analyze the computational graph structure of an RNN.
    """
    input_dim, hidden_dim, output_dim = 4, 8, 2
    seq_len = 5

    rnn = RNNForward(input_dim, hidden_dim, output_dim)
    X = torch.randn(seq_len, input_dim)

    outputs, graph_info = rnn.forward(X, trace_graph=True)

    print("Computational Graph Analysis")
    print("=" * 50)
    print(f"Sequence length: {seq_len}")
    print(f"Total nodes: {len(graph_info['nodes'])}")
    print(f"Total edges: {len(graph_info['edges'])}")
    print(f"Nodes per timestep: {len(graph_info['nodes']) // seq_len}")
    print(f"Edges per timestep: ~{len(graph_info['edges']) // seq_len}")

    # Count edge types
    recurrent_edges = [(s, t) for s, t in graph_info['edges']
                       if s.startswith('h_') and t.startswith('Wh_')]
    print(f"Recurrent edges (h_t -> Wh_t+1): {len(recurrent_edges)}")

    graph = visualize_computational_graph(3)
    graph.render('rnn_graph', format='pdf', cleanup=True)
    print("Graph structure visualization saved to 'rnn_graph.pdf'")


if __name__ == "__main__":
    analyze_graph_structure()
```

Key properties of the RNN computational graph:

- Depth grows with sequence length: the unrolled graph contains $T$ copies of the cell, so the longest path from $x_1$ to $h_T$ passes through $T$ nonlinearities.
- Shared parameters: every copy of the cell reads the same $W_{xh}$, $W_{hh}$, and $b_h$, so their gradients accumulate contributions from all timesteps.
- Sequential dependency: each $h_t$ depends on $h_{t-1}$, so the graph cannot be evaluated in parallel across time.
RNNs can produce outputs in different ways depending on the task. The choice of output strategy fundamentally affects the network's behavior and the loss computation.
Common output patterns:
| Pattern | Description | Example Tasks | Output Used |
|---|---|---|---|
| Many-to-One | Entire sequence → single output | Sentiment analysis, document classification | h_T (final hidden) |
| Many-to-Many (aligned) | Each input → corresponding output | POS tagging, NER, frame-level video | y_1, y_2, ..., y_T |
| Many-to-Many (unaligned) | Input sequence → output sequence (different lengths) | Machine translation, summarization | Encoder-decoder architecture |
| One-to-Many | Single input → sequence output | Image captioning, music generation | Generate until <END> |
Many-to-One: Sequence Classification
The entire sequence is compressed into the final hidden state, which is used for prediction:
$$\hat{y} = \text{softmax}(W_{hy} h_T + b_y)$$
Consideration: All information must be preserved in $h_T$. For long sequences, early information may be lost.
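This loss of early information can be observed directly: perturb only the first input token and measure how much $h_T$ changes as the sequence grows. A small illustrative experiment (model and dimensions are assumptions, not from the original text):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)

diffs = {}
for T in [5, 20, 80]:
    x = torch.randn(1, T, 8)
    x_perturbed = x.clone()
    x_perturbed[0, 0] += 1.0  # change only the FIRST token

    _, h_T = rnn(x)
    _, h_T_p = rnn(x_perturbed)

    diffs[T] = (h_T - h_T_p).norm().item()
    print(f"T={T:3d}: ||Δh_T|| = {diffs[T]:.6f}")
```

With a default-initialized tanh RNN, the effect of the first token on $h_T$ typically shrinks rapidly as $T$ grows, which is exactly the concern for many-to-one models on long sequences.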
Many-to-Many (aligned): Sequence Labeling
Every timestep produces an output:
$$\hat{y}_t = \text{softmax}(W_{hy} h_t + b_y) \quad \forall t \in \{1, \ldots, T\}$$
Consideration: Decisions at time $t$ see only past context ($x_1, \ldots, x_t$). For tasks needing future context, use bidirectional RNNs.
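In PyTorch, the bidirectional variant is a one-flag change on `nn.RNN`; each position's output then concatenates forward and backward hidden states. A minimal sketch (dimensions are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, T, n, d = 2, 10, 8, 16

birnn = nn.RNN(input_size=n, hidden_size=d, batch_first=True, bidirectional=True)
x = torch.randn(batch, T, n)

outputs, h_n = birnn(x)
# Each position now sees both past and future context:
print(outputs.shape)  # (2, 10, 32) — forward and backward states concatenated
print(h_n.shape)      # (2, 2, 16) — one final state per direction
```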
Many-to-Many (unaligned): Encoder-Decoder
Encoder processes input, final state initializes decoder:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ManyToOneRNN(nn.Module):
    """
    RNN for sequence classification (many-to-one).
    Uses only the final hidden state for prediction.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, input_dim)
        Returns:
            logits: (batch, num_classes)
        """
        # Process entire sequence
        outputs, h_n = self.rnn(x)  # outputs: (batch, seq, hidden)

        # Use FINAL hidden state only
        final_hidden = outputs[:, -1, :]  # (batch, hidden)

        # Classify
        logits = self.classifier(final_hidden)
        return logits


class ManyToManyAlignedRNN(nn.Module):
    """
    RNN for sequence labeling (many-to-many, aligned).
    Each timestep produces a prediction.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, input_dim)
        Returns:
            logits: (batch, seq_len, num_classes)
        """
        # Process sequence
        outputs, _ = self.rnn(x)  # (batch, seq, hidden)

        # Classify EVERY timestep
        logits = self.classifier(outputs)  # (batch, seq, num_classes)
        return logits


class EncoderDecoderRNN(nn.Module):
    """
    RNN for sequence-to-sequence (many-to-many, unaligned).
    Encoder compresses input, decoder generates output.
    """

    def __init__(self, input_vocab_size, output_vocab_size, embed_dim, hidden_dim):
        super().__init__()
        # Encoder
        self.encoder_embed = nn.Embedding(input_vocab_size, embed_dim)
        self.encoder_rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)

        # Decoder
        self.decoder_embed = nn.Embedding(output_vocab_size, embed_dim)
        self.decoder_rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, output_vocab_size)

    def encode(self, src):
        """Encode source sequence."""
        embedded = self.encoder_embed(src)   # (batch, src_len, embed)
        _, h_n = self.encoder_rnn(embedded)  # h_n: (1, batch, hidden)
        return h_n

    def decode_step(self, token, hidden):
        """Single decoding step."""
        embedded = self.decoder_embed(token).unsqueeze(1)  # (batch, 1, embed)
        output, hidden = self.decoder_rnn(embedded, hidden)
        logits = self.output_layer(output.squeeze(1))      # (batch, vocab)
        return logits, hidden

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        Full forward pass with optional teacher forcing.

        Args:
            src: Source tokens (batch, src_len)
            tgt: Target tokens (batch, tgt_len)
            teacher_forcing_ratio: Probability of using ground truth
        """
        batch_size = src.shape[0]
        tgt_len = tgt.shape[1]
        vocab_size = self.output_layer.out_features

        # Encode source
        hidden = self.encode(src)

        # Decode
        outputs = torch.zeros(batch_size, tgt_len, vocab_size)
        input_token = tgt[:, 0]  # Start token

        for t in range(1, tgt_len):
            logits, hidden = self.decode_step(input_token, hidden)
            outputs[:, t] = logits

            # Teacher forcing decision
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = logits.argmax(1)
            input_token = tgt[:, t] if teacher_force else top1

        return outputs


class OneToManyRNN(nn.Module):
    """
    RNN for generation from single input (one-to-many).
    Used for tasks like image captioning.
    """

    def __init__(self, context_dim, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        # Project context to initial hidden state
        self.context_projection = nn.Linear(context_dim, hidden_dim)

        # Decoder
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, context, max_len=50, start_token=1, end_token=2):
        """
        Generate sequence from context vector.

        Args:
            context: Context vector (batch, context_dim)
            max_len: Maximum generation length
            start_token: Token ID for sequence start
            end_token: Token ID for sequence end
        """
        batch_size = context.shape[0]

        # Initialize hidden state from context
        hidden = self.context_projection(context).unsqueeze(0)  # (1, batch, hidden)
        hidden = torch.tanh(hidden)

        # Start with start token
        generated = [torch.full((batch_size,), start_token, dtype=torch.long)]
        token = generated[0]

        for _ in range(max_len - 1):
            embedded = self.embedding(token).unsqueeze(1)   # (batch, 1, embed)
            output, hidden = self.rnn(embedded, hidden)
            logits = self.output_layer(output.squeeze(1))   # (batch, vocab)

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(probs, 1).squeeze(1)  # (batch,)
            generated.append(token)

            # Check for end token (simplified: stop if all ended)
            if (token == end_token).all():
                break

        return torch.stack(generated, dim=1)


# Demonstration
def demonstrate_output_strategies():
    """Show different output strategies in action."""
    torch.manual_seed(42)

    batch_size = 2
    seq_len = 10
    input_dim = 8
    hidden_dim = 16
    num_classes = 3

    x = torch.randn(batch_size, seq_len, input_dim)

    print("Output Strategy Demonstrations")
    print("=" * 60)

    # Many-to-One
    m2o = ManyToOneRNN(input_dim, hidden_dim, num_classes)
    out_m2o = m2o(x)
    print("\nMany-to-One (sequence classification):")
    print(f"  Input shape:  {x.shape}")
    print(f"  Output shape: {out_m2o.shape}")
    print("  Interpretation: One class prediction per sequence")

    # Many-to-Many Aligned
    m2m = ManyToManyAlignedRNN(input_dim, hidden_dim, num_classes)
    out_m2m = m2m(x)
    print("\nMany-to-Many Aligned (sequence labeling):")
    print(f"  Input shape:  {x.shape}")
    print(f"  Output shape: {out_m2m.shape}")
    print("  Interpretation: One class prediction per timestep")

    # One-to-Many
    context = torch.randn(batch_size, 64)
    o2m = OneToManyRNN(context_dim=64, embed_dim=16,
                       hidden_dim=hidden_dim, vocab_size=100)
    out_o2m = o2m(context, max_len=15)
    print("\nOne-to-Many (generation from context):")
    print(f"  Context shape:   {context.shape}")
    print(f"  Generated shape: {out_o2m.shape}")
    print("  Interpretation: Variable length sequence from single input")


if __name__ == "__main__":
    demonstrate_output_strategies()
```

In practice, RNNs process batches of sequences for computational efficiency. This introduces complexity due to variable-length sequences within a batch.
The padding problem:
Sequences in a batch may have different lengths (in the demonstration below: 5, 3, 7, and 2 timesteps). To process them as a single tensor, we must pad every sequence up to the longest one in the batch, keep track of each sequence's true length, and ensure that padded positions do not corrupt the final hidden states or the loss.
Padding strategies: pad with zeros (or another neutral value), then either mask out padded positions when computing the loss, or skip them entirely with PyTorch's PackedSequence.
PackedSequence (PyTorch optimization):
PyTorch provides pack_padded_sequence and pad_packed_sequence to avoid computing on padding tokens:
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (
    pad_sequence, pack_padded_sequence, pad_packed_sequence
)


class EfficientBatchedRNN(nn.Module):
    """
    RNN with efficient batched processing using PackedSequence.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, sequences: torch.Tensor, lengths: torch.Tensor):
        """
        Forward pass with variable-length sequences.

        Args:
            sequences: Padded tensor (batch, max_len, input_dim)
            lengths: Actual lengths of each sequence (batch,)
        """
        # Sort by length (required for pack_padded_sequence)
        lengths_sorted, sort_idx = lengths.sort(descending=True)
        sequences_sorted = sequences[sort_idx]

        # Pack the padded sequence
        packed = pack_padded_sequence(
            sequences_sorted,
            lengths_sorted.cpu(),  # lengths must be on CPU
            batch_first=True
        )

        # Process with RNN (efficiently - no padding computation)
        packed_output, h_n = self.rnn(packed)

        # Unpack back to padded tensor
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # Restore original order
        _, unsort_idx = sort_idx.sort()
        output = output[unsort_idx]
        h_n = h_n[:, unsort_idx]

        # Apply output layer
        return self.output(output), h_n

    def forward_masked(self, sequences: torch.Tensor, mask: torch.Tensor):
        """
        Alternative: forward pass with mask (simpler but less efficient).

        Args:
            sequences: Padded tensor (batch, max_len, input_dim)
            mask: Boolean mask, True for valid positions (batch, max_len)
        """
        # Standard forward pass
        output, h_n = self.rnn(sequences)

        # Apply output layer
        output = self.output(output)

        # Zero out padded positions (for loss computation)
        output = output * mask.unsqueeze(-1).float()
        return output, h_n


def demonstrate_batched_processing():
    """
    Demonstrate batched forward computation with variable lengths.
    """
    torch.manual_seed(42)

    input_dim = 8
    hidden_dim = 16
    output_dim = 4

    # Create sequences of different lengths
    seq1 = torch.randn(5, input_dim)  # Length 5
    seq2 = torch.randn(3, input_dim)  # Length 3
    seq3 = torch.randn(7, input_dim)  # Length 7
    seq4 = torch.randn(2, input_dim)  # Length 2

    # Pad sequences to create batch
    sequences = [seq1, seq2, seq3, seq4]
    padded_batch = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    lengths = torch.tensor([5, 3, 7, 2])

    print("Batched Processing Demonstration")
    print("=" * 60)
    print(f"Original sequence lengths: {lengths.tolist()}")
    print(f"Padded batch shape: {padded_batch.shape}")
    print(f"Max length (after padding): {padded_batch.shape[1]}")

    # Process with efficient batched RNN
    model = EfficientBatchedRNN(input_dim, hidden_dim, output_dim)
    output, final_hidden = model(padded_batch, lengths)

    print(f"Output shape: {output.shape}")
    print(f"Final hidden shape: {final_hidden.shape}")

    # Verify that padding doesn't affect the hidden state
    # by comparing single-sequence processing vs batched
    print("\n" + "-" * 40)
    print("Verification: Single vs. Batched Processing")

    single_model = nn.RNN(input_dim, hidden_dim, batch_first=True)

    # Copy weights to ensure same model
    with torch.no_grad():
        single_model.weight_ih_l0.copy_(model.rnn.weight_ih_l0)
        single_model.weight_hh_l0.copy_(model.rnn.weight_hh_l0)
        single_model.bias_ih_l0.copy_(model.rnn.bias_ih_l0)
        single_model.bias_hh_l0.copy_(model.rnn.bias_hh_l0)

    # Process sequence 0 individually
    single_out, single_h = single_model(seq1.unsqueeze(0))

    # The corresponding output in the batch (correctly packed and unpacked)
    # should match up to seq1's original length; note the batch must be
    # sorted along with the lengths before packing
    sorted_lens, sort_idx = lengths.sort(descending=True)
    packed = pack_padded_sequence(
        padded_batch[sort_idx], sorted_lens.cpu(), batch_first=True
    )
    batch_out, batch_h = single_model(packed)
    batch_out_unpacked, _ = pad_packed_sequence(batch_out, batch_first=True)

    # Find seq1 in sorted order (length 5)
    seq1_idx_in_sorted = (sort_idx == 0).nonzero().item()

    print(f"Single sequence output shape: {single_out.shape}")
    print(f"Extracted from batch shape:   {batch_out_unpacked[seq1_idx_in_sorted, :5].shape}")

    # Compare hidden states
    diff = (single_h - batch_h[:, seq1_idx_in_sorted:seq1_idx_in_sorted + 1]).abs().max()
    print(f"Hidden state difference: {diff.item():.2e} (should be ~0)")


def explain_packing():
    """
    Visual explanation of how PackedSequence works.
    """
    print("\n" + "=" * 60)
    print("HOW PACKEDSEQUENCE WORKS")
    print("=" * 60)
    print("""
    Original padded batch (batch=4, max_len=7):
        Seq 0: [a a a a a 0 0]  (length 5)
        Seq 1: [b b b 0 0 0 0]  (length 3)
        Seq 2: [c c c c c c c]  (length 7)
        Seq 3: [d d 0 0 0 0 0]  (length 2)

    After sorting by length (descending):
        Seq 2: [c c c c c c c]  (length 7)
        Seq 0: [a a a a a 0 0]  (length 5)
        Seq 1: [b b b 0 0 0 0]  (length 3)
        Seq 3: [d d 0 0 0 0 0]  (length 2)

    PackedSequence data (concatenated valid tokens):
        time 0: [c, a, b, d]  (all 4 sequences have a token at t=0)
        time 1: [c, a, b, d]  (all 4)
        time 2: [c, a, b]     (3 sequences: seq 3 ended)
        time 3: [c, a]        (2 sequences: seq 1 ended)
        time 4: [c, a]        (2 sequences)
        time 5: [c]           (1 sequence: seq 0 ended)
        time 6: [c]           (1 sequence)

    batch_sizes = [4, 4, 3, 2, 2, 1, 1]

    This allows the RNN to process a variable batch size at each step!
    """)


if __name__ == "__main__":
    demonstrate_batched_processing()
    explain_packing()
```

Use PackedSequence when: (1) you have highly variable sequence lengths, (2) you need the final hidden state for each sequence (at its true end, not at max_len), or (3) efficiency is critical. For simple cases with mostly similar lengths, masked processing may be simpler to implement.
Understanding the computational costs of RNN forward passes is crucial for efficient implementation and scaling.
Per-timestep computation:
At each timestep, the main operations are the hidden-to-hidden product $W_{hh} h_{t-1}$ ($O(d^2)$), the input-to-hidden product $W_{xh} x_t$ ($O(dn)$), the optional output product $W_{hy} h_t$ ($O(od)$), and the elementwise activation ($O(d)$).
Total per timestep: $O(d^2 + dn + od)$
For sequence of length T: $O(T \cdot (d^2 + dn + od))$
The hidden-to-hidden term $O(d^2)$ typically dominates since $d$ is often larger than $n$ or $o$.
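A quick back-of-envelope check of these terms (multiply-accumulate counts for the three matrix-vector products; the dimensions are illustrative):

```python
def rnn_flops_per_step(n, d, o):
    """Approximate multiply-adds per timestep for a vanilla RNN."""
    hidden_to_hidden = d * d   # W_hh @ h
    input_to_hidden = d * n    # W_xh @ x
    hidden_to_output = o * d   # W_hy @ h
    return hidden_to_hidden, input_to_hidden, hidden_to_output

# Typical setting: hidden dim larger than input/output dims
d2, dn, od = rnn_flops_per_step(n=100, d=512, o=50)
total = d2 + dn + od
print(f"d^2 term: {d2:,} ({100 * d2 / total:.0f}% of per-step cost)")
print(f"d*n term: {dn:,}")
print(f"o*d term: {od:,}")
```

With $d = 512$, $n = 100$, $o = 50$, the $d^2$ term alone is 262,144 multiply-adds, roughly three quarters of the per-step cost, illustrating why it dominates.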
| Resource | Complexity | Notes |
|---|---|---|
| Time (forward pass) | O(T × d²) | Sequential, cannot parallelize across T |
| Time (per batch) | O(B × T × d²) | Can parallelize across B |
| Memory (activations) | O(T × d) | Must store all h_t for backprop |
| Memory (parameters) | O(d² + dn + od) | Independent of sequence length |
| Memory (batch) | O(B × T × d) | Activations scale with batch and length |
Memory considerations during training:
For backpropagation through time (BPTT), we must store all intermediate activations: every hidden state $h_1, \ldots, h_T$ (and typically the pre-activations $z_t$ as well), giving $O(T \times d)$ storage per sequence.
For very long sequences, this memory requirement becomes prohibitive. Solutions include truncated BPTT (backpropagating only through a fixed window of recent timesteps) and gradient checkpointing (recomputing activations during the backward pass instead of storing them all).
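Truncated BPTT, the most common remedy, can be sketched in a few lines (illustrative model, data, and chunk size, not from the original text): the hidden state's value is carried across chunks, but `detach()` cuts the graph so each backward pass covers only one window.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
opt = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

x = torch.randn(1, 100, 4)  # one long sequence
y = torch.randn(1, 100, 1)  # per-timestep regression targets
chunk = 20                  # truncation window

h = torch.zeros(1, 1, 8)
for start in range(0, 100, chunk):
    xc = x[:, start:start + chunk]
    yc = y[:, start:start + chunk]

    out, h = rnn(xc, h)
    loss = ((head(out) - yc) ** 2).mean()

    opt.zero_grad()
    loss.backward()
    opt.step()

    # Detach: carry the VALUE of h into the next chunk, but cut the
    # graph so the next backward pass stops here (bounded memory).
    h = h.detach()
```

The tradeoff: memory stays bounded by the chunk size, but gradients can no longer capture dependencies longer than the window.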
```python
import torch
import torch.nn as nn
import time
import gc


def profile_rnn_forward(hidden_dims=(128, 256, 512, 1024),
                        seq_lengths=(50, 100, 200, 500),
                        batch_size=32, input_dim=100):
    """
    Profile RNN forward pass computational cost.
    """
    print("RNN Forward Pass Profiling")
    print("=" * 70)
    print(f"Batch size: {batch_size}, Input dim: {input_dim}")
    print("-" * 70)
    print(f"{'Hidden Dim':>12} {'Seq Len':>10} {'Time (ms)':>12} {'Memory (MB)':>14}")
    print("-" * 70)

    results = []
    for hidden_dim in hidden_dims:
        for seq_len in seq_lengths:
            # Clear memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Create model and input
            rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
            x = torch.randn(batch_size, seq_len, input_dim)

            if torch.cuda.is_available():
                rnn = rnn.cuda()
                x = x.cuda()
                torch.cuda.synchronize()

            # Warmup
            for _ in range(3):
                _ = rnn(x)

            # Time forward pass
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(10):
                output, h_n = rnn(x)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed = (time.perf_counter() - start) / 10 * 1000  # ms

            # Estimate memory (activations)
            # Hidden states: batch × seq × hidden × 4 bytes (float32)
            activation_memory = batch_size * seq_len * hidden_dim * 4 / 1e6  # MB

            # Parameter memory
            param_memory = sum(p.numel() * 4 for p in rnn.parameters()) / 1e6
            total_memory = activation_memory + param_memory

            print(f"{hidden_dim:>12} {seq_len:>10} {elapsed:>12.2f} {total_memory:>14.2f}")

            results.append({
                'hidden_dim': hidden_dim,
                'seq_len': seq_len,
                'time_ms': elapsed,
                'memory_mb': total_memory
            })

            del rnn, x, output, h_n

    return results


def analyze_sequential_bottleneck():
    """
    Demonstrate why RNNs cannot parallelize across timesteps.
    """
    print("\n" + "=" * 70)
    print("SEQUENTIAL BOTTLENECK ANALYSIS")
    print("=" * 70)

    hidden_dim = 256
    input_dim = 100
    batch_size = 32

    rnn_cell = nn.RNNCell(input_dim, hidden_dim)

    # Simulate processing: each timestep depends on the previous
    print("""
    RNNs have an inherent sequential dependency:

        h_1 = f(x_1, h_0)
              ↓
        h_2 = f(x_2, h_1)   ← Must wait for h_1
              ↓
        h_3 = f(x_3, h_2)   ← Must wait for h_2
              ↓
        ...

    This means:
    - We CANNOT compute h_3 until h_2 is done
    - We CANNOT parallelize across timesteps
    - Processing time scales linearly with sequence length

    Compare to feedforward networks:
    - All inputs can be processed in parallel
    - GPUs can process the entire batch in one operation

    This is why Transformers (with self-attention) are faster for
    long sequences—they allow parallel processing across positions.
    """)

    # Demonstrate timing
    seq_lengths = [10, 50, 100, 200]
    print("Timing demonstration:")
    for seq_len in seq_lengths:
        x = torch.randn(batch_size, seq_len, input_dim)
        h = torch.zeros(batch_size, hidden_dim)

        start = time.perf_counter()
        for t in range(seq_len):
            h = rnn_cell(x[:, t, :], h)  # Sequential!
        elapsed = (time.perf_counter() - start) * 1000

        print(f"  Seq len {seq_len:3d}: {elapsed:.2f} ms (linear with length)")


def memory_during_training():
    """
    Analyze memory requirements during training (storing activations).
    """
    print("\n" + "=" * 70)
    print("MEMORY REQUIREMENTS FOR TRAINING")
    print("=" * 70)
    print("""
    During training, we must store activations for backpropagation:

    Forward pass stores:
    - All hidden states h_1, ..., h_T    : O(B × T × d) × 4 bytes
    - All pre-activations z_1, ..., z_T  : O(B × T × d) × 4 bytes
    - Intermediate computation results   : Additional overhead

    Example calculation:
    - Batch size B = 32
    - Sequence length T = 500
    - Hidden dimension d = 512

    Hidden states memory: 32 × 500 × 512 × 4 bytes = 32.77 MB
    With pre-activations: ~65 MB
    With gradients during backward: ~130 MB

    For long sequences (T = 5000):
    32 × 5000 × 512 × 4 bytes = 327 MB (just hidden states!)

    This is why gradient checkpointing and truncated BPTT are important.
    """)

    # Concrete example
    for seq_len in [100, 500, 1000, 5000]:
        batch_size = 32
        hidden_dim = 512

        # Memory for hidden states (float32)
        hidden_mem = batch_size * seq_len * hidden_dim * 4 / 1e6  # MB

        # Approximate total training memory (with gradients, ~3x)
        training_mem = hidden_mem * 3

        print(f"  T={seq_len:4d}: ~{hidden_mem:.1f} MB activations, "
              f"~{training_mem:.1f} MB training")


if __name__ == "__main__":
    # Run profiling on CPU (works everywhere)
    profile_rnn_forward(
        hidden_dims=(128, 256),
        seq_lengths=(50, 100),
        batch_size=16
    )
    analyze_sequential_bottleneck()
    memory_during_training()
```

Unlike CNNs or Transformers, RNNs cannot parallelize across timesteps because each h_t depends on h_{t-1}. This inherent sequential dependency means RNN processing time scales linearly with sequence length, regardless of available compute. This is a fundamental limitation that motivated the development of parallel alternatives like Transformers.
We've traced the complete forward computation through an RNN. Let's consolidate the key insights: the forward pass is a single recurrence applied $T$ times with shared weights; the unrolled computational graph makes gradient flow and memory costs explicit; the output strategy (many-to-one, many-to-many, one-to-many) should match the task; variable-length batches require padding, masking, or packing; and computation is inherently sequential, scaling as $O(T \cdot d^2)$.
What's next:
Now that we understand how data flows forward through an RNN, the next page explores the reverse direction: Backpropagation Through Time (BPTT). We'll see how gradients flow backward through the unrolled computational graph, understand the chain rule across timesteps, and discover why this process is both powerful and problematic.
You now have a complete understanding of RNN forward computation. You can trace data through the network, construct computational graphs, choose appropriate output strategies, handle batched variable-length inputs, and reason about computational costs. This foundation is essential for understanding training dynamics.