At the heart of the transformer's attention mechanism lies a seemingly simple but critically important design choice: dividing attention scores by $\sqrt{d_k}$ before applying softmax. This scaling factor is not arbitrary—it's a mathematically motivated solution to a fundamental problem that plagued early attention implementations.
Without proper scaling, attention scores grow in magnitude with the key dimension, softmax saturates into nearly one-hot distributions, and gradients through the attention weights vanish, stalling training.
The "Attention is All You Need" paper introduced scaled dot-product attention precisely to address these issues. Understanding why this scaling works requires diving into the statistics of dot products and the behavior of softmax in different input regimes.
This page provides the rigorous mathematical justification for the $\sqrt{d_k}$ factor and explores its consequences for transformer training.
By the end of this page, you will understand the statistical justification for the √d_k scaling factor, how it prevents softmax saturation, the relationship between dimension and variance, and practical implications for attention in high-dimensional spaces.
Before understanding the solution, let's clearly identify the problem. Consider computing attention scores without scaling:
$$S_{ij} = Q_i \cdot K_j = \sum_{l=1}^{d_k} Q_{il} K_{jl}$$
The dot product sums $d_k$ terms. As $d_k$ increases, the magnitude of this sum grows, even if individual vector elements are well-behaved.
The Variance Problem:
Assume $Q$ and $K$ elements are independently drawn from a standard normal distribution $\mathcal{N}(0, 1)$:
$$Q_{il} \sim \mathcal{N}(0, 1), \quad K_{jl} \sim \mathcal{N}(0, 1)$$
Each product $Q_{il} K_{jl}$ has mean $\mathbb{E}[Q_{il} K_{jl}] = \mathbb{E}[Q_{il}]\,\mathbb{E}[K_{jl}] = 0$ and, by independence, variance $\text{Var}(Q_{il} K_{jl}) = \mathbb{E}[Q_{il}^2]\,\mathbb{E}[K_{jl}^2] = 1$.
Since the dot product is a sum of $d_k$ independent terms: $$\text{Var}(Q_i \cdot K_j) = \sum_{l=1}^{d_k} \text{Var}(Q_{il} K_{jl}) = d_k$$
So $Q_i \cdot K_j \sim \mathcal{N}(0, d_k)$ — the variance scales with dimension.
```python
import numpy as np

def demonstrate_variance_growth():
    """Show how dot product variance grows with dimension."""
    np.random.seed(42)
    dimensions = [1, 4, 16, 64, 256, 1024]
    n_samples = 10000

    print("Dot Product Variance vs. Dimension")
    print("=" * 50)
    print(f"{'d_k':<8} {'Theoretical Var':<18} {'Empirical Var':<18} {'Std Dev':<12}")
    print("-" * 50)

    for d_k in dimensions:
        # Sample random Q and K vectors
        Q = np.random.randn(n_samples, d_k)
        K = np.random.randn(n_samples, d_k)

        # Compute dot products
        dot_products = np.sum(Q * K, axis=1)
        empirical_var = np.var(dot_products)
        empirical_std = np.std(dot_products)

        print(f"{d_k:<8} {d_k:<18.2f} {empirical_var:<18.2f} {empirical_std:<12.2f}")

    return dot_products

# Run demonstration
dots = demonstrate_variance_growth()

# Show distribution for a specific d_k
d_k = 512
Q = np.random.randn(10000, d_k)
K = np.random.randn(10000, d_k)
dot_products = np.sum(Q * K, axis=1)

print(f"\nFor d_k = {d_k}:")
print(f"  Mean:  {dot_products.mean():.3f} (expected: 0)")
print(f"  Std:   {dot_products.std():.3f} (expected: {np.sqrt(d_k):.3f})")
print(f"  Range: [{dot_products.min():.1f}, {dot_products.max():.1f}]")
```

Consequence for Softmax:
When $d_k$ is large (e.g., 64 in typical transformers), dot products have standard deviation $\sqrt{64} = 8$. This means typical scores range from roughly $-16$ to $+16$ (within 2 standard deviations).
When we apply softmax to values in this range: $$\text{softmax}([..., -16, 0, +16, ...])$$
The result is extremely peaked—essentially one-hot. The $+16$ entry dominates exponentially: $$e^{16} / (e^{-16} + e^0 + e^{16}) \approx e^{16} / e^{16} = 1$$
In the saturation regime, softmax outputs are nearly one-hot regardless of small input changes. The gradient becomes vanishingly small because softmax(z)_i ≈ 1 implies ∂softmax_i/∂z_i ≈ 0. Training grinds to a halt.
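To see this numerically, here is a minimal check (in the same NumPy/SciPy style as the examples that follow) of softmax and its diagonal gradient $s_i(1 - s_i)$ on illustrative scores from the unscaled regime versus the same scores after dividing by $\sqrt{d_k} = 8$; the specific values are chosen for illustration, not taken from a real model.

```python
import numpy as np
from scipy.special import softmax

# Scores typical of unscaled attention at d_k = 64 (std ~ 8)
z_unscaled = np.array([-16.0, 0.0, 16.0])
# The same scores after dividing by sqrt(d_k) = 8
z_scaled = z_unscaled / 8.0

for label, z in [("unscaled", z_unscaled), ("scaled", z_scaled)]:
    s = softmax(z)
    diag_grad = s * (1 - s)  # diagonal of the softmax Jacobian
    print(f"{label:>8}: softmax = {np.round(s, 6)}, "
          f"max diag gradient = {diag_grad.max():.6f}")

# unscaled: essentially one-hot, max gradient ~1e-7 (saturated)
#   scaled: spread across keys, max gradient ~0.1 (trainable)
```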
The solution is elegant: divide by $\sqrt{d_k}$ to normalize variance back to 1.
Mathematical Justification:
If $X \sim \mathcal{N}(0, d_k)$, then $X/\sqrt{d_k} \sim \mathcal{N}(0, 1)$.
More precisely: $$\text{Var}\left(\frac{Q_i \cdot K_j}{\sqrt{d_k}}\right) = \frac{\text{Var}(Q_i \cdot K_j)}{d_k} = \frac{d_k}{d_k} = 1$$
After scaling, attention scores have unit variance regardless of dimension. Softmax operates in a stable regime where gradients flow effectively.
Scaled Dot-Product Attention:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
```python
import numpy as np
from scipy.special import softmax

def compare_scaled_unscaled(d_k: int, n: int = 10):
    """Compare attention behavior with and without scaling."""
    np.random.seed(42)

    # Generate random Q and K
    Q = np.random.randn(n, d_k)
    K = np.random.randn(n, d_k)

    # Compute scores
    scores_unscaled = Q @ K.T
    scores_scaled = scores_unscaled / np.sqrt(d_k)

    # Apply softmax
    attn_unscaled = softmax(scores_unscaled, axis=-1)
    attn_scaled = softmax(scores_scaled, axis=-1)

    # Analyze distributions
    print(f"Dimension d_k = {d_k}")
    print("=" * 50)
    print("\nScore Statistics:")
    print(f"  Unscaled - Var: {scores_unscaled.var():.2f}, Range: [{scores_unscaled.min():.1f}, {scores_unscaled.max():.1f}]")
    print(f"  Scaled   - Var: {scores_scaled.var():.2f}, Range: [{scores_scaled.min():.1f}, {scores_scaled.max():.1f}]")

    # Entropy measures how "spread out" the attention is
    def entropy(p):
        return -np.sum(p * np.log(p + 1e-10), axis=-1).mean()

    max_entropy = np.log(n)  # Uniform distribution entropy
    print(f"\nAttention Entropy (max = {max_entropy:.2f}):")
    print(f"  Unscaled: {entropy(attn_unscaled):.3f} ({100*entropy(attn_unscaled)/max_entropy:.1f}% of max)")
    print(f"  Scaled:   {entropy(attn_scaled):.3f} ({100*entropy(attn_scaled)/max_entropy:.1f}% of max)")

    # Max attention weight (peakedness)
    print("\nMax Attention Weight (per row):")
    print(f"  Unscaled: {attn_unscaled.max(axis=-1).mean():.4f}")
    print(f"  Scaled:   {attn_scaled.max(axis=-1).mean():.4f}")

    return attn_scaled, attn_unscaled

# Test with different dimensions
for d_k in [4, 64, 512]:
    compare_scaled_unscaled(d_k)
    print()
```

| d_k | Unscaled Variance | Scaled Variance | Unscaled Entropy | Scaled Entropy |
|---|---|---|---|---|
| 4 | ~4 | ~1 | High (healthy) | High (healthy) |
| 64 | ~64 | ~1 | Low (peaked) | Moderate (good) |
| 512 | ~512 | ~1 | Very low (saturated) | Moderate (good) |
After √d_k scaling, attention behavior is consistent across dimensions. Whether d_k is 32 or 512, softmax operates in approximately the same regime. This is crucial for stable hyperparameter transfer across model sizes.
The scaling factor's most important effect is on gradient flow during training. Let's analyze the softmax gradient in different input regimes.
Recall the Softmax Gradient:
For $s = \text{softmax}(z)$: $$\frac{\partial s_i}{\partial z_j} = s_i(\delta_{ij} - s_j)$$
Case 1: Saturated Softmax (Unscaled, Large d_k)
When scores are large and one dominates, the softmax output is nearly one-hot: $s_i \approx 1$ for the dominant position and $s_j \approx 0$ elsewhere. Every Jacobian entry $s_i(\delta_{ij} - s_j)$ is then close to zero, and almost no gradient flows back through the scores.
Case 2: Balanced Softmax (Scaled)
When scores are moderate, the softmax output is spread across several keys, the products $s_i(\delta_{ij} - s_j)$ have non-negligible magnitude, and gradients propagate to all query-key pairs.
```python
import numpy as np
from scipy.special import softmax

def analyze_softmax_gradients(scores: np.ndarray, label: str):
    """Analyze gradient magnitudes for softmax on given scores."""
    s = softmax(scores)

    # Diagonal of the Jacobian: d(s_i)/d(z_i) = s_i * (1 - s_i)
    diag_grad = s * (1 - s)

    # Full Jacobian: d(s_i)/d(z_j) = s_i * (delta_ij - s_j) = diag(s) - s s^T
    jacobian = np.diag(s) - np.outer(s, s)
    off_diag = jacobian - np.diag(np.diag(jacobian))  # off-diagonal entries -s_i * s_j

    print(f"\n{label}")
    print("-" * 40)
    print(f"Score range: [{scores.min():.1f}, {scores.max():.1f}]")
    print(f"Softmax output: {np.round(s, 4)}")
    print(f"Max softmax: {s.max():.6f}")
    print(f"Entropy: {-np.sum(s * np.log(s + 1e-10)):.4f}")
    print(f"Diagonal gradient magnitudes: {np.round(diag_grad, 6)}")
    print(f"Max gradient magnitude: {diag_grad.max():.6f}")
    print(f"Mean gradient magnitude: {diag_grad.mean():.6f}")

# Scenario: d_k = 512 dimensions
d_k = 512
np.random.seed(42)

# Sample a single query-key dot product scenario
Q = np.random.randn(d_k)
K = np.random.randn(4, d_k)  # 4 keys

# Unscaled scores
scores_unscaled = K @ Q  # Shape (4,)
scores_scaled = scores_unscaled / np.sqrt(d_k)

analyze_softmax_gradients(scores_unscaled, "UNSCALED (d_k=512)")
analyze_softmax_gradients(scores_scaled, "SCALED (d_k=512)")

print("\n" + "=" * 50)
print("Impact on training:")
print("=" * 50)

# Simulate gradient accumulation over many examples
n_examples = 1000
unscaled_grads = []
scaled_grads = []

for _ in range(n_examples):
    Q = np.random.randn(d_k)
    K = np.random.randn(4, d_k)
    scores_u = K @ Q
    scores_s = scores_u / np.sqrt(d_k)
    s_u = softmax(scores_u)
    s_s = softmax(scores_s)
    grad_u = s_u * (1 - s_u)
    grad_s = s_s * (1 - s_s)
    unscaled_grads.append(grad_u.mean())
    scaled_grads.append(grad_s.mean())

print(f"\nMean gradient magnitude over {n_examples} samples:")
print(f"  Unscaled: {np.mean(unscaled_grads):.6f}")
print(f"  Scaled:   {np.mean(scaled_grads):.6f}")
print(f"  Ratio (scaled/unscaled): {np.mean(scaled_grads)/np.mean(unscaled_grads):.1f}x")
```

Practical Impact:
The gradient difference is dramatic. For $d_k = 512$:
| Metric | Unscaled | Scaled | Ratio |
|---|---|---|---|
| Mean gradient magnitude | ~0.00001 (vanishing) | ~0.2 (healthy) | ~4 orders of magnitude |
| Training progress | Stalled | Normal | — |
| Epochs to converge | ∞ | Normal | — |
Without scaling, gradients are effectively zero, and the model cannot learn attention patterns. With scaling, gradients have sufficient magnitude for optimization to proceed.
The gradient issue isn't about training being 'slower'—it's about training being impossible. Vanishing gradients mean the attention weights become fixed at their random initialization. Learning never begins.
While $\sqrt{d_k}$ is the standard choice, it's worth understanding why this specific value was chosen and what alternatives exist.
Why Not Other Scalings?
Dividing by $d_k$ over-corrects: the score variance becomes $1/d_k$, and softmax collapses toward a nearly uniform distribution. Weaker choices such as $d_k^{1/4}$ under-correct, leaving variance $\sqrt{d_k}$ that still grows with dimension. Only $\alpha = \sqrt{d_k}$ normalizes the variance to exactly 1.
The √d_k Derivation:
The key step is choosing the scale $\alpha$ so that the scaled scores have unit variance. If we want: $$\text{Var}\left(\frac{Q \cdot K}{\alpha}\right) = 1$$
Then: $$\frac{\text{Var}(Q \cdot K)}{\alpha^2} = 1 \implies \alpha = \sqrt{\text{Var}(Q \cdot K)} = \sqrt{d_k}$$
```python
import numpy as np
from scipy.special import softmax

def compare_scaling_strategies(d_k: int = 64, n: int = 8):
    """Compare different scaling strategies."""
    np.random.seed(42)
    Q = np.random.randn(n, d_k)
    K = np.random.randn(n, d_k)
    scores = Q @ K.T

    strategies = {
        'No scaling': 1,
        'Scale by √d_k': np.sqrt(d_k),
        'Scale by d_k': d_k,
        'Scale by d_k^(1/4)': d_k ** 0.25,
        'Scale by 2√d_k': 2 * np.sqrt(d_k),
    }

    print(f"Scaling Strategy Comparison (d_k = {d_k})")
    print("=" * 70)
    print(f"{'Strategy':<20} {'Scale':<10} {'Score Var':<12} {'Entropy':<12} {'Max Attn':<10}")
    print("-" * 70)

    max_entropy = np.log(n)

    for name, scale in strategies.items():
        scaled_scores = scores / scale
        attn = softmax(scaled_scores, axis=-1)

        score_var = scaled_scores.var()
        entropy = -np.sum(attn * np.log(attn + 1e-10), axis=-1).mean()
        max_attn = attn.max(axis=-1).mean()

        print(f"{name:<20} {scale:<10.2f} {score_var:<12.2f} {entropy:<12.3f} {max_attn:<10.4f}")

    print(f"\n(Reference: Max entropy = {max_entropy:.3f} for uniform distribution)")
    print("√d_k scaling achieves unit variance and balanced entropy.")

compare_scaling_strategies(d_k=64)
print()
compare_scaling_strategies(d_k=512)
```

Learned Temperature:
Some architectures learn a scalar temperature parameter:
$$\text{Attention} = \text{softmax}(QK^T / \tau)$$
where $\tau$ is learned rather than fixed at $\sqrt{d_k}$. In practice, the fixed $\sqrt{d_k}$ value works well for standard attention, and learned temperatures appear mainly in variants such as the cosine attention described below.
An alternative approach is cosine attention, which normalizes Q and K before dot product: Attention = softmax(Q̂ · K̂^T / τ) where Q̂, K̂ are L2-normalized. This bounds dot products to [-1, 1], making a learned temperature τ more meaningful.
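The following is a minimal sketch of that cosine-attention idea, not a reference implementation from any particular paper; the fixed temperature value and the epsilon guard are illustrative assumptions.

```python
import numpy as np
from scipy.special import softmax

def cosine_attention(Q, K, V, tau=0.1, eps=1e-8):
    """Cosine-similarity attention sketch: L2-normalize Q and K, then divide
    by a temperature tau (fixed here for illustration; in practice it would
    typically be a learned parameter)."""
    Q_hat = Q / (np.linalg.norm(Q, axis=-1, keepdims=True) + eps)
    K_hat = K / (np.linalg.norm(K, axis=-1, keepdims=True) + eps)
    scores = (Q_hat @ K_hat.T) / tau   # raw similarities bounded to [-1, 1] before tau
    weights = softmax(scores, axis=-1)
    return weights @ V, weights

np.random.seed(0)
Q, K, V = np.random.randn(4, 64), np.random.randn(6, 64), np.random.randn(6, 32)
out, w = cosine_attention(Q, K, V)
print(out.shape, w.shape)             # (4, 32) (4, 6)
print(np.round(w.sum(axis=-1), 6))    # each row sums to 1
print(float(np.abs(Q @ K.T).max()))   # unnormalized dot products grow with d_k
```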
The $\sqrt{d_k}$ scaling assumes standard normal initialization for Q and K projections. If initialization differs, the optimal scaling changes.
Standard Assumptions: the derivation treats the entries of $Q$ and $K$ as independent, zero-mean, unit-variance random variables, which holds approximately at initialization when the projection weights are suitably scaled and the inputs have unit variance.
Xavier/Glorot Initialization:
With Xavier initialization, $\sigma^2 = 2/(d_{in} + d_{out})$, which accounts for the fan-in and fan-out of the layer. Combined with $\sqrt{d_k}$ scaling, this produces well-behaved attention scores.
Interaction with Layer Normalization:
When inputs are layer-normalized before projection, their variance is controlled (roughly unit) at every layer, so the statistical assumptions behind the $\sqrt{d_k}$ factor continue to hold as depth increases.
Without normalization, input variance can grow or shrink through layers, potentially requiring adaptive scaling.
```python
import numpy as np
from scipy.special import softmax

def analyze_initialization_effects(init_std: float, d_model: int = 64, d_k: int = 16):
    """Analyze how initialization affects attention score distribution."""
    np.random.seed(42)
    n = 10

    # Input with unit variance
    X = np.random.randn(n, d_model)

    # Initialize projections
    W_q = np.random.randn(d_model, d_k) * init_std
    W_k = np.random.randn(d_model, d_k) * init_std

    # Compute Q, K
    Q = X @ W_q
    K = X @ W_k

    # Compute attention scores
    scores = Q @ K.T
    scores_scaled = scores / np.sqrt(d_k)

    # Analyze
    print(f"Init std = {init_std}")
    print(f"  Q variance: {Q.var():.4f}")
    print(f"  Score variance (unscaled): {scores.var():.4f}")
    print(f"  Score variance (√d_k scaled): {scores_scaled.var():.4f}")

    # Compute optimal scale
    optimal_scale = np.sqrt(scores.var())
    scores_optimal = scores / optimal_scale
    print(f"  Optimal scale: {optimal_scale:.2f} (vs √d_k = {np.sqrt(d_k):.2f})")
    print(f"  Optimally scaled variance: {scores_optimal.var():.4f}")

print("Effect of Initialization Standard Deviation")
print("=" * 60)
print("d_model = 64, d_k = 16, √d_k = 4.0")
print()

# Standard initialization: ~1/sqrt(d_model)
std_standard = 1 / np.sqrt(64)
print("Standard initialization (1/√d_model):")
analyze_initialization_effects(std_standard)

# Xavier initialization
std_xavier = np.sqrt(2 / (64 + 16))
print("\nXavier initialization:")
analyze_initialization_effects(std_xavier)

# Small initialization
print("\nSmall initialization (0.01):")
analyze_initialization_effects(0.01)

# Large initialization
print("\nLarge initialization (0.5):")
analyze_initialization_effects(0.5)
```

Practical Guidance:
For standard transformer training, use Xavier/Glorot-style initialization for the Q/K/V projections, keep the $\sqrt{d_k}$ scaling factor, and apply layer normalization so that attention inputs maintain stable variance.
The combination of proper initialization + scaling + normalization creates a robust system where attention operates in a healthy regime throughout training.
If attention training fails or produces degenerate patterns, check: (1) Are scores before softmax in a reasonable range (roughly -3 to +3)? (2) Are attention weights neither uniform nor peaked? (3) Is gradient magnitude through attention layers comparable to other layers?
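One way to operationalize that checklist is a small diagnostic helper along these lines; the helper name and the thresholds in the printed hints are just the rough rules of thumb above, not canonical values.

```python
import numpy as np

def diagnose_attention(scores: np.ndarray, weights: np.ndarray) -> None:
    """Print rough health checks for pre-softmax scores and attention weights.

    scores:  (n_q, n_k) pre-softmax attention scores
    weights: (n_q, n_k) post-softmax attention weights
    """
    n_k = weights.shape[-1]
    entropy = -np.sum(weights * np.log(weights + 1e-10), axis=-1).mean()
    max_entropy = np.log(n_k)

    print(f"Score range: [{scores.min():.2f}, {scores.max():.2f}] "
          f"(healthy: roughly within +/-3)")
    print(f"Mean max weight per row: {weights.max(axis=-1).mean():.3f} "
          f"(1/{n_k} = uniform, 1.0 = fully peaked)")
    print(f"Mean entropy: {entropy:.3f} of max {max_entropy:.3f} "
          f"(near 0 = saturated, near max = uniform)")
```

It can be called on the scores and weights returned by the `scaled_dot_product_attention` implementation shown in the next section.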
Let's formalize the complete scaled dot-product attention mechanism as presented in "Attention is All You Need":
Algorithm: Scaled Dot-Product Attention
Input: queries $Q \in \mathbb{R}^{n_q \times d_k}$, keys $K \in \mathbb{R}^{n_k \times d_k}$, values $V \in \mathbb{R}^{n_k \times d_v}$, and an optional additive mask $M \in \mathbb{R}^{n_q \times n_k}$.
Output: Attention output $\in \mathbb{R}^{n_q \times d_v}$
Steps:
1. Compute raw scores $S = QK^T \in \mathbb{R}^{n_q \times n_k}$.
2. Scale: $S \leftarrow S / \sqrt{d_k}$.
3. If a mask is provided, add it to $S$ ($-\infty$ at masked positions).
4. Normalize each row: $A = \text{softmax}(S)$.
5. Aggregate: $\text{Output} = AV$.
Return Output
```python
import numpy as np
from scipy.special import softmax

def scaled_dot_product_attention(
    Q: np.ndarray,                 # (n_q, d_k)
    K: np.ndarray,                 # (n_k, d_k)
    V: np.ndarray,                 # (n_k, d_v)
    mask: np.ndarray = None,       # (n_q, n_k) or broadcastable
    return_weights: bool = False
) -> np.ndarray:
    """
    Scaled dot-product attention as defined in 'Attention is All You Need'.

    Args:
        Q: Query matrix with shape (n_q, d_k)
        K: Key matrix with shape (n_k, d_k)
        V: Value matrix with shape (n_k, d_v)
        mask: Optional mask where -inf values prevent attention
        return_weights: If True, also return attention weights

    Returns:
        Output with shape (n_q, d_v)
        Optionally, attention weights with shape (n_q, n_k)
    """
    d_k = Q.shape[-1]

    # Step 1: Compute QK^T
    scores = Q @ K.T  # (n_q, n_k)

    # Step 2: Scale by sqrt(d_k)
    scores = scores / np.sqrt(d_k)

    # Step 3: Apply mask (if provided)
    if mask is not None:
        scores = scores + mask

    # Step 4: Softmax normalization
    attention_weights = softmax(scores, axis=-1)  # (n_q, n_k)

    # Step 5: Weighted aggregation
    output = attention_weights @ V  # (n_q, d_v)

    if return_weights:
        return output, attention_weights
    return output

def create_causal_mask(n: int) -> np.ndarray:
    """Create causal (lower triangular) mask."""
    mask = np.triu(np.ones((n, n)), k=1) * -1e9
    return mask

def create_padding_mask(lengths: np.ndarray, max_len: int) -> np.ndarray:
    """Create padding mask from sequence lengths."""
    # lengths: (batch_size,)
    # Output: (batch_size, max_len) - True where should be masked
    mask = np.arange(max_len)[None, :] >= lengths[:, None]
    return mask.astype(float) * -1e9

# Demonstration
np.random.seed(42)
n_q, n_k, d_k, d_v = 3, 5, 8, 16

Q = np.random.randn(n_q, d_k)
K = np.random.randn(n_k, d_k)
V = np.random.randn(n_k, d_v)

# Standard attention
output, weights = scaled_dot_product_attention(Q, K, V, return_weights=True)

print("Scaled Dot-Product Attention")
print("=" * 50)
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape}")
print("\nAttention weights (each row sums to 1):")
print(np.round(weights, 3))
print(f"Row sums: {weights.sum(axis=-1)}")

# With causal mask (self-attention case where n_q = n_k)
Q_self = np.random.randn(4, d_k)
K_self = np.random.randn(4, d_k)
V_self = np.random.randn(4, d_v)
causal_mask = create_causal_mask(4)

output_causal, weights_causal = scaled_dot_product_attention(
    Q_self, K_self, V_self, mask=causal_mask, return_weights=True
)

print("\nWith Causal Mask:")
print(np.round(weights_causal, 3))
print("Notice: Upper triangle is 0 (can't attend to future)")
```

When Q, K, V all derive from the same input X, we have self-attention. The algorithm is identical—we just set Q = XW_Q, K = XW_K, V = XW_V and apply scaled dot-product attention. In this case, n_q = n_k = n.
Understanding the computational cost of scaled dot-product attention is crucial for efficient implementation and scaling.
Time Complexity:
| Operation | Shape Transformation | FLOPs |
|---|---|---|
| QK^T | (n_q, d_k) × (d_k, n_k) | O(n_q · n_k · d_k) |
| Scale | (n_q, n_k) | O(n_q · n_k) |
| Softmax | (n_q, n_k) | O(n_q · n_k) |
| AV | (n_q, n_k) × (n_k, d_v) | O(n_q · n_k · d_v) |
Total: O(n_q · n_k · (d_k + d_v)) ≈ O(n² · d) for self-attention where n_q = n_k = n.
Space Complexity: the attention matrix requires $O(n_q \cdot n_k)$ storage, and the inputs and outputs require $O(n \cdot d)$.
Total: O(n² + n · d) — the n² term dominates for long sequences.
```python
import numpy as np
import time

def benchmark_attention(sequence_lengths: list, d_model: int = 512, num_runs: int = 5):
    """Benchmark attention computation for different sequence lengths."""
    from scipy.special import softmax

    d_k = d_v = d_model // 8  # Typical head dimension

    print("Attention Complexity Benchmark")
    print("=" * 60)
    print(f"d_model = {d_model}, d_k = d_v = {d_k}")
    print(f"{'Seq Len':<10} {'Time (ms)':<15} {'Memory (MB)':<15} {'n² scaling':<15}")
    print("-" * 60)

    reference_time = None

    for n in sequence_lengths:
        Q = np.random.randn(n, d_k).astype(np.float32)
        K = np.random.randn(n, d_k).astype(np.float32)
        V = np.random.randn(n, d_v).astype(np.float32)

        # Warmup
        _ = softmax(Q @ K.T / np.sqrt(d_k), axis=-1) @ V

        # Benchmark
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            scores = Q @ K.T / np.sqrt(d_k)
            attn = softmax(scores, axis=-1)
            output = attn @ V
            times.append(time.perf_counter() - start)

        avg_time = np.mean(times) * 1000  # ms

        # Memory estimate (attention matrix is the bottleneck)
        attn_memory_mb = (n * n * 4) / (1024 * 1024)  # float32

        # n² scaling factor
        if reference_time is None:
            reference_time = avg_time
            reference_n = n
            scaling = 1.0
        else:
            expected_scaling = (n / reference_n) ** 2
            actual_scaling = avg_time / reference_time
            scaling = actual_scaling

        print(f"{n:<10} {avg_time:<15.3f} {attn_memory_mb:<15.2f} {scaling:<15.2f}")

    return None

# Run benchmark
sequence_lengths = [128, 256, 512, 1024, 2048]
benchmark_attention(sequence_lengths)

print("\nNote: Time should scale ~quadratically with sequence length")
print("Memory for attention matrix scales exactly quadratically")
```

The Quadratic Bottleneck:
For long sequences, the O(n²) memory and compute become prohibitive:
| Sequence Length | Attention Matrix Entries | Memory (float32) |
|---|---|---|
| 512 | 262K | 1 MB |
| 2048 | 4.2M | 16 MB |
| 8192 | 67M | 256 MB |
| 32768 | 1.07B | 4 GB |
| 131072 | 17.2B | 64 GB |
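These figures follow directly from $n^2$ entries at 4 bytes each; a quick sanity check:

```python
def attn_matrix_memory_mb(n: int, bytes_per_elem: int = 4) -> float:
    """Memory for a single n x n attention matrix in MB (float32 by default)."""
    return n * n * bytes_per_elem / 1024**2

for n in [512, 2048, 8192, 32768, 131072]:
    print(f"n = {n:>6}: {n*n/1e6:>10.1f}M entries, {attn_matrix_memory_mb(n):>10.1f} MB")
```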
This motivates efficient attention variants (Linear Attention, Longformer, BigBird, FlashAttention) that reduce the O(n²) complexity while approximating full attention behavior.
Modern LLMs like GPT-4 and Claude handle 100K+ tokens. Naive O(n²) attention would require over 40GB just for the attention matrix. Efficient attention implementations (like FlashAttention) are essential—they reduce memory from O(n²) to O(n) through clever recomputation.
For completeness, here's how scaled dot-product attention is implemented in practice using PyTorch, with all the standard features:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

def scaled_dot_product_attention(
    query: torch.Tensor,                       # (batch, seq_q, d_k)
    key: torch.Tensor,                         # (batch, seq_k, d_k)
    value: torch.Tensor,                       # (batch, seq_k, d_v)
    attn_mask: Optional[torch.Tensor] = None,  # (batch, seq_q, seq_k) or broadcastable
    dropout_p: float = 0.0,
    is_causal: bool = False,
    training: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled dot-product attention with all standard features.

    Args:
        query, key, value: Query, key, and value tensors
        attn_mask: Additive mask (0 for attend, -inf for don't attend)
        dropout_p: Dropout probability on attention weights
        is_causal: If True, apply causal masking
        training: Whether in training mode (affects dropout)

    Returns:
        output: (batch, seq_q, d_v)
        attn_weights: (batch, seq_q, seq_k)
    """
    d_k = query.size(-1)

    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply causal mask if requested
    if is_causal:
        seq_q, seq_k = query.size(-2), key.size(-2)
        causal_mask = torch.triu(
            torch.ones(seq_q, seq_k, device=query.device, dtype=torch.bool),
            diagonal=1
        )
        scores.masked_fill_(causal_mask, float('-inf'))

    # Apply explicit mask if provided
    if attn_mask is not None:
        scores = scores + attn_mask

    # Softmax normalization
    attn_weights = F.softmax(scores, dim=-1)

    # Apply dropout
    if training and dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # Compute output
    output = torch.matmul(attn_weights, value)

    return output, attn_weights

class SelfAttention(nn.Module):
    """Self-attention module with QKV projections."""

    def __init__(self, d_model: int, d_k: int, d_v: int, dropout: float = 0.1):
        super().__init__()
        self.d_k = d_k

        # QKV projections
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)

        # Output projection
        self.W_o = nn.Linear(d_v, d_model, bias=False)

        self.dropout = dropout

        # Initialize with Xavier
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ) -> torch.Tensor:
        # Project to Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Apply scaled dot-product attention
        attn_output, _ = scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
            training=self.training
        )

        # Output projection
        output = self.W_o(attn_output)
        return output

# Usage example
if __name__ == "__main__":
    batch_size, seq_len, d_model = 2, 10, 64
    d_k = d_v = 16

    x = torch.randn(batch_size, seq_len, d_model)

    # Create self-attention layer
    attention = SelfAttention(d_model, d_k, d_v, dropout=0.1)

    # Forward pass (causal for decoder-style)
    output = attention(x, is_causal=True)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Verify causal masking
    attention.eval()  # Disable dropout for checking
    Q = attention.W_q(x)
    K = attention.W_k(x)
    V = attention.W_v(x)
    _, weights = scaled_dot_product_attention(Q, K, V, is_causal=True, training=False)
    print("\nCausal attention weights (should be lower triangular):")
    print(weights[0].round(decimals=2))
```

PyTorch 2.0+ includes F.scaled_dot_product_attention() which uses FlashAttention when available. This provides significant speedups and memory savings. Always prefer built-in implementations over custom ones for production.
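As a minimal usage sketch of that built-in (the tensor shapes here are arbitrary):

```python
import torch
import torch.nn.functional as F

batch, heads, seq, d_k = 2, 8, 128, 64
q = torch.randn(batch, heads, seq, d_k)
k = torch.randn(batch, heads, seq, d_k)
v = torch.randn(batch, heads, seq, d_k)

# Fused attention; PyTorch dispatches to FlashAttention-style kernels
# when the hardware and dtype support them, otherwise a math fallback.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```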
We've thoroughly explored the scaled dot-product attention mechanism—the fundamental operation that enables transformers to learn complex dependencies.
What's Next:
With the core scaled dot-product attention understood, we'll examine a subtle but important property: position independence. Self-attention is inherently permutation-equivariant—it doesn't know where tokens are in the sequence. The next page explores this property and its implications.
You now understand why scaled dot-product attention uses the √d_k factor, how it prevents training pathologies, and the computational characteristics of the attention operation. This knowledge is essential for implementing, debugging, and optimizing attention-based models.