Masked modeling revolutionized natural language processing through BERT and its successors. The core idea is elegantly simple: hide parts of the input and train the model to reconstruct them. This forces the model to learn deep semantic relationships in order to fill in the missing information.
Transferring this success to computer vision required overcoming fundamental differences between language and images: words are discrete, information-dense tokens, while pixels are continuous and highly redundant, so a model can often fill small gaps from neighboring pixels alone, and there is no natural vocabulary of visual tokens to predict.

Modern masked image modeling methods solve these challenges, largely through aggressive masking and carefully chosen prediction targets, and now achieve state-of-the-art results.
Masking creates a powerful pretext task because predicting masked content requires understanding context. To predict a missing word, you need grammar and semantics; to predict missing image patches, you need an understanding of objects, textures, and spatial relationships.
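As a rough sense of scale (a minimal sketch; the 75% ratio here simply anticipates MAE's default), a 224x224 image split into 16x16 patches yields 196 patch tokens, about 147 of which get hidden:

```python
import torch
from einops import rearrange

img = torch.randn(1, 3, 224, 224)              # one RGB image
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
print(patches.shape)                           # torch.Size([1, 196, 768])

mask_ratio = 0.75
num_mask = int(patches.shape[1] * mask_ratio)  # 147 of 196 patches hidden
ids = torch.randperm(patches.shape[1])
visible = patches[:, ids[num_mask:]]           # the model sees only these 49 patches
print(visible.shape)                           # torch.Size([1, 49, 768])
```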
MAE, introduced by He et al. (2022) at Meta AI, demonstrates that simple masked reconstruction with high masking ratios (75%) produces excellent visual representations.
- **Encoder**: processes only the visible (unmasked) patches
- **Decoder**: reconstructs the full image from the encoded patches plus mask tokens
Key insight: Asymmetric encoder-decoder design. Heavy encoder sees few patches; light decoder reconstructs all.
```python
import torch
import torch.nn as nn
from einops import rearrange


class MAE(nn.Module):
    """Masked Autoencoder for self-supervised learning."""

    def __init__(self, encoder, decoder_dim=512, decoder_depth=8,
                 patch_size=16, mask_ratio=0.75):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.encoder = encoder

        # Decoder components
        self.decoder_embed = nn.Linear(encoder.embed_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(decoder_dim, 8, decoder_dim * 4,
                                       batch_first=True),
            num_layers=decoder_depth
        )
        self.decoder_pred = nn.Linear(decoder_dim, patch_size ** 2 * 3)
        nn.init.normal_(self.mask_token, std=0.02)

    def random_masking(self, x):
        """Random masking: keep a subset of patches, drop the rest."""
        N, L, D = x.shape
        len_keep = int(L * (1 - self.mask_ratio))

        # Shuffle patch indices with per-sample random noise
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep patches of the shuffled order
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        # Binary mask in the original patch order: 0 = kept, 1 = masked
        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)
        return x_masked, mask, ids_restore

    def forward(self, imgs):
        # Patchify and embed
        patches = self.encoder.patch_embed(imgs)
        patches = patches + self.encoder.pos_embed[:, 1:, :]

        # Mask and encode visible patches only
        visible, mask, ids_restore = self.random_masking(patches)
        encoded = self.encoder.blocks(visible)
        encoded = self.encoder.norm(encoded)

        # Decode: embed, append mask tokens, restore original patch order
        x = self.decoder_embed(encoded)
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x = torch.cat([x, mask_tokens], dim=1)
        x = torch.gather(x, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
        x = self.decoder(x)
        pred = self.decoder_pred(x)

        # MSE loss on masked patches only
        target = self.patchify(imgs)
        loss = ((pred - target) ** 2).mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def patchify(self, imgs):
        """Convert an image batch into flattened per-patch pixel targets."""
        p = self.patch_size
        return rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
```
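A minimal sketch of how this module might be driven in a pretraining loop. It assumes a ViT-style encoder exposing `embed_dim`, `patch_embed`, `pos_embed`, `blocks`, and `norm` (a `timm` ViT fits this interface); the random tensors stand in for a real image dataloader, and the hyperparameters are purely illustrative:

```python
import torch
import timm  # any ViT exposing patch_embed / pos_embed / blocks / norm would do

encoder = timm.create_model('vit_base_patch16_224', num_classes=0)
model = MAE(encoder, decoder_dim=512, decoder_depth=8, patch_size=16, mask_ratio=0.75)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)

# Stand-in for a real dataloader: 10 batches of 8 random 224x224 RGB "images"
dataloader = [torch.randn(8, 3, 224, 224) for _ in range(10)]

model.train()
for imgs in dataloader:
    loss = model(imgs)       # reconstruction loss on masked patches only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```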
BEiT (BERT Pre-training of Image Transformers) by Bao et al. (2021) brings BERT's approach to vision by predicting discrete visual tokens rather than raw pixels.

BEiT uses a pre-trained discrete VAE (the dVAE tokenizer from DALL-E, with a vocabulary of 8,192 visual tokens) to convert image patches into discrete tokens, and trains the transformer to predict the tokens of masked patches:
$$\mathcal{L} = -\mathbb{E}\left[\sum_{i \in M} \log P(z_i | x_{\backslash M})\right]$$
where $M$ is the set of masked positions, $z_i$ is the discrete token, and $x_{\backslash M}$ is the visible context.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BEiT(nn.Module):
    """BEiT: BERT Pre-Training of Image Transformers."""

    def __init__(self, encoder, tokenizer, vocab_size=8192, mask_ratio=0.4):
        super().__init__()
        self.encoder = encoder
        self.tokenizer = tokenizer  # Pre-trained dVAE
        self.mask_ratio = mask_ratio

        # Prediction head: predict a discrete token per masked patch
        self.head = nn.Linear(encoder.embed_dim, vocab_size)

        # Learnable mask token
        self.mask_token = nn.Parameter(torch.zeros(1, 1, encoder.embed_dim))

        # Freeze tokenizer
        for p in self.tokenizer.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def get_visual_tokens(self, imgs):
        """Get discrete tokens from the dVAE tokenizer."""
        return self.tokenizer.get_codebook_indices(imgs)

    def forward(self, imgs):
        # Get target tokens
        with torch.no_grad():
            target_tokens = self.get_visual_tokens(imgs)

        # Embed patches
        x = self.encoder.patch_embed(imgs)
        B, L, D = x.shape

        # Create random mask over patch positions
        num_mask = int(L * self.mask_ratio)
        mask_idx = torch.rand(B, L, device=x.device).argsort(dim=1)[:, :num_mask]

        # Replace masked positions with the mask token
        mask_tokens = self.mask_token.expand(B, L, -1)
        mask = torch.zeros(B, L, device=x.device).scatter_(1, mask_idx, 1).bool()
        x = torch.where(mask.unsqueeze(-1), mask_tokens, x)

        # Encode
        x = x + self.encoder.pos_embed[:, 1:, :]
        x = self.encoder.blocks(x)
        x = self.encoder.norm(x)

        # Predict only at masked positions
        x_masked = x[mask].reshape(B, num_mask, -1)
        logits = self.head(x_masked)

        # Cross-entropy loss against the dVAE token ids
        targets = target_tokens[mask].reshape(B, num_mask)
        loss = F.cross_entropy(logits.transpose(1, 2), targets)
        return loss
```

| Aspect | MAE | BEiT |
|---|---|---|
| Target | Raw pixels | Discrete tokens |
| Mask ratio | 75% | 40% |
| Tokenizer | None | dVAE (pre-trained) |
| Loss | MSE | Cross-entropy |
| Encoder input | Visible patches only (~4x fewer tokens) | All patches |
SimMIM (Simple Framework for Masked Image Modeling) by Xie et al. (2022) strips masked image modeling down to directly regressing raw pixel values with an L1 loss, showing that elaborate tokenizers aren't necessary.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class SimMIM(nn.Module):
    """SimMIM: simple masked image modeling."""

    def __init__(self, encoder, patch_size=32, mask_ratio=0.6):
        super().__init__()
        self.encoder = encoder
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

        # Simple linear prediction head
        self.head = nn.Linear(encoder.embed_dim, patch_size ** 2 * 3)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, encoder.embed_dim))

    def forward(self, imgs):
        B, C, H, W = imgs.shape
        p = self.patch_size
        num_patches = (H // p) * (W // p)

        # Random mask over patch positions
        num_mask = int(num_patches * self.mask_ratio)
        mask_idx = torch.rand(B, num_patches, device=imgs.device).argsort(1)[:, :num_mask]
        mask = torch.zeros(B, num_patches, device=imgs.device)
        mask.scatter_(1, mask_idx, 1)

        # Embed patches and replace masked ones with the mask token
        x = self.encoder.patch_embed(imgs)
        mask_expanded = mask.unsqueeze(-1).expand(-1, -1, x.shape[-1])
        x = x * (1 - mask_expanded) + self.mask_token * mask_expanded

        # Encode and predict raw pixels for every patch
        x = self.encoder.forward_features(x)
        pred = self.head(x)  # [B, num_patches, p*p*3]

        # Target: original pixel values, flattened per patch
        target = rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        # L1 loss on masked patches only
        loss = F.l1_loss(pred[mask.bool()], target[mask.bool()])
        return loss
```

Data2Vec by Baevski et al. (2022) unifies masked modeling across modalities (vision, speech, text) by predicting contextualized representations rather than raw inputs or discrete tokens.
A student network sees the masked input and is trained to match features averaged over the top layers of an EMA teacher that sees the full input; this lets the model learn to predict rich, contextualized features rather than low-level signals.
```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class Data2Vec(nn.Module):
    """Data2Vec: unified self-supervised learning via teacher feature prediction."""

    def __init__(self, encoder, ema_decay=0.9998, top_k=8):
        super().__init__()
        self.student = encoder
        self.teacher = copy.deepcopy(encoder)
        self.ema_decay = ema_decay
        self.top_k = top_k

        # Freeze teacher: it is updated only through the EMA rule below
        for p in self.teacher.parameters():
            p.requires_grad = False

        # Prediction head
        self.pred_head = nn.Linear(encoder.embed_dim, encoder.embed_dim)

    @torch.no_grad()
    def update_teacher(self):
        """EMA update: the teacher slowly tracks the student weights."""
        for s, t in zip(self.student.parameters(), self.teacher.parameters()):
            t.data = self.ema_decay * t.data + (1 - self.ema_decay) * s.data

    def forward(self, x, mask):
        # Teacher: full (unmasked) input; average the last top_k layer outputs
        with torch.no_grad():
            teacher_features = self.teacher.get_intermediate_layers(x, n=self.top_k)
            target = torch.stack(teacher_features, dim=0).mean(dim=0)
            target = F.layer_norm(target, target.shape[-1:])

        # Student: masked input
        student_out = self.student(x, mask=mask)
        pred = self.pred_head(student_out)

        # Smooth L1 loss on masked positions only
        loss = F.smooth_l1_loss(pred[mask], target[mask], beta=2.0)
        return loss
```

I-JEPA (Image Joint Embedding Predictive Architecture) by Assran et al. (2023) predicts abstract representations of target regions rather than pixels, avoiding the pixel-level details that can dominate reconstruction losses.
Instead of reconstructing masked patches in pixel space, a lightweight predictor takes the encoded context block and predicts the representations of several masked target blocks entirely in latent space.
This forces learning of semantic features rather than texture details.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class IJEPA(nn.Module):
    """I-JEPA: Joint Embedding Predictive Architecture."""

    def __init__(self, context_encoder, predictor, target_encoder):
        super().__init__()
        self.context_encoder = context_encoder
        self.predictor = predictor            # Lightweight transformer
        self.target_encoder = target_encoder  # EMA copy of the context encoder
        for p in self.target_encoder.parameters():
            p.requires_grad = False

    def forward(self, x, context_masks, target_masks):
        # Encode context (visible regions only)
        context = self.context_encoder(x, mask=context_masks)

        # Predict the representation of each target block from the context
        predictions = []
        for target_mask in target_masks:
            pred = self.predictor(context, target_mask)
            predictions.append(pred)

        # Target representations come from the frozen EMA encoder on the full image
        with torch.no_grad():
            target = self.target_encoder(x)
            targets = [target[target_mask] for target_mask in target_masks]

        # Loss: match predicted representations to target representations
        loss = sum(F.mse_loss(p, t) for p, t in zip(predictions, targets))
        return loss / len(predictions)
```

Pixel-level reconstruction can focus on textures and low-level patterns. By predicting in representation space, I-JEPA learns higher-level semantic features that transfer better to downstream tasks requiring object understanding.
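The sketch above keeps `target_encoder` frozen inside the module; in the actual method its weights track the context encoder through an exponential moving average after every update. A minimal version of that rule (the function name and momentum value are assumptions, not taken from the reference code):

```python
import torch


@torch.no_grad()
def update_target_encoder(model: IJEPA, momentum: float = 0.996):
    """Move each target-encoder parameter a small step toward the context encoder."""
    for ctx_p, tgt_p in zip(model.context_encoder.parameters(),
                            model.target_encoder.parameters()):
        tgt_p.data.mul_(momentum).add_(ctx_p.data, alpha=1 - momentum)

# Typically called once per optimizer step: update_target_encoder(model)
```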
| Method | Target | Key Innovation |
|---|---|---|
| MAE | Pixels | 75% masking, asymmetric encoder-decoder |
| BEiT | Discrete tokens | dVAE tokenizer, BERT-style training |
| SimMIM | Pixels | Simple L1 loss, large mask patches |
| Data2Vec | Features | Predict contextualized teacher features |
| I-JEPA | Representations | Predict abstract target embeddings |
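One common way to check what these representations buy you downstream is linear probing: freeze the pre-trained encoder and train only a linear classifier on top. A minimal sketch, assuming the encoder returns one pooled feature vector per image (the class and attribute names here are illustrative):

```python
import torch
import torch.nn as nn


class LinearProbe(nn.Module):
    """Frozen pre-trained encoder + trainable linear classifier."""

    def __init__(self, encoder, num_classes=1000):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():  # keep the learned representation fixed
            p.requires_grad = False
        self.classifier = nn.Linear(encoder.embed_dim, num_classes)

    def forward(self, imgs):
        with torch.no_grad():
            feats = self.encoder(imgs)       # assumed: [B, embed_dim] pooled features
        return self.classifier(feats)
```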
You now understand masked modeling approaches from MAE through I-JEPA. These methods learn powerful visual representations by reconstructing hidden portions of images. Next, we'll explore vision-language pre-training that connects visual and textual understanding.