For decades, convolutional neural networks (CNNs) dominated computer vision. From LeNet to ResNet, the inductive bias of convolutions—local connectivity, translation equivariance, parameter sharing—seemed perfectly suited to visual data. Then, in October 2020, Google's "An Image is Worth 16x16 Words" demonstrated that a pure transformer, with minimal image-specific modifications, could match or exceed the best CNNs on image classification.
The Vision Transformer (ViT) treats an image as a sequence of patches, applies standard transformer architecture, and achieves state-of-the-art results when pre-trained on sufficient data. This success sparked a fundamental rethinking of computer vision architectures.
This page covers how Vision Transformers convert images to sequences, the complete ViT architecture, pre-training strategies including CLIP and MAE, and extensions like Swin Transformer and DETR. You'll understand when ViTs outperform CNNs, their data requirements, and how to apply them effectively.
Why Transformers for Vision?
Convolutions carry strong inductive biases (local connectivity, translation equivariance, parameter sharing) that help a model learn from limited data. But the same biases also limit it: receptive fields grow only gradually with depth, so relationships between distant image regions are hard to capture directly.
Transformers have weaker inductive biases (they don't "know" about spatial structure) but more flexibility to learn arbitrary patterns. Given sufficient data, this flexibility wins.
The key insight of ViT is treating image patches as "visual words." An image is divided into fixed-size patches, each patch is linearly projected to an embedding vector, and the resulting sequence is processed by a standard transformer.
The Patch Embedding Process:
Divide: Split image into non-overlapping P × P patches
Flatten: Convert each patch to a 1D vector
Project: Linear projection to embedding dimension
Add position: Add learnable position embeddings
Add [CLS]: Prepend a learnable classification token

For a 224×224 RGB image with P = 16, this process yields (224/16)² = 196 patches, each flattened to 16·16·3 = 768 values before projection to the embedding dimension D.
| Model | Image Size | Patch Size | Patches | Hidden Dim | Layers | Heads | Params |
|---|---|---|---|---|---|---|---|
| ViT-B/16 | 224×224 | 16×16 | 196 | 768 | 12 | 12 | 86M |
| ViT-B/32 | 224×224 | 32×32 | 49 | 768 | 12 | 12 | 88M |
| ViT-L/16 | 224×224 | 16×16 | 196 | 1024 | 24 | 16 | 307M |
| ViT-H/14 | 224×224 | 14×14 | 256 | 1280 | 32 | 16 | 632M |
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Optional


class PatchEmbedding(nn.Module):
    """
    Convert image to sequence of patch embeddings.

    Image: [B, C, H, W] -> Patches: [B, N, D]
    where N = (H/P) * (W/P) and D is embedding dimension
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Linear projection implemented as Conv2d with kernel_size = patch_size
        # This is equivalent to flatten + linear but more efficient
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, channels, height, width]
        x = self.projection(x)  # [batch, embed_dim, H/P, W/P]
        x = rearrange(x, 'b d h w -> b (h w) d')  # [batch, num_patches, embed_dim]
        return x


class ViTAttention(nn.Module):
    """
    Multi-head self-attention for Vision Transformer.
    Standard transformer attention, nothing vision-specific.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape

        # Compute Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv.unbind(0)

        # Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class ViTMLP(nn.Module):
    """
    MLP block for Vision Transformer.
    Typically expands dimension by 4x, then contracts.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0
    ):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ViTBlock(nn.Module):
    """
    Transformer block for Vision Transformer.
    Pre-norm architecture: LayerNorm before attention/MLP.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        attn_drop: float = 0.0
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = ViTAttention(embed_dim, num_heads, qkv_bias, attn_drop, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = ViTMLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) for image classification.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 1000,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        attn_drop: float = 0.0
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # Patch embedding
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Class token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            ViTBlock(embed_dim, num_heads, mlp_ratio, qkv_bias, dropout, attn_drop)
            for _ in range(depth)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        # Initialize position embeddings with truncated normal
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # [B, N, D]

        # Prepend class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # [B, N+1, D]

        # Add position embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Classification from [CLS] token
        x = self.norm(x)
        cls_output = x[:, 0]  # First token is [CLS]
        return self.head(cls_output)
```

Smaller patch sizes create more tokens, giving finer granularity but higher computational cost (quadratic in number of patches). ViT-B/16 (patch size 16) typically outperforms ViT-B/32, but requires roughly 4× more compute. For efficiency-constrained settings, larger patches or hierarchical approaches (Swin) are preferred.
ViT's performance is highly dependent on pre-training. Unlike CNNs, which work well with limited data due to their inductive biases, ViT requires large-scale pre-training to learn visual patterns from scratch.
The original ViT paper used ImageNet-21K (14M images) and JFT-300M (300M images):
| Pre-training Data | ImageNet Top-1 Accuracy |
|---|---|
| ImageNet-1K only | 77.9% (worse than ResNet) |
| ImageNet-21K | 84.4% |
| JFT-300M | 88.6% |
Key insight: ViT only surpasses CNNs when pre-trained on >10M images. With limited data, CNNs' inductive biases provide a significant advantage.
Facebook's DeiT (2021) showed that with proper training techniques, ViT could match CNNs on ImageNet-1K alone:
Key Techniques:
Distillation: A dedicated distillation token learns from a strong CNN teacher's predictions (see the sketch below)
Augmentation: RandAugment, Mixup, CutMix, and random erasing substitute for the missing data scale
Regularization: Stochastic depth and label smoothing, with a carefully tuned training recipe
DeiT achieved 85.2% accuracy on ImageNet-1K without external data, proving that training recipes matter as much as scale.
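As a concrete illustration of the distillation recipe, here is a minimal sketch of DeiT-style hard distillation. It assumes a student that emits separate logits from its [CLS] token and its distillation token; the helper name is illustrative, while the equal weighting follows the paper's hard-distillation variant:

```python
import torch
import torch.nn.functional as F

def deit_hard_distillation_loss(cls_logits: torch.Tensor,
                                dist_logits: torch.Tensor,
                                labels: torch.Tensor,
                                teacher_logits: torch.Tensor) -> torch.Tensor:
    """cls_logits/dist_logits: [B, C] student outputs; teacher_logits: [B, C] from a CNN teacher."""
    teacher_labels = teacher_logits.argmax(dim=-1)            # hard teacher predictions
    loss_cls = F.cross_entropy(cls_logits, labels)            # supervised loss on the [CLS] head
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)  # distillation loss on the distill head
    return 0.5 * loss_cls + 0.5 * loss_dist
```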
CLIP (OpenAI, 2021) pre-trains vision encoders (including ViT) by predicting which caption goes with which image across 400M image-text pairs. CLIP-trained ViTs have remarkable zero-shot transfer: they can classify images into arbitrary categories described in natural language, without any task-specific training.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class CLIPVisionEncoder(nn.Module):
    """
    CLIP-style Vision Transformer.
    Encodes images to a shared embedding space with text.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        output_dim: int = 512  # Shared embedding dimension
    ):
        super().__init__()
        self.vit = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            num_classes=output_dim  # Project to shared space
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images to normalized embeddings."""
        features = self.vit(images)
        return F.normalize(features, dim=-1)


class CLIP(nn.Module):
    """
    CLIP: Contrastive Language-Image Pre-training.
    Learns aligned image and text embeddings.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        text_encoder: nn.Module,
        embed_dim: int = 512
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder

        # Learnable temperature parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))

    def forward(
        self,
        images: torch.Tensor,
        texts: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute image and text embeddings.

        Returns:
            image_features: [batch, embed_dim]
            text_features: [batch, embed_dim]
        """
        image_features = self.vision_encoder(images)
        text_features = self.text_encoder(texts)
        return image_features, text_features

    def compute_loss(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Contrastive loss: images should match their corresponding texts.
        """
        # Compute similarity matrix
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        # Create labels (diagonal is positive)
        batch_size = image_features.shape[0]
        labels = torch.arange(batch_size, device=image_features.device)

        # Symmetric cross-entropy loss
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)
        return (loss_i + loss_t) / 2


def clip_zero_shot_classify(
    model: CLIP,
    images: torch.Tensor,
    class_names: list[str],
    text_encoder_fn  # Function to encode text
) -> torch.Tensor:
    """
    Zero-shot classification with CLIP.

    Args:
        model: CLIP model
        images: [batch, 3, H, W]
        class_names: List of class names
        text_encoder_fn: Function that encodes text to embeddings

    Returns:
        Predicted class indices [batch]
    """
    # Encode class prompts
    prompts = [f"a photo of a {name}" for name in class_names]
    text_features = text_encoder_fn(prompts)  # [num_classes, embed_dim]
    text_features = F.normalize(text_features, dim=-1)

    # Encode images
    image_features = model.vision_encoder(images)  # [batch, embed_dim]

    # Compute similarity
    similarity = image_features @ text_features.T  # [batch, num_classes]
    return similarity.argmax(dim=-1)
```

Inspired by BERT's masked language modeling, MAE (Masked Autoencoders, 2021) proposed masked image modeling: randomly mask 75% of the image patches, run the encoder only on the small subset of visible patches, and train a lightweight decoder to reconstruct the pixel values of the masked patches.
Why 75% masking works: natural images are spatially redundant, so with a low masking ratio the model can fill in missing patches by interpolating from their neighbors. Masking 75% removes that shortcut and forces the encoder to learn global, semantic structure rather than low-level texture statistics.
MAE is also remarkably efficient: because the encoder processes only the 25% of patches that remain visible, pre-training speeds up by 3-4×.
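A minimal sketch of the random-masking step at the heart of MAE, following the shuffle-and-keep approach from the paper (function name and shapes are illustrative):

```python
import torch

def random_masking(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """tokens: [B, N, D] patch embeddings (no [CLS] token)."""
    B, N, D = tokens.shape
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=tokens.device)  # random score per patch
    ids_shuffle = noise.argsort(dim=1)              # low scores are kept
    ids_restore = ids_shuffle.argsort(dim=1)        # undoes the shuffle later

    # Gather the visible 25% of tokens for the encoder
    ids_keep = ids_shuffle[:, :len_keep]
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    # Binary mask in original patch order: 0 = visible, 1 = masked
    mask = torch.ones(B, N, device=tokens.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return visible, mask, ids_restore
```

The encoder runs only on `visible`; the decoder later re-inserts mask tokens at the positions given by `ids_restore` and reconstructs the missing pixels.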
Use supervised pre-training (ImageNet-21k) for classification-focused tasks. Use CLIP for zero-shot or multi-modal applications. Use MAE for dense prediction tasks (segmentation, detection) where learning local structure matters. For most applications, start with a CLIP-pretrained ViT—it's the most versatile.
The Swin Transformer (2021) addressed ViT's computational and architectural limitations for dense prediction tasks. It introduced a hierarchical structure with local attention, making it suitable for object detection, semantic segmentation, and other tasks requiring multi-scale features.
Key Innovations:
Hierarchical Feature Maps: Unlike ViT's fixed resolution, Swin produces feature maps at 1/4, 1/8, 1/16, 1/32 resolution (like CNNs)
Window-based Attention: Attention computed within local windows, not globally
Shifted Windows: Windows shift between layers for cross-window information flow
Patch Merging: Pools patches to reduce resolution between stages
| Aspect | ViT | Swin Transformer |
|---|---|---|
| Attention scope | Global (all patches) | Local (within windows) |
| Complexity | O(n²) in patches | O(n × w²) where w is window size |
| Feature hierarchy | Single resolution | Multi-scale pyramid |
| Dense prediction | Requires modification | Native support |
| Position encoding | Absolute | Relative |
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """
    Partition feature map into non-overlapping windows.

    Args:
        x: [B, H, W, C]
        window_size: Window size

    Returns:
        windows: [B * num_windows, window_size, window_size, C]
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows


def window_reverse(
    windows: torch.Tensor,
    window_size: int,
    H: int,
    W: int
) -> torch.Tensor:
    """
    Reverse window partition.

    Args:
        windows: [B * num_windows, window_size, window_size, C]
        window_size: Window size
        H, W: Original feature map dimensions

    Returns:
        x: [B, H, W, C]
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """
    Window-based multi-head self-attention with relative position bias.
    """

    def __init__(
        self,
        dim: int,
        window_size: Tuple[int, int],
        num_heads: int,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Relative position bias table
        # Table size: (2M-1) × (2M-1) for window size M
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        # Compute relative position index
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # [2, H, W]
        coords_flatten = torch.flatten(coords, 1)  # [2, H*W]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, H*W, H*W]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [H*W, H*W, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [H*W, H*W]
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [num_windows * B, N, C] where N = window_size^2
            mask: [num_windows, N, N] for shifted window attention
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        # Apply mask for shifted windows
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block with window attention and shifted windows.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: int = 7,
        shift_size: int = 0,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(
            dim, window_size=(window_size, window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
        )

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift for shifted windows
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # Create attention mask for shifted windows
            attn_mask = self._create_mask(H, W, x.device)
        else:
            shifted_x = x
            attn_mask = None

        # Partition into windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # Window attention
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)

        # Residual and MLP
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x

    def _create_mask(self, H: int, W: int, device: torch.device) -> torch.Tensor:
        """Create attention mask that blocks attention across shift boundaries."""
        # Label each region created by the cyclic shift, then forbid attention
        # between tokens that originated from different regions.
        img_mask = torch.zeros((1, H, W, 1), device=device)
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, ws, ws, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
        return attn_mask  # [nW, N, N]
```

In one block, attention is computed inside 7×7 windows of patches. In the next block, the window grid is shifted by ⌊7/2⌋ = 3 patches in each direction. Patches that sat at a window boundary in the first block become interior patches in the second, enabling cross-window information flow without increasing complexity.
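A quick shape check for the blocks above, pairing a regular window block with a shifted one as two consecutive Swin layers (configuration values are illustrative):

```python
# Assumes SwinTransformerBlock from the listing above is in scope.
import torch

H, W, C = 56, 56, 96                       # stage-1 resolution/width of Swin-T
x = torch.randn(2, H * W, C)               # [batch, tokens, channels]

block = SwinTransformerBlock(dim=C, num_heads=3, window_size=7, shift_size=0)
shifted_block = SwinTransformerBlock(dim=C, num_heads=3, window_size=7, shift_size=3)

out = shifted_block(block(x, H, W), H, W)  # W-MSA block followed by SW-MSA block
print(out.shape)                           # torch.Size([2, 3136, 96])
```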
Vision Transformers have been adapted for tasks beyond classification, including object detection, semantic segmentation, and other dense prediction tasks.
DETR (2020) was the first end-to-end object detection model using transformers:
Architecture: a CNN backbone extracts image features, a transformer encoder refines them, and a transformer decoder attends to them through a fixed set of learned object queries; small feed-forward heads turn each query into a class label and bounding box. During training, bipartite (Hungarian) matching assigns predictions to ground-truth objects, as sketched below.
Key Innovation: DETR eliminates hand-designed components like anchor boxes, NMS (non-maximum suppression), and region proposals. The model directly predicts a fixed set of bounding boxes and class labels.
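A minimal sketch of that set-prediction matching step, using SciPy's Hungarian solver. The real DETR cost also includes a generalized IoU term, omitted here for brevity; the weights and helper name are illustrative:

```python
import torch
from scipy.optimize import linear_sum_assignment

def hungarian_match(pred_logits: torch.Tensor, pred_boxes: torch.Tensor,
                    gt_labels: torch.Tensor, gt_boxes: torch.Tensor,
                    class_weight: float = 1.0, l1_weight: float = 5.0):
    """pred_logits: [N, C], pred_boxes: [N, 4]; gt_labels: [M] (long), gt_boxes: [M, 4]."""
    prob = pred_logits.softmax(dim=-1)                 # [N, C]
    cost_class = -prob[:, gt_labels]                   # [N, M]: -p(ground-truth class)
    cost_box = torch.cdist(pred_boxes, gt_boxes, p=1)  # [N, M]: L1 box distance
    cost = class_weight * cost_class + l1_weight * cost_box

    pred_idx, gt_idx = linear_sum_assignment(cost.detach().cpu().numpy())
    return pred_idx, gt_idx  # matched prediction / ground-truth indices
```

Predictions left unmatched are supervised toward a special "no object" class, which is what lets DETR drop NMS.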
| Task | Model | Key Innovation |
|---|---|---|
| Object Detection | DETR, Deformable DETR | End-to-end detection with set prediction |
| Semantic Segmentation | SETR, SegFormer | Dense prediction with transformer features |
| Instance Segmentation | Mask2Former | Unified mask prediction |
| Depth Estimation | DPT | Dense prediction transformer |
| Video Understanding | ViViT, TimeSformer | Spatio-temporal attention |
SegFormer (2021) designed an efficient segmentation architecture:
Key Features:
Hierarchical encoder: A pyramid of transformer stages (MiT) produces multi-scale features, like a CNN backbone
No positional encodings: A Mix-FFN with a 3×3 depthwise convolution injects positional information, so the model transfers cleanly across resolutions
Efficient attention: Spatial reduction of keys and values keeps self-attention tractable at high resolution
Lightweight all-MLP decoder: Multi-scale features are projected, upsampled, and fused with simple linear layers (see the sketch below)
SegFormer achieves state-of-the-art segmentation with fewer parameters and FLOPs than previous methods.
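A minimal sketch of a SegFormer-style all-MLP decoder under assumed channel dimensions (the class name and defaults are illustrative, not the reference implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AllMLPDecoder(nn.Module):
    """Fuse multi-scale encoder features with 1x1 projections and a linear head."""

    def __init__(self, in_dims=(64, 128, 320, 512), embed_dim=256, num_classes=19):
        super().__init__()
        self.proj = nn.ModuleList([nn.Conv2d(d, embed_dim, kernel_size=1) for d in in_dims])
        self.fuse = nn.Conv2d(embed_dim * len(in_dims), embed_dim, kernel_size=1)
        self.classify = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, features):
        # features: encoder outputs at 1/4, 1/8, 1/16, 1/32 of the input resolution
        target_size = features[0].shape[-2:]
        upsampled = [
            F.interpolate(proj(feat), size=target_size, mode='bilinear', align_corners=False)
            for feat, proj in zip(features, self.proj)
        ]
        fused = self.fuse(torch.cat(upsampled, dim=1))
        return self.classify(fused)  # class logits at 1/4 resolution
```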
For dense prediction tasks, prefer backbones that expose multi-scale feature maps; plain ViT's single-resolution output usually needs extra adapters before it can feed detection or segmentation heads.
For classification: ViT or CLIP-ViT. For detection: Swin or DINO-ViT. For segmentation: Swin, SegFormer, or Mask2Former. For multi-modal (vision-language): CLIP-ViT or SigLIP. When in doubt, Swin Transformer is a safe default for most vision tasks.
Modern vision architectures often combine the best of CNNs and transformers.
Early Fusion: A convolutional stem extracts low-level features, and the resulting feature map is tokenized for transformer layers; the hybrid variant in the original ViT paper used a ResNet stem (see the sketch after this list)
Late Fusion: A convolutional backbone does most of the feature extraction, with transformer layers added near the end to model global context, as in DETR's CNN-plus-encoder design
Alternating: Convolution and attention blocks are interleaved through the network, as in CoAtNet and MobileViT
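A minimal sketch of the early-fusion pattern, reusing ViTBlock from the ViT implementation earlier: a small convolutional stem produces the 14×14 token grid that the transformer then processes. The class name and stem configuration are illustrative, and position embeddings are omitted for brevity:

```python
# Assumes ViTBlock from the ViT listing above is in scope.
import torch
import torch.nn as nn

class ConvStemViT(nn.Module):
    def __init__(self, embed_dim: int = 768, depth: int = 12,
                 num_heads: int = 12, num_classes: int = 1000):
        super().__init__()
        # Convolutional stem: [B, 3, 224, 224] -> [B, embed_dim, 14, 14]
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3), nn.GELU(),
            nn.Conv2d(64, 192, kernel_size=3, stride=2, padding=1), nn.GELU(),
            nn.Conv2d(192, embed_dim, kernel_size=3, stride=4, padding=1),
        )
        self.blocks = nn.ModuleList([ViTBlock(embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)                  # [B, D, 14, 14]
        x = x.flatten(2).transpose(1, 2)  # [B, 196, D] token sequence
        for block in self.blocks:
            x = block(x)
        return self.head(self.norm(x).mean(dim=1))  # mean-pool tokens, then classify
```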
In practice, pre-trained ViTs are most easily used through the Hugging Face transformers library:

```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import ViTForImageClassification, ViTImageProcessor, ViTModel
from PIL import Image


# Using a pre-trained ViT for classification
def classify_image(image_path: str, top_k: int = 5):
    """
    Classify image using pre-trained ViT from Hugging Face.
    """
    # Load model and processor
    model_name = "google/vit-base-patch16-224"
    processor = ViTImageProcessor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(model_name)

    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get top-k predictions
    probs = torch.softmax(logits, dim=-1)
    top_probs, top_indices = probs[0].topk(top_k)

    results = []
    for prob, idx in zip(top_probs, top_indices):
        label = model.config.id2label[idx.item()]
        results.append({"label": label, "probability": prob.item()})
    return results


# Fine-tuning ViT for custom classification
class ViTFineTuner:
    """
    Fine-tune ViT for custom classification task.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224",
        num_classes: int = 10,
        learning_rate: float = 1e-4,
        freeze_backbone: bool = False
    ):
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.backbone = ViTModel.from_pretrained(model_name)

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # New classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.backbone.config.hidden_size),
            nn.Linear(self.backbone.config.hidden_size, num_classes)
        )

        self.learning_rate = learning_rate

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        outputs = self.backbone(pixel_values=pixel_values)
        cls_output = outputs.last_hidden_state[:, 0]  # [CLS] token
        return self.classifier(cls_output)

    def configure_optimizers(self, num_epochs: int, steps_per_epoch: int):
        optimizer = AdamW(
            list(self.backbone.parameters()) + list(self.classifier.parameters()),
            lr=self.learning_rate,
            weight_decay=0.01
        )
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=num_epochs * steps_per_epoch
        )
        return optimizer, scheduler


# Attention visualization
def visualize_attention(model, image, layer_idx: int = -1, head_idx: int = 0):
    """
    Visualize attention weights from ViT.
    Returns attention map that can be overlaid on original image.
    """
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    inputs = processor(images=image, return_tensors="pt")

    # Get attention weights
    with torch.no_grad():
        outputs = model(
            **inputs,
            output_attentions=True
        )

    # Extract attention from specified layer
    # attentions: tuple of [batch, heads, seq, seq]
    attention = outputs.attentions[layer_idx]  # [1, heads, N+1, N+1]

    # Get attention from [CLS] token to patches
    cls_attention = attention[0, head_idx, 0, 1:]  # [N]

    # Reshape to 2D (patch grid)
    num_patches = int(cls_attention.shape[0] ** 0.5)
    attention_map = cls_attention.reshape(num_patches, num_patches)

    # Convert to numpy; upsample to image size before overlaying
    attention_map = attention_map.numpy()
    return attention_map
```

You now understand Vision Transformers comprehensively—from patch embedding to hierarchical Swin Transformer, from supervised pre-training to CLIP and MAE, and from classification to dense prediction.
ViTs have largely unified computer vision under the transformer architecture, enabling transfer learning across modalities and tasks. This concludes Module 6: Transformer Variants, covering the full spectrum of transformer architectures that power modern AI.