Rotary Position Embedding (RoPE), introduced by Su et al. in 2021, represents the current state of the art in positional encoding for transformer architectures. It has become the default choice for virtually all modern large language models, including LLaMA, Mistral, Yi, Qwen, and many others.
What makes RoPE special is its elegant synthesis of earlier ideas: it is applied per position like an absolute encoding, yet the resulting attention scores depend only on relative position, and it requires no additional parameters.
The core insight is deceptively simple: if we rotate query and key vectors based on their positions, the dot product between them will depend only on their relative position. This single geometric observation unifies absolute encoding, relative encoding, and computational efficiency.
This page provides a comprehensive treatment of RoPE: its mathematical derivation, geometric interpretation, implementation, extensions, and the reasons for its widespread adoption.
RoPE has become the de facto standard for positional encoding in large language models. LLaMA-1, LLaMA-2, LLaMA-3, Mistral, Mixtral, Yi, Qwen, CodeLLaMA, DeepSeek, and many other leading models all use RoPE or its extensions. Understanding RoPE is essential for working with modern LLM architectures.
The Central Problem
We want a function $f$ such that for query and key vectors $q_m$ and $k_n$ at positions $m$ and $n$:
$$\langle f(q, m), f(k, n) \rangle = g(q, k, m - n)$$
That is, the inner product depends on the content of $q$ and $k$ and the relative position $m - n$, but not on the absolute positions $m$ and $n$ individually.
The RoPE Solution
RoPE achieves this by rotating vectors in 2D subspaces:
$$f(x, m) = R_m \cdot x$$
Where $R_m$ is a rotation matrix determined by position $m$.
For a 2D vector $[x_1, x_2]^T$:
$$R_\theta = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}$$
The rotation angle $\theta$ varies with position: $\theta = m \cdot \omega$ for some frequency $\omega$.
The Key Property
For rotation matrices, we have:
$$\langle R_\alpha x, R_\beta y \rangle = \langle R_{\alpha - \beta} x, y \rangle = \langle x, R_{\beta - \alpha} y \rangle$$
This means the dot product of rotated vectors depends only on the difference of rotation angles—i.e., the relative position!
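One way to see this: rotation matrices are orthogonal, with $R_\alpha^\top = R_{-\alpha}$, and composing rotations adds their angles, so

$$\langle R_\alpha x, R_\beta y \rangle = x^\top R_\alpha^\top R_\beta \, y = x^\top R_{\beta - \alpha} \, y = \langle x, R_{\beta - \alpha} y \rangle.$$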
Imagine two arrows (vectors) on a plane. If you rotate both arrows by the same amount, the angle between them doesn't change. RoPE exploits this: rotating query and key by amounts proportional to their positions means their dot product depends only on the position difference.
Extension to High Dimensions
For a $d$-dimensional vector, RoPE applies independent rotations to $d/2$ pairs of dimensions. The full rotation matrix is block-diagonal:
$$R_m^{(d)} = \begin{pmatrix} R_{m\theta_1} & 0 & \cdots & 0 \\ 0 & R_{m\theta_2} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & R_{m\theta_{d/2}} \end{pmatrix}$$
Where each $R_{m\theta_i}$ is a 2×2 rotation matrix and $\theta_i$ are different frequencies:
$$\theta_i = 10000^{-2(i-1)/d}, \quad i = 1, \ldots, d/2$$
This is the same geometric progression as sinusoidal positional encoding! And indeed, the sin/cos values from sinusoidal encoding reappear in the rotation matrices.
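To make the connection concrete, here is a small sketch (the helper names are ours, and we assume the interleaved sin/cos layout of the classic sinusoidal encoding) checking that the sin/cos values RoPE needs at position $m$ are exactly those appearing in the sinusoidal encoding of $m$:

```python
import numpy as np

def sinusoidal_pe(m: int, d: int, base: float = 10000.0) -> np.ndarray:
    """Classic sinusoidal encoding: [sin(m*w_0), cos(m*w_0), sin(m*w_1), cos(m*w_1), ...]."""
    i = np.arange(d // 2)
    omega = base ** (-2 * i / d)
    pe = np.empty(d)
    pe[0::2] = np.sin(m * omega)
    pe[1::2] = np.cos(m * omega)
    return pe

def rope_sin_cos(m: int, d: int, base: float = 10000.0):
    """The sin/cos entries RoPE uses in its block-diagonal rotation at position m."""
    i = np.arange(d // 2)
    theta = base ** (-2 * i / d)
    return np.sin(m * theta), np.cos(m * theta)

m, d = 42, 16
pe = sinusoidal_pe(m, d)
sin_m, cos_m = rope_sin_cos(m, d)
print(np.allclose(pe[0::2], sin_m), np.allclose(pe[1::2], cos_m))  # True True
```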
Let's derive RoPE rigorously from the requirement that the attention score encode relative position.
Setup
We seek a function $f(x, m): \mathbb{R}^d \times \mathbb{N} \to \mathbb{R}^d$ such that:
$$\langle f(q, m), f(k, n) \rangle = g(q, k, m - n)$$
for some function $g$.
Working in 2D First
Consider the 2D case. Represent vectors as complex numbers: $z = x_1 + ix_2 \in \mathbb{C}$.
The real inner product can be recovered from the conjugate product: for $z_1 = x_1 + ix_2$ and $z_2 = y_1 + iy_2$, $$\text{Re}(z_1^* z_2) = x_1 y_1 + x_2 y_2 = \langle (x_1, x_2), (y_1, y_2) \rangle$$
If we define: $$f(z, m) = z \cdot e^{im\theta}$$
Then: $$f(q, m)^* \cdot f(k, n) = q^* e^{-im\theta} \cdot k \cdot e^{in\theta} = q^* k \cdot e^{i(n-m)\theta}$$
Taking the real part: $$\text{Re}(q^* k \cdot e^{i(n-m)\theta}) = g(q, k, n - m)$$
This is exactly what we wanted: the result depends on $q$, $k$, and $(n - m)$ only!
Translating to Real Matrices
The complex multiplication $z \cdot e^{i\theta}$ corresponds to the real matrix operation:
$$\begin{pmatrix} x_1 \\ x_2 \end{pmatrix} \to \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}$$
This is a rotation by angle $\theta$.
```python
import numpy as np
import torch


def verify_rope_relative_position():
    """
    Verify that RoPE achieves relative position encoding.
    The dot product f(q,m)·f(k,n) should depend only on m-n.
    """
    def rotate_2d(x, theta):
        """Apply 2D rotation."""
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        R = np.array([[cos_t, -sin_t],
                      [sin_t, cos_t]])
        return R @ x

    # Random query and key vectors
    np.random.seed(42)
    q = np.random.randn(2)
    k = np.random.randn(2)

    # Rotation frequency
    theta = 0.1  # radians per position

    # Test at different absolute positions with same relative offset
    print("=== Verifying RoPE Relative Position Property ===\n")
    print("q·k (unrotated):", np.dot(q, k))
    print()

    # Relative offset = 5
    relative_offset = 5
    test_cases = [
        (0, 5),        # m=0, n=5
        (10, 15),      # m=10, n=15
        (100, 105),    # m=100, n=105
        (1000, 1005)   # m=1000, n=1005
    ]

    print(f"All pairs have relative offset n-m = {relative_offset}")
    print("-" * 50)

    for m, n in test_cases:
        q_rotated = rotate_2d(q, m * theta)
        k_rotated = rotate_2d(k, n * theta)
        dot_product = np.dot(q_rotated, k_rotated)
        print(f"m={m:4d}, n={n:4d}: f(q,m)·f(k,n) = {dot_product:.6f}")

    print()
    print("All dot products are identical (within floating point precision)")
    print("because they depend only on relative position, not absolute positions!")

    # Also verify: change relative offset, dot product changes
    print("\n=== Different Relative Offsets ===\n")
    m = 50  # Fixed query position
    for offset in [0, 1, 5, 10, 20]:
        n = m + offset
        q_rotated = rotate_2d(q, m * theta)
        k_rotated = rotate_2d(k, n * theta)
        dot_product = np.dot(q_rotated, k_rotated)
        print(f"Offset {offset:3d}: f(q,m)·f(k,m+{offset}) = {dot_product:.6f}")


verify_rope_relative_position()


def show_complex_interpretation():
    """
    Show the complex number interpretation of RoPE.
    """
    print("\n=== Complex Number Interpretation ===\n")

    # Represent 2D vector as complex number
    q = 1.0 + 0.5j   # q = (1.0, 0.5)
    k = 0.8 - 0.3j   # k = (0.8, -0.3)

    theta = 0.1
    m, n = 10, 15  # positions

    # Rotate using complex exponential
    q_rotated = q * np.exp(1j * m * theta)
    k_rotated = k * np.exp(1j * n * theta)

    # Inner product via real part of conjugate product
    dot_complex = np.real(np.conj(q_rotated) * k_rotated)

    # Compare to direct rotation
    def rotate_2d(x, angle):
        R = np.array([[np.cos(angle), -np.sin(angle)],
                      [np.sin(angle), np.cos(angle)]])
        return R @ x

    q_vec = np.array([1.0, 0.5])
    k_vec = np.array([0.8, -0.3])

    q_rot = rotate_2d(q_vec, m * theta)
    k_rot = rotate_2d(k_vec, n * theta)
    dot_real = np.dot(q_rot, k_rot)

    print(f"Complex multiplication: {dot_complex:.6f}")
    print(f"Matrix rotation:        {dot_real:.6f}")
    print("Both methods give identical results!")


show_complex_interpretation()
```

While the complex number derivation is elegant, practical implementations use real arithmetic (2D rotations via sin/cos) for compatibility with standard deep learning frameworks. Both formulations are mathematically equivalent.
Let's implement RoPE in its full form, used in production models like LLaMA.
The RoPE Formula
For a vector $x = [x_0, x_1, \ldots, x_{d-1}]$ at position $m$:
$$\text{RoPE}(x, m) = \begin{pmatrix} x_0 \cos(m\theta_0) - x_1 \sin(m\theta_0) \\ x_1 \cos(m\theta_0) + x_0 \sin(m\theta_0) \\ x_2 \cos(m\theta_1) - x_3 \sin(m\theta_1) \\ x_3 \cos(m\theta_1) + x_2 \sin(m\theta_1) \\ \vdots \\ x_{d-2} \cos(m\theta_{d/2-1}) - x_{d-1} \sin(m\theta_{d/2-1}) \\ x_{d-1} \cos(m\theta_{d/2-1}) + x_{d-2} \sin(m\theta_{d/2-1}) \end{pmatrix}$$
Where $\theta_i = 10000^{-2i/d}$ for $i = 0, 1, \ldots, d/2 - 1$ (zero-indexed here to match the code below).
```python
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Optional


class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) as used in LLaMA and modern LLMs.

    Key insight: Apply rotation to query and key vectors based on their
    positions. The dot product of rotated vectors automatically encodes
    relative position.
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 8192,
        base: float = 10000.0
    ):
        """
        Args:
            dim: Dimension of the embedding (must be even)
            max_seq_len: Maximum sequence length to precompute
            base: Base for frequency computation (10000 in original)
        """
        super().__init__()
        assert dim % 2 == 0, "Dimension must be even for RoPE"

        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequencies
        # theta_i = base^(-2i/dim) for i in [0, dim/2)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute sin and cos for all positions
        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int):
        """Precompute sin and cos values for efficiency."""
        # Position indices: [0, 1, 2, ..., seq_len-1]
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)

        # Outer product: [seq_len, dim/2]
        # freqs[pos, i] = pos * theta_i
        freqs = torch.outer(t, self.inv_freq)

        # Duplicate per pair: [seq_len, dim]
        # This creates [theta_0, theta_0, theta_1, theta_1, ...]
        # matching the interleaved pairing used by _rotate_half below
        freqs = torch.repeat_interleave(freqs, 2, dim=-1)

        # Compute sin and cos
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        start_pos: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor [batch, seq_len, num_heads, head_dim]
            k: Key tensor [batch, seq_len, num_heads, head_dim]
            start_pos: Starting position (for incremental decoding)

        Returns:
            Tuple of (rotated_q, rotated_k) with same shapes
        """
        seq_len = q.shape[1]

        # Extend cache if needed
        if start_pos + seq_len > self.max_seq_len:
            self._precompute_cache(start_pos + seq_len)
            self.max_seq_len = start_pos + seq_len

        # Get relevant sin and cos values
        cos = self.cos_cached[start_pos:start_pos + seq_len]
        sin = self.sin_cached[start_pos:start_pos + seq_len]

        # Apply rotation
        q_rotated = self._apply_rotation(q, cos, sin)
        k_rotated = self._apply_rotation(k, cos, sin)

        return q_rotated, k_rotated

    def _apply_rotation(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply rotation to tensor x.

        The rotation formula: [x0, x1] -> [x0*cos - x1*sin, x1*cos + x0*sin]

        We implement this efficiently using:
            x_rotated = x * cos + rotate_half(x) * sin
        where rotate_half swaps adjacent pairs and negates the first:
            [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]
        """
        # x: [batch, seq_len, num_heads, head_dim]
        # cos, sin: [seq_len, head_dim]

        # Add dimensions for broadcasting
        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, head_dim]
        sin = sin.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, head_dim]

        # Rotate half: [-x1, x0, -x3, x2, ...]
        x_rotated_half = self._rotate_half(x)

        # Apply rotation formula
        return x * cos + x_rotated_half * sin

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rotate half: rearrange pairs and negate first element.
        [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]
        """
        # Split into even-indexed and odd-indexed components
        x1, x2 = x[..., ::2], x[..., 1::2]
        # Interleave: [-x2, x1]
        return torch.stack([-x2, x1], dim=-1).flatten(-2)


class RoPEAttention(nn.Module):
    """
    Complete attention layer with RoPE.
    This is the structure used in LLaMA and similar models.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,  # For grouped-query attention
        max_seq_len: int = 8192,
        rope_base: float = 10000.0,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads else num_heads
        self.head_dim = d_model // num_heads

        # Projections
        self.W_Q = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.W_K = nn.Linear(d_model, self.num_kv_heads * self.head_dim, bias=False)
        self.W_V = nn.Linear(d_model, self.num_kv_heads * self.head_dim, bias=False)
        self.W_O = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

        # RoPE
        self.rope = RotaryPositionalEmbedding(
            dim=self.head_dim,
            max_seq_len=max_seq_len,
            base=rope_base
        )

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int = 0,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        # Project
        Q = self.W_Q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.W_K(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        V = self.W_V(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # Apply RoPE to Q and K
        Q, K = self.rope(Q, K, start_pos)

        # Expand KV heads if using grouped-query attention
        # (each group of consecutive query heads shares one kv head)
        if self.num_kv_heads < self.num_heads:
            repeats = self.num_heads // self.num_kv_heads
            K = torch.repeat_interleave(K, repeats, dim=2)
            V = torch.repeat_interleave(V, repeats, dim=2)

        # Transpose for attention: [batch, heads, seq, head_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Compute attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores + mask

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V)

        # Reshape and project
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, -1)
        output = self.W_O(output)

        return output


# Demonstration
def demonstrate_rope():
    batch_size = 2
    seq_len = 128
    d_model = 512
    num_heads = 8

    attention = RoPEAttention(
        d_model=d_model,
        num_heads=num_heads,
        max_seq_len=4096
    )

    x = torch.randn(batch_size, seq_len, d_model)

    # Causal mask
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))

    output = attention(x, mask=mask)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Head dimension: {d_model // num_heads}")

    # Count parameters
    total_params = sum(p.numel() for p in attention.parameters())
    print(f"Total parameters: {total_params:,}")
    print("Note: No position embedding parameters! RoPE is computed dynamically.")


demonstrate_rope()
```

The `_rotate_half` trick is key to efficient implementation. Instead of constructing sparse rotation matrices, we use element-wise operations. The rotation `x*cos + rotate_half(x)*sin` achieves the same result as the block-diagonal matrix multiplication but is much faster on modern hardware.
RoPE has a beautiful geometric interpretation that provides deep intuition for its behavior.
The Rotation Manifold
Consider the d-dimensional embedding space as d/2 independent 2D planes. In each plane, moving forward one position rotates the vector by a fixed angle θ_i.
Visualization in 2D
Imagine a single 2D plane in which each vector is an arrow from the origin, described by its length and its angle.
The query and key at positions $m$ and $n$ then have angles $\phi_q + m\theta$ and $\phi_k + n\theta$, where $\phi_q$ and $\phi_k$ are the content-determined angles of the unrotated vectors.
Their dot product depends on the angle between them: $(\phi_q + m\theta) - (\phi_k + n\theta) = (\phi_q - \phi_k) + (m-n)\theta$
The content-dependent term $(\phi_q - \phi_k)$ and the position-dependent term $(m-n)\theta$ combine naturally!
Multi-Scale Rotation
Different frequency components (different $\theta_i$) rotate at different rates. For example, with head dimension $d = 64$:
| Dimension Pair | θ_i (radians/position) | Full Rotation Period | Sensitivity |
|---|---|---|---|
| 0-1 (fastest) | 1.0 | ≈6.3 positions | Adjacent tokens |
| 32-33 | ≈0.01 | ≈628 positions | Sentence level |
| 62-63 (slowest) | ≈0.00013 | ≈47,000 positions | Document level |
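The table values can be reproduced directly from the frequency formula; a minimal check, assuming $d = 64$ and base 10000 as above:

```python
import numpy as np

d, base = 64, 10000.0
for pair_index in [0, 16, 31]:                 # dimension pairs (0-1), (32-33), (62-63)
    theta = base ** (-2 * pair_index / d)      # rotation speed in radians per position
    period = 2 * np.pi / theta                 # positions needed for one full rotation
    print(f"pair {2*pair_index}-{2*pair_index+1}: theta={theta:.5f}, period≈{period:,.0f} positions")
```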
Phase Encoding
Each position can be thought of as a unique point in a high-dimensional phase space, where the $i$-th component encodes position modulo the wavelength $2\pi/\theta_i$.
This is analogous to reading time from a clock: the fast-moving second hand resolves fine differences, the slow-moving hour hand distinguishes coarse ones, and together they identify a moment uniquely over a long span.
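As a small illustration (our own sketch, not part of the original implementation), the snippet below recovers a position modulo one wavelength from the phase of a rotated unit vector:

```python
import numpy as np

def position_mod_wavelength(m: int, theta: float) -> float:
    """Recover m modulo the wavelength 2*pi/theta from a vector rotated by m*theta."""
    x, y = np.cos(m * theta), np.sin(m * theta)   # rotated unit vector in one 2D plane
    phase = np.arctan2(y, x) % (2 * np.pi)        # angle in [0, 2*pi)
    return phase / theta                          # position modulo 2*pi/theta

theta = 0.01  # one RoPE frequency; wavelength ~628 positions
for m in [5, 100, 700, 1333]:
    print(m, round(position_mod_wavelength(m, theta), 2), round(m % (2 * np.pi / theta), 2))
```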
Why Rotation Works Better Than Addition
Compare RoPE to additive sinusoidal encoding:
| Aspect | Additive (Sinusoidal) | Multiplicative (RoPE) |
|---|---|---|
| Integration | Adds to embeddings | Multiplies in attention |
| Content/Position | Entangled | Cleanly separated |
| Relative Position | Must be learned | Automatic via geometry |
| Norm preservation | Changes norms | Preserves norms |
A key advantage of rotation over addition: rotation preserves vector norms, $\|R_\theta x\| = \|x\|$. This means positional encoding doesn't change the scale of embeddings, avoiding potential training instabilities.
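A quick numerical check of norm preservation (our own sketch, reusing the interleaved rotation formula from the implementation above):

```python
import torch

def apply_rope(x: torch.Tensor, m: int, base: float = 10000.0) -> torch.Tensor:
    """Interleaved RoPE rotation of a single vector x at position m."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d // 2, dtype=x.dtype) * 2 / d)  # theta_i = base^(-2i/d)
    angles = m * torch.repeat_interleave(theta, 2)                     # angle per component
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[0::2], x[1::2]
    rotate_half = torch.stack([-x_odd, x_even], dim=-1).flatten()      # [-x1, x0, -x3, x2, ...]
    return x * cos + rotate_half * sin

x = torch.randn(64)
for m in [0, 10, 1000]:
    # The norm is unchanged at every position
    print(m, x.norm().item(), apply_rope(x, m).norm().item())
```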
RoPE is defined for arbitrary positions, but in practice extending a model's context length beyond its training window requires additional techniques. Several extensions have been developed:
Direct Extrapolation Challenges
While RoPE can compute rotations for any position, models trained on short contexts often fail at longer lengths: relative distances beyond anything seen during training produce rotation angles, and therefore attention patterns, that are out of distribution for the model, and quality typically degrades sharply past the training length.
```python
import torch
import numpy as np


class RoPEWithExtensions:
    """
    RoPE with various extension methods for length extrapolation.
    """

    @staticmethod
    def linear_interpolation(
        positions: torch.Tensor,
        train_length: int,
        target_length: int
    ) -> torch.Tensor:
        """
        Position Interpolation: Scale positions to fit training range.

        Example: If trained on 2048 and we want to use 8192,
        positions 0, 4096, 8191 become 0, 1024, ~2048.
        """
        scale = train_length / target_length
        return positions * scale

    @staticmethod
    def ntk_aware_scaling(
        original_base: float,
        scale_factor: float,
        dim: int
    ) -> float:
        """
        NTK-aware Interpolation: Adjust the base frequency.

        This adjusts wavelengths so that relative position encoding
        remains meaningful at longer lengths.
        """
        # The key insight: adjusting base by scale^(dim/(dim-2))
        # preserves low-frequency components while compressing high-frequency
        new_base = original_base * (scale_factor ** (dim / (dim - 2)))
        return new_base

    @staticmethod
    def dynamic_ntk_scaling(
        seq_len: int,
        original_max_len: int,
        original_base: float,
        dim: int
    ) -> float:
        """
        Dynamic NTK: Automatically adjust base based on current length.
        Only scales when sequence exceeds training length.
        """
        if seq_len <= original_max_len:
            return original_base

        scale_factor = seq_len / original_max_len
        return original_base * (scale_factor ** (dim / (dim - 2)))

    @staticmethod
    def compute_yarn_frequencies(
        original_base: float,
        dim: int,
        scale_factor: float,
        beta_fast: float = 32,
        beta_slow: float = 1,
        mscale: float = 1.0
    ) -> torch.Tensor:
        """
        YaRN: Yet another RoPE extensioN

        Applies different interpolation strategies to different
        frequency bands for optimal extrapolation.
        """
        # Compute original frequencies
        freq_indices = torch.arange(0, dim, 2).float()
        freqs = 1.0 / (original_base ** (freq_indices / dim))

        # Compute interpolation factors per frequency
        # Low frequencies (slow rotation) get more interpolation
        # High frequencies (fast rotation) get less interpolation

        # Wavelength of each frequency component
        wavelengths = 2 * np.pi / freqs

        # Compute "ramp" - how much to interpolate each frequency
        # Ranges from 0 (no interpolation) to 1 (full interpolation)
        low_bound = wavelengths / (2 * np.pi * beta_fast)
        high_bound = wavelengths / (2 * np.pi * beta_slow)

        ramp = torch.clamp(
            (wavelengths / scale_factor - low_bound) / (high_bound - low_bound),
            0.0, 1.0
        )

        # Interpolate frequencies based on ramp
        interpolated_freqs = freqs / scale_factor
        freqs_yarn = freqs * (1 - ramp) + interpolated_freqs * ramp

        return freqs_yarn


def visualize_rope_extensions():
    """Visualize how different RoPE extensions affect frequencies."""
    dim = 64
    original_base = 10000.0
    original_max_len = 2048
    target_len = 8192
    scale_factor = target_len / original_max_len

    # Original frequencies
    freq_indices = torch.arange(0, dim, 2).float()
    original_freqs = 1.0 / (original_base ** (freq_indices / dim))

    # Position Interpolation (doesn't change frequencies, just positions)
    pi_freqs = original_freqs  # Same frequencies, scaled positions

    # NTK-aware
    ntk_base = RoPEWithExtensions.ntk_aware_scaling(
        original_base, scale_factor, dim
    )
    ntk_freqs = 1.0 / (ntk_base ** (freq_indices / dim))

    # YaRN
    yarn_freqs = RoPEWithExtensions.compute_yarn_frequencies(
        original_base, dim, scale_factor
    )

    print("=== RoPE Extension Comparison ===\n")
    print(f"Extending from {original_max_len} to {target_len} tokens")
    print(f"Scale factor: {scale_factor}x\n")

    print("Frequency comparison (first 5 dimension pairs):")
    print("-" * 60)
    print(f"{'Pair':<8} {'Original':>12} {'NTK':>12} {'YaRN':>12}")
    print("-" * 60)
    for i in range(5):
        print(f"{i*2}-{i*2+1:<5} {original_freqs[i]:.6f} "
              f"{ntk_freqs[i]:.6f} {yarn_freqs[i]:.6f}")

    print()
    print("Wavelength comparison (positions for full rotation):")
    print("-" * 60)
    for i in [0, dim//4 - 1, dim//2 - 1]:
        orig_wl = 2 * np.pi / original_freqs[i]
        ntk_wl = 2 * np.pi / ntk_freqs[i]
        yarn_wl = 2 * np.pi / yarn_freqs[i]
        print(f"Pair {i*2:2d}: Original {orig_wl:8.1f}, "
              f"NTK {ntk_wl:8.1f}, YaRN {yarn_wl:8.1f}")


visualize_rope_extensions()
```

For extending pre-trained models: (1) Start with Position Interpolation—it's simplest and often works. (2) If quality degrades, try NTK-aware scaling or YaRN. (3) For best results, do continued training on longer sequences. Most production systems use YaRN or similar for 4x+ extensions.
Let's systematically compare RoPE against other positional encoding approaches:
| Aspect | Sinusoidal | Learned | T5 Bias | RoPE | ALiBi |
|---|---|---|---|---|---|
| Parameters | 0 | O(L×d) | O(buckets×heads) | 0 | 0 |
| Relative Position | Indirect | No | Direct | Automatic | Direct |
| Length Extrapolation | Moderate | None | Moderate | Good | Excellent |
| Computation | Low | Low | Low | Medium | Very Low |
| Integration Point | Embeddings | Embeddings | Attention logits | Q,K vectors | Attention logits |
| Norm Preservation | No | No | N/A | Yes | N/A |
| Adoption (2024) | Legacy | BERT-era | T5 family | Most LLMs | Some LLMs |
RoPE vs. ALiBi
ALiBi (Attention with Linear Biases) is RoPE's main competitor for modern LLMs. Key differences:
ALiBi Approach: instead of rotating queries and keys, ALiBi adds a fixed, head-specific penalty to each attention logit that grows linearly with the distance between query and key positions; there are no position embeddings or rotations at all. A minimal sketch of this bias follows the lists below.
When to Choose RoPE: when compatibility with the dominant LLM ecosystem and tooling matters (LLaMA-style architectures, RoPE-scaling methods such as Position Interpolation and YaRN), and when strong performance within the trained context window is the priority.
When to Choose ALiBi: when zero-shot extrapolation to sequences longer than those seen in training is the primary requirement, or when you want the simplest and cheapest possible positional mechanism.
Most modern LLMs use RoPE, but ALiBi has shown competitive results with potentially better extrapolation in some benchmarks.
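For reference, here is a minimal sketch of how an ALiBi bias can be constructed (the per-head slopes follow the geometric $2^{-8(h+1)/n}$ schedule described in the ALiBi paper for a power-of-two head count; this is an illustration, not any particular library's implementation):

```python
import torch

def alibi_bias(seq_len: int, num_heads: int) -> torch.Tensor:
    """Additive attention bias [heads, seq, seq]: -slope_h * distance on the causal part."""
    # Per-head geometric slopes: 2^(-8/n), 2^(-16/n), ...
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    pos = torch.arange(seq_len)
    # Relative position j - i is <= 0 for keys at or before the query in causal attention;
    # future positions are zeroed here and handled by the usual causal mask.
    rel = (pos[None, :] - pos[:, None]).clamp(max=0).float()   # [seq, seq]
    return slopes[:, None, None] * rel                         # added to attention logits

bias = alibi_bias(seq_len=8, num_heads=4)
print(bias.shape)   # torch.Size([4, 8, 8])
print(bias[0])      # head 0: zeros on the diagonal, increasingly negative toward older keys
```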
When implementing RoPE in production systems, several practical considerations arise:
Memory and Computation
RoPE computation is efficient but not free: the precomputed sin/cos tables occupy O(max_seq_len × head_dim) memory, and applying the rotation adds a few element-wise multiplies and additions for every query and key in every attention layer.
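As a rough, illustrative estimate (assuming an 8192-position cache, 128-dimensional heads, and fp32 tables):

```python
max_seq_len, head_dim, bytes_per_float = 8192, 128, 4
cache_bytes = max_seq_len * head_dim * 2 * bytes_per_float   # cos table + sin table
print(f"{cache_bytes / 2**20:.1f} MiB")                      # 8.0 MiB, shareable across heads/layers
```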
Hardware Optimization
Modern implementations often precompute and cache the sin/cos tables once, fuse the rotation into the surrounding attention kernel, and apply it in reduced precision alongside the rest of attention to minimize memory traffic.
```python
import torch
import torch.nn as nn
from typing import Optional, Tuple


class OptimizedRoPE(nn.Module):
    """
    Production-optimized RoPE implementation with:
    - Precomputed sin/cos tables
    - Support for KV-cache (incremental decoding)
    - Half precision support
    - Efficient memory layout
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int = 8192,
        base: float = 10000.0,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32
    ):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequency table
        inv_freq = 1.0 / (
            base ** (torch.arange(0, head_dim, 2, dtype=dtype, device=device) / head_dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        # Precompute sin/cos for max_seq_len
        self._compute_cache(max_seq_len, device, dtype)

    def _compute_cache(
        self,
        seq_len: int,
        device: Optional[torch.device],
        dtype: torch.dtype
    ):
        """Compute and cache sin/cos values."""
        positions = torch.arange(seq_len, dtype=dtype, device=device)

        # [seq_len, head_dim/2]
        freqs = torch.outer(positions, self.inv_freq)

        # [seq_len, head_dim], angles repeated per pair:
        # [theta_0, theta_0, theta_1, theta_1, ...]
        # to match the interleaved pairing used in _apply_rope
        emb = torch.repeat_interleave(freqs, 2, dim=-1)

        self.register_buffer("cos_cache", emb.cos(), persistent=False)
        self.register_buffer("sin_cache", emb.sin(), persistent=False)

    def forward(
        self,
        q: torch.Tensor,  # [batch, seq_len, heads, head_dim]
        k: torch.Tensor,  # [batch, seq_len, kv_heads, head_dim]
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to queries and keys.

        Args:
            q, k: Query and key tensors
            position_ids: Optional explicit position indices [batch, seq_len]
                          Used for KV-cache scenarios
        """
        seq_len = q.size(1)

        if position_ids is None:
            # Standard case: positions are 0, 1, 2, ..., seq_len-1
            cos = self.cos_cache[:seq_len]
            sin = self.sin_cache[:seq_len]
        else:
            # Custom positions (e.g., for KV-cache continuation)
            # Assumes position_ids has batch dimension 1, as in single-sequence decoding
            cos = self.cos_cache[position_ids].squeeze(0)
            sin = self.sin_cache[position_ids].squeeze(0)

        # Apply rotation
        q_embed = self._apply_rope(q, cos, sin)
        k_embed = self._apply_rope(k, cos, sin)

        return q_embed, k_embed

    def _apply_rope(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> torch.Tensor:
        """
        Efficient RoPE application.

        x: [batch, seq_len, heads, head_dim]
        cos, sin: [seq_len, head_dim]
        """
        # Expand for broadcasting
        # [1, seq_len, 1, head_dim]
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)

        # Split x into even and odd indices
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]

        # Create rotated half: [-x_odd, x_even] interleaved
        x_rotate = torch.stack([-x_odd, x_even], dim=-1).flatten(-2)

        # Apply rotation formula
        return x * cos + x_rotate * sin

    def extend_cache(self, new_max_len: int):
        """Extend cache for longer sequences."""
        if new_max_len <= self.max_seq_len:
            return
        self._compute_cache(
            new_max_len,
            self.inv_freq.device,
            self.inv_freq.dtype
        )
        self.max_seq_len = new_max_len


class RoPEWithKVCache:
    """
    Helper for using RoPE with KV-cache during inference.

    In autoregressive decoding, we cache past keys/values and only
    compute attention for the new token(s).
    """

    def __init__(self, rope: OptimizedRoPE):
        self.rope = rope
        self.current_position = 0

    def decode_step(
        self,
        q: torch.Tensor,         # [batch, 1, heads, head_dim] - new query
        k: torch.Tensor,         # [batch, 1, kv_heads, head_dim] - new key
        v: torch.Tensor,         # [batch, 1, kv_heads, head_dim] - new value
        cached_k: torch.Tensor,  # [batch, past_len, kv_heads, head_dim]
        cached_v: torch.Tensor   # [batch, past_len, kv_heads, head_dim]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Process one decoding step.

        Returns:
            - Rotated query
            - Updated key cache (with new rotated key)
            - Updated value cache (with new value; values are never rotated)
        """
        # Position for new token
        position_ids = torch.tensor([[self.current_position]], device=q.device)

        # Apply RoPE to new query and key
        q_rot, k_rot = self.rope(q, k, position_ids=position_ids)

        # Append to caches
        new_cached_k = torch.cat([cached_k, k_rot], dim=1)
        new_cached_v = torch.cat([cached_v, v], dim=1)

        self.current_position += 1

        return q_rot, new_cached_k, new_cached_v

    def reset(self):
        """Reset position counter for new sequence."""
        self.current_position = 0


# Benchmark
def benchmark_rope():
    """Benchmark RoPE performance."""
    import time

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_size = 8
    seq_len = 2048
    num_heads = 32
    head_dim = 128

    rope = OptimizedRoPE(head_dim=head_dim, max_seq_len=seq_len, device=device)

    q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device)
    k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device)

    # Warmup
    for _ in range(10):
        q_rot, k_rot = rope(q, k)

    if device.type == 'cuda':
        torch.cuda.synchronize()

    # Benchmark
    start = time.time()
    num_iters = 100
    for _ in range(num_iters):
        q_rot, k_rot = rope(q, k)

    if device.type == 'cuda':
        torch.cuda.synchronize()

    elapsed = time.time() - start

    print(f"Device: {device}")
    print(f"Shape: [{batch_size}, {seq_len}, {num_heads}, {head_dim}]")
    print(f"Average time per call: {elapsed / num_iters * 1000:.3f} ms")
    print(f"Throughput: {num_iters * batch_size * seq_len / elapsed / 1e6:.2f} M tokens/second")


benchmark_rope()
```

When using RoPE with KV-caching for autoregressive generation, you must use explicit position indices. The cached keys were rotated with their original positions; the new query must be rotated with the current position. Failing to track positions correctly breaks generation.
Rotary Position Embeddings represent the current culmination of positional encoding research, elegantly solving the problems that motivated earlier approaches. Let's consolidate the key insights:
The Positional Encoding Journey
Over the course of this module, we've traced the evolution of positional encoding: from the fundamental problem of position-blind self-attention, through fixed sinusoidal encodings and learned absolute embeddings, to relative position methods such as the T5 bias, and finally to rotary embeddings.
Each approach built upon insights from its predecessors, culminating in the modern understanding embodied by RoPE.
Future Directions
Research continues to push positional encoding forward: longer-context extensions of RoPE (Position Interpolation, NTK-aware scaling, YaRN, and their successors), alternatives such as ALiBi-style biases, and ongoing study of how positional information behaves at very long context lengths.
Understanding RoPE and its predecessors provides the foundation for engaging with these ongoing advances.
Congratulations! You've completed the Positional Encoding module. From the fundamental problem of transformer position-blindness through sinusoidal encoding, learned embeddings, relative position methods, and finally RoPE, you now have comprehensive understanding of how modern transformers encode sequential structure. This knowledge is essential for understanding, implementing, and improving transformer-based models.