Self-attention's power comes from its ability to connect every position to every other position in a single operation. But this global connectivity comes at a cost: quadratic complexity in sequence length.
For a sequence of length $n$, every position attends to every other position, so the attention matrix has $n^2$ entries and the core computation scales as $O(n^2 \cdot d)$.

Understanding this complexity is critical: it determines how long a context you can afford, how much GPU memory training requires, and when efficient-attention variants become necessary.
This page provides a rigorous analysis of self-attention complexity, its practical implications, and strategies to mitigate the quadratic bottleneck.
By the end of this page, you will understand the exact time and space complexity of self-attention, be able to calculate memory requirements for given configurations, appreciate the scaling challenges, and be familiar with approaches to efficient attention.
Let's precisely analyze the computational cost of each step in self-attention.
Notation: $n$ = sequence length, $d$ = model dimension, $d_k$ = query/key dimension, $d_v$ = value dimension.
Step-by-Step Complexity:
| Operation | Computation | FLOPs | Dominant Term |
|---|---|---|---|
| Q = XW_Q | (n, d) × (d, d_k) | O(n·d·d_k) | O(n·d²) if d_k ∝ d |
| K = XW_K | (n, d) × (d, d_k) | O(n·d·d_k) | O(n·d²) |
| V = XW_V | (n, d) × (d, d_v) | O(n·d·d_v) | O(n·d²) |
| S = QK^T | (n, d_k) × (d_k, n) | O(n²·d_k) | O(n²·d) |
| Scale | (n, n) scalar div | O(n²) | O(n²) |
| Softmax | Row-wise over (n, n) | O(n²) | O(n²) |
| O = AV | (n, n) × (n, d_v) | O(n²·d_v) | O(n²·d) |
| Output = OW_O | (n, d_v) × (d_v, d) | O(n·d_v·d) | O(n·d²) |
Total Time Complexity:
$$T(n, d) = O(n \cdot d^2 + n^2 \cdot d)$$
For typical configurations, $d$ is a few hundred to a few thousand while $n$ can reach many thousands; the projection term $O(n \cdot d^2)$ and the attention term $O(n^2 \cdot d)$ break even when $n$ is on the order of $d$. In modern LLMs with long contexts, the $O(n^2 \cdot d)$ attention term dominates.
```python
import numpy as np

def count_flops_detailed(n: int, d: int, d_k: int, d_v: int) -> dict:
    """
    Count FLOPs for each operation in self-attention.
    Matrix multiply (m,k) × (k,n) = 2*m*k*n FLOPs (mul + add).
    """
    flops = {}

    # QKV projections: 3 × (n, d) × (d, d_k)
    flops['Q_proj'] = 2 * n * d * d_k
    flops['K_proj'] = 2 * n * d * d_k
    flops['V_proj'] = 2 * n * d * d_v

    # Attention scores: (n, d_k) × (d_k, n)
    flops['QK^T'] = 2 * n * d_k * n

    # Scale: n² divisions
    flops['scale'] = n * n

    # Softmax: ~5 ops per element (exp, sum, div, etc.)
    flops['softmax'] = 5 * n * n

    # AV: (n, n) × (n, d_v)
    flops['AV'] = 2 * n * n * d_v

    # Output projection: (n, d_v) × (d_v, d)
    flops['O_proj'] = 2 * n * d_v * d

    flops['total'] = sum(flops.values())
    flops['projection_total'] = flops['Q_proj'] + flops['K_proj'] + flops['V_proj'] + flops['O_proj']
    flops['attention_total'] = flops['QK^T'] + flops['scale'] + flops['softmax'] + flops['AV']

    return flops


def analyze_scaling():
    """Analyze how complexity scales with n and d."""
    print("Self-Attention FLOPs Analysis")
    print("=" * 70)

    d = 512         # Model dimension
    d_k = d_v = 64  # Head dimension (d/8)

    print(f"Fixed: d={d}, d_k=d_v={d_k}")
    print(f"{'n':<10} {'Proj (M)':<15} {'Attn (M)':<15} {'Total (M)':<15} {'Attn %':<10}")
    print("-" * 70)

    for n in [128, 256, 512, 1024, 2048, 4096, 8192]:
        flops = count_flops_detailed(n, d, d_k, d_v)
        proj_m = flops['projection_total'] / 1e6
        attn_m = flops['attention_total'] / 1e6
        total_m = flops['total'] / 1e6
        attn_pct = 100 * flops['attention_total'] / flops['total']
        print(f"{n:<10} {proj_m:<15.1f} {attn_m:<15.1f} {total_m:<15.1f} {attn_pct:<10.1f}%")

    print("\n→ Attention % increases with sequence length (quadratic term)")

    # Crossover point
    print(f"\nCrossover analysis (d={d}, d_k={d_k}):")
    print("Projection FLOPs: 4 × 2nd²/h = 8nd²/h = O(nd²)")
    print("Attention FLOPs: 2n²d_k + 2n²d_v ≈ 4n²d/h = O(n²d)")
    print(f"Crossover when: 8nd²/h = 4n²d/h → n = 2d = {2*d}")
    print(f"For n > {2*d}, attention dominates.")


analyze_scaling()
```

Doubling sequence length quadruples attention computation time. Going from 2K to 128K tokens (64× length) requires 4096× more attention computation. This is why efficient attention variants are essential for long-context models.
Memory usage is often a more critical bottleneck than compute, especially during training.

Memory falls into two categories: parameters, which are independent of sequence length, and activations, which grow with $n$ (and, for the attention matrix, with $n^2$).

Memory Breakdown:
| Component | Size | Scaling | Notes |
|---|---|---|---|
| W_Q, W_K, W_V, W_O | 4 × d × d_k | O(d²) | Independent of n |
| Input X | n × d | O(n·d) | |
| Q, K, V | 3 × n × d_k | O(n·d) | |
| Attention Scores | n × n | O(n²) | The bottleneck! |
| Attention Weights | n × n | O(n²) | After softmax |
| Output O | n × d_v | O(n·d) | |
Total Memory:
$$M(n, d) = O(d^2 + n \cdot d + n^2)$$
For long sequences, the $O(n^2)$ attention matrix dominates.
Training Memory (with gradients):
During training, we must store activations for backpropagation. The attention matrix and its gradient together require $2 \times n^2$ elements. With multiple layers and heads:
$$M_{\text{attn}} = 2 \times L \times h \times n^2 \times \text{sizeof(dtype)}$$
Where $L$ = layers, $h$ = heads per layer.
```python
import numpy as np

def calculate_memory_bytes(n: int, d: int, L: int, h: int, dtype_bytes: int = 4) -> dict:
    """
    Calculate memory requirements for self-attention in a transformer.

    Args:
        n: Sequence length
        d: Model dimension
        L: Number of layers
        h: Number of attention heads
        dtype_bytes: Bytes per element (4 for float32, 2 for float16/bf16)

    Returns:
        Dictionary with memory breakdown in bytes
    """
    d_k = d_v = d // h  # Head dimension

    memory = {}

    # Model parameters (per attention layer)
    params_per_layer = 4 * d * d_k * h  # W_Q, W_K, W_V, W_O
    memory['parameters'] = params_per_layer * L * dtype_bytes

    # Activations (per layer, must store for backprop):
    # Q, K, V, attention scores, attention weights, output
    memory['Q'] = n * d_k * h * L * dtype_bytes
    memory['K'] = n * d_k * h * L * dtype_bytes
    memory['V'] = n * d_k * h * L * dtype_bytes
    memory['attention_scores'] = n * n * h * L * dtype_bytes
    memory['attention_weights'] = n * n * h * L * dtype_bytes
    memory['output'] = n * d * L * dtype_bytes

    memory['total_activations'] = (memory['Q'] + memory['K'] + memory['V'] +
                                   memory['attention_scores'] +
                                   memory['attention_weights'] +
                                   memory['output'])

    # Attention matrices specifically (scores + weights)
    memory['attention_matrix_only'] = 2 * n * n * h * L * dtype_bytes

    memory['total'] = memory['parameters'] + memory['total_activations']

    return memory


def format_bytes(b):
    """Format bytes as human-readable."""
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if b < 1024:
            return f"{b:.1f} {unit}"
        b /= 1024
    return f"{b:.1f} PB"


# Analysis for different configurations
print("Self-Attention Memory Analysis")
print("=" * 80)

configs = [
    # (name, n, d, L, h, dtype_bytes)
    ("Small (BERT)", 512, 768, 12, 12, 4),
    ("Medium (GPT-2)", 1024, 1024, 24, 16, 4),
    ("Large (GPT-3 style)", 2048, 2048, 32, 32, 2),
    ("Long context 8K", 8192, 2048, 32, 32, 2),
    ("Long context 32K", 32768, 2048, 32, 32, 2),
    ("Long context 128K", 131072, 2048, 32, 32, 2),
]

print(f"{'Config':<20} {'n':<8} {'Params':<12} {'Attn Matrix':<15} {'Total Activ':<15}")
print("-" * 80)

for name, n, d, L, h, dtype_bytes in configs:
    mem = calculate_memory_bytes(n, d, L, h, dtype_bytes)
    print(f"{name:<20} {n:<8} {format_bytes(mem['parameters']):<12} "
          f"{format_bytes(mem['attention_matrix_only']):<15} "
          f"{format_bytes(mem['total_activations']):<15}")

print("\n→ Attention matrix grows QUADRATICALLY with sequence length")
print("   At 128K tokens, the attention matrices alone need tens of TB!")
```

For a 128K-context model with 32 layers and 32 heads, storing just the attention matrices (scores plus weights) requires tens of terabytes, far beyond any single GPU's memory. FlashAttention and gradient checkpointing are essential for long-context training.
How does self-attention's O(n²) complexity compare to other sequence models?
Architecture Complexity Comparison:
| Architecture | Time per Layer | Memory | Max Path Length | Parallelizable |
|---|---|---|---|---|
| Self-Attention | O(n²·d) | O(n²) | O(1) | Yes |
| RNN (LSTM/GRU) | O(n·d²) | O(n·d) | O(n) | No (sequential) |
| CNN (size k) | O(n·k·d²) | O(n·d) | O(log_k(n)) | Yes |
| Linear Attention | O(n·d²) | O(n·d) | O(1) | Yes |
| Sparse Attention | O(n·s·d) | O(n·s) | O(n/s) | Partially |
The Trade-off: self-attention pays $O(n^2)$ time and memory per layer but connects any two positions with an $O(1)$ path and parallelizes fully across the sequence; RNNs cost only $O(n \cdot d^2)$ per layer but process tokens sequentially, with an $O(n)$ path between distant positions.
Break-Even Analysis:
At what sequence length does self-attention become slower than an RNN?

Assume the RNN must execute its $n$ recurrent steps one after another, while self-attention's $O(n^2 \cdot d)$ operations can all be parallelized across GPU cores.

With enough parallelism, self-attention is faster for $n$ below a threshold that depends on hardware; only beyond it does the quadratic term outweigh the parallelism advantage.
```python
import numpy as np

def compare_architectures(n_values: list, d: int = 512):
    """
    Compare theoretical complexity of different architectures.
    """
    results = {
        'n': n_values,
        'self_attention': [],
        'rnn': [],
        'cnn_k3': [],
        'linear_attention': [],
    }

    k = 3  # CNN kernel size

    for n in n_values:
        # Self-attention: O(n²·d)
        results['self_attention'].append(n**2 * d)

        # RNN: O(n·d²)
        results['rnn'].append(n * d**2)

        # CNN with log_k(n) layers for a full receptive field
        num_layers = int(np.ceil(np.log(n) / np.log(k)))
        results['cnn_k3'].append(num_layers * n * k * d**2)

        # Linear attention: O(n·d²)
        results['linear_attention'].append(n * d**2)

    return results


# Generate comparison
n_values = [2**i for i in range(7, 18)]  # 128 to 131072
d = 512

results = compare_architectures(n_values, d)

print("Architecture Complexity Comparison (d = 512)")
print("=" * 80)
print(f"{'n':<10} {'Self-Attn':<20} {'RNN':<20} {'Linear Attn':<20}")
print("-" * 80)

for i, n in enumerate(n_values):
    sa = results['self_attention'][i]
    rnn = results['rnn'][i]
    linear = results['linear_attention'][i]
    # Normalize to billions of ops
    sa_str = f"{sa/1e9:.2f}B ops"
    rnn_str = f"{rnn/1e9:.4f}B ops"
    lin_str = f"{linear/1e9:.4f}B ops"
    print(f"{n:<10} {sa_str:<20} {rnn_str:<20} {lin_str:<20}")

# Find crossover points
print("\nCrossover Analysis:")
print(f"Self-attention has fewer raw ops than the RNN when n < d = {d}")
print("(Assuming full parallelization of self-attention)")

# At what n does self-attention have 10x, 100x, 1000x the ops of linear attention?
for factor in [10, 100, 1000]:
    # n²d = factor × nd² → n = factor × d
    crossover_n = factor * d
    print(f"Self-attention has {factor}x the ops of linear attention at n = {crossover_n}")
```

Despite O(n²) complexity, self-attention often runs faster than O(n) RNNs for moderate n because GPUs can execute millions of operations in parallel. The RNN's serial nature is its Achilles' heel. Only for very long sequences does the quadratic term truly dominate.
Multi-head attention computes $h$ separate attention heads in parallel. How does this affect complexity?
Multi-Head Configuration: the model dimension $d$ is split across $h$ heads, each with $d_k = d_v = d / h$, and the heads' outputs are concatenated back to dimension $d$.
Key Insight: Total Computation is the Same
With $h$ heads, each head's score and value products cost $O(n^2 \cdot d / h)$, and summing over all $h$ heads gives $O(n^2 \cdot d)$, exactly the cost of a single head of width $d$.
The split doesn't increase or decrease total work—it distributes it across heads.
```python
import numpy as np

def multihead_complexity_analysis():
    """
    Analyze multi-head attention complexity.
    """
    d = 512   # Model dimension
    n = 1024  # Sequence length

    print("Multi-Head Attention Complexity Analysis")
    print("=" * 70)
    print(f"d = {d}, n = {n}")
    print()

    configs = [1, 4, 8, 16, 32]  # Different head counts

    print(f"{'Heads':<8} {'d_k=d/h':<10} {'Attn/Head':<15} {'Total Attn':<15} {'Memory/Head':<12}")
    print("-" * 70)

    for h in configs:
        d_k = d // h

        # FLOPs per head: 2·n²·d_k (for QK^T) + 2·n²·d_k (for AV)
        flops_per_head = 4 * n * n * d_k
        total_flops = h * flops_per_head

        # Memory per head: one n × n attention matrix
        mem_per_head = n * n * 4          # float32 bytes
        total_mem = h * mem_per_head      # grows linearly with h

        attn_per_head = f"{flops_per_head/1e6:.1f}M"
        attn_total = f"{total_flops/1e6:.1f}M"
        print(f"{h:<8} {d_k:<10} {attn_per_head:<15} {attn_total:<15} {mem_per_head/1e6:.1f} MB")

    print()
    print("Observations:")
    print("1. Total attention FLOPs is constant (~2.1B) regardless of head count")
    print("2. Each head has smaller d_k → less work per head")
    print("3. More heads → more parallelism opportunity")
    print("4. Each attention matrix is still n×n, so attention-weight memory grows with h")
    print()

    # Memory analysis
    print("Memory Analysis:")
    print("-" * 40)
    print("With h heads, we store h separate attention matrices of size n×n,")
    print(f"so attention-weight memory = h × n² = h × {n*n/1e6:.1f}M entries (grows with h).")
    print()

    # Q/K/V storage, by contrast, is the same for every head count
    print("Q/K/V storage:")
    for h in configs:
        d_k = d // h
        qkv_per_head = 3 * n * d_k
        total_qkv = h * qkv_per_head
        print(f"  h={h}: {h} × (3 × {n} × {d_k}) = {total_qkv} entries (constant!)")


multihead_complexity_analysis()
```

Multi-head attention is computationally equivalent to single-head attention with the same total dimension, but it is more parallelizable and often faster on GPUs. The heads can capture different types of relationships simultaneously, making it both more expressive AND more hardware-friendly.
In practice, attention performance depends heavily on how well it maps to GPU architecture.
GPU Memory Hierarchy: registers and shared memory (SRAM) are the fastest but hold only kilobytes to a few megabytes; the L2 cache holds tens of megabytes; high-bandwidth memory (HBM) holds tens of gigabytes but is far slower to access.
The Memory-Bound Problem:
Self-attention is memory-bound, not compute-bound. The $n \times n$ attention matrix must be written to HBM after the $QK^T$ product, read back for scaling and softmax, written again, and read once more for the $AV$ product.

Each HBM access is expensive. For large $n$, the attention matrix doesn't fit in L2 cache, causing many HBM round-trips.
```python
import numpy as np

def gpu_memory_analysis(n: int, h: int, dtype_bytes: int = 2):
    """
    Analyze attention's interaction with the GPU memory hierarchy.
    """
    # Attention matrix size per head
    attn_matrix_bytes = n * n * dtype_bytes
    attn_matrix_mb = attn_matrix_bytes / (1024 * 1024)

    # Typical GPU specs
    gpu_specs = {
        'A100':     {'hbm_gb': 80, 'hbm_bw_tb': 2.0,  'l2_mb': 40, 'flops_tf': 312},
        'H100':     {'hbm_gb': 80, 'hbm_bw_tb': 3.35, 'l2_mb': 50, 'flops_tf': 989},
        'RTX_4090': {'hbm_gb': 24, 'hbm_bw_tb': 1.0,  'l2_mb': 72, 'flops_tf': 83},
    }

    print(f"GPU Memory Analysis for n={n}, h={h} heads, {dtype_bytes*8}-bit")
    print("=" * 70)
    print(f"Attention matrix size (per head): {attn_matrix_mb:.1f} MB")
    print(f"Total across heads: {attn_matrix_mb * h:.1f} MB")
    print()

    total_mb = attn_matrix_mb * h

    for gpu_name, specs in gpu_specs.items():
        print(f"{gpu_name}:")
        print(f"  L2 Cache: {specs['l2_mb']} MB → ", end="")
        if total_mb <= specs['l2_mb']:
            print(f"✓ Fits in L2 ({total_mb:.1f}/{specs['l2_mb']} MB)")
        else:
            print(f"✗ Exceeds L2 ({total_mb:.1f}/{specs['l2_mb']} MB) - HBM accesses needed")

        # Estimate memory bandwidth limitation
        # Standard attention: 3 reads + 2 writes of the attention matrix
        bytes_transferred = 5 * attn_matrix_bytes * h
        time_memory_limited = bytes_transferred / (specs['hbm_bw_tb'] * 1e12)

        # Compute needed (QK^T and AV, assuming d = 512)
        flops = 4 * n * n * (512 // h) * h
        time_compute_limited = flops / (specs['flops_tf'] * 1e12)

        print(f"  Memory time: {time_memory_limited*1000:.3f} ms")
        print(f"  Compute time: {time_compute_limited*1000:.3f} ms")
        print(f"  Bottleneck: {'Memory' if time_memory_limited > time_compute_limited else 'Compute'}")
    print()


# Analyze for different sequence lengths
for n in [512, 2048, 8192]:
    gpu_memory_analysis(n, h=32, dtype_bytes=2)
    print()
```

FlashAttention: Solving the Memory Problem
FlashAttention restructures the attention computation so that the full $n \times n$ matrix is never materialized in HBM: it processes Q, K, and V in tiles small enough to fit in on-chip SRAM, maintains a running (online) softmax while streaming over key/value blocks, and recomputes attention weights during the backward pass instead of storing them.

The result: 2-4× faster attention while using O(n) memory instead of O(n²).

Key Techniques:

- Tiling: split Q, K, V into blocks that fit in SRAM and loop over them.
- Online softmax: keep a running row maximum and sum so partial results can be rescaled as new blocks arrive.
- Recomputation: regenerate attention weights in the backward pass rather than storing the $n \times n$ matrix.
- Kernel fusion: perform the matmuls, scaling, softmax, and masking in a single GPU kernel to minimize HBM traffic.
For any production transformer work, use FlashAttention (or similar implementations like FlashAttention-2, xFormers). The standard O(n²) memory implementation is too slow and memory-hungry. PyTorch 2.0+ includes FlashAttention via F.scaled_dot_product_attention().
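As a minimal sketch (assuming PyTorch 2.0+ is installed; the shapes and sizes below are illustrative, not from this page), the drop-in call looks like this:

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# SDPA expects (batch, heads, seq_len, head_dim)
batch, heads, n, d_k = 2, 8, 4096, 64
q = torch.randn(batch, heads, n, d_k, device=device, dtype=dtype)
k = torch.randn(batch, heads, n, d_k, device=device, dtype=dtype)
v = torch.randn(batch, heads, n, d_k, device=device, dtype=dtype)

# PyTorch dispatches to the fastest backend it supports for these inputs
# (FlashAttention, memory-efficient attention, or a math fallback), so the
# n×n attention matrix is never materialized when a fused kernel runs.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 4096, 64])
```

On hardware or dtypes the fused kernels don't support, PyTorch falls back to the standard math implementation, so the same call works everywhere.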
The quadratic complexity problem has spawned a rich literature of efficient attention mechanisms. Here's an overview of the main approaches:
Category 1: Sparse Attention
Compute attention only for a subset of pairs, reducing from O(n²) to O(n·s) where s ≪ n. (A simplified sliding-window sketch follows the table below.)
| Method | Pattern | Complexity | Use Case |
|---|---|---|---|
| Longformer | Local + global tokens | O(n·k) | Long documents |
| BigBird | Local + random + global | O(n) | General long sequences |
| Sparse Transformer | Fixed strided patterns | O(n√n) | Images, audio |
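To make the local pattern concrete, here is a simplified NumPy sketch of sliding-window attention; it is an illustration under the stated window assumption, not the actual Longformer implementation (which also adds global and dilated attention):

```python
import numpy as np

def sliding_window_attention(Q, K, V, window=128):
    """Each query attends only to keys within `window` positions on either
    side, so at most 2*window+1 scores exist per query (O(n·w) total)."""
    n, d_k = Q.shape
    scale = 1.0 / np.sqrt(d_k)
    out = np.zeros_like(V)

    for i in range(n):  # simple loop for clarity; real kernels batch this
        lo, hi = max(0, i - window), min(n, i + window + 1)
        scores = (Q[i] @ K[lo:hi].T) * scale
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        out[i] = weights @ V[lo:hi]
    return out


rng = np.random.default_rng(0)
n, d_k = 1024, 64
Q, K, V = [rng.standard_normal((n, d_k)) for _ in range(3)]
print(sliding_window_attention(Q, K, V).shape)  # (1024, 64)
```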
Category 2: Low-Rank / Linear Attention
Approximate the attention matrix with a low-rank factorization, reducing O(n²) to O(n·r). (A kernelized linear-attention sketch follows the table below.)
| Method | Approach | Complexity | Notes |
|---|---|---|---|
| Linear Attention | Kernel feature maps | O(n·d²) | Remove softmax via kernels |
| Performer | Random features for softmax | O(n·d·r) | FAVOR+ mechanism |
| Linformer | Project K,V to fixed length | O(n·k) | Fixed compression |
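As a sketch of the kernel idea (using the elu(x)+1 feature map often cited in linear-attention work; non-causal case only, and quality trade-offs apply), the softmax is replaced by a feature map so the small $\phi(K)^\top V$ summary can be computed once and reused for every query:

```python
import numpy as np

def linear_attention(Q, K, V):
    """Softmax-free attention: out = phi(Q) (phi(K)^T V) / (phi(Q) phi(K)^T 1),
    costing O(n·d²) time and O(d²) extra memory instead of O(n²·d)."""
    phi = lambda x: np.where(x > 0, x + 1.0, np.exp(x))  # elu(x) + 1, always positive

    Qf, Kf = phi(Q), phi(K)
    kv = Kf.T @ V        # (d_k, d_v) summary, independent of sequence length
    z = Kf.sum(axis=0)   # (d_k,) normalizer
    return (Qf @ kv) / (Qf @ z)[:, None]


rng = np.random.default_rng(0)
n, d_k = 4096, 64
Q, K, V = [rng.standard_normal((n, d_k)) for _ in range(3)]
print(linear_attention(Q, K, V).shape)  # (4096, 64)
```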
Category 3: Memory-Efficient Exact Attention
Compute exact attention but with O(n) memory instead of O(n²).
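A minimal NumPy sketch of this idea, assuming a single head and no masking: exact attention is accumulated block-by-block over the keys with an online softmax (the same trick FlashAttention fuses into its GPU kernel), so only an n × block-size slice of scores exists at any moment.

```python
import numpy as np

def chunked_attention(Q, K, V, block_size=128):
    """Exact attention computed over key/value blocks with a running
    (online) softmax, so the full n×n score matrix is never stored."""
    n, d_k = Q.shape
    scale = 1.0 / np.sqrt(d_k)

    m = np.full(n, -np.inf)          # running max of scores per query row
    l = np.zeros(n)                  # running sum of exp(score - m)
    acc = np.zeros((n, V.shape[1]))  # unnormalized output accumulator

    for start in range(0, n, block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        S = (Q @ Kb.T) * scale                 # (n, block_size) score slice

        m_new = np.maximum(m, S.max(axis=1))   # updated row max
        correction = np.exp(m - m_new)         # rescale previous accumulators
        P = np.exp(S - m_new[:, None])         # block softmax numerator

        l = l * correction + P.sum(axis=1)
        acc = acc * correction[:, None] + P @ Vb
        m = m_new

    return acc / l[:, None]


# Sanity check against dense attention
rng = np.random.default_rng(0)
n, d_k = 512, 64
Q, K, V = [rng.standard_normal((n, d_k)) for _ in range(3)]

S = (Q @ K.T) / np.sqrt(d_k)
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)

print(np.abs(A @ V - chunked_attention(Q, K, V)).max())  # effectively zero: identical result
```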
The comparison below puts these approaches side by side:

```python
def compare_attention_methods():
    """
    Compare complexity of different attention mechanisms.
    """
    methods = [
        # (name, compute FLOPs, memory elements)
        ("Standard (Dense)",   lambda n, d: n**2 * d,    lambda n, d: n**2),
        ("Longformer (k=512)", lambda n, d: n * 512 * d, lambda n, d: n * 512),
        ("Linformer (k=256)",  lambda n, d: n * 256 * d, lambda n, d: n * 256),
        ("Linear Attention",   lambda n, d: n * d**2,    lambda n, d: n * d),
        ("FlashAttention",     lambda n, d: n**2 * d,    lambda n, d: n * d),  # Same compute, O(n) memory
    ]

    d = 512

    print("Efficient Attention Comparison (d=512)")
    print("=" * 90)
    print(f"{'Method':<25} {'Compute n=2K':<15} {'Compute n=32K':<15} "
          f"{'Memory n=2K':<15} {'Memory n=32K':<15}")
    print("-" * 90)

    for name, compute_fn, memory_fn in methods:
        c_2k = compute_fn(2048, d) / 1e9
        c_32k = compute_fn(32768, d) / 1e9
        m_2k = memory_fn(2048, d) * 4 / 1e6   # float32 MB
        m_32k = memory_fn(32768, d) * 4 / 1e6

        print(f"{name:<25} {c_2k:<15.2f} {c_32k:<15.2f} "
              f"{m_2k:<15.1f} {m_32k:<15.1f}")

    print()
    print("Units: Compute in GFLOPs, Memory in MB")
    print()
    print("Key observations:")
    print("- FlashAttention: same compute as standard, but O(n) memory!")
    print("- Sparse methods: ~100x less compute and memory for long sequences")
    print("- Linear methods: even lower complexity, but may sacrifice quality")


compare_attention_methods()
```

For most applications: (1) Use FlashAttention for exact attention with O(n) memory. (2) For very long sequences (>100K), consider sparse patterns like Longformer. (3) Linear attention variants are experimental but promising for extreme lengths.
Armed with complexity analysis, here are practical guidelines for working with self-attention in production:
Memory Budget Planning:
```python
import numpy as np

def plan_attention_memory(
    available_gpu_memory_gb: float,
    model_params_gb: float,
    n_layers: int,
    n_heads: int,
    d_model: int,
    dtype_bytes: int = 2,  # float16/bf16
):
    """
    Plan maximum sequence length given memory constraints.
    """
    # Memory available for activations (rough estimate)
    # Typically: params (×2 for optimizer state), activations, attention
    available_for_activations = (available_gpu_memory_gb - model_params_gb * 2) * 0.5
    available_for_activations_bytes = available_for_activations * 1e9

    # Attention matrix memory: n² × h × L × dtype_bytes
    # Solve for n: n² = available / (h × L × dtype_bytes × 2)
    coeff = n_heads * n_layers * dtype_bytes * 2  # ×2 for gradients
    max_n_squared = available_for_activations_bytes / coeff
    max_n = int(np.sqrt(max_n_squared))

    return max_n


# Example configurations
configs = [
    ("A100 80GB, 7B model", 80, 14, 32, 32, 4096),
    ("A100 40GB, 7B model", 40, 14, 32, 32, 4096),
    ("RTX 4090 24GB, 1B model", 24, 2, 16, 16, 2048),
]

print("Maximum Sequence Length Planning")
print("=" * 70)
print(f"{'Config':<30} {'Max n (naive)':<15} {'With FlashAttn':<15}")
print("-" * 70)

for name, gpu_gb, params_gb, n_layers, n_heads, d_model in configs:
    max_n_naive = plan_attention_memory(gpu_gb, params_gb, n_layers, n_heads, d_model)

    # FlashAttention reduces attention memory from O(n²) to O(n);
    # a conservative rule of thumb used here: ~10x longer sequences
    max_n_flash = max_n_naive * 10

    print(f"{name:<30} {max_n_naive:<15,} {max_n_flash:<15,}")

print("\nNote: These are rough estimates. Actual limits depend on many factors.")
```

With FlashAttention on an 80GB A100, you can train sequences up to roughly 16K-32K tokens for typical LLM architectures. For 100K+ context, sparse attention or sliding-window approaches become necessary even with FlashAttention.
We've thoroughly analyzed the computational and memory complexity of self-attention, understanding both the theoretical foundations and practical implications.
Module Complete:
This completes our exploration of self-attention, from its mathematical formulation through complexity analysis, memory requirements, hardware considerations, and efficient-attention strategies.

With this foundation, you're ready to explore multi-head attention in depth, the full transformer architecture, and advanced attention mechanisms.
You now have a comprehensive understanding of self-attention—from mathematical formulation to practical implementation considerations. This knowledge forms the foundation for understanding transformers, LLMs, and modern deep learning architectures.