In the previous pages, we identified the information bottleneck as the fundamental limitation of basic seq2seq models: compressing an entire source sequence into a single fixed-size vector inevitably loses information, especially for long sequences.
Attention mechanisms solve this problem elegantly. Instead of forcing the encoder to compress everything into one vector, attention allows the decoder to dynamically access any part of the encoded sequence at each generation step. The decoder learns to focus on relevant source positions while generating each output token.
This seemingly simple idea—letting the decoder look back at the encoder outputs—revolutionized neural sequence modeling and laid the groundwork for the Transformer architecture that now dominates deep learning.
This page previews attention in the RNN context: the core attention mechanism, additive versus multiplicative attention, how attention integrates with seq2seq models, and the properties that make attention so powerful. Chapter 35 will provide a comprehensive treatment of attention and Transformers.
Before diving into the mechanism, let's understand why attention is necessary through a concrete example.
Machine Translation Example
Consider translating: "The black cat sat on the mat" → "Le chat noir était assis sur le tapis"
When generating the French word "chat" (cat), the decoder needs to focus on the English word "cat". When generating "noir" (black), it needs "black". When generating "tapis" (mat), it needs "mat".
Without attention, the decoder must recover each of these correspondences from a single fixed-length vector. With attention, it can focus directly on the relevant source word at every generation step.
The attention weights α form a probability distribution over source positions—this is 'soft' attention, a differentiable weighted average. 'Hard' attention would discretely select one position, which is not differentiable and requires reinforcement learning. We focus on soft attention.
The attention mechanism computes a context vector as a weighted sum of encoder hidden states, where weights depend on the current decoder state.
Given:
- Encoder hidden states $\mathbf{h}_1^{\text{enc}}, \ldots, \mathbf{h}_{T_x}^{\text{enc}}$, one per source position
- The current decoder state $\mathbf{s}_t$
Step 1: Compute Alignment Scores
For each source position $i$, compute how well it aligns with the current decoder state:
$$e_{ti} = \text{score}(\mathbf{s}_t, \mathbf{h}_i^{\text{enc}})$$
Step 2: Normalize to Attention Weights
Apply softmax to get a probability distribution:
$$\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_{j=1}^{T_x} \exp(e_{tj})}$$
Step 3: Compute Context Vector
Weighted sum of encoder states:
$$\mathbf{c}_t = \sum_{i=1}^{T_x} \alpha_{ti} \mathbf{h}_i^{\text{enc}}$$
Step 4: Use Context in Decoder
Combine context with decoder hidden state for output:
$$\tilde{\mathbf{s}}_t = \tanh(\mathbf{W}_c[\mathbf{c}_t; \mathbf{s}_t])$$

$$P(y_t | y_{<t}, \mathbf{x}) = \text{softmax}(\mathbf{W}_o \tilde{\mathbf{s}}_t)$$
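The four steps can be sketched in a few lines of PyTorch, using the simple dot-product score for brevity (all dimensions here are illustrative, not from a real model):

```python
import torch

torch.manual_seed(0)

T_x, d = 5, 8                          # source length, hidden size (illustrative)
h_enc = torch.randn(T_x, d)            # encoder hidden states h_i
s_t = torch.randn(d)                   # current decoder state s_t

# Step 1: alignment scores (dot-product score for simplicity)
e_t = h_enc @ s_t                      # [T_x]

# Step 2: softmax normalizes scores into attention weights
alpha_t = torch.softmax(e_t, dim=-1)   # [T_x], sums to 1

# Step 3: context vector = weighted sum of encoder states
c_t = alpha_t @ h_enc                  # [d]

# Step 4 would concatenate c_t with s_t and project to the vocabulary.
```

Note that every operation here is differentiable, which is what allows attention to be trained end-to-end with the rest of the network.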
Key Properties
| Property | Implication |
|---|---|
| Weights sum to 1 | $\sum_i \alpha_{ti} = 1$ (valid probability distribution) |
| Differentiable | End-to-end training with backpropagation |
| Position-specific | Different context for each decoder step |
| Soft selection | Smooth combination, not discrete choice |
| Interpretable | Weights show which source positions are attended |
The Score Function
The score function $\text{score}(\mathbf{s}, \mathbf{h})$ determines how alignment is computed. Different choices yield different attention variants (next section).
Several score functions have been proposed, each with different computational and modeling properties.
Additive Attention (Bahdanau)
Also called "concat" attention. Uses a feedforward network:
$$e_{ti} = \mathbf{v}^\top \tanh(\mathbf{W}_s \mathbf{s}_t + \mathbf{W}_h \mathbf{h}_i^{\text{enc}})$$
where:
- $\mathbf{W}_s$ and $\mathbf{W}_h$ are learned projection matrices for the decoder and encoder states
- $\mathbf{v}$ is a learned vector that reduces the combined representation to a scalar score
Multiplicative Attention (Luong)
Also called "dot-product" attention. Direct dot product:
$$e_{ti} = \mathbf{s}_t^\top \mathbf{h}_i^{\text{enc}}$$
or with learned transformation:
$$e_{ti} = \mathbf{s}_t^\top \mathbf{W}_a \mathbf{h}_i^{\text{enc}}$$
Scaled Dot-Product Attention
Divides by square root of dimension to prevent large logits:
$$e_{ti} = \frac{\mathbf{s}_t^\top \mathbf{h}_i^{\text{enc}}}{\sqrt{d}}$$
This is the variant used in Transformers.
| Variant | Formula | Parameters | Complexity |
|---|---|---|---|
| Additive | $\mathbf{v}^\top \tanh(\mathbf{W}_s \mathbf{s} + \mathbf{W}_h \mathbf{h})$ | $d_a(d_s + d_h) + d_a$ | Slower (non-parallelizable) |
| Dot-Product | $\mathbf{s}^\top \mathbf{h}$ | 0 (requires $d_s = d_h$) | Fast (matrix multiplication) |
| General | $\mathbf{s}^\top \mathbf{W}_a \mathbf{h}$ | $d_s \times d_h$ | Fast (one matrix) |
| Scaled Dot-Product | $\frac{\mathbf{s}^\top \mathbf{h}}{\sqrt{d}}$ | 0 | Fast + numerically stable |
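A quick numerical check of why the $\sqrt{d}$ scaling matters: a dot product of $d$-dimensional vectors with independent unit-variance components has standard deviation about $\sqrt{d}$, so unscaled logits grow with dimension and push the softmax toward saturation. This sketch (with arbitrary dimensions) estimates both scales empirically:

```python
import torch

torch.manual_seed(0)

# Raw dot-product logits have std ~ sqrt(d); dividing by sqrt(d)
# keeps their scale roughly constant as dimension grows.
for d in (16, 256):
    q = torch.randn(10_000, d)
    k = torch.randn(10_000, d)
    raw = (q * k).sum(dim=-1)          # 10,000 sample dot products
    scaled = raw / d ** 0.5
    print(f"d={d}: raw std {raw.std():.1f}, scaled std {scaled.std():.2f}")
```

The raw standard deviation grows like $\sqrt{d}$ while the scaled one stays near 1 regardless of dimension.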
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class AdditiveAttention(nn.Module):
    """
    Bahdanau-style additive attention.
    Uses a feedforward network to compute alignment scores.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        attention_dim: int
    ):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.decoder_proj = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(
        self,
        encoder_outputs: torch.Tensor,
        decoder_hidden: torch.Tensor,
        mask: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            encoder_outputs: [batch, src_len, encoder_dim]
            decoder_hidden: [batch, decoder_dim]
            mask: [batch, src_len] - True for valid positions

        Returns:
            context: [batch, encoder_dim]
            attention_weights: [batch, src_len]
        """
        src_len = encoder_outputs.size(1)

        # Project encoder and decoder states
        encoder_proj = self.encoder_proj(encoder_outputs)  # [batch, src_len, attn_dim]
        decoder_proj = self.decoder_proj(decoder_hidden)   # [batch, attn_dim]

        # Expand decoder projection to match source length
        decoder_proj = decoder_proj.unsqueeze(1).expand(-1, src_len, -1)

        # Compute scores
        energy = torch.tanh(encoder_proj + decoder_proj)  # [batch, src_len, attn_dim]
        scores = self.v(energy).squeeze(-1)               # [batch, src_len]

        # Mask padding positions
        if mask is not None:
            scores = scores.masked_fill(~mask, float('-inf'))

        # Normalize to attention weights
        attention_weights = F.softmax(scores, dim=-1)  # [batch, src_len]

        # Compute context vector
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)  # [batch, encoder_dim]

        return context, attention_weights


class DotProductAttention(nn.Module):
    """
    Luong-style dot-product attention.
    Fast and efficient, requires matching dimensions.
    """

    def __init__(self, scaled: bool = True):
        super().__init__()
        self.scaled = scaled

    def forward(
        self,
        query: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        mask: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: [batch, query_dim]
            keys: [batch, src_len, key_dim]
            values: [batch, src_len, value_dim]
            mask: [batch, src_len]

        Returns:
            context: [batch, value_dim]
            attention_weights: [batch, src_len]
        """
        # Compute dot product: query @ keys^T
        # query: [batch, 1, query_dim], keys: [batch, src_len, key_dim]
        scores = torch.bmm(query.unsqueeze(1), keys.transpose(1, 2))
        scores = scores.squeeze(1)  # [batch, src_len]

        # Scale by sqrt(d) for numerical stability
        if self.scaled:
            d = query.size(-1)
            scores = scores / math.sqrt(d)

        # Mask and normalize
        if mask is not None:
            scores = scores.masked_fill(~mask, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)

        # Compute context
        context = torch.bmm(attention_weights.unsqueeze(1), values)
        context = context.squeeze(1)

        return context, attention_weights


class GeneralAttention(nn.Module):
    """
    General attention with learned transformation.
    Allows different encoder/decoder dimensions.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        scaled: bool = True
    ):
        super().__init__()
        self.W = nn.Linear(encoder_dim, decoder_dim, bias=False)
        self.scaled = scaled

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            decoder_hidden: [batch, decoder_dim]
            encoder_outputs: [batch, src_len, encoder_dim]
            mask: [batch, src_len]

        Returns:
            context: [batch, encoder_dim]
            attention_weights: [batch, src_len]
        """
        # Transform encoder outputs
        transformed = self.W(encoder_outputs)  # [batch, src_len, decoder_dim]

        # Compute scores: decoder @ transformed^T
        scores = torch.bmm(
            decoder_hidden.unsqueeze(1),  # [batch, 1, decoder_dim]
            transformed.transpose(1, 2)   # [batch, decoder_dim, src_len]
        ).squeeze(1)  # [batch, src_len]

        if self.scaled:
            scores = scores / math.sqrt(decoder_hidden.size(-1))

        if mask is not None:
            scores = scores.masked_fill(~mask, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)

        return context, attention_weights
```

Let's see how attention integrates into the full seq2seq architecture we developed earlier.
Bahdanau Attention (Input-Feeding)
In the original Bahdanau attention, the context is computed from the previous decoder state and fed as additional input to the current decoder step:
Luong Attention
In Luong attention, context is computed from the current decoder state:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionDecoder(nn.Module):
    """
    LSTM decoder with Luong-style attention.
    Computes context from current hidden state after each LSTM step.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        encoder_dim: int,
        decoder_dim: int,
        attention_dim: int,
        num_layers: int = 1,
        dropout: float = 0.2
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=decoder_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Attention mechanism
        self.attention = AdditiveAttention(encoder_dim, decoder_dim, attention_dim)

        # Combine context + hidden for output
        self.context_combine = nn.Linear(encoder_dim + decoder_dim, decoder_dim)

        # Output projection
        self.fc_out = nn.Linear(decoder_dim, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        input_token: torch.Tensor,
        hidden: tuple[torch.Tensor, torch.Tensor],
        encoder_outputs: torch.Tensor,
        src_mask: torch.Tensor = None
    ) -> tuple[torch.Tensor, tuple, torch.Tensor]:
        """
        Single decoder step with attention.

        Args:
            input_token: [batch, 1] - current input token
            hidden: (h, c) - LSTM states
            encoder_outputs: [batch, src_len, encoder_dim]
            src_mask: [batch, src_len] - True for valid positions

        Returns:
            output: [batch, vocab_size] - token probabilities
            hidden: Updated (h, c)
            attention_weights: [batch, src_len]
        """
        # Embed input
        embedded = self.dropout(self.embedding(input_token))  # [batch, 1, embed]

        # LSTM step
        rnn_output, hidden = self.lstm(embedded, hidden)
        rnn_output = rnn_output.squeeze(1)  # [batch, decoder_dim]

        # Attention over encoder outputs
        context, attention_weights = self.attention(
            encoder_outputs, rnn_output, src_mask
        )  # context: [batch, encoder_dim]

        # Combine context and RNN output
        combined = torch.cat([context, rnn_output], dim=-1)
        combined = torch.tanh(self.context_combine(combined))
        combined = self.dropout(combined)

        # Project to vocabulary
        output = self.fc_out(combined)  # [batch, vocab_size]

        return output, hidden, attention_weights


class AttentionSeq2Seq(nn.Module):
    """
    Complete seq2seq model with attention.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: AttentionDecoder,
        device: torch.device
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(
        self,
        src: torch.Tensor,
        src_lengths: torch.Tensor,
        trg: torch.Tensor,
        teacher_forcing_ratio: float = 0.5
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Training forward with attention.

        Args:
            src: [batch, src_len]
            src_lengths: [batch]
            trg: [batch, trg_len]
            teacher_forcing_ratio: probability of teacher forcing

        Returns:
            outputs: [batch, trg_len-1, vocab_size]
            attentions: [batch, trg_len-1, src_len]
        """
        batch_size = src.size(0)
        trg_len = trg.size(1)
        src_len = src.size(1)
        vocab_size = self.decoder.vocab_size

        # Storage for outputs and attention weights
        outputs = torch.zeros(batch_size, trg_len - 1, vocab_size).to(self.device)
        attentions = torch.zeros(batch_size, trg_len - 1, src_len).to(self.device)

        # Encode source
        encoder_outputs, hidden = self.encoder(src, src_lengths)

        # Create source mask
        src_mask = torch.arange(src_len, device=self.device)[None, :] < src_lengths[:, None]

        # First input is <sos>
        decoder_input = trg[:, 0:1]

        for t in range(1, trg_len):
            # Decode with attention
            output, hidden, attn_weights = self.decoder(
                decoder_input, hidden, encoder_outputs, src_mask
            )

            outputs[:, t-1] = output
            attentions[:, t-1] = attn_weights

            # Next input
            use_tf = torch.rand(1).item() < teacher_forcing_ratio
            decoder_input = trg[:, t:t+1] if use_tf else output.argmax(-1, keepdim=True)

        return outputs, attentions

    def translate(
        self,
        src: torch.Tensor,
        src_lengths: torch.Tensor,
        max_length: int = 50,
        sos_idx: int = 2,
        eos_idx: int = 3
    ) -> tuple[list[int], torch.Tensor]:
        """
        Greedy translation with attention visualization.

        Returns:
            tokens: List of generated token indices
            attention_matrix: [generated_len, src_len]
        """
        self.eval()

        with torch.no_grad():
            encoder_outputs, hidden = self.encoder(src, src_lengths)

            src_mask = torch.arange(
                src.size(1), device=self.device
            )[None, :] < src_lengths[:, None]

            decoder_input = torch.tensor([[sos_idx]], device=self.device)
            tokens = []
            attentions_list = []

            for _ in range(max_length):
                output, hidden, attn = self.decoder(
                    decoder_input, hidden, encoder_outputs, src_mask
                )

                pred_token = output.argmax(dim=-1).item()
                attentions_list.append(attn.squeeze(0))

                if pred_token == eos_idx:
                    break

                tokens.append(pred_token)
                decoder_input = torch.tensor([[pred_token]], device=self.device)

        attention_matrix = torch.stack(attentions_list, dim=0)
        return tokens, attention_matrix
```

One of attention's most valuable properties is interpretability. The attention weights reveal which source positions influenced each output position.
Attention Heatmap
For a translation from English to French, the attention matrix shows alignment:
| | The | black | cat | sat | on | the | mat |
|---|---|---|---|---|---|---|---|
| Le | 0.60 | 0.05 | 0.15 | 0.05 | 0.05 | 0.05 | 0.05 |
| chat | 0.05 | 0.05 | 0.80 | 0.03 | 0.02 | 0.02 | 0.03 |
| noir | 0.05 | 0.80 | 0.05 | 0.03 | 0.02 | 0.02 | 0.03 |
| était | 0.05 | 0.05 | 0.10 | 0.60 | 0.10 | 0.05 | 0.05 |
| assis | 0.05 | 0.05 | 0.05 | 0.75 | 0.05 | 0.03 | 0.02 |
| sur | 0.03 | 0.03 | 0.03 | 0.03 | 0.75 | 0.10 | 0.03 |
| le | 0.03 | 0.03 | 0.03 | 0.03 | 0.05 | 0.75 | 0.08 |
| tapis | 0.03 | 0.03 | 0.03 | 0.03 | 0.03 | 0.08 | 0.77 |
The diagonal pattern shows the model learning approximate word alignment, with some deviations for reordering (e.g., "black cat" → "chat noir").
While attention weights are interpretable, they should be viewed cautiously. High attention doesn't necessarily mean 'the model used this for prediction'—it means 'this contributed to the context vector.' The relationship between attention and model behavior is complex.
```python
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


def visualize_attention(
    source_tokens: list[str],
    target_tokens: list[str],
    attention_matrix: torch.Tensor,
    save_path: str = None
):
    """
    Create attention heatmap visualization.

    Args:
        source_tokens: List of source tokens
        target_tokens: List of target tokens (generated)
        attention_matrix: [target_len, source_len] tensor
        save_path: Optional path to save figure
    """
    # Convert to numpy
    attn = attention_matrix.cpu().numpy()

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 8))

    # Create heatmap
    sns.heatmap(
        attn,
        xticklabels=source_tokens,
        yticklabels=target_tokens,
        cmap='Blues',
        ax=ax,
        cbar_kws={'label': 'Attention Weight'}
    )

    ax.set_xlabel('Source Tokens')
    ax.set_ylabel('Target Tokens')
    ax.set_title('Attention Weights')

    # Rotate x labels for readability
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()


def analyze_attention_patterns(
    attention_matrices: list[torch.Tensor],
    threshold: float = 0.3
) -> dict:
    """
    Analyze attention patterns across multiple examples.

    Args:
        attention_matrices: List of [target_len, source_len] tensors
        threshold: Minimum attention weight to consider "focused"

    Returns:
        Dictionary of statistics
    """
    stats = {
        'avg_entropy': [],
        'avg_max_attention': [],
        'diagonal_alignment': [],
        'num_focused_positions': []
    }

    for attn in attention_matrices:
        attn = attn.cpu()

        # Attention entropy (lower = more focused)
        # H = -sum(p * log(p))
        entropy = -torch.sum(attn * torch.log(attn + 1e-10), dim=-1)
        stats['avg_entropy'].append(entropy.mean().item())

        # Maximum attention per target position
        max_attn = attn.max(dim=-1).values
        stats['avg_max_attention'].append(max_attn.mean().item())

        # Diagonal alignment (for monotonic tasks like translation)
        min_len = min(attn.size(0), attn.size(1))
        diagonal = torch.diag(attn[:min_len, :min_len])
        stats['diagonal_alignment'].append(diagonal.mean().item())

        # Number of positions with attention > threshold
        focused = (attn > threshold).sum(dim=-1).float()
        stats['num_focused_positions'].append(focused.mean().item())

    # Aggregate statistics
    return {
        'entropy': {
            'mean': np.mean(stats['avg_entropy']),
            'std': np.std(stats['avg_entropy'])
        },
        'max_attention': {
            'mean': np.mean(stats['avg_max_attention']),
            'std': np.std(stats['avg_max_attention'])
        },
        'diagonal_alignment': {
            'mean': np.mean(stats['diagonal_alignment']),
            'std': np.std(stats['diagonal_alignment'])
        },
        'focused_positions': {
            'mean': np.mean(stats['num_focused_positions']),
            'std': np.std(stats['num_focused_positions'])
        }
    }
```

Attention's effectiveness stems from several complementary factors:
1. Eliminates the Bottleneck
Without attention: information flows through the single vector $\mathbf{c}$.

With attention: direct pathways from every $\mathbf{h}_i^{\text{enc}}$ to the decoder.
$$\text{Information capacity:} \quad d \quad \text{vs} \quad T_x \cdot d$$
2. Shortens Gradient Paths
In vanilla seq2seq, gradients from output $y_t$ to source $x_1$ must traverse: $$y_t \to \mathbf{s}_t \to \ldots \to \mathbf{s}_1 \to \mathbf{c} \to \mathbf{h}_{T_x}^{\text{enc}} \to \ldots \to \mathbf{h}_1^{\text{enc}}$$
With attention, there's a direct path: $$y_t \to \mathbf{c}_t \to \mathbf{h}_1^{\text{enc}}$$
This dramatically improves gradient flow for learning long-range dependencies.
3. Task-Appropriate Inductive Bias
Attention encodes the assumption that outputs depend on weighted combinations of inputs—which is accurate for many sequence transduction tasks. The model learns which combinations, but the compositional structure is built-in.
| Component | Without Attention | With Attention |
|---|---|---|
| Encoder burden | Must compress everything into c | Can produce any useful representation per position |
| Decoder access | Same c for all timesteps | Dynamic c_t specific to each timestep |
| Long sequences | Performance degrades significantly | Scales well with sequence length |
| Gradient flow | Long paths, vanishing gradients | Direct paths to every source position |
| Interpretability | Black box | Attention weights show alignment |
Attention's success with RNNs raised a natural question: if attention provides direct access to all positions, do we even need the sequential RNN? The answer—no—led to the Transformer architecture, which uses attention exclusively. Chapter 35 covers this evolution in depth.
Basic attention has spawned numerous extensions addressing its limitations.
Coverage Mechanism
Problem: Attention may repeatedly focus on the same positions (over-translation) or ignore some positions (under-translation).
Solution: Track cumulative attention and penalize re-attention:
$$\text{coverage}_t = \sum_{t'=1}^{t-1} \alpha_{t'}$$
$$e_{ti} = f(\mathbf{s}_t, \mathbf{h}_i^{\text{enc}}, \text{coverage}_{t-1,i})$$
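The idea can be sketched in a few lines. Note this is a simplified additive penalty (subtracting scaled cumulative attention from the raw scores), an illustrative stand-in for the learned score function $f$ above; dimensions and the penalty weight are arbitrary:

```python
import torch

torch.manual_seed(0)

# Simplified coverage sketch: positions that have already received
# attention mass get their scores pushed down, discouraging re-attention.
T_x, lam = 6, 1.0
coverage = torch.zeros(T_x)

for t in range(3):                        # a few decoder steps
    raw_scores = torch.randn(T_x)         # stand-in for score(s_t, h_i)
    scores = raw_scores - lam * coverage  # penalize already-covered positions
    alpha = torch.softmax(scores, dim=-1)
    coverage = coverage + alpha           # accumulate attention mass
```

Since each step's weights sum to 1, total coverage after $t$ steps sums to $t$, giving a direct readout of which positions have been over- or under-attended.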
Local Attention
Problem: Global attention over all $T_x$ positions is $O(T_x)$ per step.
Solution: Attend only to a window around a predicted position:
$$\alpha_{ti} \propto \exp(-\frac{(i - p_t)^2}{2\sigma^2}) \cdot \text{score}(\mathbf{s}_t, \mathbf{h}_i)$$
where $p_t$ is a predicted or fixed alignment position.
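The windowing can be sketched as follows, with a fixed $p_t$ and post-hoc renormalization for simplicity (illustrative values; in practice $p_t$ is typically predicted from the decoder state):

```python
import torch

torch.manual_seed(0)

# Local attention sketch: multiply content-based weights by a Gaussian
# window centered at p_t, so distant positions are suppressed.
T_x, sigma = 10, 2.0
p_t = 4.0
positions = torch.arange(T_x, dtype=torch.float32)

content = torch.softmax(torch.randn(T_x), dim=-1)  # stand-in content weights
window = torch.exp(-(positions - p_t) ** 2 / (2 * sigma ** 2))

alpha = content * window
alpha = alpha / alpha.sum()   # renormalize after windowing
```

Positions far from $p_t$ receive exponentially damped weight, so in a practical implementation scores outside the window need not be computed at all.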
Multi-Head Attention
Problem: Single attention head captures one type of relationship.
Solution: Multiple parallel attention heads, each learning different patterns:
$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = [\text{head}_1; \ldots; \text{head}_h]\mathbf{W}^O$$
This is central to Transformers (Chapter 35).
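As a preview, PyTorch ships a built-in implementation in `nn.MultiheadAttention`; a minimal sketch with illustrative dimensions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Minimal multi-head sketch: 4 heads jointly attend over 5 source
# positions for a single decoder query position.
d_model, num_heads, src_len = 16, 4, 5
mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

query = torch.randn(1, 1, d_model)         # one decoder position
memory = torch.randn(1, src_len, d_model)  # encoder outputs as keys/values

context, weights = mha(query, memory, memory)
# context: [1, 1, d_model]; weights: [1, 1, src_len] (averaged over heads)
```

Each head applies its own learned projections before scaled dot-product attention, so different heads can specialize in different alignment patterns.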
These limitations motivated the development of self-attention and Transformers, which eliminate the recurrent bottleneck entirely. Chapter 35 covers the full attention and Transformer story.
We have introduced the attention mechanism—the key innovation that overcomes the seq2seq bottleneck and laid the foundation for modern deep learning architectures.
Module Complete!
This concludes our exploration of Advanced RNN Topics. You've mastered bidirectional processing, deep stacking, sequence-to-sequence modeling, and attention mechanisms.
In Chapter 35, we'll dive deep into Attention & Transformers—exploring self-attention, multi-head attention, positional encodings, and the Transformer architecture that now powers state-of-the-art models across NLP, vision, and beyond.
Congratulations! You've completed Module 6: Advanced RNN Topics. You now understand the full landscape of RNN architectures—from basic recurrence through bidirectional processing, deep stacking, sequence-to-sequence translation, and attention mechanisms. These concepts form the foundation for understanding modern sequence models.