Throughout this module, we've examined the Transformer's components in isolation: encoder-decoder structure, layer normalization, feed-forward networks, and residual connections. Now we synthesize these elements into the complete architecture, understanding how they work together as an integrated system.
The original Transformer from "Attention Is All You Need" (Vaswani et al., 2017) established an architecture that has proven remarkably versatile. With minor modifications, it underlies virtually all state-of-the-art NLP models, from BERT and GPT to T5 and beyond. Understanding the full architecture—not just its parts—is essential for using, adapting, and innovating upon Transformers.
This page provides a comprehensive walkthrough of the complete architecture, practical implementation guidance, common configurations, and the design rationale that makes the Transformer so effective.
By the end of this page, you should be able to implement a Transformer from scratch, understand every component's role in the information flow, debug common training issues, and adapt the architecture for various tasks.
Let's walk through the complete end-to-end information flow in a Transformer encoder-decoder model.
Input Processing Pipeline
Input tokens are first mapped to embeddings, scaled by $\sqrt{d_{model}}$, summed with positional encodings, and passed through dropout before entering the encoder (the decoder input is processed the same way).
Encoder Stack
For each of $N$ encoder layers: the representation passes through a multi-head self-attention sublayer followed by a position-wise feed-forward sublayer, each wrapped in a residual connection with layer normalization and dropout.
The encoder produces contextualized representations $Z = (z_1, ..., z_n)$ for all input positions.
Decoder Stack
For each of $N$ decoder layers: the target representation passes through a masked (causal) self-attention sublayer, a cross-attention sublayer that attends to the encoder output $Z$, and a position-wise feed-forward sublayer, each with a residual connection, layer normalization, and dropout.
Output Projection
The final decoder representation is projected by a linear layer to vocabulary-sized logits, and a softmax turns these into a probability distribution over the next token.
The diagram shows Post-LN (original). In Pre-LN, layer normalization is applied before each sublayer instead of after the residual addition, and a final layer normalization is added before the output projection.
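As a minimal sketch of the difference (pseudo-helper functions, with `sublayer`, `norm`, and `dropout` standing in for an attention/FFN module, a LayerNorm, and a Dropout; not part of the implementation below):

```python
# Post-LN (original paper): apply the sublayer, add the residual, then normalize.
def post_ln_block(x, sublayer, norm, dropout):
    return norm(x + dropout(sublayer(x)))

# Pre-LN (used in the implementation below): normalize first, then apply the
# sublayer and add the residual; a final LayerNorm follows the last layer.
def pre_ln_block(x, sublayer, norm, dropout):
    return x + dropout(sublayer(norm(x)))
```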
Let's implement a complete Transformer architecture, integrating all the components we've studied. This implementation follows modern best practices (Pre-LN, GELU activation) while remaining faithful to the original design.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention mechanism.
    Supports both self-attention and cross-attention.
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Linear projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size = query.size(0)

        # Project and reshape: [batch, seq, d_model] -> [batch, n_heads, seq, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, V)

        # Reshape and project: [batch, n_heads, seq, d_k] -> [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        return self.W_o(attn_output)


class FeedForward(nn.Module):
    """Position-wise feed-forward network with GELU activation."""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))


class EncoderLayer(nn.Module):
    """Single encoder layer with Pre-LN configuration."""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Pre-LN self-attention
        normed = self.norm1(x)
        x = x + self.dropout1(self.self_attn(normed, normed, normed, src_mask))
        # Pre-LN feed-forward
        normed = self.norm2(x)
        x = x + self.dropout2(self.ffn(normed))
        return x


class DecoderLayer(nn.Module):
    """Single decoder layer with Pre-LN configuration."""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Pre-LN masked self-attention
        normed = self.norm1(x)
        x = x + self.dropout1(self.self_attn(normed, normed, normed, tgt_mask))
        # Pre-LN cross-attention
        normed = self.norm2(x)
        x = x + self.dropout2(
            self.cross_attn(normed, encoder_output, encoder_output, memory_mask)
        )
        # Pre-LN feed-forward
        normed = self.norm3(x)
        x = x + self.dropout3(self.ffn(normed))
        return x


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class Transformer(nn.Module):
    """
    Complete Transformer model for sequence-to-sequence tasks.

    Implements the architecture from "Attention Is All You Need"
    with modern Pre-LN configuration for stable training.
    """
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_encoder_layers: int = 6,
        n_decoder_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 5000,
        share_embeddings: bool = False
    ):
        super().__init__()
        self.d_model = d_model

        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        if share_embeddings and src_vocab_size == tgt_vocab_size:
            self.tgt_embedding = self.src_embedding
        else:
            self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ])
        self.encoder_norm = nn.LayerNorm(d_model)  # Final norm for Pre-LN

        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ])
        self.decoder_norm = nn.LayerNorm(d_model)  # Final norm for Pre-LN

        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize parameters using Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode source sequence."""
        # Embed and add positional encoding
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        # Pass through encoder layers
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return self.encoder_norm(x)

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Decode target sequence given encoder output."""
        # Embed and add positional encoding
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        # Pass through decoder layers
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, tgt_mask, memory_mask)
        return self.decoder_norm(x)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Full forward pass for training.

        Args:
            src: Source tokens [batch, src_len]
            tgt: Target tokens [batch, tgt_len]
            src_mask: Source padding mask
            tgt_mask: Target causal mask
            memory_mask: Cross-attention mask

        Returns:
            Logits [batch, tgt_len, vocab_size]
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, tgt_mask, memory_mask)
        logits = self.output_projection(decoder_output)
        return logits

    @staticmethod
    def generate_causal_mask(size: int, device: torch.device) -> torch.Tensor:
        """Generate causal mask for decoder self-attention."""
        mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
        return mask == 0  # True where attention is allowed

    @torch.no_grad()
    def generate(
        self,
        src: torch.Tensor,
        max_len: int,
        bos_token_id: int,
        eos_token_id: int,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Autoregressive generation.

        Args:
            src: Source sequence [batch, src_len]
            max_len: Maximum generation length
            bos_token_id: Beginning of sequence token
            eos_token_id: End of sequence token
            src_mask: Source padding mask

        Returns:
            Generated sequences [batch, gen_len]
        """
        batch_size = src.size(0)
        device = src.device

        # Encode source once
        encoder_output = self.encode(src, src_mask)

        # Start with BOS token
        generated = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=device
        )

        for _ in range(max_len - 1):
            tgt_mask = self.generate_causal_mask(generated.size(1), device)
            decoder_output = self.decode(
                generated, encoder_output, tgt_mask, src_mask
            )
            # Get next token (greedy)
            logits = self.output_projection(decoder_output[:, -1, :])
            next_token = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)

            # Stop if all sequences have generated EOS
            if (next_token == eos_token_id).all():
                break

        return generated
```

Understanding the parameter distribution and standard configurations helps in selecting appropriate model sizes and debugging memory issues.
Parameter Count Analysis
For a Transformer with model dimension $d$, feed-forward dimension $d_{ff} = 4d$, vocabulary size $V$ (shared between source and target), and $L$ layers in each stack:
Per Encoder Layer: $4d^2$ for the Q, K, V, and output projections plus $2 \cdot d \cdot d_{ff} = 8d^2$ for the feed-forward network, roughly $12d^2$ in total.
Per Decoder Layer: $4d^2$ for self-attention, $4d^2$ for cross-attention, and $8d^2$ for the feed-forward network, roughly $16d^2$ in total.
Embeddings: roughly $2Vd$ for a shared input embedding plus the output projection (LayerNorm and bias terms contribute comparatively little).
Total Model (encoder-decoder): $$\text{Params} \approx L \cdot 12d^2 + L \cdot 16d^2 + 2Vd = 28Ld^2 + 2Vd$$
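As a rough sanity check, plugging in Transformer Base values ($L = 6$, $d = 512$) with an assumed vocabulary of $V = 32{,}000$ (the same figure used in the counting utility below) gives:

$$28 \cdot 6 \cdot 512^2 + 2 \cdot 32{,}000 \cdot 512 \approx 44\text{M} + 33\text{M} \approx 77\text{M}$$

The ~65M usually quoted for the original base model is lower mainly because the paper ties the pre-softmax output projection to the shared embedding matrix, collapsing the $2Vd$ term to roughly $Vd$.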
| Configuration | d_model | d_ff | Heads | Layers | Parameters |
|---|---|---|---|---|---|
| Transformer Base | 512 | 2048 | 8 | 6+6 | ~65M |
| Transformer Big | 1024 | 4096 | 16 | 6+6 | ~213M |
| BERT-Base | 768 | 3072 | 12 | 12 (enc) | ~110M |
| BERT-Large | 1024 | 4096 | 16 | 24 (enc) | ~340M |
| GPT-2 Small | 768 | 3072 | 12 | 12 (dec) | ~124M |
| GPT-2 Medium | 1024 | 4096 | 16 | 24 (dec) | ~355M |
| GPT-2 Large | 1280 | 5120 | 20 | 36 (dec) | ~774M |
| GPT-2 XL | 1600 | 6400 | 25 | 48 (dec) | ~1.5B |
| T5-Small | 512 | 2048 | 8 | 6+6 | ~60M |
| T5-Base | 768 | 3072 | 12 | 12+12 | ~220M |
| T5-Large | 1024 | 4096 | 16 | 24+24 | ~770M |
```python
def count_transformer_params(
    vocab_size: int,
    d_model: int,
    d_ff: int,
    n_heads: int,
    n_encoder_layers: int,
    n_decoder_layers: int,
    share_embeddings: bool = True
) -> dict:
    """
    Approximate parameter count for a Transformer
    (weight matrices and LayerNorm parameters; small bias terms are omitted).
    """
    # Per encoder layer
    enc_self_attn = 4 * d_model * d_model  # Q, K, V, O
    enc_ffn = 2 * d_model * d_ff           # two linear layers
    enc_norms = 2 * 2 * d_model            # 2 LayerNorms, each with γ and β
    enc_layer_total = enc_self_attn + enc_ffn + enc_norms

    # Per decoder layer
    dec_self_attn = 4 * d_model * d_model
    dec_cross_attn = 4 * d_model * d_model
    dec_ffn = 2 * d_model * d_ff
    dec_norms = 3 * 2 * d_model  # 3 LayerNorms
    dec_layer_total = dec_self_attn + dec_cross_attn + dec_ffn + dec_norms

    # Embeddings
    src_embed = vocab_size * d_model
    tgt_embed = 0 if share_embeddings else vocab_size * d_model
    out_proj = vocab_size * d_model

    # Final norms (Pre-LN)
    final_norms = 2 * 2 * d_model  # encoder + decoder final norms

    # Total
    encoder_total = n_encoder_layers * enc_layer_total
    decoder_total = n_decoder_layers * dec_layer_total
    embed_total = src_embed + tgt_embed + out_proj
    total = encoder_total + decoder_total + embed_total + final_norms

    return {
        'encoder_total': encoder_total,
        'decoder_total': decoder_total,
        'embeddings': embed_total,
        'total': total,
        'total_millions': total / 1e6
    }


# Example: Transformer Base
params = count_transformer_params(
    vocab_size=32000,
    d_model=512,
    d_ff=2048,
    n_heads=8,
    n_encoder_layers=6,
    n_decoder_layers=6,
    share_embeddings=True
)

print("Transformer Base Parameter Count:")
for key, value in params.items():
    if key == 'total_millions':
        print(f"  {key}: {value:.1f}M")
    else:
        print(f"  {key}: {value:,}")
```

Training Transformers effectively requires attention to several key factors: learning rate schedules, regularization, initialization, and optimizer choice.
Learning Rate Schedule
The original Transformer uses a warmup-then-decay schedule:
$$lr = d_{\text{model}}^{-0.5} \cdot \min\left(\text{step}^{-0.5},\; \text{step} \cdot \text{warmup\_steps}^{-1.5}\right)$$
This schedule increases the learning rate linearly over the warmup period, then decays it proportionally to $\text{step}^{-0.5}$; the $d_{\text{model}}^{-0.5}$ factor lowers the peak learning rate for larger models.
Modern approaches often use a simpler recipe: linear warmup to an explicitly chosen peak learning rate, followed by cosine or linear decay (a minimal sketch follows below).
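For illustration, here is a minimal, self-contained sketch of such a warmup-plus-cosine schedule; the function name and defaults are hypothetical, not from the original paper.

```python
import math

def warmup_cosine_lr(step: int, peak_lr: float, warmup_steps: int,
                     total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay toward min_lr."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```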
Optimizer Settings
The original used Adam with $\beta_1 = 0.9$, $\beta_2 = 0.98$, and $\epsilon = 10^{-9}$.
Modern recommendations: AdamW with decoupled weight decay (around 0.01), no weight decay on bias or LayerNorm parameters, and gradient-norm clipping (typically at 1.0), as implemented in the training utilities below.
Label Smoothing
The original Transformer uses label smoothing with $\epsilon_{ls} = 0.1$:
$$p_{\text{smooth}}(y \mid x) = (1 - \epsilon_{ls}) \cdot \mathbf{1}_{y = \text{target}} + \frac{\epsilon_{ls}}{V}$$
This prevents overconfident predictions and improves generalization. For example, with $V = 4$ and $\epsilon_{ls} = 0.1$, the target class receives probability $0.9 + 0.025 = 0.925$ and every other class receives $0.025$.
```python
import torch
import torch.nn as nn
import math


class TransformerLRScheduler:
    """
    Learning rate scheduler following the original Transformer paper.

    lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
    """
    def __init__(self, optimizer, d_model: int, warmup_steps: int = 4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_count = 0

    def step(self):
        self.step_count += 1
        lr = self._get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def _get_lr(self) -> float:
        step = self.step_count
        warmup = self.warmup_steps
        # Formula from "Attention Is All You Need"
        arg1 = step ** (-0.5)
        arg2 = step * (warmup ** (-1.5))
        return (self.d_model ** (-0.5)) * min(arg1, arg2)


class LabelSmoothingLoss(nn.Module):
    """
    Label smoothing cross-entropy loss.

    Distributes a small amount of probability mass uniformly across
    all classes, preventing the model from becoming overconfident.
    """
    def __init__(self, vocab_size: int, smoothing: float = 0.1, ignore_index: int = -100):
        super().__init__()
        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute label-smoothed cross-entropy loss.

        Args:
            logits: [batch, seq_len, vocab_size]
            targets: [batch, seq_len]

        Returns:
            Scalar loss
        """
        logits = logits.view(-1, self.vocab_size)
        targets = targets.view(-1)

        # Create smoothed distribution
        smooth_targets = torch.full_like(logits, self.smoothing / self.vocab_size)

        # Create mask for valid positions
        mask = (targets != self.ignore_index)
        valid_targets = targets[mask]

        # Set target probability
        smooth_targets[mask] = smooth_targets[mask].scatter(
            1, valid_targets.unsqueeze(1),
            1 - self.smoothing + self.smoothing / self.vocab_size
        )

        # Cross-entropy with soft targets
        log_probs = torch.log_softmax(logits, dim=-1)
        loss = -torch.sum(smooth_targets * log_probs, dim=-1)

        # Average over valid positions
        return loss[mask].mean()


def get_optimizer_and_scheduler(
    model: nn.Module,
    d_model: int,
    warmup_steps: int = 4000,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0
) -> tuple:
    """
    Create optimizer and scheduler following best practices.

    Returns:
        (optimizer, scheduler, clip_fn)
    """
    # Separate weight decay parameters
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Don't apply weight decay to biases and LayerNorm
        if 'bias' in name or 'LayerNorm' in name or 'norm' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = torch.optim.AdamW([
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ], lr=1.0, betas=(0.9, 0.98), eps=1e-9)

    scheduler = TransformerLRScheduler(optimizer, d_model, warmup_steps)

    def clip_fn():
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    return optimizer, scheduler, clip_fn


# Training loop example
def training_step(
    model: nn.Module,
    batch: dict,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    clip_fn: callable
) -> float:
    """Single training step."""
    model.train()

    src = batch['source']
    tgt = batch['target']
    tgt_input = tgt[:, :-1]   # Input: all except last
    tgt_output = tgt[:, 1:]   # Label: all except first

    # Generate masks
    tgt_mask = model.generate_causal_mask(tgt_input.size(1), tgt_input.device)

    # Forward pass
    logits = model(src, tgt_input, tgt_mask=tgt_mask)

    # Compute loss
    loss = criterion(logits, tgt_output)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    clip_fn()
    optimizer.step()

    return loss.item()
```

The Transformer architecture can be configured into three major variants, each suited for different task types.
Encoder-Only Models (BERT-style)
Examples: BERT, RoBERTa, ALBERT, DeBERTa, ELECTRA
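As a rough sketch of how this variant relates to the components implemented above (reusing the `EncoderLayer` and `PositionalEncoding` classes; the class name, mean-pooling choice, and `num_classes` parameter are illustrative assumptions, and BERT itself uses learned positional embeddings plus a [CLS] token rather than mean pooling):

```python
import math
import torch
import torch.nn as nn

class EncoderOnlyClassifier(nn.Module):
    """Illustrative BERT-style classifier built from the encoder components above."""
    def __init__(self, vocab_size: int, num_classes: int, d_model: int = 768,
                 n_heads: int = 12, n_layers: int = 12, d_ff: int = 3072,
                 dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, tokens: torch.Tensor,
                padding_mask: torch.Tensor = None) -> torch.Tensor:
        # padding_mask: [batch, 1, 1, seq_len], nonzero where tokens are real
        x = self.embedding(tokens) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, padding_mask)     # bidirectional: no causal mask
        pooled = self.norm(x).mean(dim=1)  # mean-pool over positions
        return self.classifier(pooled)
```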
Decoder-Only Models (GPT-style)
Examples: GPT-2, GPT-3, GPT-4, LLaMA, Claude, Gemini
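Again as an illustrative sketch (reusing `EncoderLayer` and `PositionalEncoding` from above; real GPT models use learned positional embeddings, have no cross-attention, and usually tie the LM head to the token embedding), a decoder-only language model is essentially the same stack with a causal mask and an LM head:

```python
import math
import torch
import torch.nn as nn

class DecoderOnlyLM(nn.Module):
    """Illustrative GPT-style language model: causal self-attention only."""
    def __init__(self, vocab_size: int, d_model: int = 768, n_heads: int = 12,
                 n_layers: int = 12, d_ff: int = 3072, dropout: float = 0.1,
                 max_len: int = 1024):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        # An EncoderLayer with a causal mask behaves like a decoder layer
        # that has no cross-attention sublayer.
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        seq_len = tokens.size(1)
        # Lower-triangular mask: nonzero where attention is allowed
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device)
        )
        x = self.embedding(tokens) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, causal_mask)
        return self.lm_head(self.norm(x))  # [batch, seq_len, vocab_size]
```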
Encoder-Decoder Models (T5-style)
Examples: T5, BART, mBART, Pegasus
| Aspect | Encoder-Only | Decoder-Only | Encoder-Decoder |
|---|---|---|---|
| Attention Type | Bidirectional | Causal (unidirectional) | Encoder: bidirectional; Decoder: causal + cross |
| Training Objective | Masked LM, NSP | Next token prediction | Span denoising, text-to-text |
| Generation | Not natively | Yes, autoregressive | Yes, with encoder conditioning |
| Understanding | Excellent | Good (via generation) | Excellent |
| Typical Tasks | Classification, NER | Generation, completion | Translation, summarization |
| Notable Models | BERT, RoBERTa | GPT family, LLaMA | T5, BART, mBART |
Why Decoder-Only Dominates Today
Decoder-only models (GPT-style) have become the dominant paradigm for large language models. Key reasons: a single stack trained with one simple objective (next-token prediction) scales cleanly to web-scale data; the same model handles both understanding and generation through prompting, without task-specific heads; and in-context learning emerges at scale.
However, encoder-decoder models still excel for tasks with distinct input/output modalities (translation, structured summarization).
For classification/embeddings: encoder-only (fine-tune BERT). For generation/dialogue: decoder-only (fine-tune GPT/LLaMA). For seq2seq with clear input/output split: encoder-decoder (fine-tune T5). When in doubt, modern decoder-only models handle most tasks well via prompt engineering.
Understanding the memory and computational requirements of Transformers is essential for practical deployment.
Attention Complexity
Self-attention has $O(n^2 \cdot d)$ computation and $O(n^2)$ memory for the attention matrix:
$$\text{Attention FLOPs} = 4 \cdot n^2 \cdot d$$
This quadratic scaling in sequence length is the Transformer's primary bottleneck for long sequences.
Feed-Forward Complexity
FFN has $O(n \cdot d^2)$ computation (assuming $d_{ff} = 4d$):
$$\text{FFN FLOPs} = 16 \cdot n \cdot d^2$$
Crossover Point
When does attention dominate over the FFN? Setting $4n^2 d = 16 n d^2$ gives a crossover at $n = 4d$.
For $d = 512$, attention dominates when $n > 2048$; for $d = 4096$, attention dominates when $n > 16384$.
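A small, hypothetical helper (using only the per-layer approximations above) makes the crossover concrete:

```python
def per_layer_flops(n: int, d: int) -> tuple:
    """Approximate per-layer FLOPs: attention ~ 4*n^2*d, FFN ~ 16*n*d^2 (d_ff = 4d)."""
    return 4 * n * n * d, 16 * n * d * d

for n in (512, 2048, 8192):
    attn, ffn = per_layer_flops(n, d=512)
    print(f"n={n:5d}: attention {attn / 1e9:6.1f} GFLOPs | FFN {ffn / 1e9:6.1f} GFLOPs")
```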
Memory During Training
Training memory includes model parameters, gradients, optimizer states (Adam keeps two additional FP32 moment estimates per parameter), and the activations stored for the backward pass.
The activations often dominate for long sequences.
```python
def estimate_memory_usage(
    batch_size: int,
    seq_len: int,
    d_model: int,
    d_ff: int,
    n_layers: int,
    n_heads: int,
    vocab_size: int,
    precision: str = "fp32"  # "fp32", "fp16", or "bf16"
) -> dict:
    """
    Estimate GPU memory usage for Transformer training.
    Returns memory estimates in GB.
    """
    bytes_per_param = 4 if precision == "fp32" else 2

    # Model parameters
    per_layer = 12 * d_model * d_model + 2 * d_model * d_ff  # simplified
    total_params = n_layers * per_layer + 2 * vocab_size * d_model
    model_memory = total_params * bytes_per_param
    gradient_memory = model_memory  # Same size as parameters

    # Optimizer states (Adam: momentum + variance)
    optimizer_memory = 2 * total_params * 4  # Always FP32 for optimizer

    # Activations (dominant for long sequences)
    # Per layer: hidden states + attention weights
    hidden_per_layer = batch_size * seq_len * d_model * bytes_per_param
    ffn_intermediate = batch_size * seq_len * d_ff * bytes_per_param
    attention_matrix = batch_size * n_heads * seq_len * seq_len * bytes_per_param
    activations_per_layer = hidden_per_layer + ffn_intermediate + attention_matrix
    total_activations = n_layers * activations_per_layer

    # Total
    total_memory = model_memory + gradient_memory + optimizer_memory + total_activations

    return {
        'model_params': total_params,
        'model_memory_gb': model_memory / 1e9,
        'gradient_memory_gb': gradient_memory / 1e9,
        'optimizer_memory_gb': optimizer_memory / 1e9,
        'activations_memory_gb': total_activations / 1e9,
        'total_memory_gb': total_memory / 1e9
    }


# Example: GPT-2 Medium-ish
result = estimate_memory_usage(
    batch_size=4,
    seq_len=1024,
    d_model=1024,
    d_ff=4096,
    n_layers=24,
    n_heads=16,
    vocab_size=50257,
    precision="fp16"
)

print("Memory Estimate (GPT-2 Medium scale, FP16, batch=4, seq=1024):")
for key, value in result.items():
    if 'memory' in key or 'total' in key:
        print(f"  {key}: {value:.2f} GB")
    else:
        print(f"  {key}: {value:,}")

# Effect of sequence length
print("Sequence Length Impact (same model, batch=1):")
for seq_len in [512, 1024, 2048, 4096, 8192]:
    result = estimate_memory_usage(
        batch_size=1,
        seq_len=seq_len,
        d_model=1024,
        d_ff=4096,
        n_layers=24,
        n_heads=16,
        vocab_size=50257,
        precision="fp16"
    )
    print(f"  seq_len={seq_len}: {result['total_memory_gb']:.2f} GB")
```

We've now synthesized all of the Transformer's components into a complete picture of the architecture and its practical costs.
Module Complete!
You've now completed a comprehensive study of the Transformer architecture. You understand the complete end-to-end information flow, how each component is implemented, how parameter counts and memory scale with model size, the training practices that keep optimization stable (warmup schedules, label smoothing, AdamW with gradient clipping), the trade-offs between encoder-only, decoder-only, and encoder-decoder variants, and the computational costs that drive architectural decisions.
This foundational knowledge prepares you for the next module on positional encoding—exploring how Transformers represent sequence order without recurrence.
Congratulations! You now have a comprehensive understanding of the Transformer architecture at the level expected of a principal engineer. You can implement Transformers from scratch, reason about their computational requirements, choose appropriate variants for different tasks, and troubleshoot common issues. This knowledge forms the foundation for understanding all modern large language models.