The previous pages revealed that not all attention heads are created equal—some capture essential syntactic and semantic patterns, while others appear redundant or learn simple positional rules. This observation leads to a practical question: can we remove unnecessary heads to make models faster and smaller without sacrificing performance?
Head pruning is a structured compression technique that identifies and removes entire attention heads from trained Transformers. Unlike unstructured weight pruning (which zeros individual parameters), head pruning removes complete computational units, yielding a smaller dense model that runs faster with standard kernels—no sparse-matrix support required.
This page covers the theory and practice of attention head pruning: importance scoring methods, pruning strategies, the accuracy-efficiency trade-off, and practical implementation considerations.
Research consistently shows that 20-40% of attention heads can be removed from large Transformers with less than 1% performance degradation. For a 12-layer, 12-head BERT model (144 heads total), this means potentially removing 30-60 heads—a significant efficiency gain for deployment scenarios.
The foundation of head pruning is importance scoring—quantifying how much each head contributes to model performance. Several methods have been developed, each with different trade-offs.
1. Gradient-Based Importance
Measure importance by the sensitivity of the loss to removing the head:
$$I_{\text{grad}}(h) = \left| \frac{\partial \mathcal{L}}{\partial \xi_h} \right|$$
where $\xi_h$ is a gating variable that scales the output of head $h$ (fixed at 1 during scoring). Intuitively, if the loss is highly sensitive to scaling the head's output, that head matters for the loss.
Implementation: multiply each head's output by a gate $\xi_h = 1$, run a backward pass, and average $\left| \partial \mathcal{L} / \partial \xi_h \right|$ over a batch of data (see the scorer class later in this section).
2. Taylor Expansion Importance
Approximate the change in loss from removing head $h$ using first-order Taylor expansion:
$$\Delta \mathcal{L}_h \approx \left| \frac{\partial \mathcal{L}}{\partial o_h} \cdot o_h \right|$$
where $o_h$ is the output of head $h$. This combines gradient magnitude (sensitivity) with output magnitude (actual contribution).
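As a concrete sketch (one possible implementation, not the only one), the gate formulation can be coded directly: attach a scalar gate of 1 to every head's output and read off the gate gradients after a backward pass. The `model_forward(batch, head_gates=...)` interface below is hypothetical—adapt the plumbing to your model. Note that with a scalar gate per head, $|\partial \mathcal{L}/\partial \xi_h|$ equals the Taylor score $\left| \frac{\partial \mathcal{L}}{\partial o_h} \cdot o_h \right|$ by the chain rule.

import torch

def head_importance_from_gates(model_forward, batch, num_layers, num_heads):
    """Estimate per-head importance as |d loss / d gate|, with all gates fixed at 1."""
    gates = torch.ones(num_layers, num_heads, requires_grad=True)
    # Hypothetical interface: the model scales each head's output by its gate
    loss = model_forward(batch, head_gates=gates)
    loss.backward()
    return gates.grad.abs()  # (num_layers, num_heads); average over several batches in practice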
| Method | What It Measures | Computational Cost | Reliability |
|---|---|---|---|
| Gradient-based | Loss sensitivity to head | One backward pass | Good for sparse gradients |
| Taylor expansion | Gradient × activation | One backward pass | More stable than pure gradient |
| Leave-one-out | Actual loss change when removed | O(H) forward passes | Ground truth but expensive |
| Attention entropy | How focused the attention is | Just forward pass | Heuristic, may miss important heads |
| Oracle (retrain) | Optimal pruned model | Full retraining per config | Best but prohibitively expensive |
3. Leave-One-Out (LOO) Importance
The most direct measure: actually remove each head and measure performance:
$$I_{\text{LOO}}(h) = \mathcal{L}(\text{model without } h) - \mathcal{L}(\text{full model})$$
Heads that cause large loss increases when removed are important.
Limitations: it requires a separate evaluation pass for each of the $H$ heads, which is expensive for large models, and it scores each head in isolation, so it misses interactions—two heads may each look removable on their own yet not be jointly removable.
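A minimal LOO sketch, assuming a hypothetical `evaluate_loss(model, dataloader, head_mask)` helper that runs the model with a binary (num_layers, num_heads) head mask and returns the average loss:

import torch

def leave_one_out_importance(model, dataloader, evaluate_loss, num_layers, num_heads):
    base_mask = torch.ones(num_layers, num_heads)
    base_loss = evaluate_loss(model, dataloader, base_mask)
    importance = torch.zeros(num_layers, num_heads)
    for layer in range(num_layers):
        for head in range(num_heads):
            mask = base_mask.clone()
            mask[layer, head] = 0.0  # silence exactly one head
            importance[layer, head] = evaluate_loss(model, dataloader, mask) - base_loss
    return importance  # larger loss increase = more important head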
4. Attention Pattern Heuristics
Simpler heuristics based on attention patterns:
Entropy-based: Low-entropy heads (very focused attention) may be more important than high-entropy heads (diffuse attention)
$$H(h) = -\sum_j \alpha_{h,j} \log \alpha_{h,j}$$
Uniformity-based: Heads with near-uniform attention may contribute less than heads with sharp attention patterns
Caveat: These heuristics don't always correlate with actual importance—a diffuse attention pattern might still be critical for certain tasks.
In practice, combining multiple importance signals often works better than any single method. For example: Taylor importance × (1 - attention entropy) weights by both gradient sensitivity and attention sharpness. Ensemble importance scores are more robust to the quirks of individual methods.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import numpy as np


@dataclass
class HeadImportance:
    """Container for head importance scores."""
    layer: int
    head: int
    gradient_importance: float
    taylor_importance: float
    attention_entropy: float
    combined_score: float


class HeadImportanceScorer:
    """
    Compute importance scores for attention heads using multiple methods.

    Supports:
    1. Gradient-based importance
    2. Taylor expansion importance
    3. Attention entropy
    4. Combined/ensemble scoring
    """

    def __init__(self, model: nn.Module, num_layers: int, num_heads: int):
        self.model = model
        self.num_layers = num_layers
        self.num_heads = num_heads

        # Storage for intermediate values during forward pass
        self.head_outputs: Dict[Tuple[int, int], torch.Tensor] = {}
        self.head_attentions: Dict[Tuple[int, int], torch.Tensor] = {}
        self.head_gradients: Dict[Tuple[int, int], torch.Tensor] = {}

    def register_hooks(self):
        """Register forward/backward hooks to capture head outputs and gradients."""
        hooks = []
        for layer_idx in range(self.num_layers):
            # This assumes a specific model structure; adapt as needed
            def make_forward_hook(layer):
                def hook(module, input, output):
                    # Capture head outputs
                    # Assuming output is (batch, heads, seq, d_k) before concatenation
                    if hasattr(module, 'last_attention_weights'):
                        for head in range(self.num_heads):
                            self.head_attentions[(layer, head)] = module.last_attention_weights[:, head].detach()
                return hook

            def make_backward_hook(layer):
                def hook(module, grad_input, grad_output):
                    # Capture gradients w.r.t head outputs
                    if grad_output[0] is not None:
                        grad = grad_output[0]
                        if grad.dim() == 4:  # (batch, heads, seq, d_k)
                            for head in range(self.num_heads):
                                self.head_gradients[(layer, head)] = grad[:, head].detach()
                return hook

            # Register hooks (model-specific path)
            # hooks.append(layer_module.register_forward_hook(make_forward_hook(layer_idx)))
            # hooks.append(layer_module.register_backward_hook(make_backward_hook(layer_idx)))
        return hooks

    def compute_gradient_importance(
        self,
        dataloader,
        loss_fn,
        num_batches: int = 100
    ) -> np.ndarray:
        """
        Compute gradient-based importance scores.

        Returns:
            Array of shape (num_layers, num_heads) with importance scores
        """
        importance = np.zeros((self.num_layers, self.num_heads))
        self.model.eval()

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            if batch_idx >= num_batches:
                break

            # Forward pass
            outputs = self.model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward pass
            loss.backward()

            # Accumulate gradient magnitudes for each head
            for (layer, head), grad in self.head_gradients.items():
                importance[layer, head] += grad.abs().mean().item()

        # Normalize
        importance /= num_batches
        return importance

    def compute_taylor_importance(
        self,
        dataloader,
        loss_fn,
        num_batches: int = 100
    ) -> np.ndarray:
        """
        Compute Taylor expansion importance: |gradient × output|

        Returns:
            Array of shape (num_layers, num_heads) with importance scores
        """
        importance = np.zeros((self.num_layers, self.num_heads))
        self.model.eval()

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            if batch_idx >= num_batches:
                break

            outputs = self.model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()

            # Taylor importance = gradient * activation
            for (layer, head) in self.head_outputs.keys():
                output = self.head_outputs[(layer, head)]
                grad = self.head_gradients.get((layer, head))
                if grad is not None:
                    taylor = (grad * output).abs().mean().item()
                    importance[layer, head] += taylor

        importance /= num_batches
        return importance

    def compute_attention_entropy(self) -> np.ndarray:
        """
        Compute average attention entropy for each head.
        Lower entropy = more focused attention.

        Returns:
            Array of shape (num_layers, num_heads) with entropy values
        """
        entropy = np.zeros((self.num_layers, self.num_heads))

        for (layer, head), attn in self.head_attentions.items():
            # attn shape: (batch, seq, seq) - attention weights
            # Compute entropy for each query position, then average
            attn_clamped = attn.clamp(min=1e-9)
            head_entropy = -(attn_clamped * attn_clamped.log()).sum(dim=-1).mean()
            entropy[layer, head] = head_entropy.item()

        return entropy

    def compute_combined_importance(
        self,
        gradient_weight: float = 0.5,
        entropy_weight: float = 0.3,
        taylor_weight: float = 0.2
    ) -> np.ndarray:
        """
        Compute combined importance score from multiple signals.
        Higher score = more important head (should NOT be pruned)
        """
        # Normalize each signal to [0, 1]
        def normalize(arr):
            if arr.max() - arr.min() < 1e-8:
                return np.zeros_like(arr)
            return (arr - arr.min()) / (arr.max() - arr.min())

        grad_imp = normalize(self.compute_gradient_importance(None, None, 0))
        taylor_imp = normalize(self.compute_taylor_importance(None, None, 0))
        entropy = normalize(self.compute_attention_entropy())

        # Low entropy = focused attention = potentially more important
        entropy_imp = 1 - entropy

        combined = (
            gradient_weight * grad_imp +
            taylor_weight * taylor_imp +
            entropy_weight * entropy_imp
        )
        return combined


def demonstrate_importance_computation():
    """
    Demonstrate head importance computation on synthetic data.
    """
    print("Head Importance Scoring Demonstration")
    print("=" * 60)

    num_layers = 12
    num_heads = 12

    # Simulate importance scores (in practice, computed from model)
    np.random.seed(42)

    # Gradient importance - higher in middle layers
    gradient_imp = np.random.rand(num_layers, num_heads) * 0.5
    for layer in range(4, 8):
        gradient_imp[layer] += 0.3 + np.random.rand(num_heads) * 0.2

    # Taylor importance - correlated with gradient but noisier
    taylor_imp = gradient_imp * (0.7 + np.random.rand(num_layers, num_heads) * 0.6)

    # Entropy - lower in specialized heads
    entropy = np.random.rand(num_layers, num_heads) * 2  # Raw entropy
    entropy[5, 2] = 0.3  # A focused, specialized head
    entropy[6, 0] = 0.4  # Another specialized head

    print("Sample importance scores for Layer 5:")
    print(f"  Gradient importance: {gradient_imp[5].round(3)}")
    print(f"  Taylor importance:   {taylor_imp[5].round(3)}")
    print(f"  Attention entropy:   {entropy[5].round(3)}")
    print()

    # Identify least important heads
    combined = gradient_imp * 0.5 + taylor_imp * 0.3 + (1 - entropy / entropy.max()) * 0.2

    # Flatten and find least important
    flat_combined = combined.flatten()
    least_important_indices = np.argsort(flat_combined)[:10]

    print("10 Least Important Heads (pruning candidates):")
    print("-" * 40)
    for idx in least_important_indices:
        layer = idx // num_heads
        head = idx % num_heads
        print(f"  Layer {layer:2d}, Head {head:2d}: importance = {combined[layer, head]:.3f}")

    print()
    print("10 Most Important Heads (keep these):")
    print("-" * 40)
    most_important_indices = np.argsort(flat_combined)[-10:][::-1]
    for idx in most_important_indices:
        layer = idx // num_heads
        head = idx % num_heads
        print(f"  Layer {layer:2d}, Head {head:2d}: importance = {combined[layer, head]:.3f}")


if __name__ == "__main__":
    demonstrate_importance_computation()

Once importance scores are computed, the next question is: how do we decide which heads to prune? Several strategies exist, with different trade-offs between simplicity, optimality, and computational cost.
1. Global Top-K Pruning
The simplest approach: rank all heads by importance, prune the K least important globally.
# Sort all heads by importance
all_heads = [(layer, head, importance[layer, head])
             for layer in range(L) for head in range(H)]
all_heads.sort(key=lambda x: x[2]) # Sort by importance
# Prune bottom K
heads_to_prune = all_heads[:K]
Pros: Simple; respects the global importance ranking.
Cons: May remove all heads from one layer, causing severe degradation.
2. Layer-Balanced Pruning
Constrain pruning to remove at most K heads per layer:
heads_to_prune = []
for layer in range(L):
    layer_heads = sorted(range(H), key=lambda h: importance[layer, h])
    prune_from_layer = min(K_per_layer, len(layer_heads) - 1)  # Keep at least 1
    heads_to_prune.extend([(layer, h) for h in layer_heads[:prune_from_layer]])
Pros: Prevents catastrophic layer collapse.
Cons: May keep unimportant heads in some layers while pruning important ones in others.
| Strategy | Pruning Pattern | Risk | Best For |
|---|---|---|---|
| Global Top-K | Remove K least important overall | May empty some layers | Small pruning ratios (<30%) |
| Layer-balanced | Remove K per layer | Suboptimal allocation | Consistent layer structure |
| Threshold-based | Remove if importance < θ | Unpredictable count | Quality-focused pruning |
| Iterative/gradual | Prune few, retrain, repeat | Expensive | Maximum accuracy retention |
| Learned pruning | Train gates to zero | Training complexity | End-to-end optimization |
3. Threshold-Based Pruning
Prune all heads below an importance threshold:
$$\text{Prune head } h \text{ if } I_h < \theta$$
The threshold can be set from a percentile of the importance distribution, tuned on a validation set, or derived from a target parameter/latency budget, as sketched below.
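A small sketch of percentile-based threshold selection, assuming `importance` is the (num_layers, num_heads) array produced by any of the scoring methods above:

import numpy as np

def select_by_threshold(importance: np.ndarray, percentile: float = 30.0):
    theta = np.percentile(importance, percentile)  # e.g., prune the bottom 30% of heads
    return [(layer, head)
            for layer in range(importance.shape[0])
            for head in range(importance.shape[1])
            if importance[layer, head] < theta]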
4. Iterative (Gradual) Pruning
Prune a small number of heads, fine-tune the model, repeat:
for iteration in range(num_iterations):
    # 1. Compute importance scores
    # 2. Prune K_small heads (e.g., 5% of remaining)
    # 3. Fine-tune for E epochs
    # 4. Evaluate; stop if accuracy drops too much
Pros: Allows the model to adapt after each pruning step.
Cons: Computationally expensive; requires a retraining budget.
5. Learned (Differentiable) Pruning
Make pruning decisions learnable by introducing differentiable gates:
$$o_{\text{pruned}} = g(\theta_h) \cdot o_h$$
where $g(\theta_h)$ is a soft gate that approaches 0 for pruned heads. Train the model to minimize:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \lambda \sum_h g(\theta_h)$$
The regularization term encourages gates to go to zero, effectively pruning heads.
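A minimal sketch of learned gates with an L1-style sparsity penalty is shown below. Production implementations typically use stochastic "hard concrete" gates (an L0 relaxation, as in Voita et al., 2019); the plain sigmoid gate here is a simplification, and the module and attribute names are illustrative.

import torch
import torch.nn as nn

class HeadGates(nn.Module):
    """One learnable gate per head; multiply each head's output by its gate."""
    def __init__(self, num_layers: int, num_heads: int):
        super().__init__()
        self.logits = nn.Parameter(torch.full((num_layers, num_heads), 2.0))  # start mostly open

    def forward(self, layer: int, head_outputs: torch.Tensor) -> torch.Tensor:
        # head_outputs: (batch, heads, seq, d_k)
        g = torch.sigmoid(self.logits[layer]).view(1, -1, 1, 1)
        return g * head_outputs

    def sparsity_loss(self) -> torch.Tensor:
        # lambda * this term in the total loss pushes gates toward zero
        return torch.sigmoid(self.logits).sum()

# total_loss = task_loss + lam * gates.sparsity_loss()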
Pruning without fine-tuning typically causes significant accuracy loss. Even removing "unimportant" heads disrupts the learned computations that remaining heads and layers depend on. A short fine-tuning phase (1-10% of original training) after pruning is almost always necessary to recover accuracy.
Head pruning is a form of structured pruning—removing entire computational units rather than individual weights. Understanding this distinction is crucial for realizing efficiency gains.
Unstructured Pruning (Weight-Level):
Zero out individual weights based on magnitude:
$$W'_{ij} = \begin{cases} W_{ij} & \text{if } |W_{ij}| > \theta \\ 0 & \text{otherwise} \end{cases}$$
Advantages: fine-grained control over exactly which parameters are removed; very high sparsity levels are reachable with modest accuracy loss.
Disadvantages: the weight matrices keep their original shapes, so the resulting sparse tensors give little or no speedup on standard dense hardware without specialized sparse kernels (the snippet below illustrates why).
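A quick illustration of the point: magnitude pruning zeros entries in place, so the tensor's shape—and the work a dense matmul kernel does—is unchanged. The helper name is illustrative.

import torch

def magnitude_prune(weight: torch.Tensor, sparsity: float = 0.5) -> torch.Tensor:
    # Zero the smallest-magnitude entries; the matrix keeps its shape,
    # so standard dense kernels still do the full amount of work.
    k = max(1, int(sparsity * weight.numel()))
    theta = weight.abs().flatten().kthvalue(k).values
    return weight * (weight.abs() > theta)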
Why Structured Pruning Gives Real Speedups:
When you remove an attention head entirely, you can physically remove the corresponding rows of the query, key, and value projection matrices ($W_Q$, $W_K$, $W_V$), the matching columns of the output projection $W_O$, and the head's attention computation itself.
The result is a smaller dense model that runs faster with standard operations.
Mathematical View:
Original model with $h$ heads: the attention sub-layer costs $O(h \cdot n^2 d_k)$ per layer (where $n$ is sequence length and $d_k$ the per-head dimension), and its projections hold $4\, d_{\text{model}} \cdot h\, d_k$ parameters.
After pruning to $h'$ heads: the cost drops to $O(h' \cdot n^2 d_k)$ and the projections to $4\, d_{\text{model}} \cdot h'\, d_k$ parameters.
Speedup factor: $h/h'$ for attention computation (e.g., 12/8 = 1.5× with 4 heads pruned)
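A quick back-of-the-envelope check of this factor for BERT-Base-like sizes ($d_{\text{model}} = 768$, 12 heads of $d_k = 64$), pruned to 8 heads:

d_model, h, d_k = 768, 12, 64
h_pruned = 8

# W_Q, W_K, W_V map d_model -> h*d_k; W_O maps h*d_k -> d_model
params_full = 4 * d_model * (h * d_k)
params_pruned = 4 * d_model * (h_pruned * d_k)

print(f"Attention projection params: {params_full:,} -> {params_pruned:,} "
      f"({params_pruned / params_full:.0%} of original)")
print(f"Attention speedup factor: {h / h_pruned:.2f}x")  # 1.50x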
Other Structured Pruning Targets:
Beyond heads, structured pruning can target feed-forward (FFN) intermediate neurons, entire Transformer layers, and embedding or hidden dimensions.
Each provides different trade-offs between compression and accuracy.
Head pruning inevitably involves a trade-off between model accuracy and computational efficiency. Understanding this trade-off—and the factors that influence it—is essential for practical pruning decisions.
Typical Pruning Curve:
Accuracy
|
|───────────────────────┐
| │
| └────────┐
| └───────────────
|________________________________________________________
Heads pruned (%) →
0-20%: Minimal accuracy loss (often <0.5%)
20-40%: Moderate loss (1-3%)
40-60%: Significant loss (3-10%)
>60%: Severe degradation
The "knee" of the curve—where accuracy drops more steeply—typically falls around 30-40% pruning for well-trained models.
| Heads Pruned | Speedup | Avg. Accuracy Drop | Notable Effects |
|---|---|---|---|
| 0% | 1.0× | 0% | Baseline |
| 20% | ~1.15× | <0.5% | Often noise-level difference |
| 30% | ~1.25× | 0.5-1% | Still competitive with baseline |
| 40% | ~1.4× | 1-2% | Noticeable on some tasks |
| 50% | ~1.6× | 2-4% | Clear accuracy trade-off |
| 60%+ | ~1.8×+ | 5%+ | May fail on complex tasks |
Factors Affecting the Trade-off:
1. Task Complexity
Simpler tasks (sentiment classification) tolerate more pruning than complex tasks (question answering, coreference resolution). Complex reasoning may require the full diversity of heads.
2. Fine-Tuning Budget
Models fine-tuned longer after pruning recover more accuracy. With sufficient fine-tuning, 40% pruning can match baseline; without fine-tuning, even 20% hurts.
3. Original Model Size
Larger models have more redundancy. BERT-Large tolerates higher pruning ratios than BERT-Base for the same relative accuracy loss.
4. Pruning Method Quality
Sophisticated importance scoring and iterative pruning achieve better accuracy at the same pruning ratio than naive methods.
5. Head vs. Layer Pruning
Pruning entire layers is more efficient but also more destructive. A 50% head pruning often beats 50% layer pruning in accuracy.
The optimal pruning ratio depends on your specific constraints. For production with strict latency requirements, accept 1-2% accuracy loss for 30% pruning. For research where every accuracy point matters, limit to 10-15% pruning. Always validate on your specific task—generic pruning benchmarks may not transfer.
Implementing head pruning requires careful handling of model surgery—removing heads from trained models while preserving the remaining functionality.
Step 1: Identify Heads to Prune
Use importance scoring to rank heads and select pruning targets.
Step 2: Modify Model Architecture
The cleanest approach is to create a new model with fewer heads per layer, then transfer weights:
def create_pruned_model(original_model, heads_to_keep):
    """
    Create a new model with only the specified heads.

    Args:
        original_model: The trained model
        heads_to_keep: Dict mapping layer_idx -> list of head indices to keep
    """
    # Create new config with variable heads per layer
    new_config = copy.deepcopy(original_model.config)
    new_config.num_attention_heads_per_layer = {
        layer: len(heads) for layer, heads in heads_to_keep.items()
    }

    # Initialize new model
    pruned_model = ModelClass(new_config)

    # Transfer weights for kept heads
    for layer_idx, head_indices in heads_to_keep.items():
        transfer_head_weights(original_model, pruned_model, layer_idx, head_indices)

    return pruned_model
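The `transfer_head_weights` helper referenced above is not defined in the snippet. Here is a hedged sketch, assuming each attention layer stores fused projections `W_q`, `W_k`, `W_v`, `W_o` as `nn.Linear` modules (so rows of Q/K/V and columns of `W_o` belong to individual heads); the `.layers[layer_idx].attention` path and the default `d_k` are illustrative and must be adapted to your architecture:

import torch

@torch.no_grad()
def transfer_head_weights(original_model, pruned_model, layer_idx, head_indices, d_k=64):
    orig_attn = original_model.layers[layer_idx].attention    # assumed attribute path
    pruned_attn = pruned_model.layers[layer_idx].attention
    # Row indices in W_q/W_k/W_v (and column indices in W_o) that belong to kept heads
    keep = torch.cat([torch.arange(h * d_k, (h + 1) * d_k) for h in head_indices])
    for name in ("W_q", "W_k", "W_v"):
        getattr(pruned_attn, name).weight.copy_(getattr(orig_attn, name).weight[keep, :])
    # The output projection mixes head dimensions along its input (column) axis
    pruned_attn.W_o.weight.copy_(orig_attn.W_o.weight[:, keep])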
import torch
import torch.nn as nn
import copy
from typing import Dict, List, Tuple, Set
from dataclasses import dataclass


@dataclass
class PruningConfig:
    """Configuration for head pruning."""
    target_pruning_ratio: float  # e.g., 0.3 for 30% pruning
    min_heads_per_layer: int = 1  # Minimum heads to keep per layer
    importance_method: str = "taylor"  # "gradient", "taylor", "leave_one_out"
    fine_tune_epochs: int = 3


class HeadPruner:
    """
    Prune attention heads from a trained Transformer model.

    Workflow:
    1. Compute importance scores for all heads
    2. Select heads to prune based on strategy
    3. Create pruned model with fewer heads
    4. Transfer weights from original to pruned model
    5. Fine-tune to recover accuracy
    """

    def __init__(
        self,
        model: nn.Module,
        num_layers: int,
        num_heads: int,
        d_model: int
    ):
        self.model = model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads

    def compute_importance(
        self,
        dataloader,
        num_batches: int = 100,
        method: str = "taylor"
    ) -> torch.Tensor:
        """
        Compute importance scores for all heads.

        Returns:
            Tensor of shape (num_layers, num_heads) with importance scores
        """
        importance = torch.zeros(self.num_layers, self.num_heads)
        self.model.eval()

        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break

            # Forward pass with hooks to capture head outputs
            with torch.enable_grad():
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()

            # Aggregate importance (implementation depends on method)
            # This is a simplified placeholder
            for layer in range(self.num_layers):
                for head in range(self.num_heads):
                    # In practice, extract from model's attention layers
                    importance[layer, head] += torch.rand(1).item()

        importance /= min(batch_idx + 1, num_batches)
        return importance

    def select_heads_to_prune(
        self,
        importance: torch.Tensor,
        config: PruningConfig
    ) -> Set[Tuple[int, int]]:
        """
        Select which heads to prune based on importance scores.

        Returns:
            Set of (layer, head) tuples to prune
        """
        total_heads = self.num_layers * self.num_heads
        num_to_prune = int(total_heads * config.target_pruning_ratio)

        # Collect all heads with their importance
        all_heads = []
        for layer in range(self.num_layers):
            for head in range(self.num_heads):
                all_heads.append((layer, head, importance[layer, head].item()))

        # Sort by importance (ascending = least important first)
        all_heads.sort(key=lambda x: x[2])

        # Select heads to prune, respecting min_heads_per_layer
        heads_to_prune = set()
        pruned_per_layer = {l: 0 for l in range(self.num_layers)}
        max_prune_per_layer = self.num_heads - config.min_heads_per_layer

        for layer, head, imp in all_heads:
            if len(heads_to_prune) >= num_to_prune:
                break
            if pruned_per_layer[layer] < max_prune_per_layer:
                heads_to_prune.add((layer, head))
                pruned_per_layer[layer] += 1

        return heads_to_prune

    def create_pruned_model(
        self,
        heads_to_prune: Set[Tuple[int, int]]
    ) -> nn.Module:
        """
        Create a new model with specified heads removed.
        """
        # Determine heads to keep per layer
        heads_to_keep = {}
        for layer in range(self.num_layers):
            kept = [h for h in range(self.num_heads) if (layer, h) not in heads_to_prune]
            heads_to_keep[layer] = kept

        # Create new model (architecture-specific)
        pruned_model = self._construct_pruned_architecture(heads_to_keep)

        # Transfer weights
        self._transfer_weights(pruned_model, heads_to_keep)

        return pruned_model

    def _construct_pruned_architecture(
        self,
        heads_to_keep: Dict[int, List[int]]
    ) -> nn.Module:
        """Construct a model with variable heads per layer."""
        # This is highly architecture-specific
        # For illustration, we'll sketch the key weight transfers
        pruned_model = copy.deepcopy(self.model)

        for layer_idx, kept_heads in heads_to_keep.items():
            num_kept = len(kept_heads)
            new_d_heads = num_kept * self.d_k

            # Get original layer
            # original_attn = self.model.layers[layer_idx].attention
            # pruned_attn = pruned_model.layers[layer_idx].attention

            # Resize projections
            # pruned_attn.W_q = nn.Linear(self.d_model, new_d_heads, bias=False)
            # pruned_attn.W_k = nn.Linear(self.d_model, new_d_heads, bias=False)
            # pruned_attn.W_v = nn.Linear(self.d_model, new_d_heads, bias=False)
            # pruned_attn.W_o = nn.Linear(new_d_heads, self.d_model, bias=False)
            pass  # Architecture-specific implementation

        return pruned_model

    def _transfer_weights(
        self,
        pruned_model: nn.Module,
        heads_to_keep: Dict[int, List[int]]
    ):
        """Transfer weights for kept heads from original to pruned model."""
        for layer_idx, kept_heads in heads_to_keep.items():
            # Get weight matrices
            # original = self.model.layers[layer_idx].attention
            # pruned = pruned_model.layers[layer_idx].attention

            for i, orig_head_idx in enumerate(kept_heads):
                # Slice weights for this head from original
                start = orig_head_idx * self.d_k
                end = (orig_head_idx + 1) * self.d_k

                # Copy to new position in pruned model
                new_start = i * self.d_k
                new_end = (i + 1) * self.d_k

                # Transfer Q, K, V weights
                # pruned.W_q.weight[new_start:new_end] = original.W_q.weight[start:end]
                # ... similar for K, V

                # Transfer W_o weights (column slices)
                # pruned.W_o.weight[:, new_start:new_end] = original.W_o.weight[:, start:end]
                pass  # Architecture-specific implementation


def demonstrate_pruning_workflow():
    """Demonstrate the complete head pruning workflow."""
    print("Head Pruning Workflow Demonstration")
    print("=" * 60)

    # Simulated parameters
    num_layers = 12
    num_heads = 12
    total_heads = num_layers * num_heads

    print(f"Original model: {num_layers} layers × {num_heads} heads = {total_heads} heads")
    print()

    # Step 1: Compute importance (simulated)
    print("Step 1: Computing importance scores...")
    importance = torch.rand(num_layers, num_heads)
    # Make some heads clearly less important
    importance[0, :3] *= 0.1   # First layer, first 3 heads unimportant
    importance[11, 5:] *= 0.1  # Last layer, last 7 heads unimportant
    print(f"  Importance range: [{importance.min():.3f}, {importance.max():.3f}]")
    print()

    # Step 2: Select heads to prune
    print("Step 2: Selecting heads to prune (30% target)...")
    target_ratio = 0.30
    num_to_prune = int(total_heads * target_ratio)

    # Flatten and sort
    flat = [(l, h, importance[l, h].item())
            for l in range(num_layers) for h in range(num_heads)]
    flat.sort(key=lambda x: x[2])

    # Select lowest importance, keeping at least 1 per layer
    pruned = set()
    per_layer = {l: 0 for l in range(num_layers)}
    for l, h, imp in flat:
        if len(pruned) >= num_to_prune:
            break
        if per_layer[l] < num_heads - 1:  # Keep at least 1
            pruned.add((l, h))
            per_layer[l] += 1

    print(f"  Selected {len(pruned)} heads to prune")
    print(f"  Pruned per layer: {dict(sorted(per_layer.items()))}")
    print()

    # Step 3: Summarize pruned model
    print("Step 3: Pruned model structure")
    remaining_heads = total_heads - len(pruned)
    print(f"  Remaining heads: {remaining_heads} ({100*remaining_heads/total_heads:.1f}%)")
    for layer in range(num_layers):
        kept = num_heads - per_layer[layer]
        bar = "█" * kept + "░" * per_layer[layer]
        print(f"  Layer {layer:2d}: {bar} ({kept} heads)")

    print()
    print("Step 4: Fine-tuning required to recover accuracy")
    print("  Recommendation: 1-3 epochs at reduced learning rate")


if __name__ == "__main__":
    demonstrate_pruning_workflow()

Beyond basic head pruning, several advanced techniques offer improved trade-offs or address specific challenges.
1. Dynamic Head Pruning
Rather than statically removing heads, dynamically skip heads based on input:
$$o = \sum_{h \in \text{active}(x)} g_h(x) \cdot \text{head}_h(x)$$
where $\text{active}(x)$ selects which heads to compute for input $x$.
Benefits: Different inputs may need different heads; dynamic routing adapts per example.
Challenges: The selection mechanism adds overhead, and the irregular computation pattern is harder to batch efficiently. One illustrative routing sketch follows.
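A possible (purely illustrative) routing mechanism: a small linear scorer picks the top-k heads per input, and the remaining head outputs are masked out. The hard top-k shown here is not differentiable; real systems typically train with a soft or straight-through relaxation.

import torch
import torch.nn as nn

class DynamicHeadRouter(nn.Module):
    def __init__(self, d_model: int, num_heads: int, k_active: int):
        super().__init__()
        self.scorer = nn.Linear(d_model, num_heads)
        self.k_active = k_active

    def forward(self, x: torch.Tensor, head_outputs: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model); head_outputs: (batch, heads, seq, d_k)
        scores = self.scorer(x.mean(dim=1))                     # (batch, heads)
        topk = scores.topk(self.k_active, dim=-1).indices
        mask = torch.zeros_like(scores).scatter_(1, topk, 1.0)  # hard top-k selection
        return head_outputs * mask.unsqueeze(-1).unsqueeze(-1)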
2. Task-Specific Pruning
Different downstream tasks may need different heads. Prune separately per task:
BERT-Base:
├── Sentiment → Keep 60 heads (syntactic less important)
├── NER → Keep 90 heads (need positional patterns)
├── QA → Keep 100 heads (complex reasoning)
└── Paraphrase → Keep 70 heads (semantic similarity)
This creates task-specific efficient models from one pre-trained base.
3. Lottery Ticket Hypothesis for Heads
The lottery ticket hypothesis suggests that sparse subnetworks exist early in training that can match full-model performance. Applied to heads: train briefly, score the heads, rewind the surviving heads to their early weights, and retrain only that reduced set of heads.
Surprisingly, this sometimes matches full training—suggesting head importance is partially determined at initialization.
4. Pruning-Aware Training
Train with pruning in mind from the start: for example, randomly drop entire heads during training (head dropout) or attach sparsity-regularized gates, so the network never becomes overly reliant on any single head (a head-dropout sketch follows the next paragraph).
Models trained this way are more "prunable" than standard training.
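A minimal head-dropout sketch, assuming head outputs are laid out as (batch, heads, seq, d_k): each head is independently zeroed with probability p during training, and survivors are rescaled as in standard dropout.

import torch
import torch.nn as nn

class HeadDropout(nn.Module):
    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, head_outputs: torch.Tensor) -> torch.Tensor:
        # head_outputs: (batch, heads, seq, d_k)
        if not self.training or self.p == 0.0:
            return head_outputs
        keep = (torch.rand(head_outputs.shape[:2], device=head_outputs.device) > self.p)
        keep = keep.to(head_outputs.dtype).unsqueeze(-1).unsqueeze(-1)
        return head_outputs * keep / (1.0 - self.p)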
Head pruning composes well with other compression methods: quantization (reduce precision of remaining heads), knowledge distillation (train smaller model to match pruned model), and layer pruning (remove entire layers). A carefully combined approach can achieve 4-10× speedup while maintaining accuracy.
5. Neural Architecture Search for Optimal Head Configuration
Instead of starting from a full model and pruning, use NAS to find optimal head counts per layer:
Search space:
Layer 1: {4, 6, 8, 12} heads
Layer 2: {4, 6, 8, 12} heads
...
Layer 12: {4, 6, 8, 12} heads
Optimize for: accuracy + λ × (latency or parameters)
This can discover non-uniform configurations that manual pruning might miss—e.g., more heads in middle layers, fewer in early/late layers.
6. Hardware-Aware Pruning
Tailor pruning to the target hardware: for example, keep head counts that map cleanly onto the accelerator's tile or warp sizes, prefer removing the computation that is actually the bottleneck (memory-bandwidth-bound projections versus compute-bound attention), and measure latency on the real device rather than counting FLOPs.
A 30% head reduction might give only 10% speedup if remaining operations don't map efficiently to hardware.
We've explored the theory and practice of attention head pruning, from importance scoring through practical implementation. The table below consolidates the key insights into practical recommendations:
| Scenario | Recommended Approach |
|---|---|
| Quick deployment optimization | 20-25% global pruning + 1 epoch fine-tuning |
| Maximum efficiency needed | 40% iterative pruning + 3-5 epochs fine-tuning |
| Task-specific deployment | Per-task pruning with task-specific fine-tuning |
| Research/exploration | Comprehensive importance analysis + ablation studies |
This Completes Module 3: Multi-Head Attention
You now have a deep understanding of multi-head attention: why multiple heads are necessary, how they're combined, the parallelism that makes Transformers efficient, what heads learn, and how to prune unnecessary heads for efficiency. This knowledge forms the foundation for understanding modern Transformer architectures and their optimization.
Congratulations! You've completed the Multi-Head Attention module. You understand not just the mechanics of multiple attention heads, but the deeper principles: why single heads are insufficient, how heads specialize, and how to make attention efficient through pruning. This prepares you for the Transformer Architecture module, where these components come together into the complete architecture.