The original LSTM architecture from 1997 was a breakthrough, but researchers quickly began exploring variations—some theoretical refinements, others practical simplifications. Understanding these variants is essential for choosing the right architecture for a given task.
This page covers the major LSTM variants, from the original formulation through peephole connections, coupled gates, and the influential GRU simplification. We'll analyze what each variant adds or removes, and when each is most appropriate.
By the end of this page, you will understand:
• The original LSTM vs. modern standard implementations • Peephole connections: what they add and when they help • Coupled forget-input gates and their implications • GRU: the elegant simplification and how it compares • Deep LSTMs, bidirectional variants, and stacking strategies • How to choose the right variant for your task
The LSTM we use today differs from the original 1997 architecture in several ways. Understanding this evolution clarifies which components are essential.
Original LSTM (Hochreiter & Schmidhuber, 1997):
LSTM with Forget Gate (Gers et al., 2000):
Peephole LSTM (Gers & Schmidhuber, 2000):
| Year | Variant | Key Change | Impact |
|---|---|---|---|
| 1997 | Original LSTM | Input + Output gates only | Solved vanishing gradients |
| 2000 | Forget Gate LSTM | Added forget gate | Became standard, cleared memory |
| 2000 | Peephole LSTM | Cell state in gate inputs | Better timing precision |
| 2014 | GRU | Simplified to 2 gates | Faster, often comparable |
| 2014-16 | Various normalizations | Layer/weight normalization | Improved stability |
| 2016-17 | AWD-LSTM | Weight dropout regularization | State-of-art language models |
The "Standard" LSTM Today:
When papers and libraries refer to "LSTM" without qualification, they typically mean:
$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$ $$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$ $$\tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$$ $$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$ $$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$ $$h_t = o_t \odot \tanh(c_t)$$
This is the forget-gate LSTM without peepholes—a good balance of expressivity and efficiency.
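These equations correspond directly to PyTorch's built-in `nn.LSTMCell`. A minimal sketch (the sizes are arbitrary) confirming the state shapes:

```python
import torch
import torch.nn as nn

# Standard (forget-gate, no-peephole) LSTM cell, as defined by the equations above
cell = nn.LSTMCell(input_size=128, hidden_size=256)

x = torch.randn(32, 128)   # batch of inputs x_t
h = torch.zeros(32, 256)   # previous hidden state h_{t-1}
c = torch.zeros(32, 256)   # previous cell state c_{t-1}

h_next, c_next = cell(x, (h, c))
print(h_next.shape, c_next.shape)  # torch.Size([32, 256]) torch.Size([32, 256])
```

Both states have shape `(batch, hidden_size)`; the cell carries them separately, unlike the GRU we meet later.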
In standard LSTM, gates compute their values based on $h_{t-1}$ and $x_t$, but they cannot directly observe the cell state $c_{t-1}$. Peephole connections add this direct pathway.
Motivation:
Gates should ideally know what's stored in the cell state to make optimal decisions.
Without peepholes, gates can only infer cell state indirectly through $h_{t-1} = o_{t-1} \odot \tanh(c_{t-1})$—a lossy, gated view.
Peephole Equations:
$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + w_f \odot c_{t-1} + b_f)$$ $$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + w_i \odot c_{t-1} + b_i)$$ $$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$ $$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + w_o \odot c_t + b_o)$$
Note: $w_f, w_i, w_o$ are vectors (element-wise multiplication), not matrices. The output gate peephole uses $c_t$ (current cell state) rather than $c_{t-1}$.
When Do Peepholes Help?
Precise timing tasks: When gates need to fire based on cell state values reaching thresholds (e.g., count reaching N)
Value-conditional gating: When whether to forget/write/output depends on what is currently stored
Time series with value triggers: Financial applications where actions trigger at specific price levels
When Are Peepholes Unnecessary?
Language modeling: Typically no benefit; $h_{t-1}$ provides sufficient information
Standard classification: Peephole overhead > benefit for most classification tasks
Real-time constraints: Additional computation (3 × hidden_size parameters) may not be worth it
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PeepholeLSTMCell(nn.Module):
    """LSTM cell with peephole connections."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

        # Standard weights
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))

        # Peephole weights (vectors, not matrices)
        self.weight_ci = nn.Parameter(torch.randn(hidden_size))  # Input gate peephole
        self.weight_cf = nn.Parameter(torch.randn(hidden_size))  # Forget gate peephole
        self.weight_co = nn.Parameter(torch.randn(hidden_size))  # Output gate peephole

        self._init_weights()

    def _init_weights(self):
        # Standard initialization
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.orthogonal_(self.weight_hh)
        # Forget bias = 1
        self.bias.data[self.hidden_size:2 * self.hidden_size].fill_(1.0)
        # Small peephole weights
        nn.init.uniform_(self.weight_ci, -0.1, 0.1)
        nn.init.uniform_(self.weight_cf, -0.1, 0.1)
        nn.init.uniform_(self.weight_co, -0.1, 0.1)

    def forward(self, x, state):
        h_prev, c_prev = state

        # Standard gate pre-activations
        gates = (F.linear(x, self.weight_ih)
                 + F.linear(h_prev, self.weight_hh)
                 + self.bias)
        i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)

        # Add peephole connections to input and forget gates
        # using PREVIOUS cell state c_{t-1}
        i = torch.sigmoid(i_pre + self.weight_ci * c_prev)
        f = torch.sigmoid(f_pre + self.weight_cf * c_prev)
        g = torch.tanh(g_pre)

        # Update cell state
        c = f * c_prev + i * g

        # Output gate peephole uses CURRENT cell state c_t
        o = torch.sigmoid(o_pre + self.weight_co * c)

        # Hidden state
        h = o * torch.tanh(c)

        return h, c


# Parameter comparison
standard_lstm_params = 4 * 256 * (256 + 128 + 1)       # 4n(n+d+1)
peephole_lstm_params = standard_lstm_params + 3 * 256  # + 3n peepholes

print(f"Standard LSTM: {standard_lstm_params:,} parameters")
print(f"Peephole LSTM: {peephole_lstm_params:,} parameters")
print(f"Overhead: {3 * 256 / standard_lstm_params * 100:.1f}%")
```

The influential study "An Empirical Exploration of Recurrent Network Architectures" (Jozefowicz et al., 2015) found that peephole connections rarely help significantly for language modeling and other typical tasks. The standard non-peephole LSTM is usually the best default choice. Consider peepholes only for tasks with strong value-dependent timing requirements.
An important observation about standard LSTM: the input gate and forget gate often learn complementary patterns. When we write new information, we often want to forget old; when we preserve memory, we often want to block new input.
Coupling Hypothesis:
Instead of learning $f_t$ and $i_t$ independently, we could enforce:
$$i_t = 1 - f_t$$
This reduces the cell state update to:
$$c_t = f_t \odot c_{t-1} + (1 - f_t) \odot \tilde{c}_t$$
Interpretation:
This is a weighted average between old and new, rather than independent scaling.
When Coupling Works Well:
When Coupling Hurts:
Historical Note:
Coupled gates became the foundation for the GRU architecture (next section), which further simplified the gating mechanism while retaining most of LSTM's power.
```python
import torch
import torch.nn as nn


class CoupledGateLSTMCell(nn.Module):
    """LSTM with coupled input-forget gates: i = 1 - f"""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

        # Only 3 gates now: forget (input = 1 - forget), cell, output
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.zeros(3 * hidden_size))

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.orthogonal_(self.weight_hh)
        # Forget bias = 1 (remember by default)
        self.bias.data[:self.hidden_size].fill_(1.0)

    def forward(self, x, state):
        h_prev, c_prev = state

        gates = (torch.mm(x, self.weight_ih.t())
                 + torch.mm(h_prev, self.weight_hh.t())
                 + self.bias)

        # Split into 3 gates (not 4)
        f_pre, g_pre, o_pre = gates.chunk(3, dim=-1)

        f = torch.sigmoid(f_pre)
        i = 1 - f  # Coupled: input = 1 - forget
        g = torch.tanh(g_pre)
        o = torch.sigmoid(o_pre)

        # Cell update is now a weighted average
        c = f * c_prev + i * g  # = f*c + (1-f)*g
        h = o * torch.tanh(c)

        return h, c


# Comparison
standard_params = 4 * 256 * (256 + 128)  # 4n(n+d)
coupled_params = 3 * 256 * (256 + 128)   # 3n(n+d)
print(f"Standard LSTM: {standard_params:,} params")
print(f"Coupled LSTM: {coupled_params:,} params")
print(f"Savings: {(standard_params - coupled_params) / standard_params * 100:.1f}%")
```

The Gated Recurrent Unit (GRU), introduced by Cho et al. in 2014, is the most significant simplification of LSTM. It combines the cell state and hidden state into a single state vector and uses only two gates instead of three.
GRU Equations:
$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \quad \text{(Update gate)}$$ $$r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \quad \text{(Reset gate)}$$ $$\tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \quad \text{(Candidate)}$$ $$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \quad \text{(Update)}$$
Key Differences from LSTM:
| Aspect | LSTM | GRU |
|---|---|---|
| Number of gates | 3 (forget, input, output) | 2 (update, reset) |
| State vectors | 2 (h and c) | 1 (h only) |
| Parameters per layer | 4n(n+d) | 3n(n+d) |
| Cell/hidden coupling | Separate, gated output | Unified, direct output |
| Update mechanism | Additive: f·c + i·c̃ | Interpolation: (1-z)·h + z·h̃ |
| Memory protection | Cell state behind output gate | Exposed in hidden state |
Understanding GRU's Gates:
Update Gate ($z_t$): Analogous to LSTM's coupled forget-input gates. Determines how much of the new candidate to incorporate:
Reset Gate ($r_t$): Controls how much of the previous hidden state influences the candidate computation:
The reset gate is unique to GRU and provides a different kind of "forgetting" than LSTM's forget gate. It doesn't erase memory directly—it determines whether memory influences the next candidate.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GRUCell(nn.Module):
    """
    Gated Recurrent Unit implementation.
    Simpler than LSTM with comparable performance.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

        # Gate computations (3 combined: reset, update, candidate)
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.zeros(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.zeros(3 * hidden_size))

        self._init_weights()

    def _init_weights(self):
        """Initialize for stable training."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.orthogonal_(self.weight_hh)
        # GRU doesn't need the forget-bias trick since the update gate works differently

    def forward(self, x: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input (batch, input_size)
            h_prev: Previous hidden state (batch, hidden_size)
        Returns:
            h: New hidden state (batch, hidden_size)
        """
        # Compute gate pre-activations from input and hidden state separately
        gates_x = F.linear(x, self.weight_ih, self.bias_ih)
        gates_h = F.linear(h_prev, self.weight_hh, self.bias_hh)

        # Split into reset, update, and candidate components
        r_x, z_x, n_x = gates_x.chunk(3, dim=-1)
        r_h, z_h, n_h = gates_h.chunk(3, dim=-1)

        # Reset and update gates
        r = torch.sigmoid(r_x + r_h)
        z = torch.sigmoid(z_x + z_h)

        # Candidate uses the reset-gated hidden state
        # This is the key difference: r gates the hidden contribution to the candidate
        n = torch.tanh(n_x + r * n_h)

        # Interpolate between old and new
        h = (1 - z) * h_prev + z * n

        return h


# Compare parameter counts
input_size, hidden_size = 256, 512

lstm_params = 4 * hidden_size * (input_size + hidden_size + 1)
gru_params = 3 * hidden_size * (input_size + hidden_size + 1)

print(f"LSTM parameters: {lstm_params:,}")
print(f"GRU parameters: {gru_params:,}")
print(f"GRU savings: {(1 - gru_params/lstm_params) * 100:.1f}%")
```

Choose GRU when: • Training speed is important (faster per step) • Memory/compute is constrained (fewer parameters) • Sequential dependencies are moderate (<100 steps) • Empirical testing shows comparable accuracy
Choose LSTM when: • Very long dependencies (>100 steps) are critical • Task requires accumulation (separate f and i gates help) • Memory protection is important (output gate) • LSTM has proven better for your specific domain
In practice, try both and measure—the performance difference is often small.
The choice between LSTM and GRU has been extensively studied. Let's examine the empirical evidence and theoretical differences.
Empirical Studies:
Multiple large-scale studies have compared LSTM and GRU:
Chung et al. (2014) — Original GRU paper: GRU competitive with LSTM on speech and music modeling
Jozefowicz et al. (2015) — 10,000+ architectures tested: "LSTM and GRU performed comparably on most tasks"
Greff et al. (2017) — Systematic LSTM component analysis: "Forget gate and output activation are crucial; other components less so"
Task-Specific Patterns:
| Task Type | LSTM Advantage | GRU Advantage | Notes |
|---|---|---|---|
| Language Modeling | Slight | — | Output gate helps word prediction |
| Machine Translation | — | Slight | GRU often faster, similar quality |
| Speech Recognition | — | — | Roughly equivalent |
| Music Generation | — | — | Both work well |
| Very Long Sequences | Significant | — | Cell state helps 200+ step deps |
| Small Data | — | Slight | Fewer params = less overfitting |
| Real-time Inference | — | Clear | 25% fewer operations |
Gradient Flow Comparison:
Both architectures improve over vanilla RNN, but slightly differently:
LSTM Gradient Path: $$\frac{\partial c_T}{\partial c_t} = \prod_{k=t+1}^{T} f_k$$
GRU Gradient Path: $$\frac{\partial h_T}{\partial h_t} = \prod_{k=t+1}^{T} (1 - z_k) + \text{(additional terms)}$$
LSTM's gradient path is "purer"—only forget gates affect it. GRU's path involves more complex interactions. In practice, both are effective, but LSTM may have a theoretical edge for very long sequences.
```
LSTM Gradient Flow (simplified):
─────────────────────────────────
c_t = f_t · c_{t-1} + i_t · c̃_t
∂c_T/∂c_t = ∏(k=t+1 to T) f_k

Key property: Pure product of forget gates
When f ≈ 1: Gradient ≈ 1 (preserved)
When f ≈ 0: Gradient → 0 (vanishes)

GRU Gradient Flow (simplified):
────────────────────────────────
h_t = (1-z_t) · h_{t-1} + z_t · h̃_t
∂h_T/∂h_t = ∏(k=t+1 to T) [(1-z_k) + z_k · ∂h̃_k/∂h_{k-1}]

Key property: Product PLUS additional terms from candidate gradients
The ∂h̃_k/∂h_{k-1} involves reset gate and weights

When z ≈ 0 (preserve): (1-z) ≈ 1, pure preservation
When z ≈ 1 (update): Gradient flows through candidate path

Conclusion:
- LSTM: Cleaner separation between "preserve" and "gradient" paths
- GRU: More entangled, but often sufficient in practice
```

For most practical purposes, LSTM and GRU are interchangeable. Differences are usually within noise on typical benchmarks. The choice often comes down to:
• Default to LSTM for new projects (more widely tested, more resources available) • Use GRU if speed or memory matters and initial experiments show comparable quality • Always benchmark both if the task is important enough to optimize
Beyond single-layer unidirectional LSTM, several structural variants significantly expand modeling capacity.
Stacked (Deep) LSTMs:
Multiple LSTM layers can be stacked, with each layer's output becoming the next layer's input:
$$h_t^{(1)} = \text{LSTM}^{(1)}(x_t, h_{t-1}^{(1)})$$ $$h_t^{(2)} = \text{LSTM}^{(2)}(h_t^{(1)}, h_{t-1}^{(2)})$$ $$\vdots$$ $$h_t^{(L)} = \text{LSTM}^{(L)}(h_t^{(L-1)}, h_{t-1}^{(L)})$$
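PyTorch builds this stack directly into `nn.LSTM` via `num_layers`; a minimal sketch (sizes are arbitrary):

```python
import torch
import torch.nn as nn

# 3-layer stacked LSTM; dropout is applied between layers, not after the last
stacked = nn.LSTM(input_size=128, hidden_size=256, num_layers=3,
                  dropout=0.2, batch_first=True)

x = torch.randn(8, 50, 128)      # (batch, seq_len, input_size)
output, (h_n, c_n) = stacked(x)

print(output.shape)  # torch.Size([8, 50, 256])  -- top layer's h_t at every step
print(h_n.shape)     # torch.Size([3, 8, 256])   -- final h for each of the 3 layers
```

Note that `output` exposes only the top layer; the per-layer final states come back stacked along the first dimension of `h_n` and `c_n`.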
Benefits of Depth:
Bidirectional LSTMs:
Process sequences in both directions and combine:
$$\overrightarrow{h}_t = \text{LSTM}_{\text{forward}}(x_t, \overrightarrow{h}_{t-1})$$ $$\overleftarrow{h}_t = \text{LSTM}_{\text{backward}}(x_t, \overleftarrow{h}_{t+1})$$ $$h_t = [\overrightarrow{h}_t; \overleftarrow{h}_t] \quad \text{or} \quad h_t = \overrightarrow{h}_t + \overleftarrow{h}_t$$
Benefits of Bidirectionality:
Limitation:
```python
import torch
import torch.nn as nn


class StackedBidirectionalLSTM(nn.Module):
    """
    Multi-layer bidirectional LSTM with residual connections.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 2,
        dropout: float = 0.2,
        bidirectional: bool = True,
        residual: bool = True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.residual = residual
        self.num_directions = 2 if bidirectional else 1

        # Input projection if sizes don't match
        if input_size != hidden_size:
            self.input_proj = nn.Linear(input_size, hidden_size)
        else:
            self.input_proj = nn.Identity()

        # Stack of LSTM layers: layers after the first consume the
        # (possibly direction-doubled) output of the previous layer
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            layer_input_size = hidden_size if i == 0 else hidden_size * self.num_directions
            self.layers.append(nn.LSTM(
                input_size=layer_input_size,
                hidden_size=hidden_size,
                num_layers=1,
                batch_first=True,
                bidirectional=bidirectional,
            ))

        # Layer normalization between layers
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_size * self.num_directions)
            for _ in range(num_layers - 1)
        ])

        # Dropout between layers
        self.dropout = nn.Dropout(dropout)

        # Projection for residual if bidirectional (doubles dimension)
        if residual and bidirectional:
            self.residual_proj = nn.Linear(hidden_size, hidden_size * 2)
        else:
            self.residual_proj = None

    def forward(self, x, states=None):
        """
        Args:
            x: Input tensor (batch, seq_len, input_size)
            states: Optional initial states for each layer
        Returns:
            output: (batch, seq_len, hidden_size * num_directions)
            final_states: List of (h_n, c_n) for each layer
        """
        # Project input
        x = self.input_proj(x)

        # Optional: project for first residual
        if self.residual and self.residual_proj is not None:
            residual = self.residual_proj(x)
        else:
            residual = x

        final_states = []

        for i, lstm in enumerate(self.layers):
            # Get layer state if provided
            layer_state = states[i] if states else None

            # LSTM forward
            output, state = lstm(x, layer_state)
            final_states.append(state)

            # Layer norm (except last layer)
            if i < len(self.layer_norms):
                output = self.layer_norms[i](output)

            # Residual connection (except first layer)
            if self.residual and i > 0:
                output = output + residual

            # Update residual for next layer
            residual = output

            # Dropout (except last layer)
            if i < self.num_layers - 1:
                output = self.dropout(output)

            # Output becomes input to next layer
            x = output

        return output, final_states


# Example usage
model = StackedBidirectionalLSTM(
    input_size=256,
    hidden_size=512,
    num_layers=3,
    bidirectional=True,
    residual=True,
)

x = torch.randn(32, 100, 256)  # batch, seq, features
output, states = model(x)
print(f"Output shape: {output.shape}")  # (32, 100, 1024)
```

Research has produced many LSTM variants beyond the basics. Here are the most impactful.
Layer-Normalized LSTM (LN-LSTM):
Applies layer normalization to the gates before activation:
$$f_t = \sigma(\text{LN}(W_f \cdot [h, x] + b_f))$$
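A minimal sketch of this idea (names and structure are illustrative, not a standard API; for brevity it applies one LayerNorm across all four gate pre-activations rather than one per gate):

```python
import torch
import torch.nn as nn


class LNLSTMCell(nn.Module):
    """Sketch: LSTM cell with layer normalization on gate pre-activations.

    Simplification: a single LayerNorm over the concatenated pre-activations
    stands in for per-gate normalization; a second LayerNorm normalizes the
    cell state before the output nonlinearity.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size)
        self.ln_gates = nn.LayerNorm(4 * hidden_size)  # normalize pre-activations
        self.ln_cell = nn.LayerNorm(hidden_size)       # normalize cell state

    def forward(self, x, state):
        h_prev, c_prev = state
        pre = self.ln_gates(self.linear(torch.cat([x, h_prev], dim=-1)))
        i, f, g, o = pre.chunk(4, dim=-1)
        c = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(self.ln_cell(c))
        return h, c


cell = LNLSTMCell(64, 128)
h, c = cell(torch.randn(4, 64), (torch.zeros(4, 128), torch.zeros(4, 128)))
print(h.shape)  # torch.Size([4, 128])
```

Because LayerNorm re-centers and re-scales each example's pre-activations, gate saturation becomes less sensitive to weight scale and initialization.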
Benefits:
AWD-LSTM (ASGD Weight-Dropped LSTM):
Introduced by Merity et al. (2017), AWD-LSTM combines several regularization techniques, most notably weight dropout (DropConnect) applied to the recurrent weight matrices.
Achieved state-of-the-art language modeling before Transformers dominated.
Mogrifier LSTM:
Introduces mutual gating between input and hidden state:
$$x^{(1)} = x \odot \sigma(W_x^{(1)} h)$$ $$h^{(1)} = h \odot \sigma(W_h^{(1)} x^{(1)})$$ $$x^{(2)} = x^{(1)} \odot \sigma(W_x^{(2)} h^{(1)})$$ $$\ldots$$
Multiple rounds of this "mogrification" before standard LSTM computation improves language modeling.
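A sketch of the mogrification step (an illustrative implementation, not the authors' code; note the original paper scales the gates by $2\sigma(\cdot)$ so that an untrained gate is roughly the identity):

```python
import torch
import torch.nn as nn


class Mogrifier(nn.Module):
    """Sketch: mutual gating of x and h before a standard LSTM step.

    Odd-numbered rounds modulate x using h; even-numbered rounds modulate
    h using x, alternating as in the equations above.
    """

    def __init__(self, input_size: int, hidden_size: int, rounds: int = 4):
        super().__init__()
        # Alternate projections: h -> gate for x, then x -> gate for h
        self.proj = nn.ModuleList([
            nn.Linear(hidden_size if r % 2 == 0 else input_size,
                      input_size if r % 2 == 0 else hidden_size, bias=False)
            for r in range(rounds)
        ])

    def forward(self, x, h):
        for r, proj in enumerate(self.proj):
            if r % 2 == 0:
                x = x * (2 * torch.sigmoid(proj(h)))  # x^(k) = x ⊙ 2σ(W_x h)
            else:
                h = h * (2 * torch.sigmoid(proj(x)))  # h^(k) = h ⊙ 2σ(W_h x)
        return x, h


mog = Mogrifier(input_size=64, hidden_size=128, rounds=4)
x, h = mog(torch.randn(4, 64), torch.randn(4, 128))
print(x.shape, h.shape)  # torch.Size([4, 64]) torch.Size([4, 128])
```

The modulated `x` and `h` then feed into an ordinary LSTM cell, e.g. `nn.LSTMCell(64, 128)(x, (h, c))`.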
| Variant | Key Innovation | Best For | Overhead |
|---|---|---|---|
| LN-LSTM | Layer norm on gates | Stable training | ~10% slower |
| AWD-LSTM | Multiple regularizations | Language modeling | Training only |
| Mogrifier | Mutual input-hidden gating | Language modeling | ~30% slower |
| IndRNN | Independent recurrent units | Very long sequences | Simplified |
| SRU | Simple Recurrent Unit | Speed-critical | Much faster |
For most applications: Standard LSTM with forget bias = 1.0
For language modeling: AWD-LSTM or Mogrifier if pushing SOTA
For unstable training: LN-LSTM
For very long sequences: Consider IndRNN or attention augmentation
For speed-critical inference: SRU or GRU
Always profile on your specific task—benchmarks don't always transfer.
We've explored the rich landscape of LSTM variants—from historical evolution through modern extensions.
Key Insights:
Module Complete:
With this page, you've completed the comprehensive study of Long Short-Term Memory networks. You now understand:
This knowledge prepares you for the next module on Gated Recurrent Units, where we'll dive deeper into GRU and compare it systematically with LSTM.
Congratulations! You've mastered Long Short-Term Memory networks—the architecture that revolutionized sequence modeling and enabled accurate deep learning for sequential data. You're now equipped to implement, debug, and optimize LSTM networks for any application.