While attention mechanisms receive most of the spotlight, the position-wise feed-forward networks (FFNs) are where the majority of a Transformer's parameters and computation reside. In the standard configuration, the FFN contains roughly two-thirds of each layer's parameters and performs the bulk of the non-linear transformation.
If attention is the mechanism for gathering information across positions, the feed-forward network is the mechanism for processing that information at each position. It applies a learned non-linear transformation independently to each position's representation, adding the computational depth necessary for complex reasoning.
This page provides a comprehensive examination of feed-forward layers: their architecture, mathematical properties, role in the Transformer, activation function choices, and modern innovations that have improved efficiency and performance.
In a standard Transformer layer with d_model=512 and d_ff=2048, the FFN has 2×512×2048 = 2.1M parameters (two linear layers), while multi-head attention has 4×512×512 = 1.05M parameters (Q, K, V, O projections). The FFN dominates!
The position-wise feed-forward network applies the same neural network independently to each position in the sequence. "Position-wise" means no information flows between positions within the FFN—that's attention's job.
Standard FFN Formulation
For input $x \in \mathbb{R}^d$ (representation at a single position), the FFN computes:
$$\text{FFN}(x) = W_2 \cdot \sigma(W_1 x + b_1) + b_2$$
where:
- $W_1 \in \mathbb{R}^{d_{ff} \times d_{model}}$ and $b_1 \in \mathbb{R}^{d_{ff}}$: the expansion projection
- $W_2 \in \mathbb{R}^{d_{model} \times d_{ff}}$ and $b_2 \in \mathbb{R}^{d_{model}}$: the contraction projection
- $\sigma$: a non-linear activation function (ReLU in the original Transformer)
Dimensionality Flow
$$x \in \mathbb{R}^{d_{model}} \xrightarrow{W_1} h \in \mathbb{R}^{d_{ff}} \xrightarrow{\sigma} h' \in \mathbb{R}^{d_{ff}} \xrightarrow{W_2} y \in \mathbb{R}^{d_{model}}$$
The intermediate dimension $d_{ff}$ is typically $4 \times d_{model}$:
- Original Transformer: $d_{model} = 512$, $d_{ff} = 2048$
- BERT-base: $d_{model} = 768$, $d_{ff} = 3072$
- GPT-3: $d_{model} = 12288$, $d_{ff} = 49152$
This expansion-contraction pattern is the inverse of a bottleneck: the network first expands to a wider hidden dimension, increasing capacity, before projecting back down.
```python
import torch
import torch.nn as nn


class PositionWiseFFN(nn.Module):
    """
    Position-wise Feed-Forward Network (FFN) as used in the original Transformer.

    Architecture:
        Linear(d_model → d_ff) → ReLU → Linear(d_ff → d_model)

    The same network is applied independently to each position in the
    sequence. No information flows between positions within the FFN.
    """

    def __init__(
        self,
        d_model: int = 512,
        d_ff: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu"
    ):
        super().__init__()

        # First linear projection: expand to intermediate dimension
        self.w1 = nn.Linear(d_model, d_ff)
        # Second linear projection: contract back to model dimension
        self.w2 = nn.Linear(d_ff, d_model)

        # Activation function
        if activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "gelu":
            self.activation = nn.GELU()
        else:
            raise ValueError(f"Unknown activation: {activation}")

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply position-wise FFN.

        Args:
            x: Input tensor of shape [batch_size, seq_len, d_model]

        Returns:
            Output tensor of shape [batch_size, seq_len, d_model]
        """
        # Expand: d_model → d_ff
        hidden = self.w1(x)
        # Apply non-linearity
        hidden = self.activation(hidden)
        # Apply dropout (after activation, before second projection)
        hidden = self.dropout(hidden)
        # Contract: d_ff → d_model
        output = self.w2(hidden)
        return output


def visualize_ffn_dimensions():
    """Demonstrate the dimension flow through FFN."""
    d_model, d_ff = 512, 2048
    batch_size, seq_len = 2, 10

    ffn = PositionWiseFFN(d_model, d_ff)
    x = torch.randn(batch_size, seq_len, d_model)

    # Trace through the network
    print(f"Input shape: {x.shape}")
    print(f"  → {[batch_size, seq_len, d_model]}")

    with torch.no_grad():
        h1 = ffn.w1(x)
        print(f"After W1: {h1.shape}")
        print(f"  → Expanded from {d_model} to {d_ff} (4x)")

        h2 = ffn.activation(h1)
        print(f"After activation: {h2.shape} (same)")

        y = ffn.w2(h2)
        print(f"After W2: {y.shape}")
        print(f"  → Contracted back to {d_model}")

    # Parameter count
    w1_params = d_model * d_ff + d_ff      # weights + bias
    w2_params = d_ff * d_model + d_model   # weights + bias
    total = w1_params + w2_params
    print("Parameter count:")
    print(f"  W1: {d_model} × {d_ff} + {d_ff} = {w1_params:,}")
    print(f"  W2: {d_ff} × {d_model} + {d_model} = {w2_params:,}")
    print(f"  Total: {total:,}")


visualize_ffn_dimensions()
```

The FFN is applied identically and independently to each position. This means the same weights W₁, W₂ are used for all positions, but each position's computation depends only on its own input—there's no cross-position information flow within the FFN.
What exactly does the FFN do? Why is this expansion-contraction architecture necessary? Understanding the FFN's role requires examining both theoretical and empirical perspectives.
The Mixing and Processing Interpretation
Consider the flow through a Transformer layer:
The attention mechanism can be viewed as a content-based retrieval operation: positions query other positions for relevant information. But apart from the softmax over weights, attention applies no non-linear transformation to the content itself; its output is a weighted sum of linearly projected values. The FFN provides the necessary non-linearity and the added capacity for learning complex transformations.
The "Key-Value Memory" Perspective
Geva et al. (2021) provided an influential interpretation: FFN layers function as a form of learned key-value memory. The first layer $W_1$ acts as keys, and the second layer $W_2$ acts as values:
$$\text{FFN}(x) = \sum_{i=1}^{d_{ff}} \text{softmax-like}(W_1^{(i)} \cdot x) \cdot W_2^{(i)}$$
Each "slot" in the intermediate dimension corresponds to a learned pattern (key) and associated output (value). When the input matches a key (high activation), the corresponding value contributes to the output.
This explains why FFN width matters: more intermediate dimensions = more patterns the network can learn and recall.
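This decomposition can be checked numerically. The sketch below is illustrative only: it uses ReLU rather than the softmax-like weighting above, random weights rather than learned ones, and no biases. It writes the FFN output as a sum over intermediate "slots", each contributing its activation times one column of $W_2$:

```python
import torch

torch.manual_seed(0)
d_model, d_ff = 16, 64

W1 = torch.randn(d_ff, d_model)  # rows act as "keys"
W2 = torch.randn(d_model, d_ff)  # columns act as "values"
x = torch.randn(d_model)

# Standard bias-free FFN with ReLU
h = torch.relu(W1 @ x)           # slot activations: how strongly x matches each key
y_ffn = W2 @ h

# Identical output, written as a sum over key-value "memory slots"
y_memory = sum(h[i] * W2[:, i] for i in range(d_ff))

assert torch.allclose(y_ffn, y_memory, atol=1e-5)
print(f"Active slots: {(h > 0).sum().item()} / {d_ff}")
```

Only the slots whose keys match the input (positive activation) contribute their values, which is why widening $d_{ff}$ adds retrievable patterns.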
Theoretical Analysis: Expressiveness
The FFN provides essential expressiveness. Consider what attention alone (without FFN) could compute: each attention layer outputs a weighted sum of linearly projected value vectors, so stacking attention layers without any intervening non-linearity only composes weighted averages and remains limited in the functions it can represent.

The FFN adds:
- Non-linearity between attention operations, without which the layer stack would collapse toward a (nearly) linear map
- Per-position processing depth for transforming the information attention has gathered
- The majority of each layer's learned parameters, and with them most of its capacity for storing patterns
The Universal Approximation Connection
A two-layer FFN with ReLU and sufficient intermediate dimension can approximate any continuous function to arbitrary precision (universal approximation theorem). While the full Transformer is overparameterized, having this approximation power at each layer and position enables learning very complex input-output mappings.
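To make this concrete, here is a hand-constructed (not trained) example: a single hidden ReLU layer whose weights are set analytically to equal the piecewise-linear interpolant of $\sin$ on $[0, \pi]$. More hidden units mean more knots and smaller error, illustrating the approximation power in its simplest form:

```python
import numpy as np


def relu_interpolant(knots, values):
    """Build f(x) = values[0] + sum_i a_i * ReLU(x - knots[i]): a one-hidden-layer
    ReLU network equal to the piecewise-linear interpolant of (knots, values)
    on [knots[0], knots[-1]]."""
    slopes = np.diff(values) / np.diff(knots)
    # Each hidden unit switches on at a knot; its output weight is the slope change
    a = np.diff(slopes, prepend=0.0)

    def f(x):
        x = np.asarray(x, dtype=float)[..., None]
        return values[0] + (a * np.maximum(x - knots[:-1], 0.0)).sum(-1)

    return f


knots = np.linspace(0, np.pi, 21)          # 20 hidden ReLU units
f = relu_interpolant(knots, np.sin(knots))

xs = np.linspace(0, np.pi, 1000)
max_err = np.abs(f(xs) - np.sin(xs)).max()
print(f"Max error with 20 hidden units: {max_err:.4f}")
```

Doubling the number of knots roughly quarters the error, consistent with the $O(h^2)$ error of piecewise-linear interpolation.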
A useful mental model: attention gathers relevant information from across the sequence into each position's representation, then the FFN applies a learned transformation to process and refine that gathered information. Both are essential—attention without FFN lacks depth; FFN without attention lacks cross-position reasoning.
The choice of activation function in the FFN significantly impacts both training dynamics and final performance. The field has evolved from simple ReLU to more sophisticated alternatives.
ReLU (Rectified Linear Unit) – Original Transformer
$$\text{ReLU}(x) = \max(0, x)$$
Properties:
- Cheapest to compute: a simple threshold at zero
- Produces exact zeros, yielding sparse activations
- Non-smooth at $x = 0$, with zero gradient for all negative inputs
- Can suffer from "dead" units that never activate
GELU (Gaussian Error Linear Unit) – BERT, GPT
$$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$
Or the approximation: $$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{2/\pi}(x + 0.044715x^3)\right]\right)$$
Properties:
- Smooth everywhere, with small non-zero gradients for negative inputs
- Weights the input by the probability that a standard normal falls below it ($x \cdot \Phi(x)$)
- Non-monotonic: dips slightly below zero for small negative inputs
- Adopted by BERT and the GPT family after empirical gains over ReLU
Swish (SiLU) – EfficientNet, Some LLMs
$$\text{Swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$
Properties:
- Smooth and non-monotonic, with a shape very close to GELU
- Self-gating: the input gates itself through $\sigma(x)$
- Found via automated search over candidate activation functions
- The $\beta = 1$ case is known as SiLU; it is the gate used in SwiGLU
```python
import torch
import torch.nn.functional as F
import numpy as np


class ActivationFunctions:
    """Common activation functions used in Transformer FFNs."""

    @staticmethod
    def relu(x: torch.Tensor) -> torch.Tensor:
        """ReLU: max(0, x)"""
        return F.relu(x)

    @staticmethod
    def gelu(x: torch.Tensor) -> torch.Tensor:
        """GELU: x * Φ(x) where Φ is the CDF of standard normal."""
        return F.gelu(x)

    @staticmethod
    def gelu_approx(x: torch.Tensor) -> torch.Tensor:
        """Fast GELU approximation used in many implementations."""
        return 0.5 * x * (1 + torch.tanh(
            np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)
        ))

    @staticmethod
    def swish(x: torch.Tensor) -> torch.Tensor:
        """Swish (SiLU): x * sigmoid(x)"""
        return x * torch.sigmoid(x)

    @staticmethod
    def swish_beta(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Parameterized Swish: x * sigmoid(beta * x)"""
        return x * torch.sigmoid(beta * x)


def compare_activations():
    """Compare activation functions and their gradients."""
    activations = {
        'ReLU': F.relu,
        'GELU': F.gelu,
        'Swish': lambda x: x * torch.sigmoid(x),
    }

    print("Activation comparison at key points:")
    print("-" * 60)
    test_points = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    for name, fn in activations.items():
        values = fn(test_points)
        print(f"{name:10s}: {values.tolist()}")

    print("Gradient at x=0:")
    for name, fn in activations.items():
        x0 = torch.tensor([0.0], requires_grad=True)
        y = fn(x0)
        y.backward()
        print(f"  {name}: {x0.grad.item():.4f}")

    print("Gradient at x=-1:")
    for name, fn in activations.items():
        x0 = torch.tensor([-1.0], requires_grad=True)
        y = fn(x0)
        y.backward()
        print(f"  {name}: {x0.grad.item():.4f}")


def sparsity_analysis():
    """Analyze sparsity induced by different activations."""
    torch.manual_seed(42)

    # Simulate FFN intermediate activations
    batch_size, seq_len, d_ff = 16, 128, 2048
    pre_activation = torch.randn(batch_size, seq_len, d_ff)

    relu_out = F.relu(pre_activation)
    gelu_out = F.gelu(pre_activation)
    swish_out = pre_activation * torch.sigmoid(pre_activation)

    # Count "zero" or near-zero activations
    threshold = 1e-6
    relu_zeros = (relu_out.abs() < threshold).float().mean().item()
    gelu_zeros = (gelu_out.abs() < threshold).float().mean().item()
    swish_zeros = (swish_out.abs() < threshold).float().mean().item()

    print("Sparsity analysis (fraction of near-zero activations):")
    print(f"  ReLU:  {relu_zeros:.4f} ({relu_zeros*100:.1f}%)")
    print(f"  GELU:  {gelu_zeros:.4f} ({gelu_zeros*100:.1f}%)")
    print(f"  Swish: {swish_zeros:.4f} ({swish_zeros*100:.1f}%)")

    # More meaningful: fraction of "small" outputs
    small_threshold = 0.1
    relu_small = (relu_out.abs() < small_threshold).float().mean().item()
    gelu_small = (gelu_out.abs() < small_threshold).float().mean().item()
    swish_small = (swish_out.abs() < small_threshold).float().mean().item()

    print("Fraction of small activations (|x| < 0.1):")
    print(f"  ReLU:  {relu_small:.4f} ({relu_small*100:.1f}%)")
    print(f"  GELU:  {gelu_small:.4f} ({gelu_small*100:.1f}%)")
    print(f"  Swish: {swish_small:.4f} ({swish_small*100:.1f}%)")


compare_activations()
sparsity_analysis()
```

| Property | ReLU | GELU | Swish |
|---|---|---|---|
| Formula | max(0,x) | x·Φ(x) | x·σ(x) |
| Smoothness | Non-smooth at 0 | Smooth everywhere | Smooth everywhere |
| Gradient at x=0 | Undefined/0 | 0.5 | 0.5 |
| Sparsity | Exact zeros | Soft suppression | Soft suppression |
| Monotonic | Yes | No (slight dip) | No (slight dip) |
| Compute Cost | Lowest | Medium | Medium |
| Used In | Original Transformer | BERT, GPT family | EfficientNet, some LLMs |
Modern large language models increasingly use Gated Linear Units (GLU) variants in their FFN layers. These provide a different architectural pattern that has shown improved performance, particularly for larger models.
Original GLU (Dauphin et al., 2017)
The Gated Linear Unit splits the intermediate computation into two parallel paths:
$$\text{GLU}(x) = (W_1 x + b_1) \otimes \sigma(W_2 x + b_2)$$
where:
- $W_1, W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$ are two independent projections of the same input
- $\otimes$ denotes element-wise (Hadamard) multiplication
- $\sigma$ is the sigmoid function
The gate $\sigma(W_2 x + b_2)$ controls how much of the content $(W_1 x + b_1)$ passes through.
SwiGLU (Shazeer, 2020)
SwiGLU replaces the sigmoid gate with Swish:
$$\text{SwiGLU}(x) = (W_1 x) \otimes \text{Swish}(W_2 x)$$
Or equivalently: $$\text{SwiGLU}(x) = (W_1 x) \otimes (W_2 x \cdot \sigma(W_2 x))$$
GeGLU
Uses GELU as the gating function: $$\text{GeGLU}(x) = (W_1 x) \otimes \text{GELU}(W_2 x)$$
Parameter Considerations
GLU variants require three weight matrices instead of two: a gate projection and an up projection (both $d_{model} \to d_{ff}$) plus a down projection ($d_{ff} \to d_{model}$), giving $3 \cdot d_{model} \cdot d_{ff}$ parameters where a standard FFN has $2 \cdot d_{model} \cdot d_{ff}$.

To maintain a similar parameter count, LLaMA and others reduce $d_{ff}$ by a factor of $2/3$: with $d_{ff} = \frac{2}{3} \cdot 4d_{model} = \frac{8}{3} d_{model}$, the three matrices total $3 \cdot d_{model} \cdot \frac{8}{3} d_{model} = 8 d_{model}^2$ parameters, matching a standard FFN with $d_{ff} = 4 d_{model}$.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU_FFN(nn.Module):
    """
    SwiGLU Feed-Forward Network as used in LLaMA and other modern LLMs.

    Architecture:
        gate = Swish(W_gate @ x)
        content = W_up @ x
        output = W_down @ (gate ⊙ content)

    The intermediate dimension is typically 8/3 * d_model to maintain
    similar parameter count to a standard FFN with 4 * d_model.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int = None,
        dropout: float = 0.0,
        bias: bool = False  # LLaMA omits biases
    ):
        super().__init__()

        # Default to 8/3 * d_model, rounded up to a multiple of 256
        if d_ff is None:
            d_ff = int(8 * d_model / 3)
            d_ff = (d_ff + 255) // 256 * 256

        # Three projections for GLU
        self.w_gate = nn.Linear(d_model, d_ff, bias=bias)  # Gate projection
        self.w_up = nn.Linear(d_model, d_ff, bias=bias)    # Up/content projection
        self.w_down = nn.Linear(d_ff, d_model, bias=bias)  # Down/output projection

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply SwiGLU FFN.

        Args:
            x: Input [batch_size, seq_len, d_model]

        Returns:
            Output [batch_size, seq_len, d_model]
        """
        # Compute gate and content in parallel
        gate = F.silu(self.w_gate(x))  # SiLU is Swish with β=1
        content = self.w_up(x)

        # Element-wise gating
        hidden = gate * content
        hidden = self.dropout(hidden)

        # Project back to model dimension
        return self.w_down(hidden)


class GeGLU_FFN(nn.Module):
    """
    GeGLU Feed-Forward Network using GELU gating.
    Used in some T5 variants and other models.
    """

    def __init__(self, d_model: int, d_ff: int = None, bias: bool = True):
        super().__init__()
        if d_ff is None:
            d_ff = int(8 * d_model / 3)
        self.w_gate = nn.Linear(d_model, d_ff, bias=bias)
        self.w_up = nn.Linear(d_model, d_ff, bias=bias)
        self.w_down = nn.Linear(d_ff, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.gelu(self.w_gate(x))
        content = self.w_up(x)
        return self.w_down(gate * content)


class ReGLU_FFN(nn.Module):
    """ReGLU: GLU variant with ReLU gating."""

    def __init__(self, d_model: int, d_ff: int = None, bias: bool = True):
        super().__init__()
        if d_ff is None:
            d_ff = int(8 * d_model / 3)
        self.w_gate = nn.Linear(d_model, d_ff, bias=bias)
        self.w_up = nn.Linear(d_model, d_ff, bias=bias)
        self.w_down = nn.Linear(d_ff, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.relu(self.w_gate(x))
        content = self.w_up(x)
        return self.w_down(gate * content)


def compare_ffn_variants():
    """Compare standard FFN vs GLU variants."""
    d_model = 4096

    # Standard FFN
    d_ff_standard = 4 * d_model  # 16384
    standard_params = 2 * d_model * d_ff_standard  # Ignoring bias

    # GLU FFN (adjusted d_ff to match params)
    d_ff_glu = int(8 * d_model / 3)  # ~10922
    glu_params = 3 * d_model * d_ff_glu

    print("FFN Parameter Comparison (d_model = 4096):")
    print(f"  Standard FFN (d_ff = {d_ff_standard:,}):")
    print(f"    Parameters: {standard_params:,}")
    print(f"  SwiGLU FFN (d_ff = {d_ff_glu:,}):")
    print(f"    Parameters: {glu_params:,}")
    print(f"    Ratio: {glu_params / standard_params:.3f}")

    # Test forward pass
    batch_size, seq_len = 2, 128
    x = torch.randn(batch_size, seq_len, d_model)
    swiglu = SwiGLU_FFN(d_model, d_ff_glu)
    _ = swiglu(x)

    # Count actual parameters
    actual_params = sum(p.numel() for p in swiglu.parameters())
    print(f"  Actual SwiGLU params: {actual_params:,}")


compare_ffn_variants()
```

GLU provides a multiplicative gating mechanism that can completely suppress certain dimensions (gate ≈ 0) or pass them unchanged (gate ≈ 1). This enables cleaner, more selective information flow compared to additive activations. Empirically, SwiGLU shows consistent improvements, especially for larger models, hence its adoption in LLaMA, PaLM, and others.
As language models scale, the FFN becomes a computational bottleneck. Mixture of Experts (MoE) offers a way to massively scale FFN capacity while keeping computational cost tractable.
The MoE Concept
Instead of a single FFN, MoE uses multiple "expert" FFNs and a router that selects which experts process each token:
$$\text{MoE}(x) = \sum_{i=1}^{N} g_i(x) \cdot \text{FFN}_i(x)$$
where:
- $N$ is the number of experts
- $\text{FFN}_i$ is the $i$-th expert network (a standard FFN)
- $g_i(x)$ is the router's gating weight for expert $i$, zero for all but the selected top-$k$ experts
Sparse Computation
The key insight: if only $k$ experts are active per token, computation scales with $k$ rather than $N$. You can have 128 experts but only compute 2, giving 64x theoretical capacity increase with 2x computational increase.
The Router
A learned router network determines expert selection:
$$\text{router}(x) = \text{softmax}(W_r x)$$
The top-$k$ values determine which experts are active. Various routing strategies exist: top-1 routing (Switch Transformer), top-2 routing (GShard, Mixtral), and expert-choice routing, in which experts select tokens rather than tokens selecting experts.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class MoELayer(nn.Module):
    """
    Simplified Mixture of Experts layer.

    Each token is routed to top-k experts. Only those experts perform
    computation for that token, enabling sparse scaling.

    Used in: Switch Transformer, Mixtral, GLaM, etc.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.d_model = d_model

        # Router: produces expert scores for each token
        self.router = nn.Linear(d_model, num_experts, bias=False)

        # Expert FFNs (each is a standard FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with expert routing.

        Args:
            x: Input [batch_size, seq_len, d_model]

        Returns:
            output: Processed output [batch_size, seq_len, d_model]
            router_logits: For auxiliary loss [batch_size, seq_len, num_experts]
        """
        batch_size, seq_len, d_model = x.shape

        # Compute router scores
        router_logits = self.router(x)  # [batch, seq, num_experts]

        # Select top-k experts for each token
        routing_weights, selected_experts = torch.topk(
            router_logits, self.top_k, dim=-1
        )
        routing_weights = F.softmax(routing_weights, dim=-1)

        # Initialize output
        output = torch.zeros_like(x)

        # Flatten for easier processing
        flat_x = x.view(-1, d_model)  # [batch*seq, d_model]
        flat_output = output.view(-1, d_model)
        flat_selected = selected_experts.view(-1, self.top_k)
        flat_weights = routing_weights.view(-1, self.top_k)

        # Route tokens to experts (simplified, not optimized)
        for i in range(self.top_k):
            for expert_idx in range(self.num_experts):
                # Find tokens routed to this expert at position i
                mask = (flat_selected[:, i] == expert_idx)
                if mask.any():
                    expert_input = flat_x[mask]
                    expert_output = self.experts[expert_idx](expert_input)
                    # Weight by routing weight
                    flat_output[mask] += flat_weights[mask, i:i+1] * expert_output

        output = flat_output.view(batch_size, seq_len, d_model)
        return output, router_logits

    def load_balance_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """
        Compute auxiliary load balancing loss.
        Encourages equal utilization of all experts.
        """
        # Average routing probability per expert
        routing_probs = F.softmax(router_logits, dim=-1)
        avg_per_expert = routing_probs.mean(dim=[0, 1])  # [num_experts]

        # Ideal is 1/num_experts for each
        target = 1.0 / self.num_experts

        # Variance from ideal
        return torch.sum((avg_per_expert - target) ** 2)


def demonstrate_moe():
    """Show MoE behavior and scaling."""
    d_model = 512
    d_ff = 2048
    num_experts = 8
    top_k = 2

    moe = MoELayer(d_model, d_ff, num_experts, top_k)

    # Parameter count comparison
    single_ffn_params = 2 * d_model * d_ff
    moe_params = sum(p.numel() for p in moe.parameters())

    print("MoE Parameter Analysis:")
    print(f"  Single FFN: {single_ffn_params:,} parameters")
    print(f"  MoE ({num_experts} experts, top-{top_k}): {moe_params:,} parameters")
    print(f"  Capacity increase: {num_experts}x")
    print(f"  Compute increase: ~{top_k}x (only top-{top_k} active)")

    # Forward pass
    batch_size, seq_len = 2, 16
    x = torch.randn(batch_size, seq_len, d_model)
    output, router_logits = moe(x)
    print(f"  Input shape:  {x.shape}")
    print(f"  Output shape: {output.shape}")

    # Analyze routing
    selected = router_logits.argmax(dim=-1)  # Top-1 for simplicity
    unique, counts = torch.unique(selected, return_counts=True)
    print("  Expert utilization (top-1):")
    for exp, cnt in zip(unique.tolist(), counts.tolist()):
        print(f"    Expert {exp}: {cnt} tokens ({cnt/(batch_size*seq_len)*100:.1f}%)")


demonstrate_moe()
```

The FFN is often the most computationally expensive component of a Transformer layer. Understanding its complexity and optimization strategies is crucial for efficient implementation.
Computational Complexity
For a single FFN layer with input of shape $[B, S, D]$ (batch, sequence, model dimension):
$$\text{FLOPs} = 2 \times B \times S \times D \times D_{ff} + 2 \times B \times S \times D_{ff} \times D$$ $$= 4 \times B \times S \times D \times D_{ff}$$
With $D_{ff} = 4D$: $$\text{FLOPs} = 16 \times B \times S \times D^2$$
For comparison, self-attention has: $$\text{Attention FLOPs} \approx 4 \times B \times S^2 \times D + 8 \times B \times S \times D^2$$ where the quadratic term covers the score and weighted-sum matmuls and the second term the Q, K, V, and O projections (four $D \times D$ matmuls at $2BSD^2$ each).

Comparing the quadratic attention term against the FFN's $16 \times B \times S \times D^2$: when $S > 4D$ (very long sequences), that term alone exceeds the FFN cost and attention dominates. When $S < 4D$ (common for many models, where S ≈ 512-2048 and D ≈ 512-4096), the FFN dominates.
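These counts can be verified with a small calculator. The sketch below uses the $2mnk$ FLOPs-per-matmul convention throughout, counting attention as four $D \times D$ projections plus the two sequence-length-dependent matmuls:

```python
def ffn_flops(B: int, S: int, D: int, D_ff: int) -> int:
    # Two matmuls: [B*S, D] @ [D, D_ff], then [B*S, D_ff] @ [D_ff, D]
    return 2 * B * S * D * D_ff + 2 * B * S * D_ff * D


def attention_flops(B: int, S: int, D: int) -> int:
    proj = 4 * (2 * B * S * D * D)   # Q, K, V, O projections
    scores = 2 * B * S * S * D       # Q @ K^T
    weighted = 2 * B * S * S * D     # softmax(scores) @ V
    return proj + scores + weighted


B, D = 8, 1024
for S in (512, 4096, 16384):
    f = ffn_flops(B, S, D, 4 * D)
    a = attention_flops(B, S, D)
    print(f"S={S:6d}: FFN {f/1e12:.1f} TFLOPs, "
          f"attention {a/1e12:.1f} TFLOPs, ratio {f/a:.2f}")
```

With $d_{ff} = 4D$ the FFN cost reduces to $16BSD^2$, and at $S = 4D$ the quadratic attention term alone matches it, as the crossover argument above states.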
Memory Considerations
The FFN's memory footprint includes:
- Weights: $2 \cdot d_{model} \cdot d_{ff}$ parameters per layer (plus gradients and optimizer states during training)
- Intermediate activations: $B \times S \times d_{ff}$ values that must be kept for the backward pass

For large models, intermediate activations can require gigabytes of memory. Techniques to reduce this:
- Gradient checkpointing: discard intermediate activations and recompute them during the backward pass
- Mixed precision: store activations in fp16/bf16
- Tensor parallelism: shard $W_1$ column-wise and $W_2$ row-wise across GPUs
```python
import time

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class EfficientFFN(nn.Module):
    """
    FFN with efficiency optimizations for large-scale training.

    Features:
    - Optional gradient checkpointing
    - Mixed precision support
    - Fused operations where available
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.0,
        use_checkpoint: bool = False,
        bias: bool = False
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        # No bias improves efficiency slightly
        self.w1 = nn.Linear(d_model, d_ff, bias=bias)
        self.w2 = nn.Linear(d_ff, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Core forward logic, separated for checkpointing."""
        hidden = self.w1(x)
        hidden = torch.nn.functional.gelu(hidden)
        hidden = self.dropout(hidden)
        return self.w2(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            # Gradient checkpointing: save memory, recompute in backward
            return checkpoint(self._forward_impl, x, use_reentrant=False)
        return self._forward_impl(x)


class TensorParallelFFN(nn.Module):
    """
    FFN designed for tensor parallelism across GPUs.

    W1 is column-parallel: split along d_ff dimension
    W2 is row-parallel: split along input dimension

    This allows partial results to be computed independently,
    then all-reduced only after W2.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_partitions: int,
        partition_id: int
    ):
        super().__init__()
        assert d_ff % num_partitions == 0
        local_d_ff = d_ff // num_partitions

        self.num_partitions = num_partitions
        self.partition_id = partition_id

        # Each GPU has a slice of the FFN
        self.w1 = nn.Linear(d_model, local_d_ff, bias=False)
        self.w2 = nn.Linear(local_d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward with local computation.
        Caller must handle all-reduce after this.
        """
        hidden = torch.nn.functional.gelu(self.w1(x))
        output = self.w2(hidden)
        # In practice: output = all_reduce(output) across partitions
        return output


def benchmark_ffn_variants():
    """Benchmark different FFN implementations."""
    d_model = 4096
    d_ff = 16384
    batch_size = 8
    seq_len = 2048

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Standard FFN
    standard = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.GELU(),
        nn.Linear(d_ff, d_model)
    ).to(device)

    # Checkpointed FFN
    checkpointed = EfficientFFN(d_model, d_ff, use_checkpoint=True).to(device)

    x = torch.randn(batch_size, seq_len, d_model, device=device)

    # Warmup
    for _ in range(3):
        standard(x).sum().backward()
        standard.zero_grad()
    if device.type == 'cuda':
        torch.cuda.synchronize()

    # Benchmark forward + backward
    n_iters = 10
    start = time.perf_counter()
    for _ in range(n_iters):
        standard(x).sum().backward()
        standard.zero_grad()
    if device.type == 'cuda':
        torch.cuda.synchronize()
    standard_time = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(n_iters):
        checkpointed(x).sum().backward()
        checkpointed.zero_grad()
    if device.type == 'cuda':
        torch.cuda.synchronize()
    checkpoint_time = time.perf_counter() - start

    print(f"Benchmark ({n_iters} iterations):")
    print(f"  Standard FFN: {standard_time:.3f}s")
    print(f"  Checkpointed: {checkpoint_time:.3f}s")
    print(f"  Overhead: {(checkpoint_time/standard_time - 1)*100:.1f}%")

    if device.type == 'cuda':
        # Memory comparison
        torch.cuda.reset_peak_memory_stats()
        standard(x).sum().backward()
        standard_mem = torch.cuda.max_memory_allocated() / 1e9
        standard.zero_grad()

        torch.cuda.reset_peak_memory_stats()
        checkpointed(x).sum().backward()
        checkpoint_mem = torch.cuda.max_memory_allocated() / 1e9

        print("Peak memory:")
        print(f"  Standard:     {standard_mem:.2f} GB")
        print(f"  Checkpointed: {checkpoint_mem:.2f} GB")
        print(f"  Savings: {(1 - checkpoint_mem/standard_mem)*100:.1f}%")


# Uncomment to run the benchmark:
# benchmark_ffn_variants()
```

We have explored the position-wise feed-forward network in comprehensive detail. Let's consolidate the key insights:
- The FFN holds roughly two-thirds of each layer's parameters and applies the same non-linear transformation independently at every position
- It can be read as a key-value memory: $W_1$ matches input patterns, $W_2$ writes the associated outputs
- Activation functions evolved from ReLU to GELU and Swish, and gated variants such as SwiGLU now dominate modern LLMs
- Mixture of Experts scales FFN capacity sparsely: many experts, only a few active per token
- For typical sequence lengths the FFN, not attention, dominates compute; gradient checkpointing and tensor parallelism manage its memory cost
Looking Ahead
The FFN works in concert with attention, residual connections, and layer normalization to form a complete Transformer layer. In the next page, we'll examine residual connections—the architectural component that enables training very deep Transformers by providing gradient highways through the network.
You now understand the position-wise feed-forward network's architecture, role, and modern variations. You can explain why FFN capacity matters, how activation functions have evolved, and what innovations like SwiGLU and MoE offer. Next, we'll explore residual connections.