The self-attention mechanism is the backbone of modern Transformer architectures, enabling models to capture long-range dependencies in sequences. However, standard attention computation faces a significant bottleneck: it requires materializing the full N × N attention matrix, where N is the sequence length. This leads to O(N²) memory consumption, which becomes prohibitive for long sequences (e.g., processing documents with thousands of tokens or high-resolution images).
Consider a sequence of length N = 8192 with attention computed in float32. The attention matrix alone would require 8192 × 8192 × 4 bytes = 268,435,456 bytes ≈ 256 MB, per attention head and per batch element.
This quickly exhausts GPU memory, limiting the practical sequence lengths that Transformer models can handle.
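The memory figure above is simple arithmetic and can be checked directly (a back-of-the-envelope sketch; real models multiply this further by batch size and head count):

```python
# Size of the full N x N attention matrix in float32 (4 bytes per element).
N = 8192
bytes_per_elem = 4
attn_bytes = N * N * bytes_per_elem
print(attn_bytes)          # 268435456 bytes
print(attn_bytes / 2**20)  # 256.0 MiB -- per head, per batch element
```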
Tiled Attention (also known as chunked or blocked attention) solves this by computing attention in smaller tiles or blocks, never materializing the complete N × N matrix. The key insight is that attention can be computed incrementally using the online softmax algorithm, which maintains running statistics (maximum value and exponential sum) that allow proper normalization without seeing all values at once.
The algorithm processes the Query (Q), Key (K), and Value (V) matrices in blocks, iterating over tiles of K and V while maintaining running statistics for each query.
The online softmax algorithm maintains two running statistics for each query position:
When processing a new block, these statistics are updated:
m_new = max(m_old, max(current_scores))
l_new = l_old × exp(m_old - m_new) + sum(exp(current_scores - m_new))
The output is incrementally adjusted, where attention_weights = exp(current_scores - m_new) / l_new:
output_new = output_old × (l_old × exp(m_old - m_new) / l_new) + attention_weights × V_tile
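The update rules above can be sketched for a single query row as follows; `online_softmax_update` is an illustrative helper name, not a required API:

```python
import numpy as np

def online_softmax_update(m_old, l_old, out_old, scores, v_tile):
    """One online-softmax step for a single query row.

    m_old: running max (scalar), l_old: running exp-sum (scalar),
    out_old: running output (d,), scores: (t,) scores against this K tile,
    v_tile: (t, d) values for this tile.
    """
    m_new = max(m_old, scores.max())
    # Rescale the old exp-sum to the new max before adding this tile's terms.
    l_new = l_old * np.exp(m_old - m_new) + np.exp(scores - m_new).sum()
    # attention_weights = exp(current_scores - m_new) / l_new
    weights = np.exp(scores - m_new) / l_new
    # Rescale the old partial output, then add this tile's contribution.
    out_new = out_old * (l_old * np.exp(m_old - m_new) / l_new) + weights @ v_tile
    return m_new, l_new, out_new
```

Starting from m = -inf, l = 0, and a zero output vector and feeding each K/V tile in turn reproduces full softmax attention for that query row.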
Implement the forward pass of the memory-efficient tiled attention algorithm. Given Query (Q), Key (K), and Value (V) matrices along with a block size, compute the attention output without ever creating the full N × N attention matrix.
The final output should match standard attention: Output = softmax(Q × K^T) × V, but computed in a memory-efficient blocked manner.
For standard attention: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Note: For this problem, we omit the scaling factor $\sqrt{d_k}$ for simplicity. The focus is on the tiled computation with online softmax.
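Putting the pieces together, one possible forward pass looks like the following (a NumPy sketch with unscaled scores, per the note above; the function name is illustrative):

```python
import numpy as np

def tiled_attention(Q, K, V, block_size):
    """Compute softmax(Q @ K.T) @ V without materializing the N x N matrix."""
    N, d = Q.shape
    out = np.zeros((N, d))
    m = np.full(N, -np.inf)   # running row-wise max
    l = np.zeros(N)           # running row-wise exp-sum
    for start in range(0, N, block_size):
        K_tile = K[start:start + block_size]
        V_tile = V[start:start + block_size]
        scores = Q @ K_tile.T                      # (N, t): one tile, never N x N
        m_new = np.maximum(m, scores.max(axis=1))
        scale = np.exp(m - m_new)                  # rescale old statistics
        p = np.exp(scores - m_new[:, None])
        l_new = l * scale + p.sum(axis=1)
        out = out * (l * scale / l_new)[:, None] + (p / l_new[:, None]) @ V_tile
        m, l = m_new, l_new
    return out
```

Peak extra memory is O(N × block_size) for the tile of scores, rather than O(N²) for the full matrix.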
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
block_size = 1

Output: [[1.5379, 2.5379], [2.4621, 3.4621]]

With Q and K as identity-like matrices, the attention scores are:
• S = Q × Kᵀ = [[1, 0], [0, 1]]
Applying softmax row-wise:
• Row 1: softmax([1, 0]) ≈ [0.7311, 0.2689]
• Row 2: softmax([0, 1]) ≈ [0.2689, 0.7311]
The output is computed as attention_weights × V:
• Row 1: 0.7311 × [1, 2] + 0.2689 × [3, 4] ≈ [1.5379, 2.5379]
• Row 2: 0.2689 × [1, 2] + 0.7311 × [3, 4] ≈ [2.4621, 3.4621]
With block_size=1, each query is processed separately, computing attention incrementally using online softmax. The result matches standard attention but uses O(N) memory instead of O(N²).
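The walk-through above can be reproduced with a dense (non-tiled) reference computation, which is useful for validating a tiled implementation:

```python
import numpy as np

Q = np.array([[1.0, 0.0], [0.0, 1.0]])
K = np.array([[1.0, 0.0], [0.0, 1.0]])
V = np.array([[1.0, 2.0], [3.0, 4.0]])

S = Q @ K.T                                   # [[1, 0], [0, 1]]
P = np.exp(S - S.max(axis=1, keepdims=True))  # subtract row max for stability
P /= P.sum(axis=1, keepdims=True)             # row-wise softmax
out = P @ V
print(out)  # rows are approximately [1.5379, 2.5379] and [2.4621, 3.4621]
```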
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
block_size = 2

Output: [[1.5379, 2.5379], [2.4621, 3.4621]]

With block_size=2, the entire sequence fits in a single block, so the computation proceeds in one step rather than iterating through tiles. However, the mathematical result is identical to block_size=1—only the memory footprint differs during computation.
This example demonstrates that the tiled algorithm produces the same output regardless of block size, while larger blocks reduce loop overhead and smaller blocks reduce peak memory usage.
Q = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
K = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
V = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
block_size = 1

Output: [[2.9075, 3.9075, 4.9075], [4.0, 5.0, 6.0], [5.0925, 6.0925, 7.0925]]

For a 3×3 identity-like Q and K, the attention scores form a matrix where the diagonal entries are highest (1.0) and the off-diagonal entries are 0.0.
After softmax, each row becomes approximately [0.5761, 0.2119, 0.2119] when the query aligns with the first key, with similar patterns shifted for other rows.
The tiled algorithm with block_size=1 processes this by iterating over one key/value row at a time, updating the running maximum, exponential sum, and partial output after each tile.
The center row produces [4.0, 5.0, 6.0] because the second query has highest attention on the second value, which happens to be [4, 5, 6] itself, with symmetric influence from neighbors.
Constraints