One of the most fascinating aspects of multi-head attention is that heads naturally specialize to capture different types of patterns. Unlike hand-engineered features, these specializations emerge purely from training on large-scale data. Understanding what heads learn offers a window into how these models process language and which linguistic structures they discover on their own.
This page explores the rich landscape of what attention heads learn, from syntactic patterns to semantic relationships, and the methods researchers use to analyze these learned behaviors.
No one tells BERT or GPT that "head 5 in layer 7 should track subject-verb agreement." These specializations emerge from the training objective alone. The fact that interpretable linguistic patterns spontaneously appear in learned attention weights is a profound finding—suggesting that certain linguistic structures are statistically learnable from text alone.
Some of the most striking findings in attention head analysis involve heads that capture syntactic structure—the grammatical relationships between words.
Core Syntactic Phenomena Captured:
1. Subject-Verb Agreement
Certain heads learn to connect subjects to their verbs, even across multiple intervening tokens:
"The keys to the cabinet are on the table."
A subject-verb head would show strong attention from "are" to "keys" (the true subject), not to "cabinet" (the local noun). This is remarkable because the head must track the true subject across the intervening prepositional phrase "to the cabinet" rather than simply latching onto the nearest noun.
2. Direct Object Heads
Some heads attend from verbs to their direct objects:
"The chef prepared the delicious meal."
The verb "prepared" would attend to "meal" (direct object) more than to other nouns.
3. Modifier Attachment
Heads that connect modifiers to their targets:
| Pattern | Example | Attention Direction | Layer (typical) |
|---|---|---|---|
| Subject-verb | The dogs [that I saw] run | verb → subject | Middle layers |
| Direct object | She read the book | verb → object | Middle layers |
| Determiner-noun | the large dog | det → noun | Early layers |
| Adjective-noun | the large dog | adj ↔ noun | Early/middle |
| Preposition-object | on the table | prep → object | Early layers |
| Clause boundaries | I think [that she left] | that ↔ clause | Middle layers |
Syntactic heads tend to appear in specific layers: simple syntactic relationships (determiner-noun) in early layers, complex relationships (subject-verb agreement across clauses) in middle layers. This mirrors the linguistic intuition that complex syntax builds on simpler structures.
Quantitative Evidence:
Researchers have evaluated attention heads against gold-standard syntactic parses from linguistic treebanks. Individual heads recover specific dependency relations (such as direct objects or noun modifiers) with accuracy well above baseline, though no single head reproduces complete parse trees.
The Syntax Paradox:
Language models aren't trained with syntactic supervision, yet they develop syntactically-aware representations. This suggests that a substantial amount of syntactic structure is statistically learnable from raw text: predicting tokens well is easier for a model that implicitly tracks grammatical relationships.
Beyond syntax, attention heads capture semantic relationships—connections based on meaning rather than grammatical structure.
1. Coreference Resolution Heads
Some heads specialize in linking pronouns to their antecedents:
"Mary went to the store. She bought milk."
A coreference head would show strong attention from "She" to "Mary", identifying them as referring to the same entity. This requires matching the pronoun's number and gender with a compatible antecedent and linking mentions across a sentence boundary.
2. Semantic Role Heads
Heads that attend based on semantic roles (agent, patient, instrument):
"The hammer broke the window."
Despite "hammer" being the grammatical subject, a semantic role head might encode that "window" is the affected entity (patient) and "hammer" is the instrument.
3. Entity Tracking Heads
Heads that maintain consistent attention to entity mentions throughout a passage:
"Einstein developed relativity. The physicist revolutionized physics. He received the Nobel Prize."
Entity tracking heads would link all three mentions (Einstein, physicist, He) as the same entity, enabling consistent reasoning about the entity's properties.
4. Semantic Similarity Heads
Some heads attend to semantically similar words, even without explicit syntactic relationships:
"The doctor examined the patient. The physician recommended treatment."
A semantic similarity head might show mutual attention between "doctor" and "physician" as synonyms, facilitating lexical abstraction.
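One rough way to test for this behavior on a single example is to compare the mutual attention between a candidate pair of synonyms against the head's average attention. The helper below (`mutual_attention_score` is our own illustrative heuristic, not a standard metric) sketches the idea; it assumes a row-normalized attention matrix for a single head.

```python
import numpy as np

def mutual_attention_score(weights: np.ndarray, idx_a: int, idx_b: int) -> float:
    """
    Score how strongly two tokens attend to each other, relative to how much
    attention an average token receives in each row. Values well above 1
    suggest the pair is mutually salient for this head.
    """
    a_to_b = weights[idx_a, idx_b] / weights[idx_a].mean()
    b_to_a = weights[idx_b, idx_a] / weights[idx_b].mean()
    return float(np.sqrt(a_to_b * b_to_a))  # geometric mean of both directions
```

For the example above, you would call it with the token indices of "doctor" and "physician" and flag the head if the score is well above 1.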
5. Named Entity Heads
Heads that selectively attend to or between named entities.
Unlike syntactic patterns which can be verified against parse trees, semantic relationships are harder to evaluate. Coreference can be measured against annotated corpora (like OntoNotes), but broader semantic relationships often require human judgment or proxy tasks. This makes semantic head analysis more exploratory than syntactic analysis.
```python
import torch
import numpy as np
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
from collections import defaultdict


@dataclass
class AttentionPattern:
    """Represents an attention pattern for analysis."""
    head_idx: int
    layer_idx: int
    attention_weights: np.ndarray  # (seq_len, seq_len)
    tokens: List[str]

    def get_attention_for_token(self, token_idx: int) -> np.ndarray:
        """Get attention FROM a specific token TO all other tokens."""
        return self.attention_weights[token_idx]

    def get_attention_to_token(self, token_idx: int) -> np.ndarray:
        """Get attention TO a specific token FROM all other tokens."""
        return self.attention_weights[:, token_idx]


class SyntacticPatternDetector:
    """
    Detect and evaluate syntactic patterns in attention heads.

    This class implements methods from the research literature for
    analyzing whether attention heads capture syntactic relations.
    """

    def __init__(self, pattern: AttentionPattern):
        self.pattern = pattern
        self.tokens = pattern.tokens
        self.weights = pattern.attention_weights

    def evaluate_subject_verb_agreement(
        self,
        subject_indices: List[int],
        verb_indices: List[int]
    ) -> float:
        """
        Evaluate how well this head captures subject-verb relationships.

        Args:
            subject_indices: Token indices of subjects
            verb_indices: Token indices of corresponding verbs

        Returns:
            Average attention from verbs to their subjects
        """
        if len(subject_indices) != len(verb_indices):
            raise ValueError("Must provide matching subject-verb pairs")

        scores = []
        for subj_idx, verb_idx in zip(subject_indices, verb_indices):
            # Attention FROM verb TO subject
            attention_to_subject = self.weights[verb_idx, subj_idx]

            # Compare to attention to other tokens
            all_attention = self.weights[verb_idx]
            rank = (all_attention >= attention_to_subject).sum()

            # Score: 1 if subject is top-1 attention, scaled by rank otherwise
            score = 1.0 / rank
            scores.append(score)

        return np.mean(scores)

    def detect_adjacent_patterns(self) -> Dict[str, float]:
        """
        Detect heads that primarily attend to adjacent positions.

        Returns:
            Dictionary with pattern statistics
        """
        n = len(self.tokens)

        # Measure attention to previous token
        prev_attn = np.mean([
            self.weights[i, i - 1] if i > 0 else 0
            for i in range(n)
        ])

        # Measure attention to next token
        next_attn = np.mean([
            self.weights[i, i + 1] if i < n - 1 else 0
            for i in range(n)
        ])

        # Measure attention to same position
        self_attn = np.mean([self.weights[i, i] for i in range(n)])

        return {
            'previous_token': prev_attn,
            'next_token': next_attn,
            'self_attention': self_attn,
            'is_positional_head': (prev_attn > 0.3 or next_attn > 0.3)
        }

    def compute_dependency_score(
        self,
        gold_dependencies: List[Tuple[int, int]]  # (head_idx, dependent_idx)
    ) -> float:
        """
        Compare attention pattern to gold-standard dependencies.

        Args:
            gold_dependencies: List of (head, dependent) index pairs
                from a dependency parse

        Returns:
            F1 score comparing attention "edges" to gold edges
        """
        # Create soft edges from attention (threshold-based)
        threshold = 1.0 / len(self.tokens)  # Above uniform baseline

        predicted_edges = set()
        for i in range(len(self.tokens)):
            for j in range(len(self.tokens)):
                if i != j and self.weights[i, j] > threshold:
                    # Edge from i to j (attention from token i to token j)
                    predicted_edges.add((j, i))  # (head, dependent)

        gold_set = set(gold_dependencies)

        # Compute precision, recall, F1
        true_positives = len(predicted_edges & gold_set)
        precision = true_positives / len(predicted_edges) if predicted_edges else 0
        recall = true_positives / len(gold_set) if gold_set else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        return f1


class CoreferenceHeadDetector:
    """
    Detect heads that specialize in coreference resolution.
    """

    def __init__(self, pattern: AttentionPattern):
        self.pattern = pattern

    def evaluate_coreference(
        self,
        coreference_clusters: List[List[int]]  # Each cluster: list of token indices
    ) -> float:
        """
        Evaluate how well this head captures coreference.

        A good coreference head should show high within-cluster
        attention and low between-cluster attention.

        Args:
            coreference_clusters: Token indices for each coreference cluster

        Returns:
            Cluster separation score (higher = better coreference detection)
        """
        weights = self.pattern.attention_weights

        # Compute average within-cluster attention
        within_cluster_attn = []
        for cluster in coreference_clusters:
            if len(cluster) < 2:
                continue
            for i in cluster:
                for j in cluster:
                    if i != j:
                        within_cluster_attn.append(weights[i, j])

        avg_within = np.mean(within_cluster_attn) if within_cluster_attn else 0

        # Compute average between-cluster attention
        between_cluster_attn = []
        for cluster1 in coreference_clusters:
            for cluster2 in coreference_clusters:
                if cluster1 is not cluster2:
                    for i in cluster1:
                        for j in cluster2:
                            between_cluster_attn.append(weights[i, j])

        avg_between = np.mean(between_cluster_attn) if between_cluster_attn else 0

        # Return separation ratio (within/between)
        return avg_within / (avg_between + 1e-8)


def analyze_head_specialization():
    """
    Demonstrate analysis of attention head specialization.
    """
    print("Attention Head Specialization Analysis")
    print("=" * 60)

    # Example sentence with known syntactic structure
    tokens = ["The", "cat", "that", "I", "saw", "yesterday", "runs", "fast"]
    #           0      1      2      3    4        5           6       7

    # Simulate different head patterns
    seq_len = len(tokens)

    # Head 1: Subject-verb head (should link "cat" to "runs")
    weights_subj_verb = np.random.rand(seq_len, seq_len) * 0.1
    weights_subj_verb[6, 1] = 0.8  # "runs" → "cat"
    weights_subj_verb = weights_subj_verb / weights_subj_verb.sum(axis=-1, keepdims=True)

    pattern_subj_verb = AttentionPattern(
        head_idx=0, layer_idx=5,
        attention_weights=weights_subj_verb,
        tokens=tokens
    )

    # Head 2: Previous token head (positional)
    weights_prev = np.zeros((seq_len, seq_len))
    for i in range(1, seq_len):
        weights_prev[i, i - 1] = 0.9
        weights_prev[i, i] = 0.1
    weights_prev[0, 0] = 1.0  # First token attends to self
    weights_prev = weights_prev / weights_prev.sum(axis=-1, keepdims=True)

    pattern_prev = AttentionPattern(
        head_idx=1, layer_idx=2,
        attention_weights=weights_prev,
        tokens=tokens
    )

    # Analyze head 1 (subject-verb)
    detector1 = SyntacticPatternDetector(pattern_subj_verb)
    sv_score = detector1.evaluate_subject_verb_agreement(
        subject_indices=[1],  # "cat"
        verb_indices=[6]      # "runs"
    )

    print("Head 1 Analysis (Layer 5, Head 0):")
    print(f"  Subject-verb agreement score: {sv_score:.3f}")
    print(f"  Interpretation: {'Strong' if sv_score > 0.5 else 'Weak'} subject-verb head")
    print()

    # Analyze head 2 (positional)
    detector2 = SyntacticPatternDetector(pattern_prev)
    pos_patterns = detector2.detect_adjacent_patterns()

    print("Head 2 Analysis (Layer 2, Head 1):")
    print(f"  Previous token attention: {pos_patterns['previous_token']:.3f}")
    print(f"  Next token attention: {pos_patterns['next_token']:.3f}")
    print(f"  Self attention: {pos_patterns['self_attention']:.3f}")
    print(f"  Is positional head: {pos_patterns['is_positional_head']}")
    print()

    print("Summary:")
    print("  Head 1 specializes in subject-verb agreement")
    print("  Head 2 specializes in attending to previous positions")


if __name__ == "__main__":
    analyze_head_specialization()
```

A significant number of attention heads learn positional patterns, attending based on relative or absolute position rather than content. While these may seem less interesting than syntactic heads, they serve crucial functions.
1. Adjacent Position Heads
Heads that consistently attend to the previous or next token:
Previous-token heads: Allow information to flow from left to right, similar to unidirectional processing
Next-token heads: Enable right-to-left information flow in bidirectional models
These heads implement a form of local context aggregation—gathering information from immediate neighbors regardless of content.
2. Fixed-Offset Heads
Some heads attend to specific relative positions, such as a fixed offset a few tokens back rather than the immediately adjacent token.
These create predictable information highways that other layers can rely on.
3. Beginning/End of Sequence Heads
BOS/CLS heads: Attend to the beginning token, often used as a "global" aggregation point
EOS heads: Attend to end-of-sequence markers
Separator heads: Attend to sentence boundaries or [SEP] tokens
| Head Type | Attention Pattern | Function | Prevalence |
|---|---|---|---|
| Previous token | Position i → i-1 | Local left context | Very common (10-20% of heads) |
| Next token | Position i → i+1 | Local right context | Common in bidirectional models |
| Self | Position i → i | Identity/residual | Some heads in all models |
| First token (CLS) | All positions → 0 | Global aggregation | 1-2 heads per model |
| Separator | → [SEP] tokens | Segment awareness | In segmented models |
| Broad/uniform | ~Uniform over all | Background aggregation | Common in later layers |
Positional heads might seem "dumb"—they ignore content entirely. But they provide essential infrastructure: previous-token heads ensure local context is always accessible; CLS heads provide a consistent global aggregation point; separator heads enable segment-aware reasoning. Without these, content-based heads would need to redundantly learn positional awareness.
4. Delimiter and Punctuation Heads
Heads that specifically attend to structural markers such as periods, commas, brackets, and quotation marks.
These heads help the model understand document structure beyond pure syntax.
5. Sentence-Level Heads
In multi-sentence contexts, some heads attend preferentially within the current sentence, while others attend across sentence boundaries.
This division of labor enables both local coherence and discourse-level reasoning.
The Positional-Content Spectrum:
Heads exist on a spectrum from pure positional to pure content-based:
```
Pure Positional ←──────────────────────────────────────────→ Pure Content
"previous token"   "previous noun    "subject of the     "semantically
                    in clause"        current verb"        similar word"
```
Most heads combine positional and content signals to varying degrees.
Beyond individual head specialization, attention heads exhibit systematic layer-by-layer patterns. The type of information captured evolves through the network depth.
The General Pattern:
| Layer Position | Predominant Patterns | Information Type |
|---|---|---|
| Early (1-3) | Positional, local syntax | Surface patterns |
| Middle (4-8) | Syntactic relations, coreference | Grammatical structure |
| Late (9-12) | Semantic, task-specific | Abstract meaning/task |
This progression mirrors classical NLP pipelines: tokenization → parsing → semantic analysis.
Early Layers (1-3):
Token identity: Some heads attend strongly to the same token type elsewhere in the sequence
Local patterns: N-gram-like patterns, attending to fixed-window neighbors
Subword reconstruction: Heads that link WordPiece/BPE subwords of the same word
Basic POS-like patterns: Attending from punctuation to punctuation, content words to content words
Middle Layers (4-8):
Syntactic dependencies: Subject-verb, modifier-head, clause boundaries
Coreference: Pronouns to antecedents within and across sentences
Semantic roles: Agent, patient, instrument relationships
Named entity linking: Connecting mentions of the same entity
Late Layers (9-12):
Task-specific patterns: In fine-tuned models, heads adapt to the downstream task
Abstract reasoning: Logical relationships, cause-effect
Global context: Broad attention gathering diverse context
Output preparation: Aggregating information needed for predictions
Because Transformers use residual connections, each layer doesn't receive just the previous layer's output—it receives the cumulative sum of all previous layers. This means early-layer features (like positional patterns) remain accessible throughout the network, even as later layers add more abstract features. Heads can read from and write to this "residual stream" selectively.
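To make the residual stream concrete, here is a minimal pre-norm Transformer block sketch in PyTorch (schematic only; real models differ in normalization placement, dropout, and other details). Each sublayer adds its output to the running representation, so earlier-layer features are never overwritten:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Schematic pre-norm block: each sublayer *adds* to the residual stream."""
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Heads read a normalized view of the stream and write their output back in.
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                  # earlier features remain present in x
        x = x + self.mlp(self.ln2(x))     # the MLP also just adds to the stream
        return x
```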
```python
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple


def analyze_layer_patterns(
    attention_weights: np.ndarray,  # (num_layers, num_heads, seq_len, seq_len)
    tokens: List[str]
) -> Dict[str, np.ndarray]:
    """
    Analyze how attention patterns change across layers.

    Returns:
        Dictionary containing layer-wise pattern metrics
    """
    num_layers, num_heads, seq_len, _ = attention_weights.shape

    results = {
        'locality': np.zeros((num_layers, num_heads)),
        'self_attention': np.zeros((num_layers, num_heads)),
        'entropy': np.zeros((num_layers, num_heads)),
        'first_token': np.zeros((num_layers, num_heads)),
    }

    for layer in range(num_layers):
        for head in range(num_heads):
            weights = attention_weights[layer, head]  # (seq_len, seq_len)

            # Locality: Average attention to adjacent positions
            locality_scores = []
            for i in range(seq_len):
                adjacent_attn = 0
                if i > 0:
                    adjacent_attn += weights[i, i - 1]
                if i < seq_len - 1:
                    adjacent_attn += weights[i, i + 1]
                locality_scores.append(adjacent_attn)
            results['locality'][layer, head] = np.mean(locality_scores)

            # Self-attention: Average attention to same position
            results['self_attention'][layer, head] = np.mean([
                weights[i, i] for i in range(seq_len)
            ])

            # Entropy: How spread out is the attention?
            # Higher entropy = broader attention
            entropies = []
            for i in range(seq_len):
                row = weights[i]
                row = row[row > 1e-8]  # Avoid log(0)
                entropy = -np.sum(row * np.log(row + 1e-10))
                entropies.append(entropy)
            results['entropy'][layer, head] = np.mean(entropies)

            # First token: Average attention to first position
            results['first_token'][layer, head] = np.mean([
                weights[i, 0] for i in range(seq_len)
            ])

    return results


def visualize_layer_progression():
    """
    Visualize how head behaviors change across layers.
    """
    # Simulate typical pattern (for illustration)
    num_layers = 12
    num_heads = 12

    # Create synthetic data matching typical observations:
    # - Early layers: high locality, low entropy
    # - Middle layers: moderate locality, higher entropy
    # - Late layers: low locality, high entropy
    np.random.seed(42)

    # Locality decreases with depth
    locality = np.zeros((num_layers, num_heads))
    for layer in range(num_layers):
        base = 0.6 - 0.04 * layer  # Decreasing trend
        locality[layer] = base + np.random.randn(num_heads) * 0.1
    locality = np.clip(locality, 0, 1)

    # Entropy increases with depth
    entropy = np.zeros((num_layers, num_heads))
    max_entropy = np.log(50)  # Based on sequence length
    for layer in range(num_layers):
        base = 1.5 + 0.15 * layer  # Increasing trend
        entropy[layer] = base + np.random.randn(num_heads) * 0.3
    entropy = np.clip(entropy, 0, max_entropy)

    # First token attention is high in specific heads
    first_token = np.random.rand(num_layers, num_heads) * 0.1
    first_token[5:8, 0:2] = 0.6 + np.random.rand(3, 2) * 0.2  # CLS heads in middle layers

    print("Layer-Wise Head Pattern Analysis")
    print("=" * 60)
    print()

    # Summarize by layer group
    early_layers = slice(0, 4)
    middle_layers = slice(4, 8)
    late_layers = slice(8, 12)

    print("Locality (attention to adjacent positions):")
    print(f"  Early layers (0-3):  {locality[early_layers].mean():.3f}")
    print(f"  Middle layers (4-7): {locality[middle_layers].mean():.3f}")
    print(f"  Late layers (8-11):  {locality[late_layers].mean():.3f}")
    print()

    print("Entropy (attention spread):")
    print(f"  Early layers:  {entropy[early_layers].mean():.3f}")
    print(f"  Middle layers: {entropy[middle_layers].mean():.3f}")
    print(f"  Late layers:   {entropy[late_layers].mean():.3f}")
    print()

    print("First token attention:")
    print(f"  Early layers:  {first_token[early_layers].mean():.3f}")
    print(f"  Middle layers: {first_token[middle_layers].mean():.3f}")
    print(f"  Late layers:   {first_token[late_layers].mean():.3f}")
    print()

    print("Interpretation:")
    print("  - Early layers focus on local patterns (high locality)")
    print("  - Entropy increases as patterns become more abstract")
    print("  - Middle layers often have dedicated CLS/global heads")
    print()

    # Identify specialized heads
    print("\nSpecialized Head Detection:")
    print("-" * 40)

    # Find most local heads
    most_local = np.unravel_index(np.argmax(locality), locality.shape)
    print(f"  Most local head: Layer {most_local[0]}, Head {most_local[1]} "
          f"(locality={locality[most_local]:.3f})")

    # Find highest entropy heads
    most_spread = np.unravel_index(np.argmax(entropy), entropy.shape)
    print(f"  Broadest attention: Layer {most_spread[0]}, Head {most_spread[1]} "
          f"(entropy={entropy[most_spread]:.3f})")

    # Find CLS heads
    cls_head = np.unravel_index(np.argmax(first_token), first_token.shape)
    print(f"  Strongest CLS head: Layer {cls_head[0]}, Head {cls_head[1]} "
          f"(first_token_attn={first_token[cls_head]:.3f})")


if __name__ == "__main__":
    visualize_layer_progression()
```

An important question for both understanding and efficiency is: do heads learn redundant functions, or do they specialize into diverse roles?
Evidence for Redundancy:
Research has shown that many heads can be removed with minimal performance impact:
| Study | Model | Finding |
|---|---|---|
| Michel et al. (2019) | BERT, Transformer | Can remove 20-40% of heads with <1% performance drop |
| Voita et al. (2019) | Transformer MT | Only ~10% of heads are "important" for translation |
| Clark et al. (2019) | BERT | Identified interpretable patterns in a subset of heads |
| Kovaleva et al. (2019) | BERT | Many heads show vertical/diagonal patterns (positional) |
Why Redundancy Might Exist:
Training gives the model far more heads than most individual functions require, so several heads often converge on overlapping roles; as discussed below, this overlap also appears to pay off in robustness and fine-tuning headroom.
Measuring Head Diversity:
Head diversity can be quantified by comparing attention patterns:
$$\text{Diversity}(h_i, h_j) = 1 - \cos(\vec{h_i}, \vec{h_j})$$
where $\vec{h}$ is the flattened attention matrix. Low diversity indicates redundant heads.
Averaging this quantity over all head pairs within a layer gives a per-layer diversity profile.
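A sketch of that computation, assuming attention weights that have already been averaged over a batch of examples, might look like:

```python
import numpy as np

def average_pairwise_diversity(attention: np.ndarray) -> np.ndarray:
    """
    attention: (num_layers, num_heads, seq_len, seq_len), averaged over a batch.
    Returns the mean pairwise diversity 1 - cos(h_i, h_j) for each layer.
    """
    num_layers, num_heads = attention.shape[:2]
    flat = attention.reshape(num_layers, num_heads, -1)
    # Normalize each head's flattened pattern to unit length
    flat = flat / (np.linalg.norm(flat, axis=-1, keepdims=True) + 1e-8)

    per_layer = np.zeros(num_layers)
    for layer in range(num_layers):
        cos = flat[layer] @ flat[layer].T        # (num_heads, num_heads) cosine matrix
        diversity = 1.0 - cos
        # Average over distinct head pairs (exclude the diagonal)
        off_diag = diversity[~np.eye(num_heads, dtype=bool)]
        per_layer[layer] = off_diag.mean()
    return per_layer
```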
Implications for Efficiency:
The existence of redundant heads motivates head pruning, the topic of the next page: identifying and removing unimportant heads to obtain smaller, faster models.
While redundancy seems wasteful, it may contribute to Transformer robustness. Models with redundant heads are more resilient to noise, adversarial attacks, and domain shift. The "extra" capacity also provides headroom for fine-tuning—redundant heads can specialize for new tasks without disrupting essential functions.
Understanding what attention heads learn requires effective visualization and analysis techniques. Here we survey the main approaches used by researchers and practitioners.
1. Attention Heatmaps
The most direct visualization: plot attention weights as a matrix with query positions on one axis and key positions on the other.
```
                 Keys (what is attended to)
                 The    cat    sat    on     the    mat
Queries   The    .20    .50    .10    .05    .10    .05    <- "The" attends most strongly to "cat"
          sat    .10    .10    .60    .10    .05    .05    <- "sat" attends mostly to itself
          ...
```
Strengths: Direct and interpretable.
Weaknesses: Only shows one head at a time; hard to compare across layers.
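A minimal plotting helper for such heatmaps (the function name `plot_attention_heatmap` is ours) could look like:

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_attention_heatmap(weights: np.ndarray, tokens: list, title: str = "") -> None:
    """Plot one head's (seq_len, seq_len) attention matrix with tokens on both axes."""
    fig, ax = plt.subplots(figsize=(5, 5))
    im = ax.imshow(weights, cmap="viridis", vmin=0.0, vmax=1.0)
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)
    ax.set_xlabel("Keys (attended to)")
    ax.set_ylabel("Queries (attending)")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046)
    plt.tight_layout()
    plt.show()
```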
2. Attention Flow / River Diagrams
Visualize attention as flowing from query tokens to key tokens, with line thickness proportional to attention weight. Good for showing which words attend to which.
3. Aggregated Head Summaries
Rather than looking at individual examples, compute statistics across many examples, such as a head's average attention entropy, its average attention distance, and how often its strongest attention matches a given dependency relation.
4. Probing Classifiers
Train simple classifiers (linear probes) on attention weights to test if they encode specific properties:
```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def probe_for_syntax(attention_weights: np.ndarray, labels: np.ndarray) -> float:
    """Train a probe to predict syntactic labels from attention patterns."""
    # Flatten each example's attention matrix into a feature vector
    num_examples = attention_weights.shape[0]
    features = attention_weights.reshape(num_examples, -1)

    # Fit a simple logistic-regression probe
    probe = LogisticRegression(max_iter=1000)
    probe.fit(features, labels)
    return probe.score(features, labels)
```
High probe accuracy suggests the attention encodes that property.
5. Behavioral Testing
Test specific linguistic capabilities by constructing diagnostic examples:
Subject-verb agreement test:
"The keys to the cabinet [ARE/IS] on the table."
Check: Does the subject-verb head attend "ARE" → "keys" more than "ARE" → "cabinet"?
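A minimal per-head check for this sentence might look like the function below; the token indices assume simple whitespace tokenization of the example, so they are illustrative only.

```python
import numpy as np

def agreement_check(weights: np.ndarray, verb_idx: int,
                    subject_idx: int, attractor_idx: int) -> bool:
    """True if the verb attends more to the true subject than to the attractor noun."""
    return bool(weights[verb_idx, subject_idx] > weights[verb_idx, attractor_idx])

# "The keys to the cabinet are on the table ."
# indices: The=0 keys=1 to=2 the=3 cabinet=4 are=5 on=6 the=7 table=8
# passes = agreement_check(attn_weights, verb_idx=5, subject_idx=1, attractor_idx=4)
```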
6. Causal Interventions
Modify attention patterns and observe the effect on the model's output, for example by ablating a head (zeroing its contribution) or forcing its attention to a uniform distribution; see the sketch below.
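For instance, head ablation removes a single head's contribution and measures how the output changes. A hedged sketch using the `head_mask` argument exposed by HuggingFace BERT-style models (the layer and head indices chosen here are hypothetical):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("The keys to the cabinet are on the table.", return_tensors="pt")

# head_mask scales each head's output; 0 disables that head for this forward pass
head_mask = torch.ones(model.config.num_hidden_layers, model.config.num_attention_heads)
head_mask[7, 5] = 0.0  # hypothetical target: layer 7, head 5

with torch.no_grad():
    baseline = model(**inputs).last_hidden_state
    ablated = model(**inputs, head_mask=head_mask).last_hidden_state

# Compare representations (or a downstream task metric) with and without the head
print("Mean absolute change:", (baseline - ablated).abs().mean().item())
```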
We've explored the fascinating world of what attention heads learn, from syntactic structure to semantic relationships to positional patterns.
What's Next:
In the final page of this module, we'll explore head pruning—techniques for identifying and removing unnecessary heads to create smaller, faster models while maintaining performance. This connects the interpretability insights from this page to practical efficiency gains.
You now understand the rich landscape of attention head specialization. The key insight: Transformers don't just memorize—they discover structure. The emergence of syntactic and semantic patterns from pure next-token prediction is one of the most surprising findings in modern NLP, suggesting that linguistic structure is statistically learnable from text alone.