Consider a seemingly simple task: classify whether a sentence is positive or negative. Sentences vary in length—some have 5 words, others have 50. How do you build a neural network that can handle both?
The naive approach fails immediately:
With a standard feedforward network, you'd need to define a fixed input size. A network designed for 50-word sentences would need shorter inputs padded out to 50 words, would have to truncate or reject anything longer, and would learn a separate set of weights for every position, so nothing learned about a word at position 3 transfers to the same word at position 30.
Parameter sharing in time solves all these problems. By using the same weights at every timestep, RNNs can process sequences of any length with a fixed set of parameters, reuse patterns learned at one position at every other position, and generalize to sequence lengths never seen during training.
This architectural decision is as fundamental to RNNs as convolutional weight sharing is to CNNs. Understanding it deeply is essential for understanding why recurrent architectures work.
By the end of this page, you will understand: (1) why parameter sharing is necessary for sequence modeling, (2) the mathematical formalization of temporal weight sharing, (3) how it compares to spatial weight sharing in CNNs, (4) the statistical and computational benefits, and (5) the limitations and when parameter sharing assumptions break down.
To appreciate parameter sharing, we must first understand what happens without it. Let's formalize the challenge.
Scenario: Sequence classification without weight sharing
Suppose we want to classify sequences of length $T$ where each input $x_t$ has dimension $n$. The naive approach uses different weights for each timestep:
$$h_t = f(W_{hh}^{(t)} h_{t-1} + W_{xh}^{(t)} x_t + b_h^{(t)})$$
Parameter count: each timestep needs its own $W_{hh}^{(t)} \in \mathbb{R}^{d \times d}$ ($d^2$ parameters), $W_{xh}^{(t)} \in \mathbb{R}^{d \times n}$ ($dn$ parameters), and $b_h^{(t)} \in \mathbb{R}^{d}$ ($d$ parameters).
Total parameters for T timesteps: $$T \times (d^2 + dn + d) = T \times d(d + n + 1)$$
For a typical model with $d = 256$, $n = 128$, and supporting sequences up to $T = 100$: $$100 \times 256 \times (256 + 128 + 1) = 100 \times 256 \times 385 = 9,856,000 \text{ parameters}$$
That's nearly 10 million parameters just for the recurrent connections—before adding embedding layers, output layers, or any other components.
Without parameter sharing, parameter count scales linearly with maximum sequence length. Want to support sequences of length 1000 instead of 100? That's 10x more parameters. This is statistically disastrous—you'd need proportionally more training data, suffer severe overfitting, and face computational explosion.
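A quick back-of-the-envelope check of these figures, as a small sketch (the helper name `rnn_param_count` and the loop over lengths are ours; the dimensions are the ones used above):

```python
def rnn_param_count(d: int, n: int) -> int:
    # One recurrent step: W_hh (d*d) + W_xh (d*n) + b_h (d) = d*(d + n + 1)
    return d * d + d * n + d

d, n = 256, 128          # hidden and input dimensions from the example above
shared = rnn_param_count(d, n)

for T in (10, 100, 1000):
    unshared = T * shared   # a separate copy of the weights for every timestep
    print(f"T={T:4d}: shared={shared:,}  unshared={unshared:,}  ({T}x more without sharing)")

# T= 100: shared=98,560  unshared=9,856,000  (the ~10M figure above)
```

These are exactly the rows of the comparison table below.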
Beyond the parameter count—the learning problem:
Even if we could afford the parameters, the learning problem without sharing is fundamentally harder:
No positional transfer: The network learns that 'the' at position 1 signals an upcoming noun, but that knowledge doesn't transfer to 'the' at position 10. Every position learns independently.
Sparse gradients: If 'the' appears at position 5 in training but position 7 at test time, the weights at position 7 never saw gradients from 'the' examples.
Length generalization failure: A network trained on sequences up to length 50 has no weights for position 51. It cannot process longer sequences at all.
Massive data requirements: To learn that 'good' indicates positive sentiment, you need examples where 'good' appears at every possible position.
Parameter sharing resolves all of these by asserting: the same transformation should apply at every timestep.
| Approach | Parameters | T=10 | T=100 | T=1000 |
|---|---|---|---|---|
| Unshared | T × d(d+n+1) | ~1M | ~10M | ~100M |
| Shared | d(d+n+1) | ~100K | ~100K | ~100K |
| Ratio | T : 1 | 10x savings | 100x savings | 1000x savings |
With parameter sharing, we use identical weights across all timesteps:
$$h_t = f(W_{hh} h_{t-1} + W_{xh} x_t + b_h) \quad \forall t \in \{1, 2, \ldots, T\}$$
Note the crucial difference: $W_{hh}$, $W_{xh}$, and $b_h$ have no timestep superscript. They're the same matrices at every step.
What this means mathematically:
The RNN applies the same function at every timestep: $$g_\theta(h, x) = f(W_{hh} h + W_{xh} x + b_h)$$
where $\theta = \{W_{hh}, W_{xh}, b_h\}$ are the shared parameters.
The sequence processing becomes: $$h_1 = g_\theta(h_0, x_1)$$ $$h_2 = g_\theta(h_1, x_2) = g_\theta(g_\theta(h_0, x_1), x_2)$$ $$\vdots$$ $$h_T = g_\theta(h_{T-1}, x_T) = g_\theta^T(h_0, x_{1:T})$$
where $g_\theta^T$ denotes the $T$-fold composition of $g_\theta$.
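A minimal sketch of this composition, assuming nothing beyond the recurrence above (the names `g` and `theta` are ours, and the weights are random placeholders): the same parameters are threaded through a plain Python loop, which is exactly the $T$-fold application of $g_\theta$.

```python
import torch

def g(theta, h, x):
    # One shared step: g_theta(h, x) = tanh(W_hh h + W_xh x + b_h)
    W_hh, W_xh, b_h = theta
    return torch.tanh(W_hh @ h + W_xh @ x + b_h)

d, n, T = 8, 4, 6
theta = (torch.randn(d, d) * 0.1,   # W_hh, reused at every step
         torch.randn(d, n) * 0.1,   # W_xh, reused at every step
         torch.zeros(d))            # b_h,  reused at every step

h = torch.zeros(d)                       # h_0
xs = [torch.randn(n) for _ in range(T)]  # x_1 ... x_T

# h_T = g_theta(... g_theta(g_theta(h_0, x_1), x_2) ..., x_T)
for x_t in xs:
    h = g(theta, h, x_t)   # same theta every iteration

print(h.shape)  # torch.Size([8])
```

The module below makes the same point with explicit `nn.Parameter`s and a comparison against a hypothetical unshared variant.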
```python
import torch
import torch.nn as nn


class ExplicitWeightSharingRNN(nn.Module):
    """
    RNN implementation that makes weight sharing explicit.

    This implementation manually applies the same weights at each step,
    demonstrating that the same parameters are used throughout.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # These are the ONLY recurrent parameters
        # They will be used at EVERY timestep
        self.W_xh = nn.Parameter(torch.randn(hidden_dim, input_dim) * 0.01)
        self.W_hh = nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01)
        self.b_h = nn.Parameter(torch.zeros(hidden_dim))
        self.hidden_dim = hidden_dim

    def step(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        """
        Single timestep update using SHARED parameters.

        This function is called T times for a sequence of length T,
        each time using the exact same W_xh, W_hh, and b_h.
        """
        # Notice: we use self.W_xh, self.W_hh, self.b_h
        # These references point to the same tensors every time
        pre_activation = (
            torch.mm(x_t, self.W_xh.t()) +      # (batch, hidden)
            torch.mm(h_prev, self.W_hh.t()) +   # (batch, hidden)
            self.b_h                            # broadcasts to (batch, hidden)
        )
        return torch.tanh(pre_activation)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Process a sequence, demonstrating weight sharing.

        Args:
            x: Input tensor (batch_size, seq_len, input_dim)

        Returns:
            hidden_states: All hidden states (batch_size, seq_len, hidden_dim)
            final_state: Last hidden state (batch_size, hidden_dim)
        """
        batch_size, seq_len, _ = x.shape

        # Initialize hidden state
        h_t = torch.zeros(batch_size, self.hidden_dim, device=x.device)
        hidden_states = []

        # Process each timestep with the SAME weights
        for t in range(seq_len):
            h_t = self.step(x[:, t, :], h_t)  # Same step function, same weights
            hidden_states.append(h_t)

        return torch.stack(hidden_states, dim=1), h_t

    def count_parameters(self):
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class UnsharedWeightsRNN(nn.Module):
    """
    Hypothetical RNN WITHOUT weight sharing (for comparison).

    Each timestep has its own parameters. This demonstrates
    why weight sharing is essential.
    """

    def __init__(self, input_dim: int, hidden_dim: int, max_seq_len: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_seq_len = max_seq_len

        # SEPARATE parameters for each timestep!
        self.W_xh_list = nn.ParameterList([
            nn.Parameter(torch.randn(hidden_dim, input_dim) * 0.01)
            for _ in range(max_seq_len)
        ])
        self.W_hh_list = nn.ParameterList([
            nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01)
            for _ in range(max_seq_len)
        ])
        self.b_h_list = nn.ParameterList([
            nn.Parameter(torch.zeros(hidden_dim))
            for _ in range(max_seq_len)
        ])

    def forward(self, x: torch.Tensor) -> tuple:
        """Process sequence with position-specific weights."""
        batch_size, seq_len, _ = x.shape

        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max {self.max_seq_len}. "
                "Cannot process—no weights exist for these positions!"
            )

        h_t = torch.zeros(batch_size, self.hidden_dim, device=x.device)
        hidden_states = []

        for t in range(seq_len):
            # Use timestep-SPECIFIC weights
            pre_activation = (
                torch.mm(x[:, t, :], self.W_xh_list[t].t()) +
                torch.mm(h_t, self.W_hh_list[t].t()) +
                self.b_h_list[t]
            )
            h_t = torch.tanh(pre_activation)
            hidden_states.append(h_t)

        return torch.stack(hidden_states, dim=1), h_t

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Compare parameter counts
def compare_parameter_counts():
    input_dim = 128
    hidden_dim = 256
    max_seq_len = 100

    shared_rnn = ExplicitWeightSharingRNN(input_dim, hidden_dim)
    unshared_rnn = UnsharedWeightsRNN(input_dim, hidden_dim, max_seq_len)

    shared_params = shared_rnn.count_parameters()
    unshared_params = unshared_rnn.count_parameters()

    print(f"Input dim: {input_dim}, Hidden dim: {hidden_dim}")
    print(f"Max sequence length: {max_seq_len}")
    print(f"\nShared weights RNN: {shared_params:,} parameters")
    print(f"Unshared weights RNN: {unshared_params:,} parameters")
    print(f"Ratio: {unshared_params / shared_params:.1f}x more parameters without sharing")

    # Demonstrate length limitation
    print("\nLength generalization test:")
    try:
        x_long = torch.randn(1, 150, input_dim)  # Longer than max_seq_len
        _ = unshared_rnn(x_long)
    except ValueError as e:
        print(f"Unshared RNN: {e}")

    # Shared RNN handles any length
    _, final = shared_rnn(torch.randn(1, 150, input_dim))
    print("Shared RNN: Successfully processed length 150!")
    _, final = shared_rnn(torch.randn(1, 1000, input_dim))
    print("Shared RNN: Successfully processed length 1000!")


if __name__ == "__main__":
    compare_parameter_counts()
```

Parameter sharing encodes a powerful prior: 'The transformation I should apply to sequential data is the same regardless of when in the sequence I am.' This prior is appropriate for most sequential phenomena—language, music, sensor data—where patterns can occur at any position.
Parameter sharing provides statistical advantages beyond just reducing model size. These benefits are fundamental to learning from sequential data.
1. Data efficiency through pooling
With shared weights, every timestep in every training sequence provides gradients to the same parameters. If you have 10,000 training sequences with an average length of 50 timesteps, that is 500,000 timestep-level training signals, all flowing into one shared set of weights.
Without sharing, each position would only see ~10,000 examples (one per sequence). With sharing, the shared weights effectively see all 500,000 examples. This 50x increase in effective training data per parameter leads to dramatically better generalization.
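The arithmetic behind that claim, as a tiny sketch (the counts assume one occurrence per position per sequence, as in the scenario above):

```python
num_sequences = 10_000   # training sequences
avg_length = 50          # average timesteps per sequence

shared_signal = num_sequences * avg_length   # gradient signals hitting the shared weights
per_position_signal = num_sequences          # signals hitting any single unshared position

print(f"Shared weights receive  ~{shared_signal:,} timestep-level gradient signals")
print(f"Each unshared position: ~{per_position_signal:,}")
print(f"Effective increase: {shared_signal // per_position_signal}x")   # the 50x quoted above
```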
2. Position-invariant pattern learning
When the word 'not' appears, it typically inverts the sentiment of what follows—regardless of whether it's word 3 or word 30. With shared weights, the network only needs to learn this once: the same $W_{xh}$ that learns the negating effect of 'not' at position 5 automatically applies that knowledge at position 50.
3. Implicit regularization through constraint
Parameter sharing can be viewed as a form of structured regularization. Instead of freely adjusting $T \times d(d + n + 1)$ parameters, we constrain them all to be copies of the same $d(d + n + 1)$ values. This constraint prevents the model from memorizing position-specific artifacts and forces it to find transformations that genuinely apply across all positions.
Mathematical perspective:
Let $\theta_1, \ldots, \theta_T$ be the parameters we'd have without sharing. Parameter sharing imposes the constraint: $$\theta_1 = \theta_2 = \cdots = \theta_T = \theta$$
This is equivalent to adding a regularization term that penalizes differences between timestep-specific parameters—but with an infinite penalty, forcing exact equality. The result is a model that cannot overfit to position-specific noise.
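One way to make the 'infinite penalty' view concrete (our formulation, stated in the notation above): keep the untied parameters $\theta_1, \ldots, \theta_T$ as free variables but penalize their disagreement with their mean.

$$\min_{\theta_1, \ldots, \theta_T} \; L(\theta_1, \ldots, \theta_T) \;+\; \lambda \sum_{t=1}^{T} \left\lVert \theta_t - \bar{\theta} \right\rVert^2, \qquad \bar{\theta} = \frac{1}{T} \sum_{t=1}^{T} \theta_t$$

For finite $\lambda$ this is a soft weight-tying regularizer; letting $\lambda \to \infty$ forces $\theta_1 = \cdots = \theta_T$ and recovers exact parameter sharing.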
Parameter sharing trades increased bias (the assumption that the same transformation works everywhere) for dramatically reduced variance (far fewer parameters to estimate). For most sequential tasks, this tradeoff is overwhelmingly favorable—the 'same transformation' assumption is approximately true, and the variance reduction enables learning from realistic data sizes.
Parameter sharing in RNNs (temporal) and CNNs (spatial) share the same underlying principle but apply it in different ways. Understanding the comparison illuminates both architectures.
Convolutional Neural Networks (spatial sharing):
In CNNs, filter weights are shared across spatial locations. A 3×3 filter for edge detection uses the same 9 weights whether it's applied at the top-left, center, or bottom-right of an image.
$$y_{i,j} = \sum_{m}\sum_{n} W_{m,n} \cdot x_{i+m, j+n}$$
The insight: "Edges look like edges regardless of where in the image they appear."
Recurrent Neural Networks (temporal sharing):
In RNNs, weights are shared across temporal positions. The transformation from $h_{t-1}$ to $h_t$ uses the same weights whether $t=1$ or $t=100$.
$$h_t = f(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$
The insight: "Sequential patterns are meaningful regardless of when in the sequence they occur."
| Aspect | RNN (Temporal) | CNN (Spatial) |
|---|---|---|
| Shared over | Timesteps (1D) | Spatial locations (2D typically) |
| Assumption | Dynamics are time-invariant | Features are translation-invariant |
| Receptive field | All past (accumulates) | Local (fixed kernel size) |
| Memory mechanism | Hidden state carries information | Sequential layers with pooling |
| Parallelization | Sequential (inherently ordered) | Fully parallel (independent locations) |
| Length handling | Any length at inference | Any size at inference |
| Computational pattern | O(T) sequential steps | O(1) parallel over space |
Key differences:
1. Receptive field accumulation: CNNs have local receptive fields that grow by stacking layers. RNNs have a globally connected receptive field by design—at timestep $t$, the hidden state has been influenced by $x_1$ through $x_t$ regardless of network depth.
2. Dependency structure: CNNs assume spatial locality—nearby pixels are more related than distant ones. RNNs assume temporal ordering—the past influences the present, but the relationship may span arbitrary distances.
3. Parallelization: CNN filter applications are independent and can be computed in parallel. RNN timesteps depend on previous steps and must be computed sequentially, limiting parallelism.
4. Stationarity assumption: Both assume their respective invariances, but the nature differs: CNNs assume spatial stationarity (the same feature detector is useful at every location), while RNNs assume temporal stationarity (the same transition dynamics govern every timestep).
```python
import torch
import torch.nn as nn


def compare_sharing_mechanisms():
    """
    Demonstrate the structural similarity between
    CNN and RNN parameter sharing.
    """
    # CNN: Same filter applied at all spatial positions
    conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

    # Number of parameters: in_channels * out_channels * kernel_h * kernel_w + bias
    conv_params = 3 * 16 * 3 * 3 + 16
    print(f"Conv2d params: {conv_params} (same at all H×W positions)")

    # Apply to image of any size
    img_small = torch.randn(1, 3, 32, 32)
    img_large = torch.randn(1, 3, 512, 512)

    out_small = conv(img_small)  # Works!
    out_large = conv(img_large)  # Same weights, different size

    print(f"Small image: {img_small.shape} -> {out_small.shape}")
    print(f"Large image: {img_large.shape} -> {out_large.shape}")

    print("\n" + "="*50 + "\n")

    # RNN: Same weights applied at all temporal positions
    rnn = nn.RNN(input_size=128, hidden_size=256, batch_first=True)

    # Number of parameters (for 1 layer):
    #   W_ih: hidden_size * input_size
    #   W_hh: hidden_size * hidden_size
    #   b_ih: hidden_size
    #   b_hh: hidden_size
    rnn_params = 256 * 128 + 256 * 256 + 256 + 256
    print(f"RNN params: {rnn_params} (same at all T positions)")

    # Apply to sequence of any length
    seq_short = torch.randn(1, 10, 128)
    seq_long = torch.randn(1, 500, 128)

    out_short, _ = rnn(seq_short)  # Works!
    out_long, _ = rnn(seq_long)    # Same weights, different length

    print(f"Short sequence: {seq_short.shape} -> {out_short.shape}")
    print(f"Long sequence: {seq_long.shape} -> {out_long.shape}")

    # Both architectures handle variable sizes with fixed parameters
    print("\n" + "="*50)
    print("Both CNN and RNN can process variable-sized inputs")
    print("using a fixed number of shared parameters!")


def demonstrate_weight_identity():
    """
    Explicitly show that RNN uses the same weights at each step.
    """
    rnn = nn.RNNCell(input_size=4, hidden_size=8)

    # Access the weight tensors
    W_ih = rnn.weight_ih  # Input-to-hidden
    W_hh = rnn.weight_hh  # Hidden-to-hidden

    print("RNN weights:")
    print(f"W_ih id: {id(W_ih)}, shape: {W_ih.shape}")
    print(f"W_hh id: {id(W_hh)}, shape: {W_hh.shape}")

    # Process multiple timesteps
    x = torch.randn(3, 10, 4)  # batch=3, len=10, features=4
    h = torch.zeros(3, 8)

    print("\nProcessing 10 timesteps:")
    for t in range(10):
        h = rnn(x[:, t, :], h)
        # Every call uses the SAME W_ih and W_hh
        print(f"  Step {t+1}: W_ih id = {id(rnn.weight_ih)} (same object)")


if __name__ == "__main__":
    compare_sharing_mechanisms()
    print("\n" + "="*50 + "\n")
    demonstrate_weight_identity()
```

While parameter sharing is powerful, it encodes a strong prior: temporal stationarity. This assumption doesn't always hold, and understanding when it fails guides architectural choices.
Cases where temporal stationarity is violated:
1. Position-dependent semantics
In some tasks, absolute position genuinely matters. For example, the opening sentence of a document often acts as a title or summary, so the same words play a different role there than they would deeper in the text.
2. Multi-regime sequences
Some sequences have distinct phases. A customer-support dialogue, for instance, moves through a greeting, a problem description, and a resolution, and the appropriate dynamics differ from phase to phase.
3. Periodic or cyclic patterns
Time series with hourly, weekly, or yearly cycles, such as traffic or energy demand, behave differently depending on where in the cycle a timestep falls.
If your task has inherent position-dependence that cannot be inferred from the data itself, pure parameter sharing may underperform. However, the solution is rarely to abandon sharing entirely—it's to augment the architecture with position information.
Solutions that preserve sharing while adding position awareness:
1. Positional encodings
Add position information directly to the input: $$\tilde{x}_t = x_t + \text{PE}(t)$$
The RNN still uses shared weights, but now $\tilde{x}_t$ contains both content and position information. This is the approach used in Transformers.
2. Concatenated position features
Include normalized position as an additional feature: $$\tilde{x}_t = [x_t; t/T; \text{one-hot}(\text{segment})]$$
3. Segment-specific processing
For multi-regime sequences, use a separate RNN per regime, or feed a segment embedding to a single shared RNN so it knows which phase it is currently processing.
4. Explicit clock signals
For periodic data, include phase information: $$\tilde{x}_t = [x_t; \sin(2\pi t/P); \cos(2\pi t/P)]$$
where $P$ is the expected period.
```python
import math

import torch
import torch.nn as nn


class PositionAwareRNN(nn.Module):
    """
    RNN that maintains weight sharing but adds position information.

    This preserves the benefits of parameter sharing while allowing
    the model to learn position-dependent behavior.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        position_encoding: str = 'sinusoidal',
        max_len: int = 5000
    ):
        super().__init__()
        self.position_encoding = position_encoding

        if position_encoding == 'sinusoidal':
            # Sinusoidal position encodings (like Transformer)
            pe = torch.zeros(max_len, input_dim)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, input_dim, 2).float()
                * (-math.log(10000.0) / input_dim)
            )
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe)
            rnn_input_dim = input_dim
        elif position_encoding == 'concatenated':
            # Simple concatenation of normalized position
            rnn_input_dim = input_dim + 1
        elif position_encoding == 'learned':
            # Learned position embeddings
            self.position_embedding = nn.Embedding(max_len, input_dim)
            rnn_input_dim = input_dim
        else:  # 'none'
            rnn_input_dim = input_dim

        # Core RNN with SHARED weights
        self.rnn = nn.RNN(rnn_input_dim, hidden_dim, batch_first=True)
        self.hidden_dim = hidden_dim

    def add_position_info(self, x: torch.Tensor) -> torch.Tensor:
        """Add position information to input based on encoding type."""
        batch_size, seq_len, _ = x.shape

        if self.position_encoding == 'sinusoidal':
            # Add sinusoidal encodings
            return x + self.pe[:seq_len].unsqueeze(0)
        elif self.position_encoding == 'concatenated':
            # Concatenate normalized position
            positions = torch.arange(seq_len, device=x.device).float() / seq_len
            positions = positions.view(1, seq_len, 1).expand(batch_size, -1, -1)
            return torch.cat([x, positions], dim=-1)
        elif self.position_encoding == 'learned':
            # Add learned position embeddings
            positions = torch.arange(seq_len, device=x.device)
            pos_embed = self.position_embedding(positions)
            return x + pos_embed.unsqueeze(0)
        else:
            return x

    def forward(self, x: torch.Tensor) -> tuple:
        """Process sequence with position-augmented input."""
        x_pos = self.add_position_info(x)
        output, h_n = self.rnn(x_pos)
        return output, h_n


class CyclicRNN(nn.Module):
    """
    RNN designed for data with known periodicity.

    Adds phase information explicitly, allowing the model to learn
    cycle-dependent behavior while maintaining weight sharing.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        periods: list = [24, 7, 365]  # e.g., hourly, daily, yearly
    ):
        super().__init__()
        # Each period adds sin/cos features
        augmented_dim = input_dim + 2 * len(periods)
        self.rnn = nn.RNN(augmented_dim, hidden_dim, batch_first=True)
        self.periods = periods

    def add_cyclic_features(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Add cyclic (sin/cos) features for each period.

        Args:
            x: Input features (batch, seq_len, features)
            timesteps: Absolute timestamps (batch, seq_len)
        """
        cyclic_features = []
        for period in self.periods:
            phase = 2 * math.pi * timesteps / period
            cyclic_features.append(torch.sin(phase).unsqueeze(-1))
            cyclic_features.append(torch.cos(phase).unsqueeze(-1))

        cyclic = torch.cat(cyclic_features, dim=-1)
        return torch.cat([x, cyclic], dim=-1)

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> tuple:
        """Process sequence with cyclic augmentation."""
        x_cyclic = self.add_cyclic_features(x, timesteps)
        return self.rnn(x_cyclic)


# Example usage
def demonstrate_position_encoding_impact():
    """Show how position encoding affects RNN behavior."""
    torch.manual_seed(42)

    input_dim, hidden_dim, seq_len = 16, 32, 50

    # Create models
    rnn_none = PositionAwareRNN(input_dim, hidden_dim, 'none')
    rnn_sin = PositionAwareRNN(input_dim, hidden_dim, 'sinusoidal')
    rnn_cat = PositionAwareRNN(input_dim, hidden_dim, 'concatenated')

    # Same input for all models
    x = torch.randn(1, seq_len, input_dim)

    # Get outputs
    out_none, _ = rnn_none(x)
    out_sin, _ = rnn_sin(x)
    out_cat, _ = rnn_cat(x)

    print("Position Encoding Impact Analysis")
    print("="*50)

    # Check if models can distinguish positions.
    # For a model that knows position, identical inputs at different
    # positions should produce different outputs.
    x_same = torch.ones(1, seq_len, input_dim) * 0.5  # Same input at every position

    out_none_same, _ = rnn_none(x_same)
    out_sin_same, _ = rnn_sin(x_same)

    # Measure output variance across positions
    var_none = out_none_same.var(dim=1).mean().item()
    var_sin = out_sin_same.var(dim=1).mean().item()

    print("Output variance with identical inputs:")
    print(f"  Without position encoding: {var_none:.6f}")
    print(f"  With sinusoidal encoding:  {var_sin:.6f}")
    print()
    print("Higher variance indicates the model distinguishes positions")


if __name__ == "__main__":
    demonstrate_position_encoding_impact()
```

Parameter sharing has profound implications for how gradients flow during training. Understanding this is crucial for appreciating both the power and the challenges of RNN optimization.
Gradient aggregation across timesteps:
Because the same weights are used at every timestep, the total gradient is the sum of gradients from all timesteps:
$$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial L}{\partial W_{hh}^{(t)}}$$
where each term $\frac{\partial L}{\partial W_{hh}^{(t)}}$ on the right is computed as if the timestep-$t$ copy of $W_{hh}$ were a separate parameter, and the per-timestep contributions are then summed.
Expanded via chain rule: $$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial h_k} \frac{\partial h_k}{\partial W_{hh}}$$
This double sum reveals the complexity: every timestep contributes gradients, and earlier states affect all later gradients.
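A small sketch of this aggregation, assuming a hand-rolled one-layer tanh recurrence (the inputs are added directly for brevity, with no separate $W_{xh}$, and the name `W_copies` is ours): untying the weight into per-timestep copies and summing their gradients reproduces the gradient of the shared weight.

```python
import torch

torch.manual_seed(0)
d, T = 4, 3
xs = [torch.randn(d) for _ in range(T)]   # inputs added directly for brevity

# Shared weight: one parameter object used at every timestep
W = torch.randn(d, d, requires_grad=True)
h = torch.zeros(d)
for t in range(T):
    h = torch.tanh(W @ h + xs[t])
h.sum().backward()
grad_shared = W.grad.clone()

# Untied copies: same values, but a separate tensor per timestep
W_copies = [W.detach().clone().requires_grad_(True) for _ in range(T)]
h = torch.zeros(d)
for t in range(T):
    h = torch.tanh(W_copies[t] @ h + xs[t])
h.sum().backward()
grad_sum = sum(Wt.grad for Wt in W_copies)

# The shared-weight gradient equals the sum of per-timestep gradients
print(torch.allclose(grad_shared, grad_sum, atol=1e-6))  # True
```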
Shared weights mean gradients aggregate from all timesteps, providing a rich training signal. But shared weights also mean gradient pathologies (vanishing/exploding) accumulate multiplicatively across timesteps. This is why BPTT (which we'll cover next) presents unique challenges.
Implications for learning:
1. High-frequency patterns dominate early training
Because recent timesteps have stronger gradients (less decay through the chain rule), the network first learns patterns that manifest near the end of sequences. Long-range dependencies are harder to learn because their gradients are weaker.
2. Gradient magnitude varies with sequence length
Longer sequences accumulate more gradient contributions. Without careful normalization, this can cause instability when training on variable-length sequences.
3. Conflicting gradients
Different timesteps may push weights in different directions. Position 5 might want $W_{hh}$ to emphasize feature A, while position 50 wants emphasis on feature B. The final weight is a compromise.
4. Shared weights provide averaging effect
The aggregation of gradients from many timesteps acts as an averaging operation, reducing noise and potentially providing regularization. This is similar to how batch averaging reduces gradient variance.
Practical strategies: normalize the loss by sequence length so long sequences don't dominate updates, clip gradients to bound the accumulated magnitude, and truncate backpropagation through time (covered when we discuss BPTT). The analysis below probes how gradient norms actually scale with sequence length.
```python
import torch
import torch.nn as nn


def analyze_gradient_flow(seq_lengths: list = [10, 50, 100, 200]):
    """
    Analyze how gradients scale with sequence length
    in shared-weight RNNs.
    """
    input_dim, hidden_dim = 8, 16

    # Create RNN
    rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)

    # Loss: simple sum of outputs
    def compute_loss(outputs):
        return outputs.sum()

    print("Gradient Analysis: Shared Weights Across Sequence Lengths")
    print("="*60)

    results = []
    for seq_len in seq_lengths:
        # Forward pass
        x = torch.randn(1, seq_len, input_dim, requires_grad=True)
        h0 = torch.zeros(1, 1, hidden_dim)

        output, _ = rnn(x, h0)
        loss = compute_loss(output)

        # Backward pass
        rnn.zero_grad()
        loss.backward()

        # Analyze gradients
        grad_W_hh = rnn.weight_hh_l0.grad
        grad_W_ih = rnn.weight_ih_l0.grad

        grad_W_hh_norm = grad_W_hh.norm().item()
        grad_W_ih_norm = grad_W_ih.norm().item()

        results.append({
            'seq_len': seq_len,
            'grad_W_hh': grad_W_hh_norm,
            'grad_W_ih': grad_W_ih_norm,
        })

        print(f"Seq length {seq_len:3d}: "
              f"||∇W_hh|| = {grad_W_hh_norm:8.4f}, "
              f"||∇W_ih|| = {grad_W_ih_norm:8.4f}")

    # Analyze scaling
    print("\nScaling Analysis:")
    for i in range(1, len(results)):
        ratio = results[i]['grad_W_hh'] / results[i-1]['grad_W_hh']
        len_ratio = results[i]['seq_len'] / results[i-1]['seq_len']
        print(f"  {results[i-1]['seq_len']} -> {results[i]['seq_len']}: "
              f"length ×{len_ratio:.1f}, gradient ×{ratio:.2f}")


def visualize_timestep_contributions():
    """
    Show which timesteps contribute most to gradients.
    """
    import matplotlib.pyplot as plt

    input_dim, hidden_dim, seq_len = 4, 8, 20
    rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)

    # We'll track which timesteps contribute to W_hh gradients
    contributions = []

    for target_t in range(seq_len):
        # Create input where only position target_t is non-zero
        x = torch.zeros(1, seq_len, input_dim)
        x[0, target_t, :] = torch.randn(input_dim)

        h0 = torch.zeros(1, 1, hidden_dim)
        output, _ = rnn(x, h0)

        # Loss from final output only
        loss = output[0, -1, :].sum()

        rnn.zero_grad()
        loss.backward()

        grad_norm = rnn.weight_hh_l0.grad.norm().item()
        contributions.append(grad_norm)

    # Plot
    plt.figure(figsize=(10, 5))
    plt.bar(range(seq_len), contributions, alpha=0.7)
    plt.xlabel('Timestep')
    plt.ylabel('||∇W_hh|| contribution')
    plt.title('Gradient Contribution by Timestep\n'
              '(input only at that timestep, loss from final output)')
    plt.xticks(range(0, seq_len, 2))
    plt.grid(axis='y', alpha=0.3)

    # Annotate pattern
    plt.annotate('Recent inputs contribute more\n(vanishing gradient effect)',
                 xy=(seq_len-2, contributions[-2]),
                 xytext=(seq_len-8, max(contributions)*0.8),
                 arrowprops=dict(arrowstyle='->', color='red'),
                 fontsize=10)

    plt.tight_layout()
    plt.savefig('timestep_gradient_contributions.png', dpi=150)
    plt.show()


if __name__ == "__main__":
    analyze_gradient_flow()
    print()
    visualize_timestep_contributions()
```

Parameter sharing is the architectural decision that makes RNNs practically viable for sequence modeling. Let's consolidate the key insights:

- The same $W_{hh}$, $W_{xh}$, and $b_h$ are applied at every timestep, so the parameter count is independent of sequence length.
- Sharing lets the network process sequences of any length and transfer patterns learned at one position to every other position.
- Statistically, it pools gradient signal from every timestep, trading a mild stationarity bias for a large reduction in variance.
- When temporal stationarity is violated, augment the input with position or phase information rather than abandoning sharing.
- Gradients for shared weights are sums over timesteps, which provides rich signal but also sets up the vanishing/exploding gradient issues addressed by BPTT.
What's next:
Now that we understand how parameter sharing enables RNNs to process variable-length sequences efficiently, the next page explores forward computation—the actual mechanics of how RNNs process sequences from input to output. We'll trace through complete forward passes, understand the computational graph, and see how predictions are generated.
You now deeply understand parameter sharing in RNNs. You can explain why it's necessary, how it differs from CNN weight sharing, when its assumptions break and how to mitigate, and how it affects gradient flow. This understanding is essential for working with any recurrent architecture.