When a transformer model translates 'The cat sat on the mat' to French, how does it know that 'cat' corresponds to 'chat'? When BERT classifies a movie review as positive, which words drove that decision? The answer lies in attention mechanisms—and more importantly, in our ability to visualize and interpret them.
Attention mechanisms are the backbone of modern NLP and increasingly vision models. Unlike recurrent networks that process sequences step-by-step, attention allows models to directly connect any input position to any output position. This creates a natural interpretability opportunity: we can literally see which parts of the input the model 'attends to' when producing each part of the output.
But attention visualization is not as simple as 'bright colors mean important'. Multi-head attention, layer depth, and the difference between attention weights and information flow create subtle interpretation challenges. This page gives you the complete framework to correctly visualize and interpret attention patterns.
This page covers: (1) The mechanics of attention and why it creates interpretable patterns, (2) Visualizing single-head and multi-head attention, (3) Layer-by-layer attention analysis, (4) BertViz and other visualization tools, (5) Attention in vision transformers, (6) Cross-attention for multi-modal models, (7) Critical limitations of attention as explanation, and (8) When attention tells the truth and when it misleads.
Before visualizing attention, we must understand what attention computes. The scaled dot-product attention from 'Attention Is All You Need' (Vaswani et al., 2017) operates as follows:
Query-Key-Value (QKV) Framework:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where $Q$ (queries), $K$ (keys), and $V$ (values) are linear projections of the input embeddings, and $d_k$ is the key dimension, used to scale the dot products before the softmax.
The Attention Weights Matrix:
The term $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$ produces an attention weight matrix $A$ of shape [sequence_length × sequence_length]. Each row sums to 1: row $i$ is a probability distribution over source (key) positions, describing how much each of them contributes to target (query) position $i$.
$A_{ij}$ = attention weight from position $i$ to position $j$ = "how much position $i$ attends to position $j$"
High attention weight means the model 'looked' at that position. It doesn't necessarily mean that position was 'important' for the final prediction. Attention is one step in a complex computation—the model still applies transformations after aggregating attended values. Keep this distinction in mind throughout.
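To ground these definitions, here is a minimal PyTorch sketch of scaled dot-product attention on random tensors (not a trained model); the helper name and toy dimensions are illustrative. It returns both the aggregated output and the weight matrix $A$, and the final print confirms that each row of $A$ sums to 1.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(QK^T / sqrt(d_k)) V and also return the attention weights."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # [seq_q, seq_k]
    weights = F.softmax(scores, dim=-1)             # each row is a distribution
    return weights @ V, weights

# Toy example: 5 tokens, 8-dimensional queries/keys/values
torch.manual_seed(0)
seq_len, d_k = 5, 8
Q, K, V = (torch.randn(seq_len, d_k) for _ in range(3))

output, A = scaled_dot_product_attention(Q, K, V)
print(A.shape)        # torch.Size([5, 5]) -- [seq_len, seq_len]
print(A.sum(dim=-1))  # each row sums to ~1.0
```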
The simplest attention visualization shows the attention matrix as a heatmap. Given input tokens, we color cells based on attention weight values.
Basic Heatmap Visualization:
For self-attention (input = output positions), the matrix is square. For cross-attention (e.g., decoder attending to encoder), it's rectangular.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel

# Load pre-trained BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

# Sample sentence
sentence = "The cat sat on the mat because it was tired"
inputs = tokenizer(sentence, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

# Forward pass with attention outputs
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple of length num_layers
# each layer: [batch, num_heads, seq_len, seq_len]
attentions = torch.stack(outputs.attentions).squeeze(1)  # [layers, heads, seq, seq]
print(f"Attention shape: {attentions.shape}")

# Visualize single head from single layer
layer_idx = 5  # Middle layer
head_idx = 0   # First head

attention_matrix = attentions[layer_idx, head_idx].numpy()

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    attention_matrix,
    xticklabels=tokens,
    yticklabels=tokens,
    cmap='Blues',
    ax=ax,
    vmin=0,
    vmax=1
)
ax.set_xlabel('Source Token (Key)')
ax.set_ylabel('Target Token (Query)')
ax.set_title(f'Attention Weights: Layer {layer_idx}, Head {head_idx}')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('attention_heatmap.png', dpi=150)
plt.show()

# Line/arc visualization for a specific token
def plot_token_attention(attention_matrix, tokens, target_idx, ax=None):
    """Visualize attention FROM a specific target token."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 4))

    weights = attention_matrix[target_idx]
    positions = range(len(tokens))

    # Bar chart of attention weights
    bars = ax.bar(positions, weights, color='steelblue', alpha=0.7)

    # Highlight the source token
    bars[target_idx].set_color('darkred')

    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    ax.set_ylabel('Attention Weight')
    ax.set_title(f'Attention from "{tokens[target_idx]}" to all tokens')
    ax.set_ylim(0, 1)

    return ax

# What does 'it' attend to? (pronoun resolution)
it_idx = tokens.index('it')
fig, ax = plt.subplots(figsize=(12, 4))
plot_token_attention(attention_matrix, tokens, it_idx, ax)
plt.tight_layout()
plt.savefig('pronoun_attention.png', dpi=150)
plt.show()
```

Real transformer models use multi-head attention: multiple attention mechanisms running in parallel. Each head can learn to attend differently:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$
where each $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$ and $W^O$ is the learned output projection.
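To see where the per-head attention matrices that we visualize come from, here is a toy, shape-level sketch of multi-head attention (random weights; names and sizes are illustrative, and real implementations fuse the per-head projections into single matrices). Each head produces its own [seq, seq] attention matrix before the outputs are concatenated and projected.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model, n_heads = 10, 768, 12
d_head = d_model // n_heads  # 64, as in BERT-base

x = torch.randn(seq_len, d_model)

# One projection per head (illustrative)
W_q = torch.randn(n_heads, d_model, d_head) * 0.02
W_k = torch.randn(n_heads, d_model, d_head) * 0.02
W_v = torch.randn(n_heads, d_model, d_head) * 0.02
W_o = torch.randn(d_model, d_model) * 0.02

head_outputs, head_weights = [], []
for h in range(n_heads):
    Q, K, V = x @ W_q[h], x @ W_k[h], x @ W_v[h]        # [seq, d_head] each
    A = F.softmax(Q @ K.T / d_head ** 0.5, dim=-1)      # [seq, seq], one per head
    head_weights.append(A)
    head_outputs.append(A @ V)                          # [seq, d_head]

multi_head = torch.cat(head_outputs, dim=-1) @ W_o      # [seq, d_model]
print(multi_head.shape)                 # torch.Size([10, 768])
print(torch.stack(head_weights).shape)  # [12, seq, seq]: what we visualize per head
```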
Why Multiple Heads? A single attention distribution can capture only one kind of relationship per token. Running several heads in parallel lets each head specialize (for example in syntax, coreference, or position), as the head-analysis code later in this section shows.
Visualizing Multiple Heads:
Displaying 12 heads × 12 layers = 144 attention matrices is overwhelming. Common strategies: plot all heads of a single layer as a grid of small multiples, average attention across heads within a layer, or compute summary statistics (diagonal ratio, previous/next-token focus, [CLS] focus, entropy) to automatically flag specialized heads. The code below applies the first and third strategies.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel

# Load model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

sentence = "The lawyer questioned the witness because she thought he was lying"
inputs = tokenizer(sentence, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

with torch.no_grad():
    outputs = model(**inputs)

attentions = torch.stack(outputs.attentions).squeeze(1)  # [12, 12, seq, seq]
num_layers, num_heads, seq_len, _ = attentions.shape

# Visualize all 12 heads from layer 8 (often captures semantic relations)
layer = 8
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

for head in range(num_heads):
    ax = axes[head // 4, head % 4]
    attn = attentions[layer, head].numpy()

    im = ax.imshow(attn, cmap='Blues', vmin=0, vmax=attn.max())
    ax.set_title(f'Head {head}', fontsize=10)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))

    if head % 4 == 0:
        ax.set_yticklabels(tokens, fontsize=6)
    else:
        ax.set_yticklabels([])

    if head >= 8:
        ax.set_xticklabels(tokens, fontsize=6, rotation=90)
    else:
        ax.set_xticklabels([])

plt.suptitle(f'All 12 Attention Heads - Layer {layer}', fontsize=14)
plt.tight_layout()
plt.savefig('all_heads_layer8.png', dpi=150)
plt.show()

# Identify specialized heads
def analyze_head_patterns(attentions, tokens):
    """Analyze what patterns different heads capture."""
    n_layers, n_heads, seq_len, _ = attentions.shape

    patterns = []
    for layer in range(n_layers):
        for head in range(n_heads):
            attn = attentions[layer, head].numpy()

            # Metrics for head specialization
            diagonal_ratio = np.trace(attn) / seq_len  # Self-attention

            # Previous token attention (diagonal shifted by 1)
            prev_token = np.mean([attn[i, i - 1] for i in range(1, seq_len)])

            # Next token attention
            next_token = np.mean([attn[i, i + 1] for i in range(seq_len - 1)])

            # Special token ([CLS]) attention
            cls_attention = np.mean(attn[:, 0])  # Attention TO [CLS]

            # Entropy (spread of attention)
            entropy = -np.sum(attn * np.log(attn + 1e-10)) / seq_len

            patterns.append({
                'layer': layer,
                'head': head,
                'self_attn': diagonal_ratio,
                'prev_token': prev_token,
                'next_token': next_token,
                'cls_attn': cls_attention,
                'entropy': entropy
            })

    return patterns

patterns = analyze_head_patterns(attentions, tokens)

# Find specialized heads
print("Head Specialization Analysis:")
print("=" * 60)

# Highest self-attention
self_attn = max(patterns, key=lambda x: x['self_attn'])
print(f"Position-aware (self): L{self_attn['layer']}H{self_attn['head']} ({self_attn['self_attn']:.3f})")

# Previous token focus
prev_focus = max(patterns, key=lambda x: x['prev_token'])
print(f"Previous token focus: L{prev_focus['layer']}H{prev_focus['head']} ({prev_focus['prev_token']:.3f})")

# [CLS] focus
cls_focus = max(patterns, key=lambda x: x['cls_attn'])
print(f"[CLS] aggregator: L{cls_focus['layer']}H{cls_focus['head']} ({cls_focus['cls_attn']:.3f})")
```

Research has found that specific heads consistently capture specific linguistic phenomena: syntactic heads track grammatical relations, coreference heads link pronouns to antecedents, and positional heads focus on adjacent tokens. The paper 'What Does BERT Look At?' (Clark et al., 2019) provides detailed analysis of BERT's attention heads.
Manual heatmap creation quickly becomes tedious. BertViz (Jesse Vig, 2019) provides interactive attention visualization that has become the standard tool for exploring transformer attention.
BertViz Views:
Head View: All heads from all layers in one interface. Click any head to see its attention pattern. Lines connect attended tokens with width proportional to attention weight.
Model View: Aggregated attention across all heads in a layer. Useful for seeing overall layer behavior.
Neuron View: Traces how individual query and key neurons contribute to attention. Most granular analysis.
Key Features: interactive exploration of every layer and head, inline rendering in Jupyter notebooks (from scripts the views open in a browser), and direct consumption of the attention tensors returned by HuggingFace models with output_attentions=True.
```python
# Install BertViz: pip install bertviz

import torch
import numpy as np
from bertviz import head_view, model_view
from transformers import BertTokenizer, BertModel

# Load model with attention outputs
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

# Analyze a sentence pair
sentence_a = "The cat sat on the mat."
sentence_b = "It was very comfortable."

# Tokenize
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
input_ids = inputs['input_ids'][0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)

# Get attention
with torch.no_grad():
    outputs = model(**inputs)
    attention = outputs.attentions  # Tuple of (batch, heads, seq, seq), one per layer

# HEAD VIEW: Interactive visualization of all heads
# Shows connections between tokens with lines
# Width of line = attention weight
head_view(attention, tokens)

# MODEL VIEW: Aggregated view across heads
# Useful for seeing layer-by-layer patterns
model_view(attention, tokens)

# In Jupyter notebooks these render inline as interactive widgets;
# from scripts they open in a browser.

# Programmatic analysis alongside visualization
def summarize_attention_patterns(attention, tokens):
    """Generate a human-readable summary of attention patterns."""
    attention_stack = torch.stack(attention).squeeze(1)  # [layers, heads, seq, seq]

    summaries = []
    for layer_idx in range(attention_stack.shape[0]):
        layer_attn = attention_stack[layer_idx]  # [heads, seq, seq]

        # Average across heads
        avg_attn = layer_attn.mean(dim=0).numpy()

        # Find strongest non-diagonal attention
        np.fill_diagonal(avg_attn, 0)
        max_attn = np.unravel_index(np.argmax(avg_attn), avg_attn.shape)
        source, target = max_attn
        weight = avg_attn[source, target]

        summaries.append(f"Layer {layer_idx}: '{tokens[source]}' → '{tokens[target]}' ({weight:.3f})")

    return summaries

summaries = summarize_attention_patterns(attention, tokens)
print("Strongest Attention Connections per Layer:")
for s in summaries:
    print(s)
```

Beyond BertViz, several other tools support attention-based interpretation:

| Tool | Main Use Case | Interactivity | Installation |
|---|---|---|---|
| BertViz | General transformer attention exploration | High (web-based) | pip install bertviz |
| Transformers Interpret | Attribution + attention for HuggingFace | Medium | pip install transformers-interpret |
| Ecco | NLP interpretation with attention + embeddings | High | pip install ecco |
| AllenNLP Interpret | NLP interpretation suite | High (demo) | pip install allennlp-interpret |
| LIT (Language Interpretability Tool) | Google's interactive ML analysis | Very High | pip install lit-nlp |
Deep transformers stack many layers, and attention patterns evolve dramatically from early to late layers. Understanding this evolution reveals how models build up representations.
Layer Progression Patterns:
Early Layers (1-3): Attention often focuses on positional patterns—adjacent tokens, special tokens ([CLS], [SEP], [PAD]). Local syntactic structure.
Middle Layers (4-8): Increasingly semantic attention. Pronouns attend to antecedents. Related concepts connect. Long-range dependencies emerge.
Late Layers (9-12): Task-specific patterns. For classification, attention concentrates on [CLS]. For generation, attention patterns become more diffuse and abstract.
Aggregating Across Layers:
No single layer tells the complete story. Information flows through the network, transforming at each layer. Recent work suggests that methods which trace attention paths through all layers, such as attention rollout or attention flow, reflect information flow better than any single layer's attention map.
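In equation form, mirroring the rollout code below, the head-averaged attention $\bar{A}^{(\ell)}$ of layer $\ell$ is blended with the identity matrix to account for the residual connection, row-normalized, and the resulting matrices are multiplied from the first layer up to the last:

$$\tilde{A}^{(\ell)} = \operatorname{norm}\!\left(\tfrac{1}{2}\,\bar{A}^{(\ell)} + \tfrac{1}{2}\,I\right), \qquad \text{Rollout} = \tilde{A}^{(L)}\,\tilde{A}^{(L-1)}\cdots\tilde{A}^{(1)}$$

where $\operatorname{norm}(\cdot)$ renormalizes each row to sum to 1 and $L$ is the number of layers.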
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

sentence = "John gave Mary a book because she asked for it"
inputs = tokenizer(sentence, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

with torch.no_grad():
    outputs = model(**inputs)

attentions = torch.stack(outputs.attentions).squeeze()  # [12, 12, seq, seq]

# Analyze attention entropy (spread) across layers
def attention_entropy(attn_matrix):
    """Calculate entropy of attention distribution per row, averaged."""
    # attn_matrix: [seq, seq], each row sums to 1
    epsilon = 1e-10
    entropy_per_row = -torch.sum(attn_matrix * torch.log(attn_matrix + epsilon), dim=-1)
    return entropy_per_row.mean().item()

# Track metrics across layers
layer_metrics = []
for layer in range(12):
    layer_attn = attentions[layer].mean(dim=0)  # Average across heads

    entropy = attention_entropy(layer_attn)
    diagonal = torch.trace(layer_attn).item() / len(tokens)

    # Attention to [CLS]
    cls_attn = layer_attn[:, 0].mean().item()

    layer_metrics.append({
        'layer': layer,
        'entropy': entropy,
        'self_attention': diagonal,
        'cls_attention': cls_attn
    })

# Plot layer progression
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

layers = range(12)
metrics = layer_metrics

axes[0].plot(layers, [m['entropy'] for m in metrics], 'o-', color='steelblue')
axes[0].set_xlabel('Layer')
axes[0].set_ylabel('Attention Entropy')
axes[0].set_title('Attention Spread (Entropy)')

axes[1].plot(layers, [m['self_attention'] for m in metrics], 'o-', color='darkorange')
axes[1].set_xlabel('Layer')
axes[1].set_ylabel('Self-Attention Ratio')
axes[1].set_title('Diagonal Attention')

axes[2].plot(layers, [m['cls_attention'] for m in metrics], 'o-', color='forestgreen')
axes[2].set_xlabel('Layer')
axes[2].set_ylabel('[CLS] Attention')
axes[2].set_title('Attention to [CLS]')

plt.suptitle('Attention Patterns Across BERT Layers', fontsize=14)
plt.tight_layout()
plt.savefig('layer_progression.png', dpi=150)
plt.show()

# ATTENTION ROLLOUT: Aggregate attention across layers
def attention_rollout(attentions, add_residual=True):
    """
    Compute attention rollout as described in Abnar & Zuidema (2020).
    Aggregates attention across layers accounting for residual connections.
    """
    # attentions: [n_layers, n_heads, seq_len, seq_len]
    n_layers, n_heads, seq_len, _ = attentions.shape

    # Average across heads
    attn_avg = attentions.mean(dim=1)  # [n_layers, seq_len, seq_len]

    # Add identity matrix for residual connection
    if add_residual:
        eye = torch.eye(seq_len)
        attn_avg = 0.5 * attn_avg + 0.5 * eye

    # Normalize rows
    attn_avg = attn_avg / attn_avg.sum(dim=-1, keepdim=True)

    # Rollout: multiply attention matrices
    rollout = attn_avg[0]
    for layer in range(1, n_layers):
        rollout = torch.matmul(attn_avg[layer], rollout)

    return rollout

rollout_attn = attention_rollout(attentions)

# Compare single layer vs rollout
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Single layer (last layer)
single_layer = attentions[-1].mean(dim=0).numpy()
im1 = axes[0].imshow(single_layer, cmap='Blues')
axes[0].set_title('Single Layer (Layer 12, avg heads)')
axes[0].set_xticks(range(len(tokens)))
axes[0].set_xticklabels(tokens, rotation=90, fontsize=8)
axes[0].set_yticks(range(len(tokens)))
axes[0].set_yticklabels(tokens, fontsize=8)
plt.colorbar(im1, ax=axes[0])

# Attention rollout
im2 = axes[1].imshow(rollout_attn.numpy(), cmap='Blues')
axes[1].set_title('Attention Rollout (All Layers)')
axes[1].set_xticks(range(len(tokens)))
axes[1].set_xticklabels(tokens, rotation=90, fontsize=8)
axes[1].set_yticks(range(len(tokens)))
axes[1].set_yticklabels(tokens, fontsize=8)
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.savefig('rollout_comparison.png', dpi=150)
plt.show()
```

Attention rollout assumes attention matrices can be simply multiplied across layers. This ignores the non-linear transformations (FFN, LayerNorm) between attention layers. More sophisticated methods like 'Attention Flow' (Abnar & Zuidema, 2020) or gradient-based methods may better capture true information flow.
Vision Transformers (ViT) apply the same attention mechanism to images by treating image patches as 'tokens'. This creates a powerful opportunity: we can visualize which parts of an image the model attends to for each patch.
ViT Patch Structure: The input image is split into fixed-size patches (16×16 pixels for ViT-Base), each patch is linearly projected into an embedding, and a learnable [CLS] token is prepended. The resulting sequence is processed by standard transformer layers, exactly as if the patches were word tokens.
Attention Visualization for ViT:
For a 224×224 image with 16×16 patches, we get 196 patches + 1 [CLS] = 197 tokens. The attention from [CLS] to all patches can be reshaped back to a 14×14 spatial grid and overlaid on the original image.
This creates attention maps showing which regions the model 'looks at' for its classification decision.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import ViTImageProcessor, ViTModel
import requests

# Load ViT model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTModel.from_pretrained('google/vit-base-patch16-224', output_attentions=True)

# Load sample image
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/1200px-Cat_November_2010-1a.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')

# Process image
inputs = processor(images=image, return_tensors="pt")

# Get attention
with torch.no_grad():
    outputs = model(**inputs)

attentions = torch.stack(outputs.attentions).squeeze()  # [12, 12, 197, 197]
# 197 = 196 patches (14x14) + 1 [CLS] token

# Extract attention from [CLS] token to all patches
def get_attention_map(attentions, layer, head, cls_idx=0, patch_size=14):
    """Get spatial attention map from [CLS] to patches."""
    attn = attentions[layer, head]  # [197, 197]

    # Attention FROM [CLS] TO patches (excluding [CLS] itself)
    cls_to_patches = attn[cls_idx, 1:].numpy()  # [196]

    # Reshape to spatial grid
    attn_map = cls_to_patches.reshape(patch_size, patch_size)
    return attn_map

# Visualize attention for different heads
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

for head in range(12):
    ax = axes[head // 4, head % 4]

    # Get attention from last layer
    attn_map = get_attention_map(attentions, layer=-1, head=head)

    # Resize to original image size
    attn_resized = np.array(Image.fromarray(attn_map).resize((224, 224), Image.BILINEAR))

    # Overlay on original image
    ax.imshow(image.resize((224, 224)))
    ax.imshow(attn_resized, cmap='hot', alpha=0.6)
    ax.set_title(f'Head {head}')
    ax.axis('off')

plt.suptitle('ViT Attention Maps (Last Layer, [CLS] → Patches)', fontsize=14)
plt.tight_layout()
plt.savefig('vit_attention_heads.png', dpi=150)
plt.show()

# Aggregate across heads and layers for overall attention
def aggregate_vit_attention(attentions, method='mean'):
    """Aggregate attention across heads and layers."""
    # attentions: [n_layers, n_heads, 197, 197]
    if method == 'mean':
        # Simple average
        avg_attn = attentions.mean(dim=(0, 1))  # [197, 197]
        cls_to_patches = avg_attn[0, 1:].numpy()
    elif method == 'rollout':
        # Attention rollout across layers
        n_layers, n_heads, seq_len, _ = attentions.shape
        attn_avg = attentions.mean(dim=1)  # [n_layers, 197, 197]
        eye = torch.eye(seq_len)
        attn_avg = 0.5 * attn_avg + 0.5 * eye
        attn_avg = attn_avg / attn_avg.sum(dim=-1, keepdim=True)

        rollout = attn_avg[0]
        for layer in range(1, n_layers):
            rollout = torch.matmul(attn_avg[layer], rollout)
        cls_to_patches = rollout[0, 1:].numpy()

    return cls_to_patches.reshape(14, 14)

# Compare aggregation methods
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original Image')
axes[0].axis('off')

mean_attn = aggregate_vit_attention(attentions, 'mean')
mean_resized = np.array(Image.fromarray(mean_attn).resize((224, 224), Image.BILINEAR))
axes[1].imshow(image.resize((224, 224)))
axes[1].imshow(mean_resized, cmap='hot', alpha=0.6)
axes[1].set_title('Mean Attention')
axes[1].axis('off')

rollout_attn = aggregate_vit_attention(attentions, 'rollout')
rollout_resized = np.array(Image.fromarray(rollout_attn).resize((224, 224), Image.BILINEAR))
axes[2].imshow(image.resize((224, 224)))
axes[2].imshow(rollout_resized, cmap='hot', alpha=0.6)
axes[2].set_title('Attention Rollout')
axes[2].axis('off')

plt.tight_layout()
plt.savefig('vit_aggregation.png', dpi=150)
plt.show()
```
Some of the most interpretable attention patterns emerge in cross-attention, where one sequence (e.g., generated text) attends to another (e.g., image patches or a source sentence). This is central to encoder-decoder machine translation, image captioning (as in the BLIP example below), and other multi-modal models in which a decoder attends to a separate encoder.
Why Cross-Attention is More Interpretable:
In cross-attention, we're explicitly asking: 'When generating word X, which parts of the image does the model look at?' This is more directly interpretable than self-attention, where the same sequence attends to itself.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import requests

# Load BLIP model (image captioning with cross-attention)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base",
    output_attentions=True
)

# Load image
url = 'https://upload.wikimedia.org/wikipedia/commons/b/bc/Juvenile_Ragdoll.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')

# Generate caption with attention
inputs = processor(images=image, return_tensors="pt")

# Generate with output attentions
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        output_attentions=True
    )

# The cross-attentions show decoder tokens attending to encoder (image) tokens
# outputs.cross_attentions is complex for generation (one per generated token)

# Decode generated caption
caption = processor.decode(outputs.sequences[0], skip_special_tokens=True)
print(f"Generated caption: {caption}")

# For detailed cross-attention analysis, use a single forward pass.
# Below is a simplified example showing the concept.

def visualize_translation_alignment(src_tokens, tgt_tokens, attention_matrix):
    """
    Visualize cross-attention alignment for translation or similar tasks.
    attention_matrix: [tgt_len, src_len]
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    im = ax.imshow(attention_matrix, cmap='Blues')

    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=45, ha='right')
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens)

    ax.set_xlabel('Source (Encoder)')
    ax.set_ylabel('Target (Decoder)')
    ax.set_title('Cross-Attention Alignment')

    plt.colorbar(im)
    plt.tight_layout()
    return fig

# Simulated translation alignment (English-French)
src_tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat', '.']
tgt_tokens = ['Le', 'chat', 's\'est', 'assis', 'sur', 'le', 'tapis', '.']

# Simulated attention matrix (would come from model in practice)
np.random.seed(42)
attention = np.zeros((len(tgt_tokens), len(src_tokens)))

# Create plausible alignment
alignments = [(0, 0), (1, 1), (2, 2), (3, 2), (4, 3), (5, 4), (6, 5), (7, 6)]
for tgt_idx, src_idx in alignments:
    attention[tgt_idx, src_idx] = np.random.uniform(0.6, 0.9)
    # Add some noise to neighbors
    for offset in [-1, 1]:
        if 0 <= src_idx + offset < len(src_tokens):
            attention[tgt_idx, src_idx + offset] = np.random.uniform(0.05, 0.15)

# Normalize rows
attention = attention / attention.sum(axis=1, keepdims=True)

fig = visualize_translation_alignment(src_tokens, tgt_tokens, attention)
plt.savefig('translation_alignment.png', dpi=150)
plt.show()

# Key insight: Cross-attention creates word-level or patch-level alignments
# that are often directly interpretable as "word X corresponds to word Y"
# or "word X looks at image region Z"
```

Cross-attention visualizations are particularly valuable for debugging multi-modal models. If an image captioning model says 'a dog on the couch' when the image shows a cat, checking cross-attention can reveal whether the model looked at the wrong region (attention error) or correctly attended but misclassified the object (recognition error).
Despite its intuitive appeal, attention visualization has fundamental limitations that every practitioner must understand. The seminal paper 'Attention is Not Explanation' (Jain & Wallace, 2019) demonstrated that attention weights do not reliably indicate feature importance.
Key Findings:
Alternative Attention Distributions: Many different attention patterns can produce the same prediction. High attention on a word doesn't mean that word was necessary.
Gradient Mismatch: Attention weights often don't align with gradient-based importance measures. The model might attend to a word but not be sensitive to changes in that word.
Attention is Input to Computation: Attention determines what gets aggregated, but the subsequent transformations (FFN layers, LayerNorm) determine how that information affects output.
Adversarial Attention: It's possible to create models with misleading attention that still perform well on tasks.
| Aspect | What Attention Shows | What Explanation Requires |
|---|---|---|
| Definition | Which inputs were weighted highly | Which inputs caused the output |
| Counterfactual | No: doesn't show what happens if input changed | Yes: requires sensitivity analysis |
| Sufficiency | No: aggregated values still transformed | Yes: should identify sufficient features |
| Uniqueness | No: multiple attention patterns give same output | Ideally: canonical explanation |
| Relationship | Correlation with output computation | Causal contribution to output |
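One simple counterfactual check implied by the table above is token occlusion: mask each token in turn, measure how much the predicted-class probability drops, and compare those drops with the attention the token receives from [CLS]. The sketch below assumes the same 'textattack/bert-base-uncased-SST-2' checkpoint used in the comparison code that follows; it illustrates the idea rather than providing a rigorous faithfulness test.

```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('textattack/bert-base-uncased-SST-2')
model = BertForSequenceClassification.from_pretrained(
    'textattack/bert-base-uncased-SST-2', output_attentions=True
)
model.eval()

sentence = "This movie was absolutely terrible and a waste of time"
inputs = tokenizer(sentence, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

with torch.no_grad():
    outputs = model(**inputs)
probs = outputs.logits.softmax(dim=-1)
pred_class = probs.argmax(dim=-1).item()
base_prob = probs[0, pred_class].item()

# Attention FROM [CLS] in the last layer, averaged over heads
attn_from_cls = outputs.attentions[-1][0].mean(dim=0)[0]

# Occlusion: replace each token with [MASK] and measure the probability drop
mask_id = tokenizer.mask_token_id
drops = []
for i in range(len(tokens)):
    if tokens[i] in ('[CLS]', '[SEP]'):
        drops.append(0.0)
        continue
    occluded = inputs['input_ids'].clone()
    occluded[0, i] = mask_id
    with torch.no_grad():
        occ_logits = model(input_ids=occluded,
                           attention_mask=inputs['attention_mask']).logits
    occ_prob = occ_logits.softmax(dim=-1)[0, pred_class].item()
    drops.append(base_prob - occ_prob)

print(f"{'token':<12}{'attn from [CLS]':>18}{'prob drop if masked':>22}")
for tok, a, d in zip(tokens, attn_from_cls.tolist(), drops):
    print(f"{tok:<12}{a:>18.3f}{d:>22.3f}")
# If the two columns rank tokens very differently, attention alone is a
# poor guide to which tokens actually drive the prediction.
```

A complementary check, shown next, compares attention against gradient-based importance on the same sentence.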
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from transformers import BertTokenizer, BertForSequenceClassification

# Load sentiment classification model
tokenizer = BertTokenizer.from_pretrained('textattack/bert-base-uncased-SST-2')
model = BertForSequenceClassification.from_pretrained(
    'textattack/bert-base-uncased-SST-2',
    output_attentions=True
)

sentence = "This movie was absolutely terrible and a waste of time"
inputs = tokenizer(sentence, return_tensors='pt', padding=True)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

# Get attention (no gradients needed for this pass)
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
attentions = torch.stack(outputs.attentions).squeeze()  # [12, 12, seq, seq]

# Attention in the last layer, averaged across heads
last_layer_attn = attentions[-1].mean(dim=0)
attn_to_cls = last_layer_attn[:, 0].numpy()    # Attention FROM each token TO [CLS]
attn_from_cls = last_layer_attn[0, :].numpy()  # Attention FROM [CLS] TO each token

# Now compute gradient-based importance
embeddings = model.bert.embeddings.word_embeddings(inputs['input_ids'])
embeddings.retain_grad()

# Forward pass with gradient tracking, using the model's own classification head
bert_outputs = model.bert(inputs_embeds=embeddings)
pooled = bert_outputs.pooler_output       # tanh-pooled [CLS] representation
logits_cls = model.classifier(pooled)     # classification head (dropout is a no-op in eval mode)

# Compute gradient w.r.t. predicted class
predicted_class = logits_cls.argmax().item()
logits_cls[0, predicted_class].backward()

# Gradient importance: L2 norm of embedding gradients
gradient_importance = embeddings.grad[0].norm(dim=-1).detach().numpy()

# Compare attention vs gradient
fig, axes = plt.subplots(3, 1, figsize=(12, 10))

x = range(len(tokens))

axes[0].bar(x, attn_from_cls, color='steelblue')
axes[0].set_xticks(x)
axes[0].set_xticklabels(tokens, rotation=45, ha='right')
axes[0].set_ylabel('Attention Weight')
axes[0].set_title('Attention FROM [CLS] (Last Layer, Avg Heads)')

axes[1].bar(x, gradient_importance, color='darkorange')
axes[1].set_xticks(x)
axes[1].set_xticklabels(tokens, rotation=45, ha='right')
axes[1].set_ylabel('Gradient Magnitude')
axes[1].set_title('Gradient-Based Importance')

# Correlation
corr, pval = spearmanr(attn_from_cls, gradient_importance)

axes[2].scatter(attn_from_cls, gradient_importance, alpha=0.7)
for i, tok in enumerate(tokens):
    axes[2].annotate(tok, (attn_from_cls[i], gradient_importance[i]), fontsize=8)
axes[2].set_xlabel('Attention Weight')
axes[2].set_ylabel('Gradient Importance')
axes[2].set_title(f'Attention vs Gradient (Spearman r={corr:.3f}, p={pval:.3f})')

plt.tight_layout()
plt.savefig('attention_vs_gradient.png', dpi=150)
plt.show()

# Key insight: If correlation is low, attention doesn't predict importance well
```

The follow-up paper 'Attention is Not Not Explanation' (Wiegreffe & Pinter, 2019) argues that the critique is too strong. Attention can be a reasonable explanation in many contexts, especially when combined with other evidence. The takeaway: attention is one useful signal among many, not a complete explanation.
Attention visualization provides a window into transformer computations, but that window must be used carefully with full awareness of its limitations. Here's the essential framework: treat attention weights as a descriptive record of where the model looked, not as a causal explanation; inspect multiple heads and layers rather than a single matrix; use aggregation methods such as attention rollout when you care about information flow across layers; and corroborate attention patterns with gradient- or perturbation-based importance before drawing conclusions.
What's Next:
Attention visualization explained what models 'look at' in a soft, distributed way. In the next page, we'll explore Saliency Maps—gradient-based methods that reveal which input features the model is most sensitive to. Saliency provides a complementary view based on counterfactual reasoning: 'How would the output change if this input changed?'
You now have a comprehensive understanding of attention visualization in transformer models. You can visualize single-head and multi-head patterns, analyze layer progression, interpret cross-attention for multi-modal tasks, and critically evaluate when attention does and doesn't provide reliable explanations.