The LSTM architecture was designed with a singular purpose: solve the vanishing gradient problem that plagued vanilla RNNs. In the previous pages, we explored the cell state highway and gate mechanisms. Now we bring these concepts together to understand precisely how and why LSTM achieves dramatically improved gradient flow.
This isn't just historical understanding; it's essential knowledge for knowing why LSTM works, when to use it, and how to optimize it for your own tasks.
By the end of this page, you will understand:
• The complete gradient flow analysis comparing vanilla RNN to LSTM
• Mathematical proofs of gradient preservation through the cell state
• How gate dynamics interact with gradient flow
• Empirical evidence and experiments demonstrating LSTM's advantage
• Remaining limitations and how they're addressed by modern architectures
Before understanding the solution, we must fully comprehend the problem. The vanishing gradient problem in vanilla RNNs is not a mere technical inconvenience—it's a fundamental barrier to learning temporal dependencies.
Vanilla RNN Backward Pass:
For a vanilla RNN with hidden state $h_t = \tanh(W_{hh}h_{t-1} + W_{xh}x_t + b)$, computing gradients for a loss at time $T$ with respect to parameters affecting time $t < T$ requires:
$$\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial h_T} \cdot \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}$$
where each Jacobian term is:
$$\frac{\partial h_k}{\partial h_{k-1}} = \text{diag}(1 - h_k^2) \cdot W_{hh}$$
```
# The product of Jacobians over T-t steps:
∏(k=t+1 to T) [diag(1 - h_k²) · W_hh]

# This product has spectral properties that doom long-range learning.

# Case 1: Spectral radius ρ(W_hh) < 1
# ───────────────────────────────────
||∏ Jacobians|| ≤ ∏ ||diag(1 - h_k²)|| · ||W_hh||
               ≤ ∏ ||W_hh||                      # since ||diag(1 - h_k²)|| ≤ 1
               ≈ ρ(W_hh)^(T-t) → 0 as T-t → ∞    # Vanishing!

# Case 2: Spectral radius ρ(W_hh) > 1
# ───────────────────────────────────
||∏ Jacobians|| can grow as ρ(W_hh)^(T-t) → ∞ as T-t → ∞   # Exploding!

# Case 3: Spectral radius ρ(W_hh) = 1 exactly
# ────────────────────────────────────────────
# Theoretically stable, but:
# - Impossible to maintain exactly during training
# - tanh derivatives still cause decay in practice
# - Eigenvector alignment issues cause problems
```

Quantifying the decay for an effective per-step factor $\gamma$:

| Time Gap (T-t) | Decay Factor (γ=0.9) | Decay Factor (γ=0.95) | Remaining Signal |
|---|---|---|---|
| 10 | 0.349 | 0.599 | Some learning possible |
| 20 | 0.122 | 0.358 | Weak learning |
| 50 | 0.005 | 0.077 | Almost none |
| 100 | 2.7×10⁻⁵ | 0.006 | Numerically zero |
| 200 | 7.2×10⁻¹⁰ | 3.5×10⁻⁵ | Underflow |
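These decay factors are easy to reproduce. The short sketch below treats $\gamma$ as an assumed effective per-step contraction factor (an illustration, not a measured quantity) and prints the same numbers:

```python
# Remaining gradient signal after a gap of T-t steps, for an assumed
# effective per-step contraction factor gamma (illustrative values only).
for gamma in (0.9, 0.95):
    for gap in (10, 20, 50, 100, 200):
        print(f"gamma={gamma}, gap={gap}: {gamma ** gap:.3e}")
```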
For a vanilla RNN to learn a dependency where information from step t is needed at step T with T-t > 20, gradients are essentially zero. The network cannot learn this connection no matter how long you train or how much data you have. This isn't a matter of optimization difficulty—it's a fundamental limitation of the architecture.
LSTM was specifically designed to create a gradient flow pathway that avoids the multiplicative decay problem. The key insight is the cell state update equation:
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
The Critical Difference:
| Vanilla RNN | LSTM Cell State |
|---|---|
| $h_t = \tanh(W_{hh}h_{t-1} + ...)$ | $c_t = f_t \odot c_{t-1} + ...$ |
| Matrix multiplication by $W_{hh}$ | Element-wise multiplication by $f_t$ |
| Squashing nonlinearity $\tanh(\cdot)$ | No nonlinearity on the recurrence |
| Jacobian: $\text{diag}(1 - h_t^2) \cdot W_{hh}$ | Jacobian: $\text{diag}(f_t)$ |
The Jacobian $\frac{\partial c_t}{\partial c_{t-1}} = \text{diag}(f_t)$ is a diagonal matrix with entries in $(0, 1)$. When $f_t \approx 1$, this is nearly the identity matrix!
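To make this concrete, here is a minimal autograd check with random, illustrative gate values (not tied to any particular model) showing that the Jacobian of the cell-state update with respect to $c_{t-1}$ is exactly $\text{diag}(f_t)$:

```python
import torch

torch.manual_seed(0)
n = 5  # hidden size (illustrative)

f_t = torch.sigmoid(torch.randn(n))   # forget gate values in (0, 1)
i_t = torch.sigmoid(torch.randn(n))   # input gate values
c_tilde = torch.tanh(torch.randn(n))  # candidate values


def cell_update(c_prev):
    # c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
    return f_t * c_prev + i_t * c_tilde


c_prev = torch.randn(n)
J = torch.autograd.functional.jacobian(cell_update, c_prev)

# The Jacobian is diagonal, with the forget gate values on the diagonal
print(torch.allclose(J, torch.diag(f_t)))  # True
```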
Two Gradient Pathways in LSTM:
LSTM actually has two distinct pathways for gradients to flow backward:
1. Cell State Path (Primary): $$\frac{\partial L}{\partial c_{t+1}} \xrightarrow{\times f_{t+1}} \frac{\partial L}{\partial c_t}$$
This path multiplies by forget gates. When $f_k \approx 1$ for all $k$, gradients are preserved.
2. Hidden State Path (Secondary): $$\frac{\partial L}{\partial h_{t+1}} \xrightarrow{\text{complex}} \frac{\partial L}{\partial h_t}$$
This path involves output gates, tanh derivatives, and weight matrices, much like a vanilla RNN, but it matters less because at every step the gradient can "offload" onto the cell state path.
The existence of the cell state path ensures that even if hidden state gradients vanish, information (and error signals) can still reach early parameters.
Let's rigorously prove that LSTM preserves gradients under appropriate conditions.
Theorem (LSTM Gradient Preservation):
For an LSTM with forget gate values $f_t$ and a loss at time $T$, the gradient reaching the cell state at time $t$ through the cell-state path alone satisfies:
$$\left| \frac{\partial L}{\partial c_t} \right| \geq \left( \prod_{k=t+1}^{T} f_{k,\min} \right) \cdot \left| \frac{\partial L}{\partial c_T} \right|$$
where $f_{k,\min} = \min_i f_{k,i}$ is the minimum forget gate value at time $k$.
Proof:
```
# Cell state recurrence:
c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t

# Taking gradient w.r.t. c_{t-1}:
∂c_t/∂c_{t-1} = diag(f_t)

# This is a diagonal matrix with entries f_t[i] ∈ (0, 1).

# For the full chain from T to t:
∂c_T/∂c_t = ∏(k=t+1 to T) diag(f_k)
          = diag(∏(k=t+1 to T) f_k)        # Product of diagonal matrices

# Each diagonal entry is:
[∂c_T/∂c_t]_{i,i} = ∏(k=t+1 to T) f_k[i]

# If all f_k[i] ≥ f_min > 0, then:
[∂c_T/∂c_t]_{i,i} ≥ f_min^(T-t)

# Key observation: If f_min = 0.99, then even for T-t = 100:
f_min^100 = 0.99^100 ≈ 0.366

# Compare to vanilla RNN, where the analogous product involves:
# - Matrix multiplications (can amplify or shrink arbitrarily)
# - tanh derivatives (never > 1, often < 1)
# - Result: γ^(T-t) where γ < 1 typically
# For γ = 0.9, γ^100 ≈ 2.7 × 10⁻⁵

# LSTM gradient / RNN gradient ≈ 0.366 / (2.7 × 10⁻⁵) ≈ 13,500× better!  QED
```

Corollary (Perfect Gradient Flow):
If $f_k = 1$ for all $k \in \{t+1, \dots, T\}$, then:
$$\frac{\partial c_T}{\partial c_t} = I$$
Gradients flow through with no attenuation whatsoever. This is the Constant Error Carousel in action.
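Both the theorem and the corollary can be checked numerically on the bare cell-state recurrence. In the sketch below the gates are random constants (so the cell-state highway is the only backward path, which isolates exactly what the proof describes); the gradient at $c_0$ comes out as the elementwise product of all forget gates, and would be exactly 1 if every forget gate were 1.

```python
import torch

torch.manual_seed(0)
n, T = 4, 50  # hidden size and sequence length (illustrative)

f = torch.sigmoid(torch.randn(T, n))  # forget gates, held constant
i = torch.sigmoid(torch.randn(T, n))  # input gates
g = torch.tanh(torch.randn(T, n))     # candidate values

c0 = torch.randn(n, requires_grad=True)
c = c0
for t in range(T):                    # c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
    c = f[t] * c + i[t] * g[t]

c.sum().backward()                    # scalar loss at the final step

# Along the cell-state path, ∂c_T/∂c_0 = diag(∏_t f_t)
print(torch.allclose(c0.grad, f.prod(dim=0)))  # True
```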
Why This Works:
Element-wise vs. matrix multiplication: Diagonal matrices have eigenvalues equal to diagonal entries. No complex spectral dynamics.
Learned scaling: The forget gate learns to maximize entries when long-term memory is needed, directly optimizing gradient flow.
Bounded growth: Since $f_t \leq 1$, gradients never explode through this path (though they can through other paths).
This proof explains why forget gate bias initialization matters so much. With bias = 0, initial f_t ≈ 0.5. For T-t = 100: 0.5^100 ≈ 8×10^-31 (total gradient death).
With bias = 1, initial f_t ≈ 0.73. For T-t = 100: 0.73^100 ≈ 2×10^-14 (still small but learnable early in training).
With bias = 2, initial f_t ≈ 0.88. For T-t = 100: 0.88^100 ≈ 3×10⁻⁶ (much better for long sequences).
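These numbers follow directly from $\sigma(b_f)^{T-t}$; the snippet below reproduces them (up to rounding) for a 100-step gap.

```python
import torch

# Initial forget-gate activation is roughly sigmoid(bias) when inputs are small.
for bias in (0.0, 1.0, 2.0):
    f = torch.sigmoid(torch.tensor(bias)).item()
    print(f"bias={bias}: f ≈ {f:.2f}, f**100 ≈ {f ** 100:.1e}")
```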
Let's trace gradients through a complete LSTM backward pass to understand all the interactions.
Full Backward Equations:
Given gradients from the next time step ($\delta h_{t+1} = \frac{\partial L}{\partial h_{t+1}}$ and $\delta c_{t+1} = \frac{\partial L}{\partial c_{t+1}}$) and from the current loss ($\delta y_t = \frac{\partial L}{\partial y_t}$ if there's output at this step):
```python
import torch


def lstm_backward_step(
    # Gradients from the future
    delta_h_next: torch.Tensor,   # ∂L/∂h_{t+1} from next step
    delta_c_next: torch.Tensor,   # ∂L/∂c_{t+1} from next step
    delta_y: torch.Tensor,        # ∂L/∂y_t from current output (if any)
    # Forward pass cached values
    x: torch.Tensor,              # Input at this step
    h_prev: torch.Tensor,         # Previous hidden state
    c_prev: torch.Tensor,         # Previous cell state
    # Gate values from forward pass
    i: torch.Tensor,              # Input gate
    f: torch.Tensor,              # Forget gate
    o: torch.Tensor,              # Output gate
    g: torch.Tensor,              # Candidate (c_tilde)
    c: torch.Tensor,              # New cell state
    # Weights (stacked for the four gates: shapes (4H, D) and (4H, H))
    W_ih: torch.Tensor,
    W_hh: torch.Tensor,
):
    """
    Complete LSTM backward pass for one time step.
    Returns gradients w.r.t. previous states and all parameters.
    """
    # Step 1: Gradient to h_t from all sources
    # ─────────────────────────────────────────
    delta_h = delta_h_next.clone()
    if delta_y is not None:
        # If there's an output layer: y = W_y @ h
        delta_h = delta_h + delta_y  # Simplified; actual term depends on the output layer

    # Step 2: Gradient to c_t
    # ────────────────────────
    # h_t = o_t * tanh(c_t)
    # ∂h_t/∂c_t = o_t * (1 - tanh²(c_t))
    tanh_c = torch.tanh(c)
    delta_c = delta_h * o * (1 - tanh_c ** 2)
    # Plus gradient flowing through the cell state highway
    delta_c = delta_c + delta_c_next * f  # THE KEY GRADIENT HIGHWAY!

    # Step 3: Gradients to gates
    # ──────────────────────────
    # c_t = f_t * c_{t-1} + i_t * g_t
    delta_f = delta_c * c_prev   # ∂L/∂f_t = ∂L/∂c_t * c_{t-1}
    delta_i = delta_c * g        # ∂L/∂i_t = ∂L/∂c_t * g_t
    delta_g = delta_c * i        # ∂L/∂g_t = ∂L/∂c_t * i_t
    delta_o = delta_h * tanh_c   # ∂L/∂o_t = ∂L/∂h_t * tanh(c_t)

    # Step 4: Gradients through gate activations
    # ───────────────────────────────────────────
    # f = sigmoid(z_f), so ∂L/∂z_f = ∂L/∂f * f * (1 - f)
    delta_z_f = delta_f * f * (1 - f)
    delta_z_i = delta_i * i * (1 - i)
    delta_z_o = delta_o * o * (1 - o)
    delta_z_g = delta_g * (1 - g ** 2)  # tanh derivative

    # Stack gate gradients in the (i, f, g, o) order used by the stacked weights
    delta_gates = torch.cat([delta_z_i, delta_z_f, delta_z_g, delta_z_o], dim=-1)

    # Step 5: Gradients to inputs and states
    # ──────────────────────────────────────
    # z_gates = W_ih @ x + W_hh @ h_prev + bias
    delta_x = delta_gates @ W_ih
    delta_h_prev = delta_gates @ W_hh
    delta_c_prev = delta_c * f  # Through the forget gate

    # Step 6: Gradients to weights
    # ────────────────────────────
    # ∂L/∂W_ih = delta_gates^T @ x,  ∂L/∂W_hh = delta_gates^T @ h_prev
    delta_W_ih = delta_gates.T @ x
    delta_W_hh = delta_gates.T @ h_prev
    delta_bias = delta_gates.sum(dim=0)

    return {
        'delta_x': delta_x,
        'delta_h_prev': delta_h_prev,
        'delta_c_prev': delta_c_prev,
        'delta_W_ih': delta_W_ih,
        'delta_W_hh': delta_W_hh,
        'delta_bias': delta_bias,
    }
```

Key Gradient Flow Observations:
Cell state gradient: delta_c_prev = delta_c * f — The forget gate directly scales how much gradient flows to the previous cell state. This is the CEC mechanism.
Forget gate gradient: delta_f = delta_c * c_prev — The forget gate learns based on how useful the previous cell state is for reducing the current loss.
Two paths to h_{t-1}: Both through delta_gates @ W_hh (matrix multiplication, can vanish) and through the cell state path (scalar multiplication by f, preserves gradients).
Input gate gradient: delta_i = delta_c * g — Input gate learns based on how useful the candidate values are.
Notice that delta_c at step t receives contributions from two sources: the current step's hidden-state gradient (flowing in through the output gate and the tanh derivative) and the next step's cell-state gradient (scaled by the forget gate).
This forking ensures that even if the output path has vanishing gradients, the cell state path keeps them flowing. It's like having a backup highway when the main road is congested.
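As a quick sanity check, the snippet below feeds random tensors (illustrative shapes) into the `lstm_backward_step` sketch above and confirms that the gradient handed back to the previous cell state is exactly the local cell-state gradient scaled elementwise by the forget gate.

```python
import torch

torch.manual_seed(0)
B, H, D = 2, 8, 6  # batch, hidden, and input sizes (illustrative)

# Random stand-ins for cached forward values and incoming gradients
delta_h_next, delta_c_next = torch.randn(B, H), torch.randn(B, H)
x, h_prev, c_prev = torch.randn(B, D), torch.randn(B, H), torch.randn(B, H)
i, f, o = (torch.sigmoid(torch.randn(B, H)) for _ in range(3))
g = torch.tanh(torch.randn(B, H))
c = f * c_prev + i * g
W_ih, W_hh = torch.randn(4 * H, D), torch.randn(4 * H, H)

grads = lstm_backward_step(delta_h_next, delta_c_next, None,
                           x, h_prev, c_prev, i, f, o, g, c, W_ih, W_hh)

# delta_c_prev is the local cell gradient times the forget gate (the highway)
delta_c = delta_h_next * o * (1 - torch.tanh(c) ** 2) + delta_c_next * f
print(torch.allclose(grads['delta_c_prev'], delta_c * f))  # True
```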
Theory predicts that LSTM preserves gradients better than vanilla RNN. Let's validate this empirically with controlled experiments.
Experiment 1: Copy Task
The copy task is designed to test long-range memory: the network is shown a short sequence of symbols, then a long stretch of blank inputs, then a delimiter, and must reproduce the original symbols at the end.
This task requires remembering information across the blank period—a pure test of gradient flow.
```python
import torch
import torch.nn as nn


def generate_copy_task_data(batch_size, seq_len, delay, vocab_size=8):
    """
    Generate copy task data.
    - seq_len symbols to remember
    - delay blank symbols
    - 1 delimiter symbol
    - seq_len output positions
    """
    # Random symbols to copy
    symbols = torch.randint(1, vocab_size, (batch_size, seq_len))

    # Input: symbols, blanks, delimiter, zeros for output phase
    total_len = seq_len + delay + 1 + seq_len
    inputs = torch.zeros(batch_size, total_len, dtype=torch.long)
    inputs[:, :seq_len] = symbols
    inputs[:, seq_len + delay] = vocab_size  # delimiter

    # Target: zeros until the output phase, then the original symbols
    targets = torch.zeros(batch_size, total_len, dtype=torch.long)
    targets[:, -seq_len:] = symbols

    return inputs, targets


def test_gradient_flow(model, delay=100, seq_len=10):
    """Measure gradient magnitudes at different time steps."""
    inputs, targets = generate_copy_task_data(32, seq_len, delay)
    inputs = nn.functional.one_hot(inputs, num_classes=10).float()

    model.zero_grad()
    outputs, _ = model(inputs)

    loss = nn.CrossEntropyLoss()(
        outputs[:, -seq_len:].reshape(-1, outputs.size(-1)),
        targets[:, -seq_len:].reshape(-1)
    )
    loss.backward()

    # Collect gradient norms for each parameter
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms[name] = param.grad.norm().item()

    return grad_norms, loss.item()


# Compare RNN vs LSTM on increasing delays
delays = [10, 25, 50, 100, 200, 500]
rnn_success = []
lstm_success = []

for delay in delays:
    # Train RNN
    rnn = nn.RNN(10, 128, batch_first=True)
    rnn_head = nn.Linear(128, 10)
    # ... training loop ...
    # rnn_success.append(rnn_accuracy)

    # Train LSTM
    lstm = nn.LSTM(10, 128, batch_first=True)
    lstm_head = nn.Linear(128, 10)
    # ... training loop ...
    # lstm_success.append(lstm_accuracy)

# Typical results are summarized in the table below: the RNN degrades to
# chance accuracy as the delay grows, while the LSTM stays far above chance.

print("Copy task demonstrates LSTM's superior long-range gradient flow")
```

| Delay | Vanilla RNN | LSTM | RNN Fails? |
|---|---|---|---|
| 10 steps | ~95% | ~99% | No |
| 25 steps | ~75% | ~99% | Degrading |
| 50 steps | ~35% | ~98% | Yes |
| 100 steps | ~12% (chance) | ~95% | Complete failure |
| 200 steps | ~12% (chance) | ~85% | Complete failure |
| 500 steps | ~12% (chance) | ~60% | Complete failure |
Experiment 2: Gradient Magnitude Tracking
We can directly measure gradient magnitudes at each time step to visualize the vanishing gradient phenomenon:
```python
import torch
import torch.nn as nn


class GradientTrackingRNN(nn.Module):
    """RNN that exposes hidden states for gradient tracking."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rnn = nn.RNNCell(input_size, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x):
        batch, seq_len, _ = x.shape
        h = torch.zeros(batch, self.hidden_size, device=x.device)
        hidden_states = []
        for t in range(seq_len):
            h = self.rnn(x[:, t], h)
            h.retain_grad()  # Keep gradient for this intermediate state
            hidden_states.append(h)
        return hidden_states


class GradientTrackingLSTM(nn.Module):
    """LSTM that exposes cell states for gradient tracking."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x):
        batch, seq_len, _ = x.shape
        h = torch.zeros(batch, self.hidden_size, device=x.device)
        c = torch.zeros(batch, self.hidden_size, device=x.device)
        cell_states = []
        for t in range(seq_len):
            h, c = self.lstm(x[:, t], (h, c))
            c.retain_grad()  # Track cell state gradient
            cell_states.append(c)
        return cell_states


def compute_gradient_profile(model, seq_len=100):
    """Compute gradient magnitude at each time step."""
    x = torch.randn(1, seq_len, 10)
    states = model(x)

    # Loss at the final step only
    loss = states[-1].sum()
    loss.backward()

    # Collect gradient norms
    grad_norms = [s.grad.norm().item() for s in states]
    return grad_norms


# Compare
rnn = GradientTrackingRNN(10, 128)
lstm = GradientTrackingLSTM(10, 128)

rnn_grads = compute_gradient_profile(rnn, 100)
lstm_grads = compute_gradient_profile(lstm, 100)

# Plotting the two profiles would show:
# - RNN gradients decay exponentially from the end of the sequence to the start
# - LSTM gradients remain relatively stable throughout
```

These experiments consistently show:

• Vanilla RNN gradient norms shrink roughly exponentially as you move backward from the loss, while LSTM cell-state gradients stay comparatively stable across the whole sequence.
• On the copy task, the vanilla RNN collapses to chance accuracy once the delay exceeds a few dozen steps, while the LSTM remains far above chance even at hundreds of steps.
This empirical evidence aligns perfectly with our theoretical analysis of the gradient flow architecture.
While LSTM dramatically improves on vanilla RNN, it has limitations that motivated further architectural innovations.
Limitation 1: Forget Gate < 1
In practice, forget gates are rarely exactly 1. Even at 0.99, gradients decay: over 100 steps the signal shrinks to about 0.37, and over 1,000 steps to roughly 4×10⁻⁵.
For very long sequences (thousands of steps), even LSTM struggles.
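A few powers make the point; the numbers below assume a constant forget-gate value of 0.99, which is already optimistic in practice.

```python
# Even a forget gate of 0.99 attenuates gradients over very long horizons.
for steps in (100, 1_000, 5_000):
    print(f"0.99**{steps} = {0.99 ** steps:.2e}")
# ≈ 3.7e-01, 4.3e-05, 1.5e-22
```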
Limitation 2: Sequential Processing
LSTM must process sequences step by step. For a sequence of length $T$, the forward pass requires $T$ sequential cell updates and backpropagation through time requires another $T$; none of these steps can be parallelized across time, because each depends on the previous state.
This makes training slow on modern parallel hardware (GPUs, TPUs).
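The contrast is easiest to see in code. The sketch below uses illustrative sizes and a generic scaled dot-product computation (not any specific library attention module): the recurrent loop has a hard data dependency between steps, while the attention-style mixing touches all positions in one batched operation.

```python
import torch
import torch.nn as nn

T, B, D, H = 512, 8, 64, 128  # sequence length, batch, input, hidden (illustrative)
x = torch.randn(T, B, D)

# LSTM: step t cannot start until step t-1 has finished.
cell = nn.LSTMCell(D, H)
h = torch.zeros(B, H)
c = torch.zeros(B, H)
for t in range(T):  # inherently sequential over time
    h, c = cell(x[t], (h, c))

# Attention-style mixing: every position interacts with every other position
# in a single batched matrix product; no dependency between time steps.
q = k = v = torch.randn(B, T, H)
scores = torch.softmax(q @ k.transpose(1, 2) / H ** 0.5, dim=-1)
mixed = scores @ v  # shape (B, T, H), computed for all T positions at once
```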
Limitation 3: Memory Capacity
The cell state has finite dimensionality. All long-term information must be compressed into $n$ dimensions. For very long sequences with rich content, information can be lost not through gradient vanishing but through compression.
Limitation 4: Indirect Access
To access information from 100 steps ago, gradients must flow through 100 forget gates. There's no "direct line" to that information. Attention mechanisms later solved this by providing direct connections.
| Limitation | Consequence | Later Solution |
|---|---|---|
| Decay even with f≈1 | Still struggles at 1000+ steps | Attention (direct connections) |
| Sequential processing | Slow training on GPUs | Transformer (parallel attention) |
| Bounded capacity | Information compression loss | External memory (NTM, DNC) |
| Indirect access | 100-step path for 100-step info | Self-attention (O(1) path) |
| Fixed recurrence | Can't adapt to content | Adaptive computation time |
Despite these limitations, LSTM was the dominant sequence architecture from 1997 to ~2017—a remarkable 20-year reign. It enabled:
• Speech recognition breakthroughs (Google, Baidu)
• Machine translation (Google Translate neural version)
• Language modeling (foundation for later work)
• Time series analysis in finance, health, IoT
Transformers eventually superseded LSTM for many tasks, but LSTM remains relevant for streaming/real-time applications where Transformer's quadratic complexity is prohibitive.
Based on our deep understanding of LSTM gradient flow, here are practical techniques to maximize it:
```python
import torch
import torch.nn as nn


class GradientOptimizedLSTM(nn.Module):
    """
    LSTM with all best practices for gradient flow.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        forget_bias: float = 1.0,
        use_layer_norm: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Use nn.LSTM but reinitialize its parameters below
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            dropout=dropout, batch_first=True
        )

        # Optional layer normalization
        if use_layer_norm:
            self.layer_norms = nn.ModuleList([
                nn.LayerNorm(hidden_size) for _ in range(num_layers)
            ])
        self.use_layer_norm = use_layer_norm

        # Apply optimal initialization
        self._init_weights(forget_bias)

    def _init_weights(self, forget_bias):
        for name, param in self.lstm.named_parameters():
            if 'weight_ih' in name:
                # Xavier for input weights
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                # Orthogonal for recurrent weights
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                # Zero all biases, then set the forget gate bias.
                # PyTorch's bias layout is [input, forget, cell, output];
                # note that nn.LSTM has both bias_ih and bias_hh per layer,
                # so the effective forget bias is the sum of the two.
                nn.init.zeros_(param)
                n = param.size(0) // 4
                param.data[n:2 * n].fill_(forget_bias)

    def forward(self, x, state=None):
        # Standard forward
        output, (h_n, c_n) = self.lstm(x, state)

        # Apply layer normalization to the output if enabled
        if self.use_layer_norm:
            output = self.layer_norms[-1](output)

        return output, (h_n, c_n)


# Usage with gradient clipping during training
model = GradientOptimizedLSTM(
    input_size=256,
    hidden_size=512,
    num_layers=2,
    forget_bias=1.0,
    use_layer_norm=True,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # placeholder; use a task-appropriate loss


def train_step(x, y):
    optimizer.zero_grad()
    output, _ = model(x)
    loss = criterion(output, y)
    loss.backward()

    # Gradient clipping - crucial for stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    return loss.item()
```

We've comprehensively analyzed how LSTM achieves dramatically improved gradient flow compared to vanilla RNNs.
Key Insights:
• The cell state's additive update gives a diagonal Jacobian $\text{diag}(f_t)$, avoiding the repeated matrix multiplications and tanh derivatives that make vanilla RNN gradients vanish.
• When forget gates stay near 1, gradients over a gap of $T-t$ steps decay only as $\prod_k f_k$, versus roughly $\gamma^{T-t}$ with $\gamma < 1$ for a vanilla RNN.
• LSTM has two backward pathways; even when the hidden-state path vanishes, the cell-state highway keeps error signals flowing.
• Forget-gate bias initialization (typically 1-2) materially improves gradient flow early in training.
• Remaining limitations, including decay over very long ranges, sequential processing, and bounded capacity, motivated attention-based architectures.
Looking Ahead:
With our understanding of LSTM's gradient flow architecture complete, the final page of this module explores LSTM Variants—modifications like peephole connections, coupled gates, and the distinction between LSTM and GRU that emerged from both theoretical insights and empirical discoveries.
You now have complete mastery of how LSTM improves gradient flow—the mathematical proofs, empirical evidence, and practical implications. This knowledge is fundamental for understanding why LSTM works, when to use it, and how to optimize it for your specific tasks.