In the previous module, we explored self-attention—a mechanism that allows each position in a sequence to attend to all other positions, computing weighted combinations based on learned compatibility functions. While self-attention represents a significant advancement over recurrent architectures, a single attention head imposes a fundamental limitation: it can only focus on one type of relationship at a time.
Consider the sentence: "The animal didn't cross the street because it was too tired."
Understanding this sentence requires simultaneously tracking multiple types of relationships: resolving the pronoun "it" back to "the animal" rather than "the street", parsing the syntactic structure that links "cross" to its subject and object, and capturing the causal relation signaled by "because".
A single attention mechanism, no matter how sophisticated, computes a single weighted average of the value vectors. This forces the model to collapse all these diverse relationships into one representation—an impossible task that inevitably loses information.
By the end of this page, you will understand why single-head attention is fundamentally limited, how multi-head attention overcomes this limitation through parallel attention mechanisms in different representation subspaces, and the mathematical formulation that enables this powerful architecture. You'll develop deep intuition for why multiple heads are necessary for expressive sequence modeling.
To understand why multiple attention heads are necessary, we must first rigorously analyze the limitations of single-head attention. This analysis reveals deep mathematical constraints that motivate the multi-head design.
Single-head attention computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where $Q, K, V$ are the query, key, and value matrices derived from the input. For each query position, this produces exactly one probability distribution over all key positions, which is then used to compute a weighted sum of value vectors.
The softmax function produces a probability distribution that must sum to 1. This means a single attention head must divide its attention budget across all positions. If the model needs to strongly attend to multiple positions for different reasons, a single head cannot do this effectively—attending strongly to one position necessarily reduces attention to others.
Mathematical Analysis: The Averaging Problem
Consider a query position $q$ that needs to gather information from two semantically distinct positions $a$ and $b$. With single-head attention:
$$o_q = \alpha_a v_a + \alpha_b v_b + \sum_{i \neq a,b} \alpha_i v_i$$
where $\sum_i \alpha_i = 1$. If $v_a$ and $v_b$ encode orthogonal types of information (say, syntactic role vs. semantic category), the output $o_q$ is a blended mixture that conflates both signals.
This conflation has several consequences:
Information Loss: The network cannot recover the individual contributions from $v_a$ and $v_b$ once they're averaged
Interference: Features from different positions may destructively interfere when combined
Representation Bottleneck: The single $d_v$-dimensional output must encode all attended information, creating a representational bottleneck
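To make the averaging problem concrete, here is a minimal numeric sketch (toy two-dimensional values and hypothetical scores): a single softmax distribution must split its mass between two orthogonal value vectors, while two heads can each commit nearly all of their mass to a different position and keep both signals separate.

```python
import torch
import torch.nn.functional as F

# Two orthogonal "signals" a query needs: e.g., a syntactic cue and a semantic cue
v_a = torch.tensor([1.0, 0.0])   # value at position a
v_b = torch.tensor([0.0, 1.0])   # value at position b

# Single head: one score vector -> one softmax -> one weighted average
scores = torch.tensor([2.0, 2.0])          # wants both positions equally
alpha = F.softmax(scores, dim=-1)          # [0.5, 0.5] -- the attention budget is split
single_head_out = alpha[0] * v_a + alpha[1] * v_b
print(single_head_out)                     # tensor([0.5000, 0.5000]) -- both signals diluted

# Two heads: each head concentrates (almost) all of its mass on a different position
alpha_1 = F.softmax(torch.tensor([8.0, 0.0]), dim=-1)   # ~[1, 0]
alpha_2 = F.softmax(torch.tensor([0.0, 8.0]), dim=-1)   # ~[0, 1]
head_1 = alpha_1[0] * v_a + alpha_1[1] * v_b             # ~v_a
head_2 = alpha_2[0] * v_a + alpha_2[1] * v_b             # ~v_b
concat = torch.cat([head_1, head_2])                     # both signals preserved as separate coordinates
print(concat)
```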
| Constraint | Mathematical Manifestation | Practical Impact |
|---|---|---|
| Single attention distribution | softmax outputs sum to 1 | Cannot attend strongly to multiple positions simultaneously |
| Fixed representation subspace | Single $W^Q, W^K, W^V$ projection | All relationships projected to same subspace |
| Output dimensionality equals $d_v$ | Output shape limited by value dimension | Limited capacity for encoding multiple relationships |
| Linear combination of values | $\sum_i \alpha_i v_i$ | Information from different positions is averaged, not preserved |
Empirical Evidence of the Limitation
Research has demonstrated that single-head attention consistently underperforms on tasks requiring multi-faceted reasoning:
Syntactic analysis: Single heads struggle to simultaneously track local dependencies (adjacent words) and long-range dependencies (subject-verb agreement across clauses)
Coreference resolution: Resolving pronouns requires attending to multiple candidate antecedents and weighing syntactic, semantic, and positional cues
Multi-hop reasoning: Questions requiring information from multiple passages cannot be answered by attending to just one location
These limitations motivated the development of multi-head attention, which we'll now explore in detail.
Multi-head attention addresses the single-head limitations through an elegant yet powerful idea: run multiple attention mechanisms in parallel, each operating in its own learned representation subspace, then combine their outputs.
The Core Insight:
Instead of computing one attention function with $d_{model}$-dimensional queries, keys, and values, we project the inputs $h$ times into lower-dimensional subspaces, apply attention in each subspace independently, then concatenate and project the results back.
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$
where each head is computed as:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
Here, $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$, and $W^O \in \mathbb{R}^{hd_v \times d_{model}}$.
Each head's projection matrices $W_i^Q, W_i^K, W_i^V$ define a unique representation subspace. By learning different projections, each head can specialize in detecting different types of relationships—one head might learn to track syntactic dependencies, another semantic similarity, another positional patterns. The network learns these specializations from data during training.
Dimensional Analysis: No Additional Computation
A crucial design decision in the original Transformer is that multi-head attention maintains the same computational cost as single-head attention would with the full dimension:
For standard Transformer configurations, the per-head dimension is chosen as $d_k = d_v = d_{model}/h$ (for example, $d_{model} = 512$ with $h = 8$ heads gives $d_k = 64$), so the concatenated width across all heads equals the model dimension.
Computation comparison:
| Configuration | Q/K/V Projections | Attention Computation | Output Projection |
|---|---|---|---|
| Single Head ($d_k = 512$) | $3 \times n \times d_{model}^2$ | $n^2 \times d_{model}$ | $n \times d_{model}^2$ |
| Multi-Head (8 heads, $d_k = 64$) | $3 \times n \times d_{model}^2$ | $8 \times n^2 \times 64$ | $n \times d_{model}^2$ |
The total parameter count and computation are equivalent, but multi-head attention distributes the capacity across independent attention functions, enabling richer representations.
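As a quick arithmetic check of the table above (a standalone sketch, separate from the reference implementation that follows), the parameter and attention-FLOP counts can be compared directly; the sizes $n = 128$, $d_{model} = 512$, $h = 8$ are illustrative:

```python
d_model, h, n = 512, 8, 128

# Both variants use four d_model x d_model projections (Q, K, V, and the output W^O)
single_head_params = 4 * d_model * d_model
multi_head_params = 4 * d_model * d_model
print(single_head_params == multi_head_params)   # True: 1,048,576 parameters each

# Attention score/value FLOPs per layer (ignoring constant factors)
single_attn = n * n * d_model              # one n x n score matrix over d_model dims
multi_attn = h * n * n * (d_model // h)    # h score matrices over d_k dims each
print(single_attn == multi_attn)           # True: the work is redistributed, not increased
```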
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism as described in 'Attention Is All You Need'.

    This implementation clearly separates the conceptual components:
    1. Linear projections to Q, K, V for each head (in one batched operation)
    2. Parallel attention computation across all heads
    3. Concatenation and output projection

    Args:
        d_model: The model's embedding dimension
        num_heads: Number of attention heads (h)
        dropout: Dropout probability for attention weights
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        self.d_v = d_model // num_heads  # Usually d_v = d_k

        # Linear projections for Q, K, V (all heads combined)
        # Each projects from d_model to d_model (h * d_k = h * d_v = d_model)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # Output projection: concatenated heads back to d_model
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
        return_attention: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        """
        Forward pass for multi-head attention.

        Args:
            query: Query tensor of shape (batch, seq_len_q, d_model)
            key: Key tensor of shape (batch, seq_len_k, d_model)
            value: Value tensor of shape (batch, seq_len_k, d_model)
            mask: Optional attention mask (batch, 1, seq_len_q, seq_len_k)
            return_attention: Whether to return attention weights

        Returns:
            Output tensor of shape (batch, seq_len_q, d_model)
            Optionally, attention weights of shape (batch, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)
        seq_len_q = query.size(1)
        seq_len_k = key.size(1)

        # Step 1: Linear projections for all heads simultaneously
        # Shape: (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Step 2: Reshape to separate heads
        # Shape: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_heads, self.d_v).transpose(1, 2)

        # Step 3: Compute scaled dot-product attention for all heads in parallel
        # Scores: (batch, num_heads, seq_len_q, d_k) @ (batch, num_heads, d_k, seq_len_k)
        # Result: (batch, num_heads, seq_len_q, seq_len_k)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Apply mask if provided (e.g., for causal attention or padding)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over the key dimension
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Step 4: Apply attention weights to values
        # (batch, num_heads, seq_len_q, seq_len_k) @ (batch, num_heads, seq_len_k, d_v)
        # Result: (batch, num_heads, seq_len_q, d_v)
        context = torch.matmul(attention_weights, V)

        # Step 5: Concatenate heads
        # (batch, num_heads, seq_len_q, d_v) -> (batch, seq_len_q, num_heads * d_v)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)

        # Step 6: Final output projection
        output = self.W_o(context)

        if return_attention:
            return output, attention_weights
        return output


# Demonstration of the dimensional flow
def demonstrate_multi_head_shapes():
    """Show the tensor shapes at each stage of multi-head attention."""
    # Configuration
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_heads = 8
    d_k = d_model // num_heads  # 64

    print("Multi-Head Attention Dimensional Analysis")
    print("=" * 60)
    print(f"Configuration: batch={batch_size}, seq_len={seq_len}, "
          f"d_model={d_model}, heads={num_heads}, d_k={d_k}")
    print()

    # Create module and input
    mha = MultiHeadAttention(d_model, num_heads)
    x = torch.randn(batch_size, seq_len, d_model)

    print(f"Input shape: {x.shape}")
    print(f"  -> (batch={batch_size}, seq_len={seq_len}, d_model={d_model})")
    print()

    # Trace forward pass manually
    Q = mha.W_q(x)
    print(f"After Q projection: {Q.shape}")
    Q_heads = Q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    print(f"After reshape to heads: {Q_heads.shape}")
    print(f"  -> (batch, num_heads, seq_len, d_k)")
    print()

    # Attention computation
    output, attn = mha(x, x, x, return_attention=True)
    print(f"Attention weights shape: {attn.shape}")
    print(f"  -> (batch, num_heads, seq_len_q, seq_len_k)")
    print(f"  Note: Each head has its own {seq_len}x{seq_len} attention matrix")
    print()
    print(f"Final output shape: {output.shape}")
    print(f"  -> Same as input: (batch, seq_len, d_model)")


if __name__ == "__main__":
    demonstrate_multi_head_shapes()
```

Let's develop the complete mathematical formulation of multi-head attention, analyzing each component rigorously.
Setup and Notation:
Given an input sequence $X \in \mathbb{R}^{n \times d_{model}}$ with $n$ positions and $d_{model}$ embedding dimensions, fix the number of heads $h$, the per-head dimensions $d_k = d_v = d_{model}/h$, and the learned projections $W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}$ and $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$ for each head, together with the output projection $W^O \in \mathbb{R}^{hd_v \times d_{model}}$.
Per-Head Computation:
For each head $i \in \{1, \ldots, h\}$:
Project to subspace: $$Q_i = XW_i^Q \in \mathbb{R}^{n \times d_k}$$ $$K_i = XW_i^K \in \mathbb{R}^{n \times d_k}$$ $$V_i = XW_i^V \in \mathbb{R}^{n \times d_v}$$
Compute attention scores: $$A_i = \frac{Q_i K_i^T}{\sqrt{d_k}} \in \mathbb{R}^{n \times n}$$
Apply softmax row-wise: $$\tilde{A}_i = \text{softmax}(A_i) \in \mathbb{R}^{n \times n}$$
Compute attended values: $$\text{head}_i = \tilde{A}_i V_i \in \mathbb{R}^{n \times d_v}$$
Aggregation:
Concatenate all heads: $$\text{Concat} = [\text{head}_1; \text{head}_2; \ldots; \text{head}_h] \in \mathbb{R}^{n \times hd_v}$$
Project back to model dimension: $$\text{MultiHead}(X) = \text{Concat} \cdot W^O \in \mathbb{R}^{n \times d_{model}}$$
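A minimal sketch that follows these equations literally, using an explicit per-head loop rather than the batched reshape used in typical implementations; all tensors are random and the sizes are illustrative.

```python
import torch

n, d_model, h = 6, 512, 8
d_k = d_v = d_model // h

X = torch.randn(n, d_model)

# One (W_i^Q, W_i^K, W_i^V) triple per head, plus the shared output projection W^O
W_Q = [torch.randn(d_model, d_k) / d_model**0.5 for _ in range(h)]
W_K = [torch.randn(d_model, d_k) / d_model**0.5 for _ in range(h)]
W_V = [torch.randn(d_model, d_v) / d_model**0.5 for _ in range(h)]
W_O = torch.randn(h * d_v, d_model) / (h * d_v) ** 0.5

heads = []
for i in range(h):
    Q_i, K_i, V_i = X @ W_Q[i], X @ W_K[i], X @ W_V[i]   # (n, d_k), (n, d_k), (n, d_v)
    A_i = (Q_i @ K_i.T) / d_k**0.5                        # (n, n) scaled scores
    A_tilde = torch.softmax(A_i, dim=-1)                  # row-wise softmax
    heads.append(A_tilde @ V_i)                           # head_i, shape (n, d_v)

concat = torch.cat(heads, dim=-1)                         # (n, h * d_v)
output = concat @ W_O                                     # (n, d_model)
print(output.shape)                                       # torch.Size([6, 512])
```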
The output projection $W^O$ is not just a dimensionality reduction—it's a learned mixing function that combines information across heads. Each column of $W^O$ learns how to weight and combine the contributions from all heads for each output dimension. This allows the model to learn complex interactions between what different heads have discovered.
Gradient Flow Analysis:
Multi-head attention has favorable gradient properties. Let's trace the gradient flow for position $j$:
$$\frac{\partial \mathcal{L}}{\partial x_j} = \sum_{i=1}^{h} \left(\frac{\partial \text{head}_i}{\partial x_j}\right)^{\top} \frac{\partial \mathcal{L}}{\partial \text{head}_i}, \qquad \frac{\partial \mathcal{L}}{\partial \text{head}_i} = \frac{\partial \mathcal{L}}{\partial \text{MultiHead}} \left(W_i^O\right)^{\top}$$
where $W_i^O \in \mathbb{R}^{d_v \times d_{model}}$ denotes the block of rows of $W^O$ applied to $\text{head}_i$.
Key observations:
Parallel gradient paths: Gradients flow through $h$ independent attention computations, then sum through $W^O$
No vanishing through heads: Unlike RNNs, gradients don't traverse sequential operations within multi-head attention
Direct gradient connection: Each position receives gradients directly from all positions that attend to it
Scaling preserves magnitudes: The $\sqrt{d_k}$ scaling in attention keeps gradients in a reasonable range
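As a small sanity check of the direct gradient connection, the sketch below (reusing the MultiHeadAttention module from the implementation above, with illustrative dimensions) backpropagates from a single output position and confirms that every input position receives a nonzero gradient in one step.

```python
import torch

# Reuses the MultiHeadAttention module defined in the implementation sketch above.
mha = MultiHeadAttention(d_model=16, num_heads=4, dropout=0.0)   # tiny illustrative sizes
x = torch.randn(1, 5, 16, requires_grad=True)

out = mha(x, x, x)
out[0, 2].sum().backward()        # the "loss" depends only on output position 2

# Every one of the 5 input positions receives gradient directly through attention
print((x.grad.abs().sum(dim=-1) > 0).squeeze())   # tensor([True, True, True, True, True])
```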
| Component | Shape | Parameter Count |
|---|---|---|
| Query projections ($h$ heads) | $h \times (d_{model} \times d_k)$ | $d_{model}^2$ |
| Key projections ($h$ heads) | $h \times (d_{model} \times d_k)$ | $d_{model}^2$ |
| Value projections ($h$ heads) | $h \times (d_{model} \times d_v)$ | $d_{model}^2$ |
| Output projection | $d_{model} \times d_{model}$ | $d_{model}^2$ |
| Total | — | $4d_{model}^2$ |
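Assuming the MultiHeadAttention module sketched earlier (all four projections without bias), the $4d_{model}^2$ total can be confirmed directly:

```python
# Assumes the MultiHeadAttention class from the implementation sketch above.
mha = MultiHeadAttention(d_model=512, num_heads=8)

total_params = sum(p.numel() for p in mha.parameters())
print(total_params)        # 1048576
print(4 * 512 * 512)       # 1048576 -- matches the 4 * d_model^2 in the table
```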
Expressive Power Analysis:
A key question is: what can multi-head attention express that single-head cannot?
Theorem (informal): Multi-head attention can express any function that requires attending to $h$ distinct positions with different weights for the same query, which single-head attention cannot.
Proof sketch: With a single head, the output at query $q$ is one convex combination $\sum_i \alpha_i v_i$ with $\sum_i \alpha_i = 1$; placing strong weight on $h$ distinct positions forces that single attention budget to be split, and the individual contributions cannot be recovered after averaging. With $h$ heads, each head can concentrate nearly all of its attention mass on a different position, and concatenation keeps the $h$ contributions in separate coordinate blocks before $W^O$ mixes them.
This separation in expressive power is fundamental, not merely a practical convenience.
One of the most profound aspects of multi-head attention is that each head operates in its own learned representation subspace. This isn't just a computational convenience—it's the mechanism by which heads specialize to capture different types of relationships.
The Subspace Mechanism:
For head $i$, the projection matrices $W_i^Q, W_i^K, W_i^V$ define transformations from the shared $d_{model}$-dimensional input space to a $d_k$-dimensional subspace. Different heads learn different projections, meaning they:
Attend to different features: Head 1's $W^Q$ might focus on syntactic features, while Head 2's $W^Q$ focuses on semantic features
Compute different compatibility functions: The $Q_i K_i^T$ product defines a similarity metric in each head's subspace
Extract different value combinations: Each head's $W^V$ determines which aspects of token representations are extracted and combined
Think of each head as wearing differently-colored glasses. Head 1's glasses might highlight all the nouns, making them easy to match. Head 2's glasses highlight verb tenses. Head 3's glasses highlight entities that could be coreferent. Each head sees the same sentence but perceives different structure based on its learned projections.
Empirical Evidence of Specialization:
Research analyzing trained Transformer models has revealed striking patterns of head specialization. In BERT-style models, researchers have identified heads that specialize in:
Syntactic Heads: attend from words to their syntactic governors or dependents, for example from direct objects to their verbs or from determiners to the nouns they modify
Semantic Heads: link semantically related or coreferent mentions, such as a pronoun and its candidate antecedents
Positional Heads: attend almost exclusively to the previous or next token, encoding local word order
Rare Pattern Heads: attend mostly to special or delimiter tokens and punctuation, often acting as a "no-op" when a head has nothing useful to contribute
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple


def analyze_head_subspaces(
    W_q: torch.Tensor,  # Shape: (num_heads, d_model, d_k)
    W_k: torch.Tensor,  # Shape: (num_heads, d_model, d_k)
) -> dict:
    """
    Analyze the subspaces defined by different attention heads.

    Computes:
    1. Subspace similarity between heads (how much their query/key spaces overlap)
    2. Effective rank of each head's projection
    3. Principal directions of each subspace

    Returns:
        Dictionary containing analysis results
    """
    num_heads = W_q.shape[0]
    d_model = W_q.shape[1]
    d_k = W_q.shape[2]

    # Compute the "attention subspace" for each head.
    # This is the space where Q^T K compatibility is computed;
    # we can analyze W_q @ W_k^T to understand what the head attends to.
    attention_matrices = []
    for i in range(num_heads):
        # The effective attention "filter" for head i.
        # When applied to x, this determines attention patterns.
        A_i = W_q[i] @ W_k[i].T  # Shape: (d_model, d_model)
        attention_matrices.append(A_i)
    attention_matrices = torch.stack(attention_matrices)  # (num_heads, d_model, d_model)

    # 1. Compute subspace similarity via Frobenius inner product
    subspace_similarity = torch.zeros(num_heads, num_heads)
    for i in range(num_heads):
        for j in range(num_heads):
            # Normalize and compute correlation
            A_i_flat = attention_matrices[i].flatten()
            A_j_flat = attention_matrices[j].flatten()
            A_i_norm = A_i_flat / (A_i_flat.norm() + 1e-8)
            A_j_norm = A_j_flat / (A_j_flat.norm() + 1e-8)
            subspace_similarity[i, j] = (A_i_norm * A_j_norm).sum()

    # 2. Compute effective rank of each head's attention matrix
    effective_ranks = []
    for i in range(num_heads):
        # SVD of the attention matrix
        U, S, V = torch.svd(attention_matrices[i])
        # Effective rank = exp(entropy of normalized singular values)
        S_normalized = S / S.sum()
        S_normalized = S_normalized[S_normalized > 1e-10]  # Remove zeros
        entropy = -(S_normalized * torch.log(S_normalized + 1e-10)).sum()
        effective_rank = torch.exp(entropy)
        effective_ranks.append(effective_rank.item())

    # 3. Extract principal directions (top singular vectors)
    principal_directions = []
    for i in range(num_heads):
        U, S, V = torch.svd(attention_matrices[i])
        # Top-k principal directions
        k = min(5, d_model)
        principal_directions.append({
            'left_singular': U[:, :k],
            'singular_values': S[:k],
            'right_singular': V[:, :k]
        })

    return {
        'subspace_similarity': subspace_similarity,
        'effective_ranks': effective_ranks,
        'principal_directions': principal_directions,
        'attention_matrices': attention_matrices
    }


def visualize_subspace_similarity(similarity_matrix: torch.Tensor,
                                  title: str = "Head Subspace Similarity"):
    """
    Visualize pairwise similarity between attention head subspaces.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    num_heads = similarity_matrix.shape[0]

    im = ax.imshow(similarity_matrix.numpy(), cmap='RdBu_r', vmin=-1, vmax=1)

    # Add colorbar
    cbar = plt.colorbar(im)
    cbar.set_label('Subspace Correlation')

    # Labels
    ax.set_xticks(range(num_heads))
    ax.set_yticks(range(num_heads))
    ax.set_xticklabels([f'Head {i}' for i in range(num_heads)])
    ax.set_yticklabels([f'Head {i}' for i in range(num_heads)])
    ax.set_title(title)
    ax.set_xlabel('Head')
    ax.set_ylabel('Head')

    # Add correlation values as text
    for i in range(num_heads):
        for j in range(num_heads):
            val = similarity_matrix[i, j].item()
            color = 'white' if abs(val) > 0.5 else 'black'
            ax.text(j, i, f'{val:.2f}', ha='center', va='center',
                    color=color, fontsize=8)

    plt.tight_layout()
    return fig


def demonstrate_subspace_diversity():
    """
    Demonstrate that trained heads develop diverse subspaces,
    while randomly initialized heads have more similar subspaces.
    """
    d_model = 512
    num_heads = 8
    d_k = d_model // num_heads

    # Simulate random initialization (Xavier)
    scale = np.sqrt(2.0 / (d_model + d_k))
    W_q_random = torch.randn(num_heads, d_model, d_k) * scale
    W_k_random = torch.randn(num_heads, d_model, d_k) * scale

    # Simulate "trained" weights with forced diversity.
    # In practice, training finds these automatically.
    W_q_trained = torch.randn(num_heads, d_model, d_k) * scale
    W_k_trained = torch.randn(num_heads, d_model, d_k) * scale

    # Force each head to focus on different dimensions (simulating specialization)
    for i in range(num_heads):
        # Each head gets a different "focus" region
        focus_start = (i * d_model) // num_heads
        focus_end = ((i + 1) * d_model) // num_heads
        # Amplify weights in focus region
        W_q_trained[i, focus_start:focus_end, :] *= 3.0
        W_k_trained[i, focus_start:focus_end, :] *= 3.0

    # Analyze both
    random_analysis = analyze_head_subspaces(W_q_random, W_k_random)
    trained_analysis = analyze_head_subspaces(W_q_trained, W_k_trained)

    print("Subspace Analysis: Random vs Trained Heads")
    print("=" * 60)
    print()
    print("Effective Ranks (higher = more diverse attention patterns):")
    print(f"  Random:  {[f'{r:.1f}' for r in random_analysis['effective_ranks']]}")
    print(f"  Trained: {[f'{r:.1f}' for r in trained_analysis['effective_ranks']]}")
    print()
    print("Average pairwise similarity (lower = more specialized heads):")
    # Compute off-diagonal average
    random_sim = random_analysis['subspace_similarity']
    trained_sim = trained_analysis['subspace_similarity']
    n = random_sim.shape[0]
    random_avg = (random_sim.sum() - random_sim.trace()) / (n * n - n)
    trained_avg = (trained_sim.sum() - trained_sim.trace()) / (n * n - n)
    print(f"  Random:  {random_avg:.3f}")
    print(f"  Trained: {trained_avg:.3f}")
    print()
    print("Interpretation: Trained heads develop more specialized, diverse subspaces")


if __name__ == "__main__":
    demonstrate_subspace_diversity()
```

The multi-head attention mechanism involves several design choices that affect model capacity, computational efficiency, and expressive power. Understanding these trade-offs is essential for architecting Transformers for specific applications.
1. Number of Heads ($h$)
The number of heads controls the diversity of attention patterns:
Fewer heads ($h = 1-4$): Each head has a larger dimension ($d_k$), allowing more nuanced similarity computations within each head, but limiting the diversity of attention patterns
Standard heads ($h = 8-16$): The sweet spot found in most Transformer variants, balancing specialization with per-head capacity
Many heads ($h = 32+$): Maximum diversity but each head has very limited dimension, potentially reducing individual head expressivity
| Model | $d_{model}$ | Heads ($h$) | $d_k = d_v$ | Notes |
|---|---|---|---|---|
| Transformer Base | 512 | 8 | 64 | Original paper configuration |
| Transformer Large | 1024 | 16 | 64 | Same d_k, doubled heads |
| BERT-Base | 768 | 12 | 64 | Optimized for NLU tasks |
| BERT-Large | 1024 | 16 | 64 | Scaled BERT |
| GPT-2 Small | 768 | 12 | 64 | Autoregressive language model |
| GPT-2 XL | 1600 | 25 | 64 | Large autoregressive model |
| GPT-3 175B | 12288 | 96 | 128 | Largest GPT-3 variant |
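The invariant behind the table, $h \cdot d_k = d_{model}$, can be checked for each configuration (values copied from the rows above):

```python
# (d_model, num_heads, d_k) copied from the table rows above
configs = {
    "Transformer Base": (512, 8, 64),
    "Transformer Large": (1024, 16, 64),
    "BERT-Base": (768, 12, 64),
    "BERT-Large": (1024, 16, 64),
    "GPT-2 Small": (768, 12, 64),
    "GPT-2 XL": (1600, 25, 64),
    "GPT-3 175B": (12288, 96, 128),
}

for name, (d_model, h, d_k) in configs.items():
    assert h * d_k == d_model, name   # the heads exactly partition the model dimension
    print(f"{name:18s} d_model={d_model:6d}  h={h:3d}  d_k={d_k}")
```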
2. Dimension per Head ($d_k, d_v$)
The per-head dimension affects the expressive capacity of individual heads:
Small $d_k$ (16-32): Very cheap per head, but the low-dimensional dot products give each head little room to encode nuanced similarity structure
Standard $d_k$ (64): The default in most published configurations, providing enough capacity per head while keeping many heads affordable
Large $d_k$ (128+): More expressive individual heads (used in very large models such as GPT-3), but fewer heads for a fixed $d_{model}$ and diminishing returns for most tasks
Research has shown that per-head dimension below 32 significantly degrades performance, while dimensions above 64 show diminishing returns for most tasks. This suggests 64 is near the "natural" dimensionality needed to capture important relationships in language.
3. Separate vs. Shared Projections
The standard formulation uses separate $W^Q, W^K, W^V$ matrices. Alternatives include:
Shared Query-Key projections ($W^Q = W^K$): Halves the query/key parameters and makes the raw score matrix symmetric, which can limit a head's ability to model directional relationships (see the sketch after this list)
Tied heads (same projections for some heads): Reduces parameter count and can act as a regularizer, at the cost of less specialization across the tied heads
Factorized projections: Decompose each projection matrix into low-rank factors, trading some flexibility for fewer parameters and less computation
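A quick illustration (toy tensors) of one consequence of sharing the query and key projections: the raw score matrix $XW(XW)^{\top}$ is symmetric, so position $i$ scores position $j$ exactly as $j$ scores $i$ before any masking.

```python
import torch

n, d_model, d_k = 5, 32, 8
X = torch.randn(n, d_model)
W_shared = torch.randn(d_model, d_k)

scores = (X @ W_shared) @ (X @ W_shared).T / d_k**0.5   # W^Q = W^K = W_shared
print(torch.allclose(scores, scores.T))                  # True: the score matrix is symmetric
```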
4. The Output Projection Trade-off
The output projection $W^O$ can be:
Full rank ($d_{model} \times d_{model}$): The standard choice, allowing arbitrary learned mixing across heads at a cost of $d_{model}^2$ parameters
Removed entirely: Head outputs are merely concatenated, so any cross-head mixing is deferred to the subsequent feed-forward layer; this saves parameters but weakens head interaction
Low-rank factorized: $W^O$ is split into two thin matrices, reducing parameters while retaining most of the mixing capacity (see the sketch after this list)
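A sketch of the low-rank option, with a hypothetical rank $r = 64$: factor the output projection into two thin linear maps and compare parameter counts.

```python
import torch.nn as nn

d_model, r = 512, 64

full_W_o = nn.Linear(d_model, d_model, bias=False)      # d_model^2 parameters
low_rank_W_o = nn.Sequential(                           # 2 * r * d_model parameters
    nn.Linear(d_model, r, bias=False),
    nn.Linear(r, d_model, bias=False),
)

def count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print(count(full_W_o), count(low_rank_W_o))   # 262144 vs 65536 -- 4x fewer at rank 64
```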
We've explored the fundamental motivation, mathematical formulation, and design considerations for multiple attention heads. Let's consolidate the key insights:
Single-head attention collapses all relationships into one softmax-weighted average, losing information whenever a query needs several distinct positions for different reasons
Multi-head attention runs $h$ attention functions in parallel, each in its own learned $d_k$-dimensional subspace, at the same parameter and computational cost as one full-width head
Heads specialize during training, capturing syntactic, semantic, positional, and other patterns
The main design levers are the number of heads $h$, the per-head dimension $d_k$, how projections are shared, and the form of the output projection $W^O$
What's Next:
In the next page, we'll explore head concatenation and output projection—the mechanism by which information from multiple heads is combined. We'll analyze how $W^O$ enables heads to interact, the importance of this projection for model expressivity, and alternatives to simple concatenation.
You now understand why multiple attention heads are essential for Transformers—not just a practical improvement, but a fundamental solution to the representational limitations of single-head attention. The ability to attend in multiple subspaces simultaneously is what enables Transformers to model the rich, multi-faceted relationships in language and other sequences.