In modern Transformer architectures, multi-head attention is a cornerstone mechanism that enables the model to jointly attend to information from different representation subspaces. However, empirical research has revealed a remarkable finding: a significant portion of attention heads can be removed without substantially degrading model performance. Studies on models like BERT and GPT have demonstrated that up to 50% of attention heads can be pruned with less than 1% accuracy loss on downstream tasks.
Attention head pruning is a critical technique in model compression and efficient inference. The core idea is to identify and remove heads that contribute minimally to the model's predictions, thereby reducing computational cost, memory footprint, and inference latency. This is especially valuable in production environments where models must serve millions of requests with tight latency budgets.
The Pruning Process:
Given a transformer layer with H attention heads, each head produces an attention weight matrix of shape (sequence_length × sequence_length). To prune heads:
1. Compute the number of heads to keep: num_keep = floor(H × (1 − pruning_ratio)).
2. Rank the heads by their importance scores in descending order.
3. Select the num_keep highest-scoring heads.
4. Sort the selected head indices in ascending order.
5. Extract the attention matrices of the kept heads, preserving their original order.
Your Task:
Implement a function that performs attention head pruning. Given:
- an attention tensor of shape (num_heads, sequence_length, sequence_length),
- a list of importance scores, one per head,
- a pruning ratio in [0, 1] indicating the fraction of heads to remove.
Return:
- the pruned attention tensor containing only the kept heads, in their original order, and
- the list of kept head indices, sorted in ascending order.
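The procedure above can be sketched in NumPy. This is a minimal illustration, not a reference solution; the function name `prune_heads` and the exact return types are assumptions based on the interface described:

```python
import numpy as np

def prune_heads(attn, importance, ratio):
    """Keep the most important attention heads.

    attn: array of shape (num_heads, seq_len, seq_len)
    importance: one score per head
    ratio: fraction of heads to prune, in [0, 1]
    Returns (pruned_attn, kept_indices), indices sorted ascending.
    """
    num_heads = attn.shape[0]
    # floor(H × (1 - ratio)) heads survive the pruning
    num_keep = int(np.floor(num_heads * (1 - ratio)))
    # rank head indices by importance, highest first
    ranked = np.argsort(importance)[::-1][:num_keep]
    # restore original head ordering for the output
    kept = sorted(ranked.tolist())
    return attn[kept], kept
```

Indexing `attn[kept]` with the sorted index list preserves the heads' original relative order in the pruned tensor.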
This technique is fundamental to efficient transformer deployment in systems like DistilBERT, TinyBERT, and various mobile-optimized language models.
attn = (4, 3, 3) tensor with 4 attention heads
importance = [0.8, 0.3, 0.9, 0.2]
ratio = 0.5
Output: ([pruned_attn with shape (2, 3, 3)], [0, 2])
Step-by-step analysis:
Initial Setup: We have 4 attention heads with importance scores [0.8, 0.3, 0.9, 0.2]
Calculate Heads to Keep: floor(4 × (1 − 0.5)) = 2 heads to keep
Rank by Importance: sorting descending gives head 2 (0.9), head 0 (0.8), head 1 (0.3), head 3 (0.2)
Select Top 2 Heads: Heads 2 and 0 have the highest importance scores
Sort by Original Index: Return indices in ascending order → [0, 2]
Extract Attention Weights: The output contains attention matrices for heads 0 and 2 only, reducing the tensor from shape (4, 3, 3) to (2, 3, 3)
Heads 1 and 3 are pruned as they have the lowest importance scores.
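The walkthrough above can be reproduced in a few lines of plain Python (variable names are illustrative):

```python
importance = [0.8, 0.3, 0.9, 0.2]
num_keep = int(4 * (1 - 0.5))  # floor(4 × 0.5) = 2

# rank head indices by importance, highest score first
ranked = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)
# ranked == [2, 0, 1, 3]: heads 1 and 3 fall below the cutoff

# keep the top 2 heads, then restore ascending index order
kept = sorted(ranked[:num_keep])
# kept == [0, 2]
```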
attn = (2, 2, 2) tensor with 2 attention heads
importance = [0.6, 0.4]
ratio = 0.5
Output: ([pruned_attn with shape (1, 2, 2)], [0])
Pruning with minimal head count:
Calculate Heads to Keep: floor(2 × (1 − 0.5)) = 1 head to keep
Compare Importance Scores: head 0 scores 0.6, head 1 scores 0.4
Select Top Head: Head 0 has higher importance (0.6 > 0.4)
Result: Only head 0's attention weights are retained, and the kept indices list contains just [0]
This demonstrates the edge case where aggressive pruning reduces the model to a single-head attention mechanism.
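This edge case can be checked the same way (names are illustrative):

```python
importance = [0.6, 0.4]
num_keep = int(2 * (1 - 0.5))  # only 1 head survives

ranked = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)
kept = sorted(ranked[:num_keep])
# kept == [0]; the layer is reduced to single-head attention
```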
attn = (3, 2, 2) tensor with 3 attention heads
importance = [0.5, 0.8, 0.3]
ratio = 0.0
Output: ([original_attn with shape (3, 2, 2)], [0, 1, 2])
No Pruning Scenario (ratio = 0.0):
Calculate Heads to Keep: floor(3 × (1 - 0.0)) = floor(3 × 1.0) = 3 heads
All Heads Retained: With a pruning ratio of 0, no heads are removed
Result: The output attention tensor is identical to the input, and all indices [0, 1, 2] are returned
This represents the baseline case where we want to preserve the full model capacity, useful for comparison or when pruning is disabled.
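The baseline case falls out of the same logic with no special handling (names are illustrative):

```python
import math

importance = [0.5, 0.8, 0.3]
num_keep = math.floor(3 * (1 - 0.0))  # all 3 heads are kept

ranked = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)
kept = sorted(ranked[:num_keep])
# kept == [0, 1, 2]; the attention tensor passes through unchanged
```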
Constraints