In their 2019 paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer," Google researchers proposed a deceptively simple but remarkably powerful idea: treat every NLP task as a text-to-text problem.
Where BERT uses different output heads for different tasks (classification head, span extraction head, etc.) and GPT varies its prompting strategy, T5 uses the exact same model, loss function, and decoding procedure for every task. Classification? Output the class name as text. Translation? Output the translation. Summarization? Output the summary. Question answering? Output the answer.
This unification isn't merely elegant—it enables unprecedented systematic comparison of pre-training strategies, architectures, and scaling approaches, all within a single experimental framework.
This page covers T5's encoder-decoder architecture, the span corruption pre-training objective, task prefixing methodology, and how the text-to-text framing enables multi-task learning. You'll understand T5's relationship to BERT and GPT, explore variants like mT5 and Flan-T5, and learn when the encoder-decoder approach excels.
The Research Contribution:
The T5 paper was as much a comprehensive empirical study as it was a model introduction. The authors systematically explored:
Pre-training objectives: prefix language modeling, BERT-style masking, deshuffling, and many span-corruption variants
Architectures: encoder-decoder, decoder-only language models, and prefix LMs
Unlabeled datasets: C4 and various filtered or domain-specific alternatives
Transfer strategies: fine-tuning, multi-task training, adapter layers, and gradual unfreezing
Scaling: spending a fixed compute budget on bigger models, more training steps, or ensembles
This systematic approach yielded insights that continue to guide transformer research today.
T5 uses the original transformer's encoder-decoder architecture, the same structure designed for machine translation in "Attention Is All You Need." This architecture naturally handles sequence-to-sequence tasks where the output length differs from the input.
Architecture Overview:
Encoder: Bidirectional transformer that processes the input sequence
Decoder: Autoregressive transformer that generates the output sequence
Cross-Attention: Bridges encoder and decoder
| Model | Parameters | Layers (enc/dec) | Hidden Size | Attention Heads | FFN Size |
|---|---|---|---|---|---|
| T5-Small | 60M | 6/6 | 512 | 8 | 2048 |
| T5-Base | 220M | 12/12 | 768 | 12 | 3072 |
| T5-Large | 770M | 24/24 | 1024 | 16 | 4096 |
| T5-3B | 3B | 24/24 | 1024 | 32 | 16384 |
| T5-11B | 11B | 24/24 | 1024 | 128 | 65536 |
Key Architectural Choices in T5:
Relative Position Embeddings: T5 uses simplified relative position biases instead of absolute position embeddings. Each attention head learns a small set of scalar biases based on the relative distance between positions.
Pre-Layer Normalization: Layer normalization is applied before (not after) the attention and FFN sub-layers, matching GPT-2's approach. T5's layer norm is also simplified: it rescales activations without mean subtraction or a learned bias (RMSNorm-style).
No Bias Terms: Linear layers omit bias terms for slight efficiency gains.
GeGLU Activation: Later variants (T5 v1.1) replace the ReLU FFN with a GeGLU activation: $$\text{GeGLU}(x) = \text{GELU}(xW_1) \otimes (xW_2)$$ (see the sketch after this list).
Parameter Sharing: Optional cross-layer parameter sharing was explored but found to hurt performance.
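The gated FFN replaces the single ReLU projection with two parallel projections, one gated by GELU. A minimal PyTorch sketch (the module and weight names here are illustrative, not T5's exact implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeGLUFFN(nn.Module):
    """Feed-forward block with GeGLU gating, in the style of T5 v1.1."""
    def __init__(self, hidden_size: int, ffn_size: int):
        super().__init__()
        self.wi_0 = nn.Linear(hidden_size, ffn_size, bias=False)  # gated branch
        self.wi_1 = nn.Linear(hidden_size, ffn_size, bias=False)  # linear branch
        self.wo = nn.Linear(ffn_size, hidden_size, bias=False)    # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GeGLU(x) = GELU(x W1) * (x W2), followed by the output projection
        return self.wo(F.gelu(self.wi_0(x)) * self.wi_1(x))


# Shape check: (batch, seq, hidden) -> (batch, seq, hidden)
print(GeGLUFFN(512, 2048)(torch.randn(2, 10, 512)).shape)
```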
```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class RelativePositionBias(nn.Module):
    """
    T5-style relative position bias.
    Learns scalar biases based on relative position between tokens.
    """
    def __init__(self, num_heads: int, num_buckets: int = 32, max_distance: int = 128):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # Learnable bias for each (head, bucket) pair
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    def _relative_position_bucket(self, relative_position: torch.Tensor,
                                  bidirectional: bool = True) -> torch.Tensor:
        """
        Map relative position to bucket index.
        Uses logarithmic bucketing for distant positions.
        """
        ret = 0
        num_buckets = self.num_buckets
        n = -relative_position
        if bidirectional:
            # Half the buckets for each direction
            num_buckets = num_buckets // 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Logarithmic bucketing for larger distances
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(self.max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, query_length: int, key_length: int, device: torch.device,
                bidirectional: bool = True) -> torch.Tensor:
        """Compute relative position bias matrix."""
        context_position = torch.arange(query_length, device=device)[:, None]
        memory_position = torch.arange(key_length, device=device)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(relative_position, bidirectional)
        values = self.relative_attention_bias(relative_position_bucket)  # [query, key, heads]
        values = values.permute([2, 0, 1]).unsqueeze(0)                  # [1, heads, query, key]
        return values


class T5Attention(nn.Module):
    """
    T5 attention with relative position bias.
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        is_decoder: bool = False,
        has_relative_bias: bool = True,
        dropout: float = 0.1
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.is_decoder = is_decoder

        self.q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o = nn.Linear(hidden_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        if has_relative_bias:
            self.relative_attention_bias = RelativePositionBias(num_heads)
        else:
            self.relative_attention_bias = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        position_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_length = hidden_states.shape[:2]

        # Self-attention or cross-attention
        if key_value_states is None:
            key_value_states = hidden_states

        query = self.q(hidden_states)
        key = self.k(key_value_states)
        value = self.v(key_value_states)

        # Reshape for multi-head attention
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores (T5 omits the 1/sqrt(head_dim) scaling)
        scores = torch.matmul(query, key.transpose(-2, -1))

        # Add relative position bias
        if position_bias is None and self.relative_attention_bias is not None:
            position_bias = self.relative_attention_bias(
                query.size(2), key.size(2), query.device,
                bidirectional=not self.is_decoder
            )
        if position_bias is not None:
            scores = scores + position_bias

        if attention_mask is not None:
            scores = scores + attention_mask

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, value)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
        output = self.o(output)
        return output, position_bias


class T5Block(nn.Module):
    """
    T5 transformer block with optional cross-attention (for decoder).
    Note: the released T5 uses an RMSNorm-style layer norm (no bias, no mean
    subtraction); nn.LayerNorm is used here as a simplification.
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        ffn_size: int,
        is_decoder: bool = False,
        has_relative_bias: bool = True,
        dropout: float = 0.1
    ):
        super().__init__()
        self.is_decoder = is_decoder

        # Self-attention
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.self_attention = T5Attention(
            hidden_size, num_heads,
            is_decoder=is_decoder,
            has_relative_bias=has_relative_bias,
            dropout=dropout
        )

        # Cross-attention (decoder only)
        if is_decoder:
            self.cross_layer_norm = nn.LayerNorm(hidden_size, eps=1e-6)
            self.cross_attention = T5Attention(
                hidden_size, num_heads,
                is_decoder=True,
                has_relative_bias=False,  # No relative bias for cross-attention
                dropout=dropout
            )

        # Feed-forward
        self.ff_layer_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, ffn_size, bias=False),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_size, hidden_size, bias=False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        self_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        position_bias: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Pre-norm self-attention
        normed_states = self.layer_norm(hidden_states)
        attn_output, position_bias = self.self_attention(
            normed_states,
            attention_mask=self_attention_mask,
            position_bias=position_bias
        )
        hidden_states = hidden_states + attn_output

        # Cross-attention (decoder)
        if self.is_decoder and encoder_hidden_states is not None:
            normed_states = self.cross_layer_norm(hidden_states)
            cross_output, _ = self.cross_attention(
                normed_states,
                key_value_states=encoder_hidden_states,
                attention_mask=cross_attention_mask
            )
            hidden_states = hidden_states + cross_output

        # Pre-norm FFN
        normed_states = self.ff_layer_norm(hidden_states)
        hidden_states = hidden_states + self.ffn(normed_states)

        return hidden_states, position_bias
```

The encoder-decoder split is computationally convenient for seq2seq tasks: the encoder processes the input once with full bidirectional attention, and the decoder attends to that cached representation through cross-attention at every generation step, so the input never has to be re-encoded as the output grows.
T5's pre-training objective, span corruption, is a generalization of BERT's masked language modeling designed for the text-to-text framework and encoder-decoder architecture.
The Span Corruption Process:
Select spans for corruption (not individual tokens)
Replace spans with sentinel tokens <extra_id_0>, <extra_id_1>, etc.
Decoder targets are the original span contents with their sentinels
<extra_id_0> original span 1 <extra_id_1> original span 2 ...
Example:
Original: "Thank you for inviting me to your party last week."
Corrupted: "Thank you <extra_id_0> to your party <extra_id_1> week."
Target: "<extra_id_0> for inviting me <extra_id_1> last <extra_id_2>"
Span corruption has several advantages: (1) It's computationally more efficient—the target sequence is much shorter than the input, (2) It forces the model to understand context well enough to fill multi-token gaps, (3) It naturally produces a text-to-text training signal, and (4) Spans are more semantically meaningful than individual tokens.
```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch


@dataclass
class SpanCorruptionConfig:
    """Configuration for T5 span corruption."""
    noise_density: float = 0.15           # Fraction of tokens to corrupt
    mean_noise_span_length: float = 3.0   # Average span length
    sentinel_start_id: int = 32099        # <extra_id_0> in the T5 tokenizer


def random_spans_noise_mask(
    length: int,
    noise_density: float,
    mean_noise_span_length: float
) -> np.ndarray:
    """
    Create a mask indicating which tokens to corrupt.
    Returns a boolean mask where True = corrupted.
    """
    # Calculate number of noise tokens and spans
    num_noise_tokens = int(np.round(length * noise_density))
    num_noise_spans = int(np.round(num_noise_tokens / mean_noise_span_length))
    num_noise_spans = max(num_noise_spans, 1)
    num_nonnoise_tokens = length - num_noise_tokens

    def _random_segmentation(num_items: int, num_segments: int) -> List[int]:
        """Partition num_items into num_segments random segments."""
        first_in_segment = np.sort(
            np.random.choice(np.arange(1, num_items), num_segments - 1, replace=False)
        )
        first_in_segment = np.concatenate([[0], first_in_segment, [num_items]])
        return np.diff(first_in_segment).tolist()

    noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
    nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

    # Interleave non-noise and noise spans
    interleaved = []
    for noise_len, nonnoise_len in zip(noise_span_lengths, nonnoise_span_lengths):
        interleaved.append(np.zeros(nonnoise_len, dtype=bool))
        interleaved.append(np.ones(noise_len, dtype=bool))
    mask = np.concatenate(interleaved)

    # Handle edge case where the generated mask has slightly the wrong length
    if len(mask) < length:
        mask = np.concatenate([mask, np.zeros(length - len(mask), dtype=bool)])
    elif len(mask) > length:
        mask = mask[:length]
    return mask


def create_span_corruption_data(
    input_ids: List[int],
    config: SpanCorruptionConfig
) -> Tuple[List[int], List[int]]:
    """
    Apply span corruption to input tokens.
    Returns:
        encoder_input: Corrupted input with sentinel tokens
        decoder_target: Original spans with sentinel tokens
    """
    length = len(input_ids)
    mask = random_spans_noise_mask(
        length, config.noise_density, config.mean_noise_span_length
    )

    encoder_input = []
    decoder_target = []
    sentinel_idx = 0
    i = 0
    while i < length:
        if not mask[i]:
            # Non-corrupted token: add to encoder input
            encoder_input.append(input_ids[i])
            i += 1
        else:
            # Start of a corrupted span: add a sentinel to both sequences
            sentinel_token = config.sentinel_start_id - sentinel_idx
            encoder_input.append(sentinel_token)
            decoder_target.append(sentinel_token)
            # Add all tokens in the span to the decoder target
            while i < length and mask[i]:
                decoder_target.append(input_ids[i])
                i += 1
            sentinel_idx += 1

    # Add a final sentinel to the decoder target
    final_sentinel = config.sentinel_start_id - sentinel_idx
    decoder_target.append(final_sentinel)
    return encoder_input, decoder_target


class T5Dataset(torch.utils.data.Dataset):
    """
    Dataset for T5 pre-training with span corruption.
    """
    def __init__(
        self,
        texts: List[str],
        tokenizer,
        max_length: int = 512,
        config: Optional[SpanCorruptionConfig] = None
    ):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.config = config or SpanCorruptionConfig()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)[:self.max_length]
        # Apply span corruption
        encoder_input, decoder_target = create_span_corruption_data(input_ids, self.config)
        return {
            'encoder_input_ids': torch.tensor(encoder_input),
            'decoder_input_ids': torch.tensor(decoder_target[:-1]),  # Shift for AR
            'labels': torch.tensor(decoder_target[1:])               # Target for loss
        }
```

Pre-training Data: C4 (Colossal Clean Crawled Corpus)
T5 introduced C4, a massive cleaned web corpus (roughly 750 GB of text derived from Common Crawl) designed specifically for language model pre-training. Its cleaning heuristics included:
Keeping only lines that end in terminal punctuation
Discarding pages with fewer than three sentences and lines with fewer than five words
Removing pages containing offensive words, placeholder text like "lorem ipsum", or code markers such as curly braces
Deduplicating repeated three-sentence spans across the corpus
Filtering to English using langdetect
This careful cleaning was shown to significantly improve downstream task performance compared to raw or lightly filtered web text.
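If you want to inspect C4 directly, a hosted copy is available on the Hugging Face Hub under the dataset id allenai/c4. The sketch below streams a few documents rather than downloading the full multi-hundred-gigabyte corpus.

```python
from datasets import load_dataset

# Stream a few documents from the English C4 split without downloading the whole corpus
c4 = load_dataset("allenai/c4", "en", split="train", streaming=True)
for i, example in zip(range(3), c4):
    print(example["url"])
    print(example["text"][:200], "...")
```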
T5's defining innovation is treating every task as text generation. This is achieved through task prefixes—short strings prepended to the input that specify the task.
Task Prefix Examples:
| Task | Input Format | Target |
|---|---|---|
| Translation | translate English to German: The house is wonderful. | Das Haus ist wunderbar. |
| Summarization | summarize: <long article text> | <summary text> |
| Classification | sst2 sentence: This movie was great! | positive |
| Question Answering | question: <question> context: <passage> | <answer> |
| Similarity | stsb sentence1: <s1> sentence2: <s2> | 3.5 |
Benefits of Text-to-Text:
One model, one loss: the same maximum-likelihood objective and decoding procedure serve every task
No task-specific heads or architecture changes; adding a new task only requires a new prefix and examples
Multi-task training reduces to mixing datasets that share a single format
Knowledge transfers across tasks because inputs and outputs share one vocabulary and representation
For classification, T5 generates the class label as a word. This seems wasteful (generating multiple tokens when only an index is needed), but it enables transfer: a model that learned 'positive' in sentiment can transfer that understanding to other tasks using the same word. The semantic meaning of labels becomes useful.
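In practice, free generation can occasionally produce strings outside the label set. A common workaround is rank classification: score each candidate label string under the model and pick the most likely one. A minimal sketch with Hugging Face Transformers (the checkpoint choice and helper function are illustrative):

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

def classify_by_label_likelihood(input_text: str, labels: list[str]) -> str:
    """Pick the candidate label the model assigns the highest likelihood."""
    enc = tokenizer(input_text, return_tensors="pt")
    scores = {}
    for label in labels:
        target = tokenizer(label, return_tensors="pt").input_ids
        with torch.no_grad():
            out = model(input_ids=enc.input_ids,
                        attention_mask=enc.attention_mask,
                        labels=target)
        scores[label] = -out.loss.item()  # lower loss = more likely label
    return max(scores, key=scores.get)

print(classify_by_label_likelihood("sst2 sentence: This movie was great!",
                                   ["positive", "negative"]))
```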
```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


class TaskType(Enum):
    CLASSIFICATION = "classification"
    GENERATION = "generation"
    SPAN_EXTRACTION = "span_extraction"
    REGRESSION = "regression"


@dataclass
class T5TaskConfig:
    """Configuration for a T5 task."""
    task_name: str
    prefix: str
    task_type: TaskType
    label_map: Optional[Dict[int, str]] = None
    max_target_length: int = 128


class T5TaskFormatters:
    """
    Format various NLP tasks into T5's text-to-text format.
    """

    @staticmethod
    def classification(
        input_text: str,
        prefix: str,
        label: Optional[int] = None,
        label_map: Optional[Dict[int, str]] = None
    ) -> Dict[str, Optional[str]]:
        """Format a classification task."""
        formatted_input = f"{prefix}: {input_text}"
        target = label_map[label] if label is not None and label_map else None
        return {"input": formatted_input, "target": target}

    @staticmethod
    def sentence_pair_classification(
        sentence1: str,
        sentence2: str,
        prefix: str,
        label: Optional[int] = None,
        label_map: Optional[Dict[int, str]] = None
    ) -> Dict[str, Optional[str]]:
        """Format sentence-pair classification (NLI, paraphrase)."""
        formatted_input = f"{prefix} sentence1: {sentence1} sentence2: {sentence2}"
        target = label_map[label] if label is not None and label_map else None
        return {"input": formatted_input, "target": target}

    @staticmethod
    def translation(
        source_text: str,
        source_lang: str,
        target_lang: str,
        target_text: Optional[str] = None
    ) -> Dict[str, Optional[str]]:
        """Format a translation task."""
        formatted_input = f"translate {source_lang} to {target_lang}: {source_text}"
        return {"input": formatted_input, "target": target_text}

    @staticmethod
    def summarization(
        document: str,
        summary: Optional[str] = None,
        prefix: str = "summarize"
    ) -> Dict[str, Optional[str]]:
        """Format a summarization task."""
        formatted_input = f"{prefix}: {document}"
        return {"input": formatted_input, "target": summary}

    @staticmethod
    def question_answering(
        question: str,
        context: str,
        answer: Optional[str] = None
    ) -> Dict[str, Optional[str]]:
        """Format extractive question answering."""
        formatted_input = f"question: {question} context: {context}"
        return {"input": formatted_input, "target": answer}

    @staticmethod
    def regression(
        input_text: str,
        prefix: str,
        score: Optional[float] = None
    ) -> Dict[str, Optional[str]]:
        """Format a regression task (output the score as text)."""
        formatted_input = f"{prefix}: {input_text}"
        # Round to 1 decimal place for stable generation
        target = f"{score:.1f}" if score is not None else None
        return {"input": formatted_input, "target": target}


# Pre-defined task configurations for common benchmarks
GLUE_TASK_CONFIGS = {
    "sst2": T5TaskConfig(
        task_name="sst2",
        prefix="sst2 sentence",
        task_type=TaskType.CLASSIFICATION,
        label_map={0: "negative", 1: "positive"}
    ),
    "mnli": T5TaskConfig(
        task_name="mnli",
        prefix="mnli hypothesis",
        task_type=TaskType.CLASSIFICATION,
        label_map={0: "entailment", 1: "neutral", 2: "contradiction"}
    ),
    "qnli": T5TaskConfig(
        task_name="qnli",
        prefix="qnli question",
        task_type=TaskType.CLASSIFICATION,
        label_map={0: "entailment", 1: "not_entailment"}
    ),
    "stsb": T5TaskConfig(
        task_name="stsb",
        prefix="stsb sentence1",
        task_type=TaskType.REGRESSION,
        max_target_length=5
    ),
}


class T5DataProcessor:
    """
    Process datasets into T5 format.
    """
    def __init__(self, tokenizer, max_input_length: int = 512, max_target_length: int = 128):
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def process_example(self, formatted: Dict[str, Optional[str]]) -> Dict[str, Any]:
        """Tokenize a formatted example."""
        input_encoding = self.tokenizer(
            formatted["input"],
            max_length=self.max_input_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        result = {
            "input_ids": input_encoding.input_ids.squeeze(),
            "attention_mask": input_encoding.attention_mask.squeeze()
        }

        if formatted["target"] is not None:
            target_encoding = self.tokenizer(
                formatted["target"],
                max_length=self.max_target_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            )
            result["labels"] = target_encoding.input_ids.squeeze()
            # Replace padding token ids with -100 so they are ignored by the loss
            result["labels"][result["labels"] == self.tokenizer.pad_token_id] = -100

        return result
```

T5's text-to-text format naturally enables multi-task learning: training on multiple tasks simultaneously. The T5 paper extensively studied different approaches to multi-task training.
Multi-task Training Strategies:
Proportional mixing: Sample tasks proportionally to dataset size
Temperature-scaled mixing: Sample each task in proportion to min(dataset size, K)^(1/τ), using the temperature τ to flatten the distribution (see the sketch after this list)
Equal mixing: Sample uniformly across tasks
Gradual unfreezing: Start with pre-training, gradually add task data
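A minimal sketch of the temperature-scaled mixing rates referenced above, following the r_m ∝ min(e_m, K)^(1/τ) rule with the artificial dataset-size limit K = 2^21 used in the T5 paper; the dataset names and sizes below are illustrative.

```python
def mixing_rates(dataset_sizes: dict, temperature: float = 2.0, K: int = 2**21) -> dict:
    """
    Temperature-scaled mixing: each task's sampling rate is proportional to
    min(size, K) ** (1 / temperature). temperature=1 recovers proportional
    mixing; a large temperature approaches equal mixing.
    """
    clipped = {name: min(size, K) for name, size in dataset_sizes.items()}
    scaled = {name: size ** (1.0 / temperature) for name, size in clipped.items()}
    total = sum(scaled.values())
    return {name: value / total for name, value in scaled.items()}


# Illustrative dataset sizes (number of examples)
sizes = {"glue_sst2": 67_000, "squad": 88_000, "wmt_en_de": 4_500_000}
print(mixing_rates(sizes, temperature=2.0))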
The T5 paper found that multi-task pre-training followed by single-task fine-tuning performed about on par with standard unsupervised pre-training followed by fine-tuning; the extra multi-task stage did not significantly help. Multi-task learning was still useful, however, for training a single model that handles multiple tasks without task-specific fine-tuning.
Transfer Learning Findings:
The T5 paper's extensive experiments revealed several insights about transfer:
Pre-training is crucial: Models pre-trained on unsupervised data significantly outperformed those trained only on supervised data
More pre-training helps: Performance continued improving with more pre-training steps, though with diminishing returns
Encoder-decoder matches or beats decoder-only: For most tasks, especially those requiring understanding input, encoder-decoder was equal or better
Span corruption is effective: The span corruption objective matched or beat other objectives while being computationally efficient
Scale helps, but with diminishing returns: Each 4x increase in model size gave approximately 2-3% improvement on benchmarks
| Model | Parameters | SuperGLUE Score |
|---|---|---|
| T5-Small | 60M | 66.4 |
| T5-Base | 220M | 79.1 |
| T5-Large | 770M | 85.5 |
| T5-3B | 3B | 88.5 |
| T5-11B | 11B | 90.1 |
| Human baseline | – | 89.8 |
Following T5's success, several important variants have been developed, each addressing different needs.
mT5 extends T5 to 101 languages using the mC4 (multilingual C4) dataset.
Key differences from T5:
Pre-trained on mC4, a multilingual version of C4 covering 101 languages
A much larger SentencePiece vocabulary (250,000 tokens) to cover many scripts
Based on the improved T5.1.1 recipe (GeGLU feed-forward layers, no dropout during pre-training, unlabeled data only)
ByT5 operates directly on raw bytes, eliminating the tokenizer entirely.
Advantages:
No tokenizer or fixed vocabulary to build and maintain
Robust to misspellings, unusual casing, and other character-level noise
Handles any language or script out of the box
Trade-offs:
Byte sequences are several times longer than subword sequences, so training and inference are slower
More compute is spent per unit of text; ByT5 shifts capacity toward a heavier encoder to compensate
ByT5 excels on noisy text, code, morphologically rich languages, and tasks where subword tokenization hurts (e.g., character-level tasks). For standard English text tasks, regular T5 is more efficient. Consider ByT5 when you need a tokenizer-free solution or work with unusual inputs.
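The core trade-off is easy to see by tokenizing the same sentence with both models. The sketch below assumes the t5-base and google/byt5-base checkpoints from the Hugging Face Hub.

```python
from transformers import AutoTokenizer

# Subword (SentencePiece) vs byte-level tokenization of the same text
t5_tok = AutoTokenizer.from_pretrained("t5-base")
byt5_tok = AutoTokenizer.from_pretrained("google/byt5-base")

text = "Misspelled wrods and rare tokens like Pneumonoultramicroscopicsilicovolcanoconiosis"
print("T5 tokens:  ", len(t5_tok(text).input_ids))
print("ByT5 tokens:", len(byt5_tok(text).input_ids))  # roughly one token per byte
```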
Flan-T5 applies the Flan instruction tuning methodology to T5, dramatically improving zero-shot and few-shot performance.
The Flan Recipe:
Fine-tune on a large collection of tasks (over 1,800) phrased as natural-language instructions
Mix zero-shot, few-shot, and chain-of-thought prompt formats
Scale both the number of tasks and the model size
The result: A model that follows instructions much better than vanilla T5, approaching GPT-3.5 capabilities for many tasks at a fraction of the size.
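A quick sketch of zero-shot instruction following with a Flan-T5 checkpoint from the Hugging Face Hub (the prompt and checkpoint size are illustrative):

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")

# No task prefix needed: the instruction itself specifies the task
prompt = ("Answer the following question by reasoning step by step. "
          "A cafeteria had 23 apples. They used 20 for lunch and bought 6 more. "
          "How many apples do they have?")
ids = tok(prompt, return_tensors="pt").input_ids
out = model.generate(ids, max_new_tokens=64)
print(tok.decode(out[0], skip_special_tokens=True))
```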
| Variant | Key Feature | Best For |
|---|---|---|
| T5 | Original English text-to-text | Standard English NLP tasks |
| mT5 | 101 languages | Multilingual applications |
| ByT5 | Byte-level, no tokenizer | Noisy text, code, cross-lingual |
| Flan-T5 | Instruction tuning | Zero/few-shot instruction following |
| UL2 | Unified pre-training | Mixed denoising objectives |
Using T5 in practice involves understanding its generation-based inference, choosing the right variant, and optimizing for your specific use case.
Using Hugging Face Transformers:
```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load model and tokenizer
model_name = "t5-base"  # Or "google/flan-t5-base" for the instruction-tuned variant
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)


def t5_predict(input_text: str, max_length: int = 128) -> str:
    """Generate output for any T5 task."""
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True
    )
    outputs = model.generate(
        inputs.input_ids,
        max_length=max_length,
        num_beams=4,        # Beam search for better quality
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Example: Translation
translation = t5_predict("translate English to French: The house is beautiful.")
print(f"Translation: {translation}")

# Example: Summarization
article = '''Artificial intelligence has made remarkable strides in recent years,
with large language models demonstrating capabilities that were thought to be
decades away. These models can now write essays, answer questions, and even
engage in creative writing tasks.'''
summary = t5_predict(f"summarize: {article}")
print(f"Summary: {summary}")

# Example: Sentiment with a Flan-T5-style prompt
sentiment = t5_predict(
    "Classify the sentiment of the following text as positive or negative: "
    "This restaurant exceeded all my expectations!"
)
print(f"Sentiment: {sentiment}")


class T5FineTuner:
    """
    Fine-tune T5 on a custom task.
    """
    def __init__(
        self,
        model_name: str = "t5-base",
        learning_rate: float = 3e-5,
        max_epochs: int = 3
    ):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs

    def prepare_batch(self, batch: list[dict]) -> dict:
        """Prepare a batch for training."""
        inputs = [item["input"] for item in batch]
        targets = [item["target"] for item in batch]

        input_encodings = self.tokenizer(
            inputs, padding=True, truncation=True, max_length=512, return_tensors="pt"
        )
        target_encodings = self.tokenizer(
            targets, padding=True, truncation=True, max_length=128, return_tensors="pt"
        )

        labels = target_encodings.input_ids
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": input_encodings.input_ids,
            "attention_mask": input_encodings.attention_mask,
            "labels": labels
        }

    def train_step(self, batch: dict, optimizer) -> float:
        """Single training step."""
        self.model.train()
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss.item()


# Inference with beam search vs sampling
def generate_with_options(input_text: str, method: str = "beam") -> str:
    """Generate with different decoding strategies."""
    inputs = tokenizer(input_text, return_tensors="pt")

    if method == "beam":
        outputs = model.generate(
            inputs.input_ids,
            max_length=128,
            num_beams=4,
            early_stopping=True
        )
    elif method == "sample":
        outputs = model.generate(
            inputs.input_ids,
            max_length=128,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7
        )
    elif method == "diverse_beam":
        outputs = model.generate(
            inputs.input_ids,
            max_length=128,
            num_beams=4,
            num_beam_groups=4,
            diversity_penalty=0.5
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```

You now understand T5's encoder-decoder architecture, span corruption pre-training, text-to-text task formulation, and its variants like mT5 and Flan-T5.
T5 demonstrated that a unified text-to-text approach can match or beat task-specific architectures while simplifying the ML pipeline. Next, we'll explore efficient transformers—architectures designed to overcome the quadratic attention bottleneck.