So far, we've seen self-supervised methods that rely on instance discrimination—treating each image as its own class and learning to distinguish between them. But this approach ignores a fundamental property of visual data: natural images form semantic clusters.
Images of dogs cluster together. Images of cars cluster together. Even without labels, the underlying data manifold has structure. Clustering-based self-supervised learning exploits this structure directly, using cluster assignments as pseudo-labels to guide representation learning.
Instead of contrasting every image against every other image, we can:
1. Group images into clusters based on the similarity of their features
2. Treat each image's cluster assignment as a pseudo-label
3. Train the network to predict those pseudo-labels
This approach is more computationally efficient (no need for thousands of negatives) and captures higher-level semantic structure (images in the same cluster should have related representations).
Clustering-based methods face a fundamental challenge: we need good representations to cluster well, but we need good clusters to learn good representations. Breaking this circularity is the key innovation of methods like DeepCluster and SwAV. They solve it through alternating optimization, online clustering, and careful design choices that prevent degenerate solutions.
DeepCluster, introduced by Caron et al. (2018) at Facebook AI Research, established the foundational approach for clustering-based self-supervised learning. Its elegance lies in repeatedly cycling through a few simple steps:
Step 1: Feature Extraction. Use the current neural network to extract features for all images in the dataset.
Step 2: Clustering. Apply k-means clustering to the extracted features, assigning each image to one of K clusters.
Step 3: Training. Treat the cluster assignments as pseudo-labels and train the network with a standard cross-entropy loss.
Repeat until convergence.
Let $f_\theta$ be the feature extractor with parameters $\theta$. For a dataset $\{x_1, \dots, x_n\}$, each training round solves

$$\min_{\theta} \frac{1}{n} \sum_{i=1}^{n} \mathcal{L}\big(g_\theta(f_\theta(x_i)),\, y_i\big)$$

where $y_i$ is the k-means cluster assignment of $f_\theta(x_i)$, $g_\theta$ is a classification head, and $\mathcal{L}$ is the cross-entropy loss.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
import numpy as np
from typing import Tuple, List


class DeepCluster(nn.Module):
    """
    DeepCluster: Learning Visual Features by Clustering.

    Alternates between:
    1. K-means clustering on current features
    2. Training network to predict cluster assignments

    Key insight: Even random clusters provide useful signal;
    as features improve, clusters become more semantic.
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_clusters: int = 10000,
        projection_dim: int = 128
    ):
        super().__init__()
        self.backbone = backbone
        self.num_clusters = num_clusters

        # Get feature dimension from backbone
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()

        # Projection head for clustering
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, projection_dim * 4),
            nn.BatchNorm1d(projection_dim * 4),
            nn.ReLU(inplace=True),
            nn.Linear(projection_dim * 4, projection_dim)
        )

        # Classification head (output size = num_clusters)
        self.classifier = nn.Linear(projection_dim, num_clusters)

        # Store current cluster assignments
        self.cluster_assignments = None

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract normalized features for clustering."""
        with torch.no_grad():
            features = self.backbone(x)
            features = self.projector(features)
            features = F.normalize(features, dim=1)
        return features

    def compute_cluster_assignments(
        self,
        dataloader,
        device: torch.device
    ) -> np.ndarray:
        """
        Run k-means on all features and return cluster assignments.

        This is the 'offline' step that groups similar images.
        """
        self.eval()
        all_features = []
        all_indices = []

        # Extract features for all images
        for batch_idx, (images, _) in enumerate(dataloader):
            features = self.extract_features(images.to(device))
            all_features.append(features.cpu().numpy())

            # Track original indices for assignment mapping
            batch_size = images.size(0)
            start_idx = batch_idx * batch_size
            all_indices.extend(range(start_idx, start_idx + batch_size))

        features = np.concatenate(all_features, axis=0)

        # Run k-means clustering
        print(f"Running k-means with {self.num_clusters} clusters...")
        kmeans = KMeans(
            n_clusters=self.num_clusters,
            n_init=10,
            max_iter=300,
            random_state=42
        )
        assignments = kmeans.fit_predict(features)

        # Check for empty clusters and reassign if necessary
        assignments = self._handle_empty_clusters(assignments, features)

        self.cluster_assignments = assignments
        return assignments

    def _handle_empty_clusters(
        self,
        assignments: np.ndarray,
        features: np.ndarray
    ) -> np.ndarray:
        """
        Handle empty clusters to prevent training issues.

        Empty clusters cause problems in cross-entropy loss.
        We reassign samples from large clusters to empty ones.
        """
        cluster_counts = np.bincount(assignments, minlength=self.num_clusters)
        empty_clusters = np.where(cluster_counts == 0)[0]

        if len(empty_clusters) > 0:
            print(f"Found {len(empty_clusters)} empty clusters, reassigning...")

            # Find largest clusters
            sorted_clusters = np.argsort(cluster_counts)[::-1]

            for i, empty_idx in enumerate(empty_clusters):
                # Take from largest cluster
                source_cluster = sorted_clusters[i % len(sorted_clusters)]
                source_samples = np.where(assignments == source_cluster)[0]

                # Move half to empty cluster
                num_to_move = len(source_samples) // 2
                samples_to_move = np.random.choice(
                    source_samples, num_to_move, replace=False
                )
                assignments[samples_to_move] = empty_idx

        return assignments

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass returning logits over clusters."""
        features = self.backbone(x)
        projected = self.projector(features)
        logits = self.classifier(projected)
        return logits


def train_deepcluster_epoch(
    model: DeepCluster,
    dataloader,
    optimizer,
    cluster_assignments: np.ndarray,
    device: torch.device
) -> float:
    """
    Train for one epoch using current cluster assignments.

    Standard classification training with pseudo-labels.
    """
    model.train()
    total_loss = 0
    num_batches = 0

    for batch_idx, (images, _) in enumerate(dataloader):
        # Get cluster assignments for this batch
        start_idx = batch_idx * images.size(0)
        end_idx = start_idx + images.size(0)
        targets = torch.tensor(
            cluster_assignments[start_idx:end_idx],
            dtype=torch.long,
            device=device
        )

        images = images.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches


def deepcluster_training_loop(
    model: DeepCluster,
    dataloader,
    feature_dataloader,  # For clustering (no augmentation)
    num_epochs: int,
    num_clustering_steps: int,
    device: torch.device
):
    """
    Full DeepCluster training loop.

    Outer loop: Re-cluster periodically
    Inner loop: Train on current cluster assignments
    """
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.05,
        momentum=0.9,
        weight_decay=1e-4
    )

    epochs_per_clustering = num_epochs // num_clustering_steps

    for cluster_step in range(num_clustering_steps):
        print(f"\n=== Clustering Step {cluster_step + 1}/{num_clustering_steps} ===")

        # Step 1: Compute new cluster assignments
        assignments = model.compute_cluster_assignments(
            feature_dataloader, device
        )

        # Print cluster statistics
        cluster_counts = np.bincount(assignments)
        print(f"Cluster size: min={cluster_counts.min()}, "
              f"max={cluster_counts.max()}, mean={cluster_counts.mean():.1f}")

        # Step 2: Train on these assignments
        for epoch in range(epochs_per_clustering):
            loss = train_deepcluster_epoch(
                model, dataloader, optimizer, assignments, device
            )
            print(f"Epoch {cluster_step * epochs_per_clustering + epoch}: "
                  f"Loss = {loss:.4f}")
```

DeepCluster can converge to trivial solutions where all images are assigned to one cluster. Key prevention strategies include: (1) empty cluster reassignment, (2) uniform cluster weighting in the loss, (3) sufficient k-means iterations, and (4) proper feature normalization before clustering.
SwAV (Swapping Assignments between Views), introduced by Caron et al. (2020), revolutionizes clustering-based self-supervised learning by introducing online clustering. Instead of the expensive offline k-means step in DeepCluster, SwAV computes soft cluster assignments on-the-fly.
SwAV's key innovation is the swapped prediction task:
1. Produce two augmented views of the same image.
2. Compute a code (a soft cluster assignment over the prototypes) for each view.
3. Predict the code of one view from the features of the other view, and vice versa.
The loss function: $$\mathcal{L}(z^t, z^s) = \ell(z^t, q^s) + \ell(z^s, q^t)$$
where $\ell(z, q) = -\sum_k q_k \log p_k$ with $p_k = \frac{\exp(z^T c_k / \tau)}{\sum_{k'} \exp(z^T c_{k'} / \tau)}$
Here, $\{c_k\}$ are learnable prototype vectors, and $\tau$ is a temperature parameter.
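To make the swapped objective concrete, here is a minimal, self-contained sketch of $\ell(z^t, q^s) + \ell(z^s, q^t)$ using toy tensors; the batch size, embedding dimension, and prototype count are made up, and the codes are simple stand-ins for what Sinkhorn-Knopp would produce:

```python
import torch
import torch.nn.functional as F

def swapped_prediction_loss(z_t, z_s, prototypes, q_t, q_s, tau=0.1):
    """l(z^t, q^s) + l(z^s, q^t): each view's code supervises the other view's prediction."""
    p_t = F.log_softmax(z_t @ prototypes.t() / tau, dim=1)  # predictions from view t
    p_s = F.log_softmax(z_s @ prototypes.t() / tau, dim=1)  # predictions from view s
    return -(q_s * p_t).sum(dim=1).mean() - (q_t * p_s).sum(dim=1).mean()

# Toy example: batch of 4 samples, 128-dim embeddings, 300 prototypes
z_t = F.normalize(torch.randn(4, 128), dim=1)
z_s = F.normalize(torch.randn(4, 128), dim=1)
prototypes = F.normalize(torch.randn(300, 128), dim=1)
q_t = F.softmax(torch.randn(4, 300), dim=1)  # stand-in codes (SwAV computes these with Sinkhorn-Knopp)
q_s = F.softmax(torch.randn(4, 300), dim=1)
loss = swapped_prediction_loss(z_t, z_s, prototypes, q_t, q_s)
```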
To prevent all samples from collapsing to the same cluster, SwAV enforces that cluster assignments are equipartitioned across the batch using the Sinkhorn-Knopp algorithm. This iterative procedure ensures that:
1. Each sample is fully assigned: its code sums to 1 across prototypes.
2. Each prototype receives roughly the same total assignment mass within the batch.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import Optional, Tuple


class SwAV(nn.Module):
    """
    SwAV: Swapping Assignments between Views.

    Key innovations:
    1. Online clustering with learnable prototypes
    2. Swapped prediction task between views
    3. Sinkhorn-Knopp for equipartitioned codes
    4. Multi-crop strategy for efficiency
    """

    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 128,
        num_prototypes: int = 3000,
        temperature: float = 0.1,
        sinkhorn_iterations: int = 3,
        epsilon: float = 0.05
    ):
        super().__init__()
        self.temperature = temperature
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon

        # Backbone
        self.backbone = backbone
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()

        # Projection MLP
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, projection_dim)
        )

        # Learnable prototypes (cluster centers)
        self.prototypes = nn.Linear(projection_dim, num_prototypes, bias=False)

        # Normalize prototypes
        with torch.no_grad():
            w = self.prototypes.weight.data.clone()
            w = F.normalize(w, dim=1, p=2)
            self.prototypes.weight.copy_(w)

    def forward(self, crops: list) -> torch.Tensor:
        """
        Forward pass with multi-crop strategy.

        Args:
            crops: List of augmented views [2 global + optional local crops]

        Returns:
            SwAV loss
        """
        # Compute embeddings for all crops
        embeddings = []
        for crop in crops:
            z = self.backbone(crop)
            z = self.projector(z)
            z = F.normalize(z, dim=1)
            embeddings.append(z)

        # Normalize prototypes
        with torch.no_grad():
            w = self.prototypes.weight.data.clone()
            w = F.normalize(w, dim=1, p=2)
            self.prototypes.weight.copy_(w)

        # Compute prototype scores for all embeddings
        prototype_scores = [self.prototypes(z) for z in embeddings]

        # Compute soft codes using Sinkhorn-Knopp
        with torch.no_grad():
            codes = [self.sinkhorn_knopp(s / self.epsilon)
                     for s in prototype_scores[:2]]  # Only for global views

        # Swapped prediction loss
        loss = 0
        num_global = 2  # Number of global crops

        for i in range(num_global):
            # Get code from view i
            q = codes[i]

            # Predict code from all OTHER views
            for j, s in enumerate(prototype_scores):
                if i != j:
                    p = F.softmax(s / self.temperature, dim=1)
                    loss -= torch.mean(torch.sum(q * torch.log(p + 1e-10), dim=1))

        # Average over number of prediction pairs
        num_pairs = len(prototype_scores) * num_global - num_global
        loss /= num_pairs

        return loss

    @torch.no_grad()
    def sinkhorn_knopp(self, scores: torch.Tensor) -> torch.Tensor:
        """
        Sinkhorn-Knopp algorithm for computing equipartitioned codes.

        Guarantees (approximately, after a few iterations):
        1. Each sample's code sums to 1 across clusters (it is fully assigned)
        2. Each cluster receives an equal share of the batch

        This prevents collapse to a single cluster.
        """
        # Work with Q of shape [K, B]: rows are prototypes, columns are samples
        Q = torch.exp(scores).t()
        K, B = Q.shape  # K=num_prototypes, B=batch_size
        Q /= Q.sum()    # Normalize total mass to 1

        for _ in range(self.sinkhorn_iterations):
            # Row normalization: each prototype's total mass becomes 1/K
            Q /= Q.sum(dim=1, keepdim=True)
            Q /= K

            # Column normalization: each sample's total mass becomes 1/B
            Q /= Q.sum(dim=0, keepdim=True)
            Q /= B

        # Rescale so each sample's code sums to 1, return shape [B, K]
        Q *= B
        return Q.t()


class MultiCropDataset:
    """
    Multi-crop augmentation strategy for SwAV.

    Generates:
    - 2 global crops (224x224, covering 35-100% of image)
    - V local crops (96x96, covering 5-35% of image)

    This provides multiple views at different scales.
    """

    def __init__(
        self,
        dataset,
        num_global_crops: int = 2,
        num_local_crops: int = 6,
        global_scale: Tuple[float, float] = (0.35, 1.0),
        local_scale: Tuple[float, float] = (0.05, 0.35)
    ):
        self.dataset = dataset
        self.num_global = num_global_crops
        self.num_local = num_local_crops

        from torchvision import transforms

        # Global crop transforms
        self.global_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_scale),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.2, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Local crop transforms (smaller scale)
        self.local_transform = transforms.Compose([
            transforms.RandomResizedCrop(96, scale=local_scale),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.2, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, idx):
        image, label = self.dataset[idx]

        # Generate all crops
        crops = []

        # Global crops first
        for _ in range(self.num_global):
            crops.append(self.global_transform(image))

        # Then local crops
        for _ in range(self.num_local):
            crops.append(self.local_transform(image))

        return crops, label

    def __len__(self):
        return len(self.dataset)
```

SwAV's multi-crop strategy is key to its efficiency. Local crops (96×96) are much cheaper to process than global crops (224×224), but they still provide useful learning signal. By predicting codes from local crops to match global crop codes, SwAV learns scale-invariant features without the computational cost of more global crops.
The Sinkhorn-Knopp algorithm is central to SwAV's success. It solves an optimal transport problem to ensure cluster assignments are evenly distributed, preventing the trivial solution where all samples collapse to a single cluster.
The algorithm solves a regularized optimal transport problem:
$$\max_Q \text{Tr}(Q^T C) + \epsilon H(Q)$$
subject to $Q \mathbf{1} = \frac{1}{K}\mathbf{1}$ and $Q^T \mathbf{1} = \frac{1}{B}\mathbf{1}$
where:
- $Q$ is the matrix of soft assignments (codes) between prototypes and the samples in the batch
- $C$ is the matrix of similarity scores between prototypes and batch features
- $H(Q) = -\sum_{ij} Q_{ij} \log Q_{ij}$ is the entropy of $Q$
- $\epsilon$ controls the trade-off between maximizing similarity and keeping assignments smooth
Convergence: Sinkhorn-Knopp converges exponentially fast to the unique solution. In practice, 3-5 iterations suffice.
Differentiability: The algorithm is fully differentiable, but SwAV applies it with torch.no_grad() to prevent gradients flowing through the target codes.
Memory efficiency: Only requires O(K × B) memory for the assignment matrix.
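A standalone toy example makes the equipartition property tangible: after a few iterations, each sample's code sums to one and every prototype receives roughly the same share of the batch. The dimensions and values below are made up for illustration:

```python
import torch

def sinkhorn(scores: torch.Tensor, epsilon: float = 0.05, n_iters: int = 3) -> torch.Tensor:
    """Toy Sinkhorn-Knopp on a [K, B] score matrix; returns [B, K] codes."""
    Q = torch.exp(scores / epsilon)
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True)  # rows (prototypes) -> mass 1/K each
        Q /= K
        Q /= Q.sum(dim=0, keepdim=True)  # columns (samples) -> mass 1/B each
        Q /= B
    return (Q * B).t()  # each sample's code now sums to 1

scores = torch.randn(300, 64)   # K=300 prototypes, B=64 samples (toy values)
codes = sinkhorn(scores)
print(codes.sum(dim=1))         # ~1.0 for every sample
print(codes.sum(dim=0))         # ~B/K ≈ 0.21 for every prototype (equipartition)
```

The more detailed implementation below adds convergence monitoring and support for gathering the batch across GPUs.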
```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple


def sinkhorn_knopp_detailed(
    scores: torch.Tensor,
    epsilon: float = 0.05,
    num_iterations: int = 3,
    world_size: int = 1
) -> Tuple[torch.Tensor, dict]:
    """
    Detailed Sinkhorn-Knopp implementation with monitoring.

    Args:
        scores: Similarity scores [batch_size, num_prototypes]
        epsilon: Temperature for softmax (lower = harder assignments)
        num_iterations: Number of Sinkhorn iterations
        world_size: Number of GPUs for distributed training

    Returns:
        codes: Soft cluster assignments [batch_size, num_prototypes]
        stats: Dictionary of convergence statistics
    """
    with torch.no_grad():
        batch_size = scores.shape[0]
        num_prototypes = scores.shape[1]

        # Initialize Q with exponential of scaled scores
        Q = torch.exp(scores / epsilon)

        # For distributed: aggregate Q across all GPUs
        if world_size > 1:
            Q = concat_all_gather(Q)
            batch_size = Q.shape[0]

        # Initial normalization
        Q /= Q.sum()

        convergence_stats = []

        for iteration in range(num_iterations):
            # Save previous for convergence check
            Q_prev = Q.clone()

            # Step 1: Row normalization
            # Each sample should sum to 1/B across clusters
            row_sums = Q.sum(dim=1, keepdim=True)
            Q = Q / row_sums
            Q = Q / batch_size

            # Step 2: Column normalization
            # Each cluster should sum to 1/K across samples
            col_sums = Q.sum(dim=0, keepdim=True)
            Q = Q / col_sums
            Q = Q / num_prototypes

            # Track convergence
            delta = torch.abs(Q - Q_prev).max().item()
            convergence_stats.append({
                'iteration': iteration,
                'max_delta': delta,
                'row_sum_std': row_sums.std().item(),
                'col_sum_std': col_sums.std().item()
            })

        # Final normalization to get per-sample codes
        codes = Q / Q.sum(dim=1, keepdim=True)

        # Truncate if we gathered from multiple GPUs
        if world_size > 1:
            codes = codes[:batch_size // world_size]

        stats = {
            'convergence': convergence_stats,
            'code_entropy': -(codes * torch.log(codes + 1e-10)).sum(dim=1).mean().item(),
            'assignment_sharpness': codes.max(dim=1)[0].mean().item()
        }

        return codes, stats


def analyze_cluster_quality(
    model: SwAV,
    dataloader,
    device: torch.device
) -> dict:
    """
    Analyze the quality of learned clusters.

    Good clusters should:
    1. Have uniform usage (no empty clusters)
    2. Have consistent within-cluster representations
    3. Have distinct between-cluster representations
    """
    model.eval()
    all_embeddings = []
    all_assignments = []

    with torch.no_grad():
        for images, _ in dataloader:
            if isinstance(images, list):
                images = images[0]  # Take first crop
            images = images.to(device)

            # Get embeddings
            z = model.backbone(images)
            z = model.projector(z)
            z = F.normalize(z, dim=1)

            # Get assignments
            scores = model.prototypes(z)
            hard_assignments = scores.argmax(dim=1)

            all_embeddings.append(z.cpu())
            all_assignments.append(hard_assignments.cpu())

    embeddings = torch.cat(all_embeddings, dim=0)
    assignments = torch.cat(all_assignments, dim=0)

    # Metric 1: Cluster usage uniformity
    num_prototypes = model.prototypes.weight.shape[0]
    cluster_counts = torch.bincount(assignments, minlength=num_prototypes)
    usage_uniformity = cluster_counts.float().std() / cluster_counts.float().mean()

    # Metric 2: Within-cluster cosine similarity
    within_sims = []
    for k in range(min(100, num_prototypes)):  # Sample clusters
        mask = (assignments == k)
        if mask.sum() > 1:
            cluster_embs = embeddings[mask]
            sim = (cluster_embs @ cluster_embs.T).mean().item()
            within_sims.append(sim)

    # Metric 3: Between-cluster separation (using prototypes)
    prototypes = model.prototypes.weight.data.detach().cpu()
    prototypes = F.normalize(prototypes, dim=1)
    proto_sim = (prototypes @ prototypes.T)
    # Remove diagonal
    mask = ~torch.eye(num_prototypes, dtype=bool)
    between_sim = proto_sim[mask].mean().item()

    return {
        'num_active_clusters': (cluster_counts > 0).sum().item(),
        'usage_uniformity_std': usage_uniformity.item(),
        'mean_within_cluster_similarity': np.mean(within_sims) if within_sims else 0,
        'mean_between_cluster_similarity': between_sim,
        'cluster_quality_score': np.mean(within_sims) - between_sim if within_sims else 0
    }


def concat_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """
    Gather tensors from all GPUs.

    Note: torch.distributed.all_gather does not propagate gradients,
    so this is intended for use inside torch.no_grad() contexts.
    """
    if not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    tensors_gather = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
```

| Parameter | Typical Value | Effect of Increase | Effect of Decrease |
|---|---|---|---|
| epsilon (ε) | 0.05 | Softer assignments, more exploration | Sharper assignments, faster convergence |
| num_iterations | 3 | Better constraint satisfaction | Faster but less accurate |
| num_prototypes (K) | 3000 | Finer-grained clusters | Coarser semantic grouping |
| temperature (τ) | 0.1 | More uniform predictions | Sharper predictions |
ProtoNCE and PCL (Prototypical Contrastive Learning) bridge clustering and contrastive learning by using cluster centroids (prototypes) as contrastive targets.
Instead of contrasting with individual negative samples, ProtoNCE contrasts with cluster prototypes:
$$\mathcal{L}_{\text{ProtoNCE}} = -\log \frac{\exp(z \cdot c^+ / \tau)}{\exp(z \cdot c^+ / \tau) + \sum_{c^- \in \mathcal{C}^-} \exp(z \cdot c^- / \tau)}$$
where:
- $z$ is the normalized embedding of the sample
- $c^+$ is the prototype of the cluster the sample is assigned to
- $\mathcal{C}^-$ is the set of negative prototypes (all other clusters)
- $\tau$ is a temperature parameter
PCL combines two objectives:
- an instance-level InfoNCE loss between augmented views, and
- a prototype-level contrastive loss against cluster centroids:
$$\mathcal{L}_{\text{PCL}} = \mathcal{L}_{\text{instance}} + \lambda \mathcal{L}_{\text{proto}}$$
The instance loss encourages augmentation invariance, while the prototype loss encourages semantic clustering.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
import numpy as np


class PrototypicalContrastiveLearning(nn.Module):
    """
    Prototypical Contrastive Learning (PCL).

    Combines instance-level and prototype-level contrastive losses.
    Maintains a feature memory bank for efficient prototype computation.
    """

    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 128,
        num_prototypes: int = 1000,
        temperature: float = 0.07,
        proto_temperature: float = 0.2,
        memory_bank_size: int = 65536
    ):
        super().__init__()
        self.num_prototypes = num_prototypes
        self.temperature = temperature
        self.proto_temperature = proto_temperature

        # Backbone and projector
        self.backbone = backbone
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()

        self.projector = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, projection_dim)
        )

        # Memory bank for features
        self.register_buffer(
            'memory_bank',
            torch.randn(memory_bank_size, projection_dim)
        )
        self.memory_bank = F.normalize(self.memory_bank, dim=1)
        self.register_buffer('memory_ptr', torch.zeros(1, dtype=torch.long))

        # Prototypes (cluster centers)
        self.register_buffer(
            'prototypes',
            torch.randn(num_prototypes, projection_dim)
        )
        self.prototypes = F.normalize(self.prototypes, dim=1)

        # Cluster assignments for memory bank entries
        self.register_buffer(
            'cluster_assignments',
            torch.zeros(memory_bank_size, dtype=torch.long)
        )

    def update_memory_bank(self, features: torch.Tensor):
        """Update memory bank with new features (FIFO queue)."""
        batch_size = features.shape[0]
        ptr = int(self.memory_ptr)

        if ptr + batch_size > self.memory_bank.shape[0]:
            ptr = 0

        self.memory_bank[ptr:ptr + batch_size] = features.detach()
        self.memory_ptr[0] = (ptr + batch_size) % self.memory_bank.shape[0]

    @torch.no_grad()
    def update_prototypes(self):
        """Update prototypes using k-means on memory bank."""
        features = self.memory_bank.cpu().numpy()

        kmeans = KMeans(n_clusters=self.num_prototypes, n_init=5)
        assignments = kmeans.fit_predict(features)

        # Update prototypes and assignments
        self.prototypes.copy_(
            torch.from_numpy(kmeans.cluster_centers_).to(self.prototypes.device)
        )
        self.prototypes = F.normalize(self.prototypes, dim=1)

        self.cluster_assignments.copy_(
            torch.from_numpy(assignments).to(self.cluster_assignments.device)
        )

    def instance_contrastive_loss(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor
    ) -> torch.Tensor:
        """
        Instance-level InfoNCE loss.

        Positive: augmented views of same image
        Negatives: memory bank entries
        """
        batch_size = z1.shape[0]

        # Positive pairs
        pos_sim = torch.sum(z1 * z2, dim=1) / self.temperature

        # Negative pairs from memory bank
        neg_sim = torch.mm(z1, self.memory_bank.T) / self.temperature

        # InfoNCE loss
        logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
        labels = torch.zeros(batch_size, dtype=torch.long, device=z1.device)

        return F.cross_entropy(logits, labels)

    def prototype_contrastive_loss(
        self,
        z: torch.Tensor,
        assignments: torch.Tensor
    ) -> torch.Tensor:
        """
        Prototype-level contrastive loss.

        Positive: sample's assigned prototype
        Negatives: all other prototypes
        """
        # Similarity to all prototypes
        sim = torch.mm(z, self.prototypes.T) / self.proto_temperature

        # Cross-entropy with prototype assignment as target
        return F.cross_entropy(sim, assignments)

    def forward(
        self,
        view1: torch.Tensor,
        view2: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass computing combined PCL loss.
        """
        # Get embeddings
        z1 = F.normalize(self.projector(self.backbone(view1)), dim=1)
        z2 = F.normalize(self.projector(self.backbone(view2)), dim=1)

        # Get cluster assignments for current batch
        with torch.no_grad():
            proto_sim = torch.mm(z1, self.prototypes.T)
            assignments = proto_sim.argmax(dim=1)

        # Instance-level loss (symmetric)
        loss_instance = (
            self.instance_contrastive_loss(z1, z2) +
            self.instance_contrastive_loss(z2, z1)
        ) / 2

        # Prototype-level loss (symmetric)
        loss_proto = (
            self.prototype_contrastive_loss(z1, assignments) +
            self.prototype_contrastive_loss(z2, assignments)
        ) / 2

        # Update memory bank
        self.update_memory_bank(torch.cat([z1, z2], dim=0))

        # Combined loss
        return loss_instance + 0.5 * loss_proto


class AdaptivePrototypes(nn.Module):
    """
    Improved prototype management with concentration estimation.

    Prototypes are weighted by their 'concentration' - how tightly
    clustered their assigned samples are. Low-concentration prototypes
    are refined more aggressively.
    """

    def __init__(self, num_prototypes: int, feature_dim: int):
        super().__init__()
        self.register_buffer('prototypes', torch.randn(num_prototypes, feature_dim))
        self.prototypes = F.normalize(self.prototypes, dim=1)

        # Concentration parameter for each prototype
        self.register_buffer('concentrations', torch.ones(num_prototypes))

        # Running count of samples per prototype
        self.register_buffer('counts', torch.zeros(num_prototypes))

    @torch.no_grad()  # statistics update; no gradients should flow through prototypes
    def update(
        self,
        features: torch.Tensor,
        assignments: torch.Tensor,
        momentum: float = 0.9
    ):
        """
        Update prototypes and concentrations based on assigned features.
        """
        for k in range(self.prototypes.shape[0]):
            mask = (assignments == k)
            if mask.sum() > 0:
                cluster_features = features[mask]
                cluster_mean = cluster_features.mean(dim=0)
                cluster_mean = F.normalize(cluster_mean, dim=0)

                # Update prototype with momentum
                self.prototypes[k] = momentum * self.prototypes[k] + (1 - momentum) * cluster_mean
                self.prototypes[k] = F.normalize(self.prototypes[k], dim=0)

                # Update concentration (inverse variance)
                variance = ((cluster_features - cluster_mean) ** 2).mean()
                self.concentrations[k] = 1 / (variance + 1e-6)

                # Update count
                self.counts[k] += mask.sum().item()
```

DINO (Self-Distillation with No Labels), introduced by Caron et al. (2021), elegantly combines self-distillation with clustering-like behavior. It learns by having a student network match the output distribution of a teacher network, with the teacher updated via an exponential moving average of the student.
DINO's key insight is that self-distillation, with the right design choices, naturally leads to the emergence of semantic features without an explicit clustering or contrastive objective.
Architecture:
- A student network $g_\theta$ and a teacher network $g_\xi$ with identical architectures (backbone plus projection head)
- The teacher's weights are an exponential moving average (EMA) of the student's weights
- Gradients flow only through the student; the teacher is never trained directly
Training objective: $$\mathcal{L} = H(P_t, P_s) = -\sum_k P_t^{(k)} \log P_s^{(k)}$$
where $P_t = \text{softmax}(g_\xi(x) / \tau_t)$ is the teacher's output for one view $x$ and $P_s = \text{softmax}(g_\theta(x') / \tau_s)$ is the student's output for a different view $x'$.
DINO prevents collapse through two mechanisms, sketched in the snippet below:
1. Centering: a running mean of teacher outputs is subtracted from the teacher logits, preventing any single output dimension from dominating.
2. Sharpening: a low teacher temperature ($\tau_t < \tau_s$) sharpens the teacher distribution, counteracting the uniform-output collapse that centering alone would encourage.
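The snippet below is a minimal sketch of these two operations on the teacher side, using the same default temperatures as the implementation that follows; the shapes in the usage example are toy values:

```python
import torch
import torch.nn.functional as F

def teacher_targets(teacher_logits: torch.Tensor,
                    center: torch.Tensor,
                    teacher_temp: float = 0.04,
                    center_momentum: float = 0.9):
    """Center then sharpen teacher outputs; return targets and the updated center."""
    with torch.no_grad():
        # Centering: subtract the running mean of teacher outputs
        # Sharpening: divide by a low teacher temperature before the softmax
        targets = F.softmax((teacher_logits - center) / teacher_temp, dim=-1)
        # EMA update of the center using the current batch
        batch_center = teacher_logits.mean(dim=0, keepdim=True)
        new_center = center_momentum * center + (1 - center_momentum) * batch_center
    return targets, new_center

# Toy usage: batch of 8 samples, 4096-dimensional projection head output
logits = torch.randn(8, 4096)
center = torch.zeros(1, 4096)
targets, center = teacher_targets(logits, center)
```

The full module below wires these operations into the student-teacher training setup.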
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy


class DINO(nn.Module):
    """
    DINO: Self-Distillation with No Labels.

    Key innovations:
    1. Self-distillation between student and EMA teacher
    2. Centering to prevent collapse
    3. Multi-crop training strategy
    4. Emergent semantic properties (attention maps)
    """

    def __init__(
        self,
        backbone: nn.Module,
        output_dim: int = 65536,
        hidden_dim: int = 2048,
        bottleneck_dim: int = 256,
        student_temp: float = 0.1,
        teacher_temp: float = 0.04,
        center_momentum: float = 0.9,
        ema_decay: float = 0.996
    ):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        self.ema_decay = ema_decay

        # Student network
        self.student_backbone = backbone
        feature_dim = backbone.embed_dim if hasattr(backbone, 'embed_dim') else 2048
        self.student_head = DINOHead(
            feature_dim, output_dim, hidden_dim, bottleneck_dim
        )

        # Teacher network (EMA of student)
        self.teacher_backbone = deepcopy(backbone)
        self.teacher_head = deepcopy(self.student_head)

        # Freeze teacher
        for p in self.teacher_backbone.parameters():
            p.requires_grad = False
        for p in self.teacher_head.parameters():
            p.requires_grad = False

        # Center for teacher outputs (for centering operation)
        self.register_buffer('center', torch.zeros(1, output_dim))

    @torch.no_grad()
    def update_teacher(self):
        """Update teacher using exponential moving average."""
        for student_ps, teacher_ps in zip(
            self.student_backbone.parameters(),
            self.teacher_backbone.parameters()
        ):
            teacher_ps.data = (
                self.ema_decay * teacher_ps.data +
                (1 - self.ema_decay) * student_ps.data
            )

        for student_ps, teacher_ps in zip(
            self.student_head.parameters(),
            self.teacher_head.parameters()
        ):
            teacher_ps.data = (
                self.ema_decay * teacher_ps.data +
                (1 - self.ema_decay) * student_ps.data
            )

    @torch.no_grad()
    def update_center(self, teacher_output: torch.Tensor):
        """Update center with EMA of teacher outputs."""
        batch_center = teacher_output.mean(dim=0, keepdim=True)
        self.center = (
            self.center_momentum * self.center +
            (1 - self.center_momentum) * batch_center
        )

    def forward(self, crops: list) -> torch.Tensor:
        """
        Forward pass with multi-crop.

        Args:
            crops: [global_1, global_2, local_1, ..., local_V]

        Returns:
            DINO loss (distillation from centered, sharpened teacher)
        """
        n_global = 2  # Number of global crops

        # Get all student outputs
        student_outputs = []
        for crop in crops:
            feat = self.student_backbone(crop)
            out = self.student_head(feat)
            student_outputs.append(out)

        # Get teacher outputs for global crops only
        with torch.no_grad():
            teacher_outputs = []
            for crop in crops[:n_global]:
                feat = self.teacher_backbone(crop)
                out = self.teacher_head(feat)
                teacher_outputs.append(out)

        # Compute loss
        loss = 0
        n_loss_terms = 0

        for iq, q in enumerate(teacher_outputs):
            # Apply centering and sharpening to teacher output
            q = F.softmax((q - self.center) / self.teacher_temp, dim=-1)

            for iv, v in enumerate(student_outputs):
                # Skip when same view
                if iv == iq:
                    continue

                # Student output with its temperature
                p = F.log_softmax(v / self.student_temp, dim=-1)

                # Cross-entropy loss (KL divergence without constant)
                loss += torch.sum(-q * p, dim=-1).mean()
                n_loss_terms += 1

        loss /= n_loss_terms

        # Update center
        with torch.no_grad():
            teacher_out = torch.cat(teacher_outputs, dim=0)
            self.update_center(teacher_out)

        return loss


class DINOHead(nn.Module):
    """
    DINO projection head.

    Uses weight normalization on last layer to constrain
    the output space and prevent representation collapse.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dim: int = 2048,
        bottleneck_dim: int = 256,
        use_bn: bool = False,
        norm_last_layer: bool = True
    ):
        super().__init__()

        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, bottleneck_dim)
        ]
        self.mlp = nn.Sequential(*layers)

        # Last layer with optional weight normalization
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        x = self.last_layer(x)
        return x


def visualize_dino_attention(
    model: DINO,
    image: torch.Tensor,
    patch_size: int = 16
) -> torch.Tensor:
    """
    Extract and visualize self-attention from DINO's Vision Transformer.

    DINO naturally learns to attend to semantically meaningful
    regions, even without any semantic supervision.
    """
    model.eval()

    # Get attention weights from the last layer
    with torch.no_grad():
        # For ViT backbone
        features = model.teacher_backbone.get_intermediate_layers(
            image, n=1
        )[0]

        # Get self-attention weights
        attentions = model.teacher_backbone.get_last_selfattention(image)

    # Process attention map
    nh = attentions.shape[1]  # Number of heads
    w_featmap = image.shape[-2] // patch_size
    h_featmap = image.shape[-1] // patch_size

    # Take CLS token attention to other tokens
    attentions = attentions[0, :, 0, 1:].reshape(nh, w_featmap, h_featmap)

    # Average over heads
    attention_map = attentions.mean(dim=0)

    return attention_map
```

One of DINO's most remarkable properties is that its attention maps naturally segment objects without any segmentation supervision. The CLS token's attention in the final layer reliably attends to foreground objects, demonstrating that semantic understanding emerges from self-distillation alone. This makes DINO features excellent for downstream tasks like object detection and segmentation.
The evolution from DeepCluster through SwAV to DINO represents a progression toward more elegant and effective clustering-based self-supervised learning. Each method addresses limitations of its predecessors while introducing new innovations.
| Method | Clustering | Collapse Prevention | Key Innovation | ImageNet Acc. |
|---|---|---|---|---|
| DeepCluster (2018) | Offline k-means | Empty cluster reassignment | Alternating clustering/training | ~73% |
| SwAV (2020) | Online (learnable prototypes) | Sinkhorn-Knopp equipartition | Swapped prediction + multi-crop | ~75% |
| DINO (2021) | Implicit (self-distillation) | Centering + sharpening | Emergent semantic attention | ~77% |
| iBOT (2021) | Implicit + token-level | Centering + masking | Masked token prediction | ~79% |
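The accuracies above are typically obtained with a linear probe: the pretrained backbone is frozen and only a linear classifier is trained on top of its features. A minimal sketch of that evaluation protocol is shown below, assuming a generic `backbone` that maps images to feature vectors and a standard labeled dataloader; the optimizer settings are illustrative, not the exact recipes used by these papers:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def linear_probe(backbone: nn.Module, feature_dim: int, num_classes: int,
                 train_loader, device: torch.device, epochs: int = 10) -> nn.Module:
    """Train a linear classifier on frozen self-supervised features."""
    backbone.eval()  # freeze the pretrained encoder
    for p in backbone.parameters():
        p.requires_grad = False

    classifier = nn.Linear(feature_dim, num_classes).to(device)
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)

    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():
                features = backbone(images)  # frozen features
            loss = F.cross_entropy(classifier(features), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return classifier
```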
Clustering-based self-supervised learning represents a powerful paradigm that leverages the natural structure of visual data. By treating cluster assignments as pseudo-labels, these methods learn semantically meaningful representations without any manual annotation.
You now understand clustering-based self-supervised learning methods from DeepCluster through SwAV to DINO. These methods exploit the natural clustering structure of visual data to learn representations without labels. Next, we'll explore masked modeling—a completely different paradigm inspired by language model pre-training.