The Causal Attention Mechanism is a fundamental component in autoregressive sequence models, forming the backbone of modern large language models (LLMs) like GPT, LLaMA, and Claude. Unlike standard self-attention, causal attention enforces a temporal ordering constraint that prevents tokens from attending to future positions in the sequence—ensuring that predictions at each position depend only on previously generated content.
The standard self-attention mechanism, introduced in the groundbreaking "Attention Is All You Need" paper, allows each position in a sequence to dynamically focus on relevant information from all other positions. Given three learnable projections—Queries (Q), Keys (K), and Values (V)—the mechanism computes attention as follows:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
Where:
• \(Q, K \in \mathbb{R}^{n \times d_k}\) are the query and key matrices for a sequence of length \(n\)
• \(V \in \mathbb{R}^{n \times d_v}\) is the value matrix
• \(d_k\) is the key dimension, used to scale the scores
The dot product \(QK^T\) produces a matrix of attention scores (or "affinities") where each entry \((i, j)\) indicates how much position \(i\) should attend to position \(j\).
In autoregressive generation tasks (like text generation), the model must predict the next token based only on previous tokens. If the model could "see" future tokens during training, it would learn to cheat rather than truly model the sequence distribution.
The causal mask enforces this constraint by setting attention scores for future positions to \(-\infty\) (or a very large negative number like \(-10^9\)) before applying softmax. Since \(\exp(-\infty) = 0\), these positions contribute nothing to the weighted sum:
$$\text{Causal Attention}(Q, K, V, M) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right) V$$
Where the mask \(M\) is a strictly upper-triangular matrix of \(-\infty\) values:
$$M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$$
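In code, this additive mask can be built directly from the piecewise definition above. A small sketch (the helper name `causal_mask` is illustrative, not part of the required API), using \(-10^9\) in place of \(-\infty\) as the problem does:

```python
import numpy as np

def causal_mask(n):
    # np.triu with k=1 keeps only the strictly upper-triangular part,
    # i.e. the future positions j > i; everything else stays 0.
    return np.triu(np.full((n, n), -1e9), k=1)

# For n = 3 this yields the mask used in Example 1 below:
# [[0, -1e9, -1e9], [0, 0, -1e9], [0, 0, 0]]
```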
Your implementation should follow these steps precisely:
1. Compute Raw Attention Scores: Calculate \(S = \frac{QK^T}{\sqrt{d_k}}\), where \(d_k\) is the number of columns of \(Q\) (and \(K\))
2. Apply the Causal Mask: Add the mask matrix to the scores: \(S_{\text{masked}} = S + M\)
3. Normalize with Softmax: Apply row-wise softmax to obtain attention weights: \(W = \text{softmax}(S_{\text{masked}}, \text{axis}=-1)\)
4. Aggregate Values: Compute the final output as \(O = W V\)
5. Round Results: Round each output element to 4 decimal places
Implement the causal_attention function that takes Query (Q), Key (K), Value (V) matrices, and an attention mask as inputs, and returns the attention-weighted output matrix. The mask will be provided as input with 0.0 for allowed positions and large negative values for masked positions.
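A minimal NumPy sketch following the steps above might look like this (the numerically stable softmax, which subtracts the row maximum, is an implementation choice, not a requirement of the problem):

```python
import numpy as np

def causal_attention(Q, K, V, mask):
    Q, K, V, mask = (np.asarray(a, dtype=float) for a in (Q, K, V, mask))
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)        # Step 1: scaled dot products
    scores = scores + mask                 # Step 2: add the causal mask
    # Step 3: numerically stable row-wise softmax
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.round(weights @ V, 4)        # Steps 4-5: aggregate and round
```

On the Example 1 inputs below, this returns [[1.0, 2.0], [2.3395, 3.3395], [3.5105, 4.5105]].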
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mask = [[0.0, -1e9, -1e9], [0.0, 0.0, -1e9], [0.0, 0.0, 0.0]]

Expected output: [[1.0, 2.0], [2.3395, 3.3395], [3.5105, 4.5105]]

This is a 3-position sequence with a causal (lower-triangular) mask.
Step 1: Compute QK^T / sqrt(d_k)
With d_k = 2, we scale by 1/√2 ≈ 0.707:
• Position 0: only attends to itself → score with position 0 dominates
• Position 1: can attend to positions 0 and 1
• Position 2: can attend to all three positions
Step 2: Apply Mask
Adding -1e9 to the upper-triangular entries (future positions) drives their softmax weights to 0, so each row attends only to its own and earlier positions.
Step 3: Softmax (row-wise)
• Row 0: [1.0, 0.0, 0.0] → Only position 0 contributes
• Row 1: [0.3302, 0.6698, 0.0] → Weighted combination of positions 0 and 1
• Row 2: Distributed across all positions based on scaled dot-product similarities
Step 4: Weighted Value Aggregation
• Output[0] = 1.0 × V[0] = [1.0, 2.0]
• Output[1] ≈ 0.3302 × V[0] + 0.6698 × V[1] = [2.3395, 3.3395]
• Output[2] = weighted sum across all V rows = [3.5105, 4.5105]
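Position 1's arithmetic can be reproduced in a few lines. Its unmasked scaled scores against positions 0 and 1 are 0 and 1/√2:

```python
import numpy as np

# Row 1 of Example 1: scaled scores for the two visible positions.
scores = np.array([0.0, 1.0 / np.sqrt(2)])
weights = np.exp(scores) / np.exp(scores).sum()   # ≈ [0.3302, 0.6698]
output = weights @ np.array([[1.0, 2.0], [3.0, 4.0]])  # V[0], V[1]
# output ≈ [2.3395, 3.3395], matching row 1 of the expected result
```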
Q = [[1.0, 2.0], [3.0, 4.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[5.0, 6.0], [7.0, 8.0]]
mask = [[0.0, 0.0], [0.0, 0.0]]

Expected output: [[6.3395, 7.3395], [6.3395, 7.3395]]

This example uses a zero mask (no masking), allowing full bidirectional attention.
Attention Score Computation:
• Q @ K^T gives the raw dot products between queries and keys; since K is the identity, the scaled scores are simply Q/√2
• Both rows have the same score gap between their two entries (≈0.7071), so after softmax both positions compute identical attention distributions
Value Aggregation: Both output rows contain [6.3395, 7.3395], which represents a weighted average of V[0] = [5.0, 6.0] and V[1] = [7.0, 8.0]. The weights favor V[1] (≈0.67 vs ≈0.33) because the second key yields the larger scaled score for both queries.
This demonstrates that without causal masking, all positions can freely attend to each other.
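The claim that both rows end up identical can be checked directly. With K = I, the scores reduce to Q/√2, and equal per-row score gaps give equal softmax weights:

```python
import numpy as np

Q = np.array([[1.0, 2.0], [3.0, 4.0]])
V = np.array([[5.0, 6.0], [7.0, 8.0]])
scores = Q / np.sqrt(2)                 # K is the identity, so QK^T = Q
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)
out = np.round(weights @ V, 4)          # both rows: [6.3395, 7.3395]
```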
Q = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
K = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
V = [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0], [4.0, 4.0, 4.0, 4.0]]
mask = [[0.0, -1e9, -1e9, -1e9], [0.0, 0.0, -1e9, -1e9], [0.0, 0.0, 0.0, -1e9], [0.0, 0.0, 0.0, 0.0]]

Expected output: [[1.0, 1.0, 1.0, 1.0], [1.6225, 1.6225, 1.6225, 1.6225], [2.1778, 2.1778, 2.1778, 2.1778], [2.7093, 2.7093, 2.7093, 2.7093]]

This is a 4-position sequence with identity Q and K matrices (orthogonal rows).
Causal Masking Effect:
The causal mask (each position may attend only to itself and earlier positions) ensures:
• Position 0: Can only see itself → Output = V[0] = [1.0, 1.0, 1.0, 1.0]
• Position 1: Sees positions 0-1 → Weighted average tilted toward V[1]
• Position 2: Sees positions 0-2 → More balanced average
• Position 3: Sees all positions → Broadest average (but weighted by attention)
Progressive Averaging: Notice how each subsequent position's output value increases as it can attend to more previous positions. The orthogonal Q/K structure means positions have highest attention to themselves (due to the 1/√d_k scaled score of 0.5 vs 0.0 for non-matching positions).
The outputs [1.0, 1.6225, 2.1778, 2.7093] show the cumulative running average effect inherent in causal attention.
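The progressive-averaging claim above can be verified compactly. With identity Q and K and d_k = 4, each position scores 0.5 against itself and 0.0 against every other visible position:

```python
import numpy as np

n = 4
scores = 0.5 * np.eye(n)                        # QK^T / sqrt(4) for identity Q, K
scores += np.triu(np.full((n, n), -1e9), k=1)   # additive causal mask
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)
V = np.outer(np.arange(1.0, 5.0), np.ones(n))   # rows [1,1,1,1] ... [4,4,4,4]
out = np.round(weights @ V, 4)
# first column of out: 1.0, 1.6225, 2.1778, 2.7093
```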
Constraints