The transformer's self-attention mechanism is both its greatest strength and its fundamental limitation. Self-attention allows every token to interact with every other token, capturing long-range dependencies that eluded earlier architectures. However, this comes at a steep computational cost: O(n²) time and memory complexity where n is the sequence length.
For short sequences (up to ~512 tokens), this is manageable. But as we push toward longer contexts—full documents, entire codebases, hour-long audio, or high-resolution images—the quadratic scaling becomes prohibitive:
| Sequence Length | Attention Operations | Memory (FP32) |
|---|---|---|
| 512 | 262,144 | ~1 MB |
| 2,048 | 4,194,304 | ~16 MB |
| 8,192 | 67,108,864 | ~256 MB |
| 32,768 | 1,073,741,824 | ~4 GB |
| 131,072 | 17,179,869,184 | ~64 GB |
Processing a single book (~100K tokens) with standard attention would require ~40 GB just for the attention matrix—per layer, per head, per batch item. This is clearly impractical.
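To see where these numbers come from, here is a small sketch of the arithmetic (plain Python, nothing assumed beyond FP32 at 4 bytes per score): one n×n attention matrix costs n² × 4 bytes, before multiplying by layers, heads, and batch size.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 4) -> int:
    """Memory for a single n x n attention matrix (one layer, one head, one batch item)."""
    return seq_len * seq_len * bytes_per_element

for n in [512, 2_048, 8_192, 32_768, 131_072]:
    mb = attention_matrix_bytes(n) / 2**20
    print(f"seq={n:>7}: {n * n:>14,} scores, {mb:>10,.0f} MB")

# A ~100K-token book: roughly 40 GB per layer, per head, per batch item
print(f"book (100K tokens): {attention_matrix_bytes(100_000) / 1e9:.0f} GB")
```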
This page explores the landscape of efficient transformer designs. You'll understand sparse attention patterns, linear attention approximations, memory-efficient algorithms like FlashAttention, and purpose-built architectures like Longformer and BigBird. By the end, you'll know how to choose the right efficient transformer for your use case.
The Quest for Efficient Attention:
Researchers have pursued multiple strategies to reduce attention complexity:
Sparse attention: each position attends only to a restricted subset of positions (local windows, strides, a few global tokens)
Linear attention: replace the softmax with kernel feature maps so attention can be computed in O(n)
Memory-efficient exact attention: keep standard attention but compute it block by block without materializing the full matrix (FlashAttention)
Purpose-built long-context architectures: combine these patterns into complete models such as Longformer and BigBird
Each approach offers different trade-offs between efficiency, expressiveness, and ease of implementation. Understanding these trade-offs is crucial for practitioners working with long sequences.
The simplest approach to reducing attention complexity is to restrict which positions can attend to which other positions. Instead of full O(n²) attention, we define a sparse attention pattern where each position attends only to a subset of positions.
Core Insight: Not all token pairs need to interact directly. Local context is often sufficient, and long-range dependencies can be captured through a few global positions or via multi-hop attention across layers.
Each token attends only to its local neighborhood:
$$\text{Attention}(i) = \{\, j : |i - j| \leq w \,\}$$
where w is the window size. This gives O(n × w) = O(n) complexity for fixed window size.
Properties:
Complexity: O(n × w) time and memory, linear in n for fixed w
Strength: captures local context, which dominates most language tasks
Limitation: distant tokens interact only indirectly, across layers; the receptive field grows by roughly w per layer
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple


def create_local_attention_mask(seq_length: int, window_size: int, device: torch.device) -> torch.Tensor:
    """
    Create a local (sliding window) attention mask.

    Returns:
        Boolean mask where True = position CAN attend (not masked)
    """
    # Create position indices
    positions = torch.arange(seq_length, device=device)
    # Distance matrix: |i - j|
    distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
    # Local attention: |i - j| <= window_size
    local_mask = distance <= window_size
    return local_mask


def create_strided_attention_mask(
    seq_length: int,
    stride: int,
    device: torch.device
) -> torch.Tensor:
    """
    Create a strided (dilated) attention mask.
    Token i attends to positions where (i - j) % stride == 0.
    """
    positions = torch.arange(seq_length, device=device)
    relative = positions.unsqueeze(0) - positions.unsqueeze(1)
    strided_mask = (relative % stride == 0)
    return strided_mask


def create_global_local_mask(
    seq_length: int,
    window_size: int,
    global_indices: list[int],
    device: torch.device
) -> torch.Tensor:
    """
    Create mask combining local attention with global tokens.
    Global tokens can attend to and be attended by all positions.
    """
    # Start with local attention
    mask = create_local_attention_mask(seq_length, window_size, device)

    # Global tokens attend to everything
    for idx in global_indices:
        mask[idx, :] = True  # Global token attends to all
        mask[:, idx] = True  # All tokens attend to global

    return mask


class SparseAttention(nn.Module):
    """
    Sparse attention with configurable attention pattern.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        window_size: int = 256,
        global_tokens: int = 0,
        dropout: float = 0.1
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.window_size = window_size
        self.global_tokens = global_tokens

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project Q, K, V
        Q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
        K = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
        V = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)

        Q = Q.transpose(1, 2)  # [batch, heads, seq, head_dim]
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Create sparse attention pattern
        global_indices = list(range(self.global_tokens))
        sparse_mask = create_global_local_mask(
            seq_length, self.window_size, global_indices, hidden_states.device
        )

        # Compute attention with sparse mask
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # Apply sparse mask (set non-attended positions to -inf)
        sparse_mask = sparse_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]
        attn_weights = attn_weights.masked_fill(~sparse_mask, float('-inf'))

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_probs = F.softmax(attn_weights, dim=-1)
        attn_probs = self.dropout(attn_probs)

        output = torch.matmul(attn_probs, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)

        return self.out_proj(output)
```

Each token attends to every k-th token:
$$\text{Attention}(i) = \{\, j : (i - j) \bmod k = 0 \,\}$$
This captures dependencies at a coarser scale: with stride k, a token reaches positions k, 2k, 3k, ... away. Combined with local attention, each token covers both its immediate neighborhood and a set of evenly spaced distant positions.
OpenAI's Sparse Transformer combined local and strided patterns:
Factorized Attention: Split attention into two patterns
Alternating Layers: Half the layers use each pattern
Complexity: O(n√n) instead of O(n²)
This enabled training on sequences 30× longer than vanilla transformers with the same memory budget.
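As a rough illustration of the factorized idea, the mask helpers defined above can be combined. The choice of window and stride near √n, and the union-of-masks formulation, are illustrative assumptions here rather than the exact Sparse Transformer layout, but they show why the per-token cost drops to O(√n):

```python
import math
import torch

# Assumes create_local_attention_mask / create_strided_attention_mask
# from the sparse-attention example above are in scope.
seq_length = 1024
step = int(math.sqrt(seq_length))  # ~32; illustrative window size and stride
device = torch.device("cpu")

local = create_local_attention_mask(seq_length, window_size=step, device=device)
strided = create_strided_attention_mask(seq_length, stride=step, device=device)
factorized = local | strided  # union of the two patterns

attended_per_token = factorized.sum(dim=-1).float().mean()
print(f"average attended positions per token: {attended_per_token:.0f} of {seq_length}")
# Roughly 2*step + seq_length/step = O(sqrt(n)) per token -> O(n*sqrt(n)) overall
```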
With sparse attention, information from distant positions must propagate through multiple layers rather than directly. If the attention pattern has connectivity c and there are L layers, information can reach ~c^L positions. This is sufficient for most tasks but may limit very long-range reasoning that requires direct token interactions.
Longformer (2020) was designed specifically for processing long documents. It combines local windowed attention with task-specific global attention, achieving O(n) complexity while maintaining strong performance.
Attention Mechanism:
Sliding Window Attention: Every token attends to w/2 tokens on each side
Global Attention: Selected tokens attend to (and are attended by) all positions
| Task | Global Tokens | Window Size | Max Length |
|---|---|---|---|
| Classification | [CLS] only | 512 | 4096 |
| Question Answering | Question tokens | 512 | 4096 |
| Summarization | First tokens | 512 | 16384 |
| Language Modeling | None | 512 | 4096 |
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List


class LongformerAttention(nn.Module):
    """
    Longformer-style attention: sliding window + global attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        window_size: int = 512,  # One-sided window
        dropout: float = 0.1
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.window_size = window_size
        self.scale = self.head_dim ** -0.5

        # Separate projections for local and global attention
        self.q_local = nn.Linear(hidden_size, hidden_size)
        self.k_local = nn.Linear(hidden_size, hidden_size)
        self.v_local = nn.Linear(hidden_size, hidden_size)

        self.q_global = nn.Linear(hidden_size, hidden_size)
        self.k_global = nn.Linear(hidden_size, hidden_size)
        self.v_global = nn.Linear(hidden_size, hidden_size)

        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def _sliding_window_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Efficient sliding window attention.

        For simplicity, this implementation uses the mask-based approach.
        Production versions use blocked sparse attention for true efficiency.
        """
        batch_size, num_heads, seq_length, head_dim = query.shape

        # Create sliding window mask
        positions = torch.arange(seq_length, device=query.device)
        distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
        window_mask = distance <= self.window_size

        # Compute attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale

        # Apply window mask
        window_mask = window_mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(~window_mask, float('-inf'))

        if attention_mask is not None:
            scores = scores + attention_mask

        attn_probs = F.softmax(scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        return torch.matmul(attn_probs, value)

    def _global_attention(
        self,
        hidden_states: torch.Tensor,
        global_indices: List[int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute global attention for specified positions.

        Returns:
            global_output: Output for global positions (attending to all)
            to_global_output: Output from all positions attending to global
        """
        batch_size, seq_length, _ = hidden_states.shape

        if not global_indices:
            return None, None

        # Global queries attend to all keys/values
        Q_global = self.q_global(hidden_states[:, global_indices, :])
        K_all = self.k_global(hidden_states)
        V_all = self.v_global(hidden_states)

        # Reshape for attention
        Q_global = Q_global.view(batch_size, len(global_indices), self.num_heads, self.head_dim).transpose(1, 2)
        K_all = K_all.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V_all = V_all.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Global queries attend to all positions
        scores = torch.matmul(Q_global, K_all.transpose(-2, -1)) * self.scale
        attn_probs = F.softmax(scores, dim=-1)
        global_output = torch.matmul(attn_probs, V_all)
        global_output = global_output.transpose(1, 2).contiguous()

        # All queries attend to global keys
        Q_all = self.q_global(hidden_states)
        K_global = self.k_global(hidden_states[:, global_indices, :])
        V_global = self.v_global(hidden_states[:, global_indices, :])

        Q_all = Q_all.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K_global = K_global.view(batch_size, len(global_indices), self.num_heads, self.head_dim).transpose(1, 2)
        V_global = V_global.view(batch_size, len(global_indices), self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q_all, K_global.transpose(-2, -1)) * self.scale
        attn_probs = F.softmax(scores, dim=-1)
        to_global_output = torch.matmul(attn_probs, V_global)
        to_global_output = to_global_output.transpose(1, 2).contiguous()

        return global_output, to_global_output

    def forward(
        self,
        hidden_states: torch.Tensor,
        global_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Local attention projections
        Q = self.q_local(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_local(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_local(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Local sliding window attention
        local_output = self._sliding_window_attention(Q, K, V, attention_mask)
        local_output = local_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)

        # Global attention (if specified)
        if global_attention_mask is not None:
            global_indices = torch.where(global_attention_mask[0] > 0)[0].tolist()
            global_out, to_global_out = self._global_attention(hidden_states, global_indices)

            if global_out is not None:
                # Update global positions with their output
                for idx, global_idx in enumerate(global_indices):
                    local_output[:, global_idx, :] = global_out[:, idx, :].view(batch_size, self.hidden_size)
                # Add contribution from attending to global tokens
                local_output = local_output + to_global_out.view(batch_size, seq_length, self.hidden_size)

        return self.out_proj(local_output)
```

Longformer is particularly effective for document-level tasks where local context matters (reading comprehension, summarization). It's available through Hugging Face as 'allenai/longformer-base-4096', which handles inputs up to 4,096 tokens; the Longformer Encoder-Decoder (LED) variant extends this to 16K tokens for summarization. For classification, simply mark the [CLS] token as global; the model handles the rest.
A fundamentally different approach to efficient attention is to avoid the O(n²) computation entirely by reformulating attention with linear complexity kernels.
The Key Insight:
Standard softmax attention: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$
The bottleneck is computing $QK^T$, an $n \times n$ matrix. But suppose the exponential similarity inside the softmax could be approximated as an inner product of feature maps: $$\exp\!\left(\frac{q \cdot k}{\sqrt{d}}\right) \approx \phi(q)^\top \phi(k)$$
Then we could rewrite attention (up to the row-wise softmax normalization, which factors the same way) as: $$\text{Attention} \approx \phi(Q)\left(\phi(K)^T V\right)$$
Notice the associativity: $(\phi(K)^T V)$ is a $d \times d$ matrix (independent of n), and we can multiply $\phi(Q)$ by this matrix in O(nd²) time—linear in sequence length!
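A quick numerical check of this associativity, using an elu+1 feature map as a stand-in for φ (an assumption for illustration): both orderings give the same result, but the right-hand one never forms an n×n matrix.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 1024, 64
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

phi = lambda x: F.elu(x) + 1  # any positive feature map works for this check

# Left-to-right: (phi(Q) @ phi(K)^T) @ V  -> materializes an n x n matrix
out_quadratic = (phi(Q) @ phi(K).T) @ V

# Right-to-left: phi(Q) @ (phi(K)^T @ V)  -> only a d x d intermediate
out_linear = phi(Q) @ (phi(K).T @ V)

print(torch.allclose(out_quadratic, out_linear, rtol=1e-4, atol=1e-3))  # True up to float error
```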
The Performer model uses random feature approximation to approximate the softmax kernel. Specifically, it uses FAVOR+ (Fast Attention Via positive Orthogonal Random features) to create an unbiased estimator of softmax attention with only O(n⋅r⋅d) complexity, where r is the number of random features (typically 256).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class LinearAttention(nn.Module):
    """
    Linear attention using kernel feature maps.
    Complexity: O(n⋅d²) instead of O(n²⋅d)
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        feature_map: str = "elu",  # or "softmax_kernel"
        eps: float = 1e-6
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.feature_map = feature_map
        self.eps = eps

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def _feature_map(self, x: torch.Tensor) -> torch.Tensor:
        """Apply feature map to make attention kernel linear."""
        if self.feature_map == "elu":
            # ELU + 1 ensures positivity
            return F.elu(x) + 1
        elif self.feature_map == "relu":
            return F.relu(x)
        elif self.feature_map == "softmax_kernel":
            # Approximates softmax behavior
            return F.softmax(x, dim=-1)
        else:
            raise ValueError(f"Unknown feature map: {self.feature_map}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project and reshape
        Q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
        K = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
        V = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)

        Q = Q.transpose(1, 2)  # [batch, heads, seq, head_dim]
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Apply feature map
        Q = self._feature_map(Q)
        K = self._feature_map(K)

        # Linear attention: instead of QK^TV, compute Q(K^TV)
        # K^T V: [batch, heads, head_dim, head_dim]
        KV = torch.einsum('bhnd,bhnv->bhdv', K, V)

        # Normalization: sum of attention weights
        # Z = Q ⋅ (K^T ⋅ 1) where 1 is all-ones
        Z = torch.einsum('bhnd,bhd->bhn', Q, K.sum(dim=2))
        Z = Z.unsqueeze(-1) + self.eps  # Avoid division by zero

        # Output: Q(K^TV) / Z
        output = torch.einsum('bhnd,bhdv->bhnv', Q, KV) / Z

        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
        return self.out_proj(output)


class PerformerAttention(nn.Module):
    """
    Performer attention with FAVOR+ (Random Feature Approximation).
    Uses orthogonal random features to approximate softmax kernel.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_features: int = 256,
        ortho_scaling: float = 0.0,
        eps: float = 1e-6
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.num_features = num_features
        self.ortho_scaling = ortho_scaling
        self.eps = eps

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        # Register random feature matrix (fixed, not learned)
        self.register_buffer(
            "random_features",
            self._create_random_features()
        )

    def _create_random_features(self) -> torch.Tensor:
        """Create orthogonal random features for kernel approximation."""
        # QR can only produce head_dim orthogonal vectors at a time, so stack
        # as many orthogonal blocks as needed to reach num_features rows.
        blocks = []
        remaining = self.num_features
        while remaining > 0:
            q, _ = torch.linalg.qr(torch.randn(self.head_dim, self.head_dim))
            blocks.append(q.T[:min(remaining, self.head_dim), :])
            remaining -= self.head_dim
        random_matrix = torch.cat(blocks, dim=0)  # [num_features, head_dim]

        # Scale rows to roughly the norm of a d-dimensional Gaussian vector
        random_matrix = random_matrix * math.sqrt(self.head_dim)
        return random_matrix

    def _favor_plus_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute FAVOR+ features: positive random features for softmax kernel.
        φ(x) = exp(x @ Ω.T - ||x||²/2) / sqrt(m)
        """
        # x: [batch, heads, seq, head_dim]
        # random_features: [num_features, head_dim]
        projection = torch.einsum('bhsd,fd->bhsf', x, self.random_features)

        # Softmax kernel approximation
        norm_sq = (x ** 2).sum(dim=-1, keepdim=True) / 2
        features = torch.exp(projection - norm_sq)
        features = features / math.sqrt(self.num_features)
        return features

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project
        Q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
        K = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
        V = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Apply FAVOR+ features
        Q_prime = self._favor_plus_features(Q)  # [batch, heads, seq, features]
        K_prime = self._favor_plus_features(K)

        # Linear attention with random features
        KV = torch.einsum('bhsf,bhsv->bhfv', K_prime, V)
        Z = K_prime.sum(dim=2)  # [batch, heads, features]

        output = torch.einsum('bhsf,bhfv->bhsv', Q_prime, KV)
        normalizer = torch.einsum('bhsf,bhf->bhs', Q_prime, Z).unsqueeze(-1) + self.eps
        output = output / normalizer

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
        return self.out_proj(output)
```

While linear attention achieves O(n) complexity, it often underperforms standard attention on tasks requiring precise token-to-token matching. The kernel approximation introduces noise, and the lack of an explicit n×n attention matrix limits the model's ability to express sharp, token-specific attention patterns. For many tasks, sparse attention (Longformer) outperforms linear attention (Performer).
BigBird (2020) combined multiple sparse attention patterns with theoretical analysis showing that sparse attention can approximate full attention under certain conditions.
BigBird's Three Attention Components:
Local (window) attention: each token attends to its neighbors within a window of size w
Global attention: a small set of g tokens attends to, and is attended by, every position
Random attention: each token additionally attends to r randomly selected positions
The random component is the key innovation: like the edges of an expander graph, a few random connections per token guarantee that information can travel between any two positions in only a small (logarithmic) number of hops.
| Component | Description | Complexity Contribution |
|---|---|---|
| Local | Window size w | O(n × w) |
| Global | g global tokens | O(n × g) |
| Random | r random edges per token | O(n × r) |
| Total | w + g + r << n | O(n) |
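A sketch of how the three patterns compose into a single boolean mask. The window size, global-token count, and random-edge count below are illustrative choices, and the real BigBird implementation uses a blocked layout for efficiency rather than a dense mask:

```python
import torch

def create_bigbird_style_mask(
    seq_length: int,
    window_size: int = 3,   # one-sided local window
    num_global: int = 2,    # first tokens act as global
    num_random: int = 3,    # random edges per token
    seed: int = 0,
) -> torch.Tensor:
    """Boolean mask (True = may attend) combining local + global + random attention."""
    positions = torch.arange(seq_length)

    # Local: |i - j| <= window_size
    mask = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs() <= window_size

    # Global: first num_global tokens attend to and are attended by everyone
    mask[:num_global, :] = True
    mask[:, :num_global] = True

    # Random: each token attends to num_random randomly chosen tokens (fixed at init)
    gen = torch.Generator().manual_seed(seed)
    rand_targets = torch.randint(0, seq_length, (seq_length, num_random), generator=gen)
    mask[positions.unsqueeze(1), rand_targets] = True

    return mask

mask = create_bigbird_style_mask(seq_length=64)
print(mask.shape, mask.float().mean().item())  # fraction of attended pairs is well below 1.0
```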
Theoretical Properties:
BigBird's paper proved that sparse attention with these three components can:
Act as a universal approximator of sequence-to-sequence functions
Remain Turing complete, matching the expressive power of full quadratic attention
These theoretical results provide justification for why sparse attention doesn't fundamentally limit transformer capabilities.
ETC (Extended Transformer Construction):
BigBird uses a variant called ETC that handles global tokens by adding dedicated global tokens (such as [CLS]) to the sequence, rather than promoting existing sequence tokens to global status as in the simpler ITC construction.
The random attention pattern is typically fixed at initialization (not resampled each forward pass) to enable caching and parallel computation. In practice, the random component provides diminishing returns beyond the first few random connections—most of the benefit comes from local + global attention.
While sparse and linear attention reduce theoretical complexity, FlashAttention takes a different approach: compute exact standard attention, but do it much faster and with less memory through hardware-aware algorithms.
The Memory Bottleneck:
Standard attention implementations materialize the full n×n attention matrix, which:
Requires O(n²) memory per layer and head
Must be written to and read back from slow GPU high-bandwidth memory (HBM)
Makes attention memory-bandwidth-bound rather than compute-bound on modern accelerators
FlashAttention's Key Ideas:
Tiling: split Q, K, V into blocks small enough to fit in fast on-chip SRAM
Online softmax: maintain running max and sum statistics so softmax can be computed block by block, never materializing the full score matrix
Kernel fusion: perform scaling, masking, softmax, and the value multiplication in a single fused kernel
Recomputation: recompute attention blocks during the backward pass instead of storing the n×n matrix
"""FlashAttention: A Conceptual Implementation Note: This is a simplified Python version for understanding.The actual FlashAttention is a CUDA kernel with careful memory management.""" import torchimport torch.nn.functional as Fimport mathfrom typing import Tuple def standard_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor: """ Standard O(n²) memory attention implementation. This is what we're trying to improve. """ scale = Q.shape[-1] ** -0.5 # Materialize full n×n attention matrix (memory bottleneck!) attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * scale # O(n²) memory attn_weights = F.softmax(attn_weights, dim=-1) # Still O(n²) output = torch.matmul(attn_weights, V) return output def online_softmax_update( m_old: torch.Tensor, # Running max l_old: torch.Tensor, # Running sum of exp O_old: torch.Tensor, # Running output S_block: torch.Tensor, # Current block of attention scores V_block: torch.Tensor # Current block of values) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Online softmax: update running statistics with new block. This is the key to avoiding full materialization. """ # New max considering this block m_new = torch.maximum(m_old, S_block.max(dim=-1, keepdim=True).values) # Rescale old values with new max scale_old = torch.exp(m_old - m_new) scale_new = torch.exp(S_block - m_new) # Update running sum l_new = scale_old * l_old + scale_new.sum(dim=-1, keepdim=True) # Update running output O_new = (scale_old * l_old * O_old + torch.matmul(scale_new, V_block)) / l_new return m_new, l_new, O_new def flash_attention_forward( Q: torch.Tensor, # [batch, heads, seq_q, head_dim] K: torch.Tensor, # [batch, heads, seq_kv, head_dim] V: torch.Tensor, # [batch, heads, seq_kv, head_dim] block_size: int = 64) -> torch.Tensor: """ FlashAttention-style blocked computation. Memory: O(n) instead of O(n²) The actual implementation would be a fused CUDA kernel. This shows the algorithm conceptually. """ batch_size, num_heads, seq_q, head_dim = Q.shape _, _, seq_kv, _ = K.shape scale = head_dim ** -0.5 # Initialize output and running statistics O = torch.zeros_like(Q) m = torch.full((batch_size, num_heads, seq_q, 1), float('-inf'), device=Q.device) l = torch.zeros((batch_size, num_heads, seq_q, 1), device=Q.device) # Process K, V in blocks num_kv_blocks = (seq_kv + block_size - 1) // block_size for kv_block_idx in range(num_kv_blocks): kv_start = kv_block_idx * block_size kv_end = min(kv_start + block_size, seq_kv) K_block = K[:, :, kv_start:kv_end, :] V_block = V[:, :, kv_start:kv_end, :] # Compute attention scores for this block S_block = torch.matmul(Q, K_block.transpose(-2, -1)) * scale # Online softmax update m, l, O = online_softmax_update(m, l, O, S_block, V_block) return O class FlashAttentionModule(torch.nn.Module): """ Flash Attention wrapper using PyTorch's scaled_dot_product_attention. PyTorch 2.0+ includes FlashAttention optimizations automatically. 
""" def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0): super().__init__() self.hidden_size = hidden_size self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.q_proj = torch.nn.Linear(hidden_size, hidden_size) self.k_proj = torch.nn.Linear(hidden_size, hidden_size) self.v_proj = torch.nn.Linear(hidden_size, hidden_size) self.out_proj = torch.nn.Linear(hidden_size, hidden_size) self.dropout = dropout def forward(self, hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor: batch_size, seq_length, _ = hidden_states.shape Q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) K = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) V = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) # Use PyTorch's optimized attention (uses FlashAttention when available) output = F.scaled_dot_product_attention( Q, K, V, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=False ) output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size) return self.out_proj(output)| Metric | Standard Attention | FlashAttention | Improvement |
|---|---|---|---|
| Memory (seq=2048) | 64 MB | 4 MB | 16× |
| Speed (A100, seq=2048) | 1.0× baseline | 3-5× faster | 3-5× |
| Max sequence length | ~4K tokens | ~64K tokens | 16× |
FlashAttention has become the default attention implementation in most modern frameworks. PyTorch 2.0+ includes it automatically via torch.nn.functional.scaled_dot_product_attention. HuggingFace Transformers uses it when available. Always use FlashAttention unless you have a specific reason not to—it provides identical outputs with better performance.
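A small sanity check, as a sketch (which kernel actually runs depends on your hardware, dtype, and PyTorch version): the fused scaled_dot_product_attention matches an explicit softmax implementation up to floating-point error.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, dim = 1, 8, 2048, 64
Q = torch.randn(batch, heads, seq, dim)
K = torch.randn(batch, heads, seq, dim)
V = torch.randn(batch, heads, seq, dim)

# Explicit implementation: materializes the [seq, seq] score matrix
scores = (Q @ K.transpose(-2, -1)) / dim ** 0.5
reference = F.softmax(scores, dim=-1) @ V

# Fused implementation: dispatches to FlashAttention / memory-efficient kernels when available
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(reference, fused, atol=1e-4))  # same result, no full matrix kept around
```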
With many efficient transformer variants available, selecting the right one requires understanding your specific requirements and constraints.
| Model | Complexity | Accuracy | Best For |
|---|---|---|---|
| Standard + FlashAttention | O(n²) time, O(n) memory | Highest | Sequences up to ~16K, any task |
| Longformer | O(n) time and memory | Very High | Long documents, classification, QA |
| BigBird | O(n) time and memory | Very High | Long documents with theoretical guarantees |
| Performer | O(n) time and memory | Medium | Very long sequences where approximation is OK |
| Linear Transformer | O(n) time and memory | Medium | Extreme length, less precision needed |
Start with FlashAttention on standard transformers. It handles sequences up to ~64K tokens on modern GPUs. Only switch to sparse or linear attention if you're still memory/compute constrained. The accuracy trade-offs of approximate methods are often not worth it unless you're at extreme sequence lengths.
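One way to encode that advice as a rough heuristic (the thresholds are illustrative only, and the module names refer to the sketches earlier on this page):

```python
import torch.nn as nn

# Assumes FlashAttentionModule, LongformerAttention, and PerformerAttention
# from the examples above are in scope.
def pick_attention(seq_length: int, hidden_size: int, num_heads: int) -> nn.Module:
    """Illustrative heuristic following the decision guide above."""
    if seq_length <= 16_384:
        # Exact attention with FlashAttention-style kernels is usually the best default
        return FlashAttentionModule(hidden_size, num_heads)
    elif seq_length <= 65_536:
        # Long documents: sparse local + global attention (Longformer-style)
        return LongformerAttention(hidden_size, num_heads, window_size=512)
    else:
        # Extreme lengths where kernel-approximation error is acceptable
        return PerformerAttention(hidden_size, num_heads, num_features=256)

attn = pick_attention(seq_length=32_000, hidden_size=768, num_heads=12)
print(type(attn).__name__)  # LongformerAttention
```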
You now understand the landscape of efficient transformer architectures: sparse attention patterns (Longformer, BigBird), linear attention approximations (Performer), and hardware-aware optimizations (FlashAttention). The key takeaway is that the O(n²) bottleneck has been largely solved in practice—FlashAttention for most use cases, sparse attention for very long documents. Next, we'll explore Vision Transformers, which adapt the transformer architecture to computer vision.