One of the most revolutionary aspects of the Transformer architecture is its inherent parallelism. Unlike Recurrent Neural Networks (RNNs), where computation unfolds sequentially through time, multi-head attention processes all positions and all heads simultaneously. This isn't just a nice property—it's the key reason why Transformers have scaled to billions of parameters while RNNs hit fundamental bottlenecks.
Understanding this parallelism is crucial for reasoning about training speed, making efficient use of GPUs and TPUs, and deciding how to distribute large models across devices.
This page dissects the three dimensions of parallelism in multi-head attention: head parallelism, positional parallelism, and batch parallelism—and explains how modern hardware exploits each.
RNNs process sequences step-by-step: to compute output at position t, you must first compute positions 1 through t-1. For a sequence of length n, this means n sequential steps that cannot be parallelized. Transformers compute ALL positions simultaneously in a constant number of steps, regardless of sequence length. This fundamental difference explains why Transformers train orders of magnitude faster.
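A toy sketch of the two dependency structures (hypothetical shapes, plain PyTorch): the recurrent loop must run its $n$ steps one after another, while attention produces every position's output from a handful of whole-tensor operations.

import torch

n, d = 8, 16                      # toy sequence length and model width
x = torch.randn(n, d)

# RNN-style: position t cannot start until position t-1 has finished
W_h, W_x = torch.randn(d, d), torch.randn(d, d)
h = torch.zeros(d)
rnn_states = []
for t in range(n):                # n sequential steps, irreducible
    h = torch.tanh(h @ W_h + x[t] @ W_x)
    rnn_states.append(h)

# Attention-style: every position's output comes from a few batched operations
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
Q, K, V = x @ W_q, x @ W_k, x @ W_v
scores = Q @ K.T / d ** 0.5                 # all n x n scores at once
out = torch.softmax(scores, dim=-1) @ V     # all n outputs at once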
Multi-head attention exhibits parallelism along multiple independent dimensions, each of which can be exploited by modern parallel hardware:
1. Head Parallelism (across the $h$ heads)
The $h$ attention heads are completely independent: each head applies its own slice of the projection matrices and computes its own attention weights, sharing no intermediate state until the outputs are concatenated.
This means all $h$ heads can compute simultaneously, with zero synchronization overhead.
2. Positional Parallelism (across the $n$ query positions)
For a query at position $i$, the attention computation: $$\text{Attention}_i = \text{softmax}\left(\frac{q_i K^T}{\sqrt{d_k}}\right) V$$
is independent of the attention computation at any other position $j \neq i$. All $n$ positions can compute their attention outputs simultaneously.
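A quick numerical check of this independence (toy sizes): attention computed for one query row in isolation matches the corresponding row of the full computation.

import torch
import torch.nn.functional as F

n, d_k = 6, 8
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)

# Full computation: all query positions at once
full = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1) @ V

# Position i on its own: depends only on q_i (plus all of K and V)
i = 3
row_i = F.softmax(Q[i] @ K.T / d_k ** 0.5, dim=-1) @ V

print(torch.allclose(full[i], row_i, atol=1e-6))  # True: query positions are independent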
3. Batch Parallelism (across the $B$ sequences)
Different sequences in a batch share no dependencies. All $B$ sequences can be processed in parallel.
| Dimension | Count | Independence | Parallelizable? |
|---|---|---|---|
| Batch | $B$ (batch size) | Sequences are fully independent | ✓ Fully parallel |
| Heads | $h$ (number of heads) | Heads share no state until concat | ✓ Fully parallel |
| Query positions | $n_q$ (query sequence length) | Each query position independent | ✓ Fully parallel |
| Key positions | $n_k$ (key sequence length) | Attention weights computed for all keys | ✓ Fully parallel (within each query) |
| Feature dimensions | $d_k$, $d_v$ | Dot products computed element-wise | ✓ Fully parallel |
Total Available Parallelism:
The total number of independent scalar operations that can execute simultaneously is:
$$\text{Parallelism} = B \times h \times n_q \times n_k \times d_k$$
For typical configurations ($B=32$, $h=12$, $n_q = n_k = 512$, $d_k=64$), this is: $$32 \times 12 \times 512 \times 512 \times 64 \approx 6.44 \times 10^9$$
over 6 billion parallel operations—well beyond what even the largest GPUs can execute simultaneously, ensuring full hardware utilization.
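The count is easy to reproduce from the shapes (values are the ones used above):

B, h, n_q, n_k, d_k = 32, 12, 512, 512, 64   # configuration from the text

# Every scalar multiply in the QK^T score computation is independent of the others
parallel_ops = B * h * n_q * n_k * d_k
print(f"{parallel_ops:,}")                   # 6,442,450,944 ≈ 6.44e9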
Contrast with RNNs:
In an LSTM processing the same sequence, only about $B \times d_{hidden}$ operations are available at any moment, and the $n$ timesteps must execute one after another.
This parallelism gap explains why RNNs hit training speed walls while Transformers continue scaling.
RNNs have an irreducible sequential dependency: h_t = f(h_{t-1}, x_t). Even with infinite parallel hardware, you cannot compute h_t until h_{t-1} is complete. For sequence length n, this means n sequential steps. Transformers have no such sequential dependency—all positions can compute simultaneously.
The parallel nature of multi-head attention is elegantly expressed through matrix operations, which map directly to highly optimized GPU/TPU primitives.
The Complete Computation:
For multi-head attention with input $X \in \mathbb{R}^{n \times d_{model}}$:
1. Projection (all heads, batched): $$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$ where $W^Q, W^K, W^V \in \mathbb{R}^{d_{model} \times d_{model}}$
2. Reshape to heads: split the last dimension into $(h, d_k)$ and transpose, so that $$Q, K, V \in \mathbb{R}^{h \times n \times d_k}$$
3. Attention scores (batched matrix multiply): $$A = QK^T / \sqrt{d_k} \in \mathbb{R}^{h \times n \times n}$$
4. Softmax (element-wise, parallel): $$\tilde{A} = \text{softmax}(A)$$
5. Value aggregation (batched matrix multiply): $$O = \tilde{A}V \in \mathbb{R}^{h \times n \times d_k}$$
6. Concatenate and project: $$\text{Output} = \text{Concat}(O_1, \ldots, O_h)W^O$$
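A compact sketch of these six steps for a single sequence (toy sizes, no batch dimension), with the tensor shape after each step noted in the comments:

import torch
import torch.nn.functional as F

n, h, d_k = 10, 4, 16
d_model = h * d_k
X = torch.randn(n, d_model)
W_q, W_k, W_v, W_o = (torch.randn(d_model, d_model) for _ in range(4))

Q, K, V = X @ W_q, X @ W_k, X @ W_v                  # 1. project: (n, d_model)
Q = Q.view(n, h, d_k).transpose(0, 1)                # 2. reshape: (h, n, d_k)
K = K.view(n, h, d_k).transpose(0, 1)
V = V.view(n, h, d_k).transpose(0, 1)
A = Q @ K.transpose(-2, -1) / d_k ** 0.5             # 3. scores: (h, n, n)
A = F.softmax(A, dim=-1)                             # 4. softmax: (h, n, n)
O = A @ V                                            # 5. aggregate: (h, n, d_k)
out = O.transpose(0, 1).reshape(n, d_model) @ W_o    # 6. concat + project: (n, d_model)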
Batched Operations Enable Parallelism:
The key insight is that steps 3-5 run as batched operations over the head dimension: two batched matrix multiplies ($QK^T$ and $\tilde{A}V$) plus an element-wise softmax. Instead of computing $h$ separate attention operations:
# Naive (not parallel across heads)
for i in range(h):
    A_i = Q[i] @ K[i].T / sqrt(d_k)
    attn_i = softmax(A_i) @ V[i]
We compute a single batched operation:
# Batched (parallel across heads)
# Q, K, V shape: (batch, heads, seq, d_k)
A = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(d_k) # All heads simultaneously
attn = torch.matmul(softmax(A), V) # All heads simultaneously
Modern linear algebra libraries (cuBLAS, MKL, etc.) execute this batched matrix multiply as a single kernel launch that spreads the per-head computations across the GPU's cores in parallel.
Memory Access Pattern:
The matrix layout also enables coalesced memory access: with Q, K, and V stored contiguously as (batch, heads, seq, d_k) tensors, neighboring threads read neighboring elements, so memory transactions are combined and bandwidth is used efficiently.
import torch
import torch.nn.functional as F
import time
from typing import Tuple


def measure_attention_parallelism():
    """
    Demonstrate and measure the parallelism in multi-head attention.
    Compare batched vs sequential computation.
    """
    # Configuration
    batch_size = 32
    seq_len = 512
    num_heads = 12
    d_k = 64
    d_model = num_heads * d_k
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print()

    # Create inputs
    Q = torch.randn(batch_size, num_heads, seq_len, d_k, device=device)
    K = torch.randn(batch_size, num_heads, seq_len, d_k, device=device)
    V = torch.randn(batch_size, num_heads, seq_len, d_k, device=device)

    print("Multi-Head Attention Parallelism Analysis")
    print("=" * 60)
    print(f"Batch: {batch_size}, Heads: {num_heads}, Seq: {seq_len}, d_k: {d_k}")
    print()

    # Calculate theoretical parallelism
    total_parallel_ops = batch_size * num_heads * seq_len * seq_len * d_k
    print(f"Theoretical parallel operations: {total_parallel_ops:,}")
    print()

    # Warmup
    for _ in range(10):
        _ = torch.matmul(Q, K.transpose(-2, -1))
    if device.type == 'cuda':
        torch.cuda.synchronize()

    # Method 1: Fully batched (parallel across everything)
    start = time.perf_counter()
    num_runs = 100
    for _ in range(num_runs):
        # All heads compute simultaneously
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    batched_time = (time.perf_counter() - start) / num_runs * 1000

    # Method 2: Sequential over heads (simulated reduced parallelism)
    start = time.perf_counter()
    for _ in range(num_runs):
        outputs = []
        for h in range(num_heads):
            attn_scores_h = torch.matmul(Q[:, h], K[:, h].transpose(-2, -1)) / (d_k ** 0.5)
            attn_weights_h = F.softmax(attn_scores_h, dim=-1)
            output_h = torch.matmul(attn_weights_h, V[:, h])
            outputs.append(output_h)
        output_seq = torch.stack(outputs, dim=1)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    sequential_time = (time.perf_counter() - start) / num_runs * 1000

    print("Timing Comparison:")
    print("-" * 40)
    print(f"Fully batched (parallel): {batched_time:.2f} ms")
    print(f"Sequential over heads: {sequential_time:.2f} ms")
    print(f"Speedup from head parallelism: {sequential_time / batched_time:.2f}x")
    print()

    # Verify outputs match
    if device.type == 'cuda':
        torch.cuda.synchronize()
    max_diff = (output - output_seq).abs().max().item()
    print(f"Output difference (sanity check): {max_diff:.6f}")


def analyze_memory_bandwidth():
    """
    Analyze memory bandwidth utilization in attention computation.
    """
    print("\nMemory Bandwidth Analysis")
    print("=" * 60)
    batch_size = 32
    seq_len = 512
    num_heads = 12
    d_k = 64

    # Bytes per element (float32)
    bytes_per_elem = 4

    # Memory reads
    q_bytes = batch_size * num_heads * seq_len * d_k * bytes_per_elem
    k_bytes = batch_size * num_heads * seq_len * d_k * bytes_per_elem
    v_bytes = batch_size * num_heads * seq_len * d_k * bytes_per_elem
    # Intermediate attention matrix
    attn_bytes = batch_size * num_heads * seq_len * seq_len * bytes_per_elem
    # Output
    output_bytes = batch_size * num_heads * seq_len * d_k * bytes_per_elem
    total_memory = q_bytes + k_bytes + v_bytes + attn_bytes + output_bytes

    print(f"Memory footprint per forward pass:")
    print(f"  Q, K, V tensors: 3 × {q_bytes / 1e6:.1f} MB = {3*q_bytes / 1e6:.1f} MB")
    print(f"  Attention matrix: {attn_bytes / 1e6:.1f} MB")
    print(f"  Output tensor: {output_bytes / 1e6:.1f} MB")
    print(f"  Total: {total_memory / 1e6:.1f} MB")
    print()

    # FLOPs analysis
    # QK^T: batch * heads * seq * seq * d_k MACs
    qk_flops = batch_size * num_heads * seq_len * seq_len * d_k * 2  # mult + add
    # Softmax: batch * heads * seq * seq (exp + sum + div)
    softmax_flops = batch_size * num_heads * seq_len * seq_len * 5
    # AttnV: batch * heads * seq * seq * d_k MACs
    av_flops = batch_size * num_heads * seq_len * seq_len * d_k * 2
    total_flops = qk_flops + softmax_flops + av_flops

    print(f"Computation (FLOPs per forward pass):")
    print(f"  QK^T matmul: {qk_flops / 1e9:.2f} GFLOPs")
    print(f"  Softmax: {softmax_flops / 1e9:.2f} GFLOPs")
    print(f"  Attention × V: {av_flops / 1e9:.2f} GFLOPs")
    print(f"  Total: {total_flops / 1e9:.2f} GFLOPs")
    print()

    # Arithmetic intensity (FLOPs per byte)
    arithmetic_intensity = total_flops / total_memory
    print(f"Arithmetic intensity: {arithmetic_intensity:.1f} FLOPs/byte")
    print()
    print("Interpretation:")
    print("  > 10: Compute-bound (good for GPUs)")
    print("  < 10: Memory-bound (bandwidth limited)")
    print(f"  Attention is typically {'compute' if arithmetic_intensity > 10 else 'memory'}-bound")


def compare_rnn_vs_transformer_parallelism():
    """
    Compare the parallelism characteristics of RNNs vs Transformers.
    """
    print("\nRNN vs Transformer Parallelism Comparison")
    print("=" * 60)
    batch_size = 32
    seq_len = 512
    hidden_dim = 768
    num_heads = 12

    print("Configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Sequence length: {seq_len}")
    print(f"  Hidden dimension: {hidden_dim}")
    print()

    # RNN parallelism (per timestep)
    rnn_parallel_per_step = batch_size * hidden_dim
    rnn_sequential_steps = seq_len
    rnn_total_ops = rnn_parallel_per_step * rnn_sequential_steps

    # Transformer parallelism (all at once)
    # Attention: batch * heads * seq * seq
    transformer_parallel = batch_size * num_heads * seq_len * seq_len
    transformer_sequential_steps = 1  # Depth only

    print("RNN (LSTM):")
    print(f"  Parallel ops per step: {rnn_parallel_per_step:,}")
    print(f"  Sequential steps: {rnn_sequential_steps}")
    print(f"  Total operations: {rnn_total_ops:,}")
    print()
    print("Transformer (attention):")
    print(f"  Parallel ops: {transformer_parallel:,}")
    print(f"  Sequential steps: {transformer_sequential_steps}")
    print(f"  Total operations: {transformer_parallel:,}")
    print()

    parallelism_ratio = transformer_parallel / rnn_parallel_per_step
    print(f"Parallelism advantage: {parallelism_ratio:.0f}× more parallel work")
    print()
    print("Key insight:")
    print("  RNN: Parallelism limited by hidden dim, sequential over time")
    print("  Transformer: Parallelism scales with seq_len² × heads")
    print("  For seq_len=512, heads=12: 3.2M parallel vs 24K parallel")


if __name__ == "__main__":
    measure_attention_parallelism()
    analyze_memory_bandwidth()
    compare_rnn_vs_transformer_parallelism()

The parallel structure of multi-head attention maps exceptionally well to modern parallel hardware. Understanding this mapping is essential for optimizing Transformer performance.
GPU Architecture Overview:
Modern GPUs consist of thousands of simple cores grouped into streaming multiprocessors (SMs), specialized Tensor Cores for matrix math, and a deep memory hierarchy (registers, shared memory, and high-bandwidth device memory).
Mapping Attention to GPU:
| Attention Operation | GPU Mapping | Parallelization |
|---|---|---|
| Q, K, V projections | cuBLAS GEMM | Batch × positions |
| Attention scores ($QK^T$) | cuBLAS batched GEMM | Batch × heads |
| Softmax | Custom kernel | Batch × heads × positions |
| Value aggregation | cuBLAS batched GEMM | Batch × heads |
| Output projection | cuBLAS GEMM | Batch × positions |
NVIDIA Tensor Cores can perform 4×4 matrix multiplications in a single clock cycle. Since attention is dominated by matrix multiplications (QK^T and softmax×V), Tensor Cores provide 8-16× speedup over standard CUDA cores for attention computation. This is why modern GPUs are so effective for Transformers.
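In practice, the usual way to route these matrix multiplications onto Tensor Cores from PyTorch is reduced precision, for example autocast to fp16 or bf16. A minimal sketch (assumes a CUDA device; actual speedups depend on GPU generation and tensor shapes):

import torch
import torch.nn.functional as F

B, h, n, d_k = 8, 12, 512, 64
Q = torch.randn(B, h, n, d_k, device="cuda")
K = torch.randn(B, h, n, d_k, device="cuda")
V = torch.randn(B, h, n, d_k, device="cuda")

# Under autocast, the two large matmuls run in fp16 and are eligible for Tensor Cores;
# PyTorch keeps the softmax in fp32 for numerical stability.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
    out = torch.matmul(F.softmax(scores, dim=-1), V)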
Batched Matrix Multiplication:
The key primitive for efficient attention is batched GEMM (General Matrix Multiply):
$$C_{i} = A_{i} B_{i} \quad \text{for } i = 1, \ldots, \text{batch size}$$
For attention, this batch dimension becomes batch_size × num_heads.

cuBLAS Batched GEMM:
// Pseudocode for batched attention matmul
cublasSgemmStridedBatched(
handle,
CUBLAS_OP_T, CUBLAS_OP_N, // K transposed, Q not
seq_len, seq_len, d_k, // Matrix dimensions
&alpha,
K, d_k, seq_len * d_k, // K matrix and stride
Q, d_k, seq_len * d_k, // Q matrix and stride
&beta,
attn_scores, seq_len, seq_len * seq_len, // Output
batch_size * num_heads // Total batches
);
This single call computes attention scores for all heads in all sequences simultaneously.
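From PyTorch, the equivalent is just a batched matmul: folding batch and heads into one leading dimension and calling torch.bmm (or letting torch.matmul broadcast over the leading dimensions) typically lowers to the same strided-batched GEMM. A sketch with assumed toy shapes:

import torch

B, h, n, d_k = 4, 12, 128, 64
Q = torch.randn(B, h, n, d_k)
K = torch.randn(B, h, n, d_k)

# Fold (batch, heads) into a single GEMM batch dimension of size B*h
Qb = Q.reshape(B * h, n, d_k)
Kb = K.reshape(B * h, n, d_k)
scores = torch.bmm(Qb, Kb.transpose(1, 2)) / d_k ** 0.5        # (B*h, n, n)

# Equivalent: torch.matmul broadcasts over the leading (B, h) dims directly
scores_4d = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5  # (B, h, n, n)
print(torch.allclose(scores.view(B, h, n, n), scores_4d, atol=1e-5))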
TPU Architecture and Mapping:
Google's TPUs are optimized specifically for the matrix-heavy computations that dominate Transformers: the core of each chip is a large systolic-array matrix unit (MXU) that streams operands through a grid of multiply-accumulate units, a natural fit for attention's batched matrix multiplies.
TPU Optimizations for Attention:
XLA fusion: Compiler automatically fuses attention operations to reduce memory roundtrips
Memory layout optimization: TPU compiler rearranges tensors for optimal access patterns
Pipeline parallelism: Multiple attention layers process different batches simultaneously
Memory Considerations:
The $O(n^2)$ attention matrix creates memory pressure:
| Sequence Length | Attention Matrix Size (fp32) |
|---|---|
| 512 | 1 MB per head |
| 2048 | 16 MB per head |
| 8192 | 256 MB per head |
| 32768 | 4 GB per head |
For long sequences, the attention matrix may not fit in GPU memory, motivating Flash Attention and other memory-efficient approaches.
Flash Attention is a breakthrough that fuses the attention computation (QK^T, softmax, multiply by V) into a single kernel that never materializes the full n×n attention matrix. It achieves this by processing attention in tiles, reducing memory usage from O(n²) to O(n) while maintaining full parallelism. This enables processing much longer sequences on the same hardware.
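The real FlashAttention kernel is fused CUDA, but the core trick can be sketched in a few lines of PyTorch (a simplified, non-fused illustration with assumed toy sizes): process K and V in blocks and keep a running softmax max and sum, so only an n × block score tile is ever materialized.

import torch

def tiled_attention(Q, K, V, block=64):
    """Attention over K/V blocks with a running softmax (only O(n·d) extra memory)."""
    n, d_k = Q.shape
    scale = d_k ** -0.5
    acc = torch.zeros_like(Q)                 # running (unnormalized) output
    m = torch.full((n,), float("-inf"))       # running row-wise max of scores
    l = torch.zeros(n)                        # running row-wise softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = (Q @ Kb.T) * scale                # only an (n, block) score tile exists
        m_new = torch.maximum(m, s.max(dim=-1).values)
        corr = torch.exp(m - m_new)           # rescale previously accumulated partials
        p = torch.exp(s - m_new[:, None])
        l = l * corr + p.sum(dim=-1)
        acc = acc * corr[:, None] + p @ Vb
        m = m_new
    return acc / l[:, None]

Q, K, V = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax(Q @ K.T / 64 ** 0.5, dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-5))  # matches full attention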
Understanding the computational complexity of multi-head attention reveals both its power and its limitations.
Time Complexity:
For a sequence of length $n$ with model dimension $d$:
| Operation | Complexity | Dominant for |
|---|---|---|
| Q, K, V projections | $O(n \cdot d^2)$ | Short sequences |
| $QK^T$ attention scores | $O(n^2 \cdot d)$ | Long sequences |
| Softmax | $O(n^2)$ | — |
| Attention × V | $O(n^2 \cdot d)$ | Long sequences |
| Output projection | $O(n \cdot d^2)$ | Short sequences |
Total: $O(n^2 \cdot d + n \cdot d^2)$
The crossover point where the quadratic $n^2$ term dominates is approximately: $$n^2 \cdot d > n \cdot d^2 \implies n > d$$
For $d = 512$, sequences longer than 512 tokens are dominated by the $O(n^2)$ attention computation.
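A quick check of the crossover, comparing the two terms of the total cost for the $d = 512$ case used in the text:

d = 512
for n in (128, 512, 2048, 8192):
    attn = n * n * d          # O(n^2 · d): attention scores + value aggregation
    proj = n * d * d          # O(n · d^2): Q/K/V and output projections
    print(f"n={n:5d}  attention/projection cost ratio = {attn / proj:.2f}")
# ratio < 1 for n < d, equals 1 at n = d, and grows linearly with n beyond that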
Parallel Time Complexity:
With unlimited parallelism, the depth of computation determines wall-clock time:
| Operation | Parallel Depth | Sequential Dependencies |
|---|---|---|
| Projections | $O(\log d)$ | Matrix multiply reduction |
| $QK^T$ | $O(\log d)$ | Dot product reduction |
| Softmax | $O(\log n)$ | Reduction for exp and sum |
| Attention × V | $O(\log d)$ | Dot product reduction |
| Output projection | $O(\log d)$ | Matrix multiply reduction |
Total parallel depth: $O(\log n + \log d) = O(\log(nd))$
This is polylogarithmic—effectively constant for practical sequences!
Comparison:
| Architecture | Total Work | Parallel Depth |
|---|---|---|
| RNN | $O(n \cdot d^2)$ | $O(n)$ (irreducible) |
| Transformer | $O(n^2 \cdot d + n \cdot d^2)$ | $O(\log n)$ |
The Transformer's $O(\log n)$ parallel depth vs RNN's $O(n)$ is the fundamental reason for Transformer efficiency on parallel hardware.
| Sequence Length | RNN Steps | Transformer Parallel Depth | Speedup (ideal) |
|---|---|---|---|
| 128 | 128 | ~10 | 13× |
| 512 | 512 | ~12 | 43× |
| 2048 | 2048 | ~14 | 146× |
| 8192 | 8192 | ~16 | 512× |
While Transformers parallelize excellently, the O(n²) memory and compute for attention limits practical sequence lengths. At n=32K tokens, you need about 4GB just for ONE attention matrix. This has motivated extensive research into efficient attention variants (linear attention, sparse attention, etc.) that trade some expressivity for O(n) or O(n log n) complexity.
The parallel structure of Transformers enables effective scaling beyond a single GPU. Understanding the distribution strategies is essential for training large models.
1. Data Parallelism
The simplest form: replicate the model across GPUs, each processes different batches.
GPU 0: batch[0:B//N] → gradients_0
GPU 1: batch[B//N:2B//N] → gradients_1
...
GPU N-1: batch[(N-1)B//N:B] → gradients_N-1
All-reduce: gradient = mean(gradients_0, ..., gradients_N-1)
For attention, data parallelism is straightforward because sequences share no state: each replica holds a full copy of the attention weights and runs the complete forward and backward pass on its own mini-batch, so the only communication is the gradient all-reduce at the end of each step.
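In PyTorch, this pattern is what DistributedDataParallel provides. A minimal sketch, assuming the process group is already initialized and each process drives one GPU (the wrapper function name here is illustrative):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already been called
# and that this process owns GPU `local_rank`.
def wrap_for_data_parallel(model: nn.Module, local_rank: int) -> nn.Module:
    model = model.to(local_rank)
    # DDP replicates the model on each GPU; after every backward pass it
    # all-reduces gradients so all replicas apply the same update.
    return DDP(model, device_ids=[local_rank])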
2. Tensor Parallelism (for Model-Parallel Attention)
For very large models, attention layers are split across GPUs:
Head-parallel strategy:
       Q,K,V proj          Local attention        Output proj (W^O)
           │                     │                      │
      ┌────┼────┐           ┌────┼────┐            ┌────┼────┐
      │    │    │           │    │    │            │    │    │
    GPU0 GPU1 GPU2        GPU0 GPU1 GPU2         GPU0 GPU1 GPU2
    heads heads heads     heads heads heads      partial outputs
    0-3   4-7  8-11       0-3   4-7  8-11        └──All-Reduce──┘
Column-parallel splits $W^Q, W^K, W^V$ by output dimension: each GPU keeps only the columns corresponding to its own heads, so the local Q, K, V projections require no communication.
Row-parallel splits $W^O$ by input dimension: each GPU multiplies its local head outputs by its slice of $W^O$, and the partial results are summed across GPUs with an all-reduce.
import torch
import torch.nn as nn
import torch.distributed as dist
from typing import Optional


class TensorParallelMultiHeadAttention(nn.Module):
    """
    Multi-head attention with tensor parallelism across multiple GPUs.

    Each GPU handles a subset of attention heads, reducing per-GPU memory
    and enabling larger models than would fit on a single device.

    Pattern:
    1. Column-parallel: Split Q, K, V projections by output dimension
    2. Local attention: Each GPU computes attention for its heads
    3. Row-parallel: Split output projection by input dimension
    4. All-reduce: Sum partial outputs across GPUs
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        tensor_parallel_size: int,
        tensor_parallel_rank: int,
    ):
        super().__init__()
        assert num_heads % tensor_parallel_size == 0, (
            f"num_heads ({num_heads}) must be divisible by "
            f"tensor_parallel_size ({tensor_parallel_size})"
        )
        self.d_model = d_model
        self.num_heads = num_heads
        self.tp_size = tensor_parallel_size
        self.tp_rank = tensor_parallel_rank

        # Each GPU handles a subset of heads
        self.local_num_heads = num_heads // tensor_parallel_size
        self.d_k = d_model // num_heads
        self.local_d = self.local_num_heads * self.d_k

        # Column-parallel Q, K, V: each GPU has (d_model, local_d)
        self.W_q = nn.Linear(d_model, self.local_d, bias=False)
        self.W_k = nn.Linear(d_model, self.local_d, bias=False)
        self.W_v = nn.Linear(d_model, self.local_d, bias=False)

        # Row-parallel output: each GPU has (local_d, d_model)
        # Output is summed across GPUs via all-reduce
        self.W_o = nn.Linear(self.local_d, d_model, bias=False)

        self.scale = self.d_k ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with tensor parallelism.

        Args:
            x: Input tensor (batch, seq_len, d_model)

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape

        # Step 1: Local Q, K, V projections (no communication)
        Q = self.W_q(x)  # (batch, seq, local_d)
        K = self.W_k(x)
        V = self.W_v(x)

        # Step 2: Reshape for local attention computation
        # (batch, seq, local_heads, d_k) -> (batch, local_heads, seq, d_k)
        Q = Q.view(batch_size, seq_len, self.local_num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.local_num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.local_num_heads, self.d_k).transpose(1, 2)

        # Step 3: Attention for local heads (no communication)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Step 4: Reshape back
        # (batch, local_heads, seq, d_k) -> (batch, seq, local_d)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.local_d)

        # Step 5: Row-parallel output projection
        output = self.W_o(attn_output)  # (batch, seq, d_model)

        # Step 6: All-reduce across tensor parallel group
        # Each GPU has computed output from its heads; sum contributions
        if self.tp_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM)

        return output


def demonstrate_tensor_parallelism():
    """
    Demonstrate how attention is split across multiple GPUs.
    """
    print("Tensor Parallel Attention Demonstration")
    print("=" * 60)
    d_model = 768
    num_heads = 12
    tp_size = 4  # Simulating 4 GPUs

    print(f"Configuration:")
    print(f"  d_model: {d_model}")
    print(f"  Total heads: {num_heads}")
    print(f"  Tensor parallel size: {tp_size}")
    print(f"  Heads per GPU: {num_heads // tp_size}")
    print()

    # Show parameter distribution
    d_k = d_model // num_heads
    local_d = (num_heads // tp_size) * d_k
    print("Parameter distribution per GPU:")
    print(f"  W_q, W_k, W_v: {d_model} × {local_d} = {d_model * local_d:,} params each")
    print(f"  W_o: {local_d} × {d_model} = {local_d * d_model:,} params")
    print(f"  Total per GPU: {3 * d_model * local_d + local_d * d_model:,}")
    print()

    # Compare to non-parallel
    full_params = 4 * d_model * d_model  # Q, K, V, O
    per_gpu_params = (3 * d_model * local_d + local_d * d_model)
    print(f"Full model: {full_params:,} attention params")
    print(f"Per GPU: {per_gpu_params:,} params ({per_gpu_params/full_params*100:.1f}% of total)")
    print()

    # Communication analysis
    print("Communication pattern:")
    print("  Forward pass:")
    print("    - All-reduce after output projection")
    print(f"    - Volume: batch × seq × {d_model} × 4 bytes")
    print("  Backward pass:")
    print("    - All-reduce for input gradient")
    print("    - Gradient all-reduce implicit in optimizer")


if __name__ == "__main__":
    demonstrate_tensor_parallelism()

3. Sequence Parallelism
For very long sequences, split the sequence across GPUs: each device holds a contiguous slice of the positions, and keys and values are exchanged (all-to-all or all-gather) so that local queries can still attend over the full sequence.
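A minimal sketch of one sequence-parallel scheme (assumed helper, single head, equal slices per rank): each rank keeps its slice of the queries, gathers K and V from all ranks, and computes attention only for its local positions. Practical systems (e.g. Ring Attention) stream the K/V blocks instead of gathering everything at once.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sequence_parallel_attention(q_local, k_local, v_local, d_k):
    """q/k/v_local: (local_seq, d_k) — this rank's slice of the sequence."""
    world = dist.get_world_size()
    # Gather every rank's keys and values so local queries can attend globally
    k_parts = [torch.empty_like(k_local) for _ in range(world)]
    v_parts = [torch.empty_like(v_local) for _ in range(world)]
    dist.all_gather(k_parts, k_local)
    dist.all_gather(v_parts, v_local)
    K = torch.cat(k_parts, dim=0)             # (full_seq, d_k)
    V = torch.cat(v_parts, dim=0)
    scores = q_local @ K.T / d_k ** 0.5       # (local_seq, full_seq)
    return F.softmax(scores, dim=-1) @ V      # (local_seq, d_k)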
Communication Costs:
| Strategy | Forward Communication | Backward Communication |
|---|---|---|
| Data Parallel | None | All-reduce gradients |
| Tensor Parallel (heads) | All-reduce after output projection | All-reduce gradients |
| Sequence Parallel | All-to-all for K, V | All-to-all for gradients |
Practical Scaling:
For large-scale training (GPT-3, etc.), combinations are used: data parallelism across groups of nodes, tensor parallelism within a node, and pipeline parallelism across groups of layers.
We've explored how multi-head attention's parallel structure enables extraordinary computational efficiency on modern hardware. Let's consolidate the key insights: attention is independent across batch, heads, and query positions; batched matrix multiplications map that structure directly onto GPU and TPU hardware; the parallel depth is $O(\log n)$ versus an RNN's irreducible $O(n)$; and the same independence underlies data, tensor, and sequence parallelism for multi-GPU training.
What's Next:
In the next page, we'll explore what different attention heads learn—moving from the computational mechanics to the interpretability of multi-head attention. We'll see how heads specialize to capture different linguistic and semantic patterns, and how to analyze and visualize these learned representations.
You now understand why Transformers train orders of magnitude faster than RNNs despite higher theoretical complexity. The key insight: parallelism matters more than operation count. Transformers' O(n²) work is executed in O(log n) parallel depth, while RNNs' O(n) work requires O(n) sequential steps. This fundamental difference has reshaped deep learning.