In the previous modules, we explored contrastive learning methods like SimCLR and MoCo that learn representations by contrasting positive pairs against negative samples. These methods achieved remarkable success, but they came with a fundamental constraint: the need for negative samples.
Contrastive methods require carefully curated negative pairs: samples that should be pushed apart in the representation space. This creates several practical challenges:

- Large batch sizes or memory banks are needed to supply enough negatives per update
- False negatives arise when semantically similar images are treated as negatives
- Performance is sensitive to how negatives are sampled and how many are used
What if we could eliminate the need for negative samples entirely? This question led to two groundbreaking methods: BYOL (Bootstrap Your Own Latent) and SimSiam (Simple Siamese).
The central puzzle of non-contrastive learning is avoiding representation collapse—where the network learns to map all inputs to the same constant output. Without negative samples pushing representations apart, what prevents the trivial solution of outputting zeros for everything? BYOL and SimSiam provide elegant answers to this fundamental question.
BYOL, introduced by Grill et al. (2020) at DeepMind, represents a paradigm shift in self-supervised learning. It achieves state-of-the-art representations without any negative samples, using only positive pairs from augmented views of the same image.
BYOL employs an asymmetric architecture with two neural networks:
Online Network (θ): an encoder (e.g., ResNet-50), a projection MLP, and a prediction MLP; it is trained by gradient descent.

Target Network (ξ): an encoder and projection MLP with the same architecture but no predictor; it receives no gradients and is updated only through the EMA rule below.
The target network's parameters ξ are an exponential moving average (EMA) of the online network's parameters θ:
$$\xi \leftarrow \tau\xi + (1 - \tau)\theta$$
where τ is the momentum coefficient (typically 0.99 to 0.999).
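As a quick numerical check on this update rule (a hypothetical two-parameter example, not part of the original code), one EMA step moves each target weight only a fraction $(1 - \tau)$ of the way toward its online counterpart:

```python
import torch

tau = 0.99
target = torch.tensor([1.0, -2.0])
online = torch.tensor([0.0, 0.0])

# One EMA step: the target moves only (1 - tau) = 1% of the way toward the online weights.
target = tau * target + (1 - tau) * online
print(target)  # tensor([ 0.9900, -1.9800])
```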
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
from copy import deepcopy


class MLPHead(nn.Module):
    """MLP projection/prediction head used in BYOL."""

    def __init__(self, in_dim: int, hidden_dim: int = 4096, out_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent (BYOL) implementation.

    Key insight: The asymmetry between online and target networks,
    combined with EMA updates, prevents representation collapse
    without needing negative samples.
    """

    def __init__(
        self,
        backbone: nn.Module = None,
        projection_dim: int = 256,
        hidden_dim: int = 4096,
        momentum: float = 0.996
    ):
        super().__init__()
        self.momentum = momentum

        # Online network components
        self.online_encoder = backbone or resnet50(pretrained=False)
        feature_dim = self.online_encoder.fc.in_features
        self.online_encoder.fc = nn.Identity()  # Remove classification head

        self.online_projector = MLPHead(feature_dim, hidden_dim, projection_dim)
        self.online_predictor = MLPHead(projection_dim, hidden_dim, projection_dim)

        # Target network (no predictor - key asymmetry)
        self.target_encoder = deepcopy(self.online_encoder)
        self.target_projector = deepcopy(self.online_projector)

        # Freeze target network - updated via EMA only
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target_network(self):
        """
        Update target network using exponential moving average.
        This slow-moving target provides stable learning signals.
        """
        for online_params, target_params in zip(
            list(self.online_encoder.parameters()) + list(self.online_projector.parameters()),
            list(self.target_encoder.parameters()) + list(self.target_projector.parameters())
        ):
            target_params.data = (
                self.momentum * target_params.data
                + (1 - self.momentum) * online_params.data
            )

    def forward(self, view1: torch.Tensor, view2: torch.Tensor):
        """
        Forward pass computing BYOL loss.

        Args:
            view1, view2: Two augmented views of the same batch of images

        Returns:
            Symmetric BYOL loss
        """
        # Online network forward pass
        online_proj_1 = self.online_projector(self.online_encoder(view1))
        online_proj_2 = self.online_projector(self.online_encoder(view2))
        online_pred_1 = self.online_predictor(online_proj_1)
        online_pred_2 = self.online_predictor(online_proj_2)

        # Target network forward pass (no gradients)
        with torch.no_grad():
            target_proj_1 = self.target_projector(self.target_encoder(view1))
            target_proj_2 = self.target_projector(self.target_encoder(view2))
            # Detach to stop gradients flowing to target
            target_proj_1 = target_proj_1.detach()
            target_proj_2 = target_proj_2.detach()

        # Compute BYOL loss (symmetric)
        loss_1 = self.regression_loss(online_pred_1, target_proj_2)
        loss_2 = self.regression_loss(online_pred_2, target_proj_1)

        return (loss_1 + loss_2).mean()

    def regression_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Normalized MSE loss between predictions and targets.

        L2-normalizing both vectors makes this equivalent to
        2 - 2 * cosine_similarity, but more numerically stable.
        """
        pred = F.normalize(pred, dim=-1, p=2)
        target = F.normalize(target, dim=-1, p=2)
        return 2 - 2 * (pred * target).sum(dim=-1)
```

The predictor network in BYOL is not just an architectural detail; it is fundamental to preventing collapse. The predictor must learn to 'predict' the target representation, creating a non-trivial learning task. Without it, the online and target networks would simply converge to the same constant output.
The question of why BYOL doesn't collapse to trivial solutions sparked significant research debate. Several mechanisms contribute to its stability:
One theory suggests that batch normalization (BatchNorm) in BYOL's architecture implicitly introduces a form of contrastive learning:
Mathematically, BatchNorm computes: $$\hat{z}_i = \frac{z_i - \mu_B}{\sqrt{\sigma^2_B + \epsilon}}$$
where $\mu_B$ and $\sigma^2_B$ are batch statistics. If all $z_i$ were identical, the variance would be zero, causing numerical issues.
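The following minimal sketch (an illustrative demo, not from BYOL's codebase) makes this concrete: pushing a batch of identical vectors through `nn.BatchNorm1d` yields near-zero outputs because the batch variance vanishes, whereas a diverse batch keeps roughly unit spread per dimension.

```python
import torch
import torch.nn as nn

# Hypothetical demo: what BatchNorm does when every sample in the batch is identical.
bn = nn.BatchNorm1d(4, affine=False)
bn.train()

collapsed = torch.ones(8, 4) * 3.0   # all 8 samples identical
diverse = torch.randn(8, 4)          # samples differ

out_collapsed = bn(collapsed)
out_diverse = bn(diverse)

# With identical inputs the batch variance is ~0, so the normalized output is ~0
# everywhere: a "collapsed" batch is immediately visible in the statistics.
print(out_collapsed.abs().max())   # ~0 (up to eps)
print(out_diverse.std(dim=0))      # ~1 per dimension after normalization
```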
A more complete explanation focuses on the architectural asymmetry:
The predictor creates an information bottleneck: The online network must learn features that the predictor can successfully transform to match the target
EMA creates a slowly-moving target: The target changes slowly, providing consistent learning signals
The prediction task is non-trivial: The predictor must bridge the gap between two augmented views, requiring meaningful representations
| Mechanism | How It Works | Evidence |
|---|---|---|
| Predictor Network | Creates asymmetry; online network must learn 'predictable' features | Removing predictor causes immediate collapse |
| EMA Target Updates | Slow-moving target provides stable, consistent learning signal | Direct copying (τ=0) causes collapse |
| Batch Normalization | Implicit sample competition through batch statistics | Removing BN requires careful tuning to avoid collapse |
| Augmentation Strength | Strong augmentations create meaningful prediction task | Weak augmentations lead to trivial solutions |
| Network Capacity | Sufficient capacity to capture meaningful distinctions | Very shallow networks may collapse |
```python
import torch
import torch.nn.functional as F


def analyze_collapse_indicators(model: BYOL, dataloader):
    """
    Monitor metrics that indicate potential representation collapse.

    Key indicators:
    1. Representation std: Should be high (diverse representations)
    2. Cosine similarity: Should be moderate (not all identical)
    3. Effective rank: Should be high (full rank representations)
    """
    representations = []
    model.eval()

    with torch.no_grad():
        for images, _ in dataloader:
            # Get representations before projection
            reps = model.online_encoder(images)
            representations.append(reps)

    all_reps = torch.cat(representations, dim=0)

    # Metric 1: Standard deviation across batch
    # Low std indicates collapse toward identical outputs
    std_per_dim = all_reps.std(dim=0)
    mean_std = std_per_dim.mean().item()

    # Metric 2: Average pairwise cosine similarity
    # High similarity (close to 1) indicates collapse
    normalized = F.normalize(all_reps, dim=1)
    similarity_matrix = normalized @ normalized.T
    # Exclude diagonal (self-similarity = 1)
    mask = ~torch.eye(len(all_reps), dtype=torch.bool)
    avg_similarity = similarity_matrix[mask].mean().item()

    # Metric 3: Effective rank via singular values
    # Low effective rank indicates dimensional collapse
    _, singular_values, _ = torch.svd(all_reps - all_reps.mean(dim=0))
    normalized_sv = singular_values / singular_values.sum()
    entropy = -(normalized_sv * torch.log(normalized_sv + 1e-10)).sum()
    effective_rank = torch.exp(entropy).item()

    return {
        'mean_std': mean_std,
        'avg_cosine_similarity': avg_similarity,
        'effective_rank': effective_rank,
        'collapse_warning': avg_similarity > 0.9 or mean_std < 0.01
    }


# Example usage during training
def training_loop_with_monitoring(model, dataloader, epochs):
    for epoch in range(epochs):
        # ... training code ...

        if epoch % 10 == 0:
            metrics = analyze_collapse_indicators(model, dataloader)
            print(f"Epoch {epoch}:")
            print(f"  Std: {metrics['mean_std']:.4f}")
            print(f"  Similarity: {metrics['avg_cosine_similarity']:.4f}")
            print(f"  Effective Rank: {metrics['effective_rank']:.2f}")

            if metrics['collapse_warning']:
                print("  ⚠️ Warning: Potential collapse detected!")
```

Representation collapse doesn't always mean all outputs are literally identical. More commonly, it manifests as dimensional collapse, where representations become low-rank and use only a small subspace of the available dimensions. Monitoring effective rank alongside standard deviation catches both failure modes.
SimSiam, introduced by Chen & He (2021) at Facebook AI Research, distills the essence of non-contrastive learning to its simplest form. It achieves comparable results to BYOL while removing the momentum encoder entirely.
SimSiam's architecture is elegantly minimal:
Single Encoder (shared weights): a backbone $f$ followed by a projection MLP $g$, applied identically to both views; a small prediction MLP $h$ is then attached to whichever branch is being trained in each loss term.
The key innovation is the stop-gradient operation:
$$\mathcal{L} = -\frac{1}{2}\left( \text{sim}(p_1, \text{stopgrad}(z_2)) + \text{sim}(p_2, \text{stopgrad}(z_1)) \right)$$
where $p_i = h(g(f(x_i)))$ and $z_i = g(f(x_i))$, and sim is cosine similarity.
Without stop-gradient, both branches would receive gradients, and the network would learn to minimize the loss by collapsing everything to a constant. The stop-gradient breaks this symmetry:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimSiam(nn.Module):
    """
    Simple Siamese (SimSiam) implementation.

    Key insight: Stop-gradient is sufficient to prevent collapse.
    No momentum encoder, no large batches, no negative samples.
    """

    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 2048,
        prediction_dim: int = 512
    ):
        super().__init__()

        # Encoder
        self.encoder = backbone
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()

        # Projector (3-layer MLP)
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, feature_dim, bias=False),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim, bias=False),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, projection_dim, bias=False),
            nn.BatchNorm1d(projection_dim, affine=False)  # No learnable params
        )

        # Predictor (2-layer MLP with bottleneck)
        self.predictor = nn.Sequential(
            nn.Linear(projection_dim, prediction_dim, bias=False),
            nn.BatchNorm1d(prediction_dim),
            nn.ReLU(inplace=True),
            nn.Linear(prediction_dim, projection_dim)  # Output matches projector
        )

    def forward(self, view1: torch.Tensor, view2: torch.Tensor):
        """
        Forward pass computing SimSiam loss.

        The stop-gradient is the key to preventing collapse.
        It makes one branch a 'fixed' target for the other.
        """
        # Encode and project both views
        z1 = self.projector(self.encoder(view1))  # N x projection_dim
        z2 = self.projector(self.encoder(view2))

        # Predict from one view to the other
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        # Symmetric loss with stop-gradient
        # The .detach() implements stop-gradient
        loss = (
            self.cosine_loss(p1, z2.detach())
            + self.cosine_loss(p2, z1.detach())
        ) / 2

        return loss

    def cosine_loss(self, p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Negative cosine similarity loss.
        Minimizing this encourages p and z to point in the same direction.
        """
        p = F.normalize(p, dim=1)
        z = F.normalize(z, dim=1)
        return -(p * z).sum(dim=1).mean()


# Training loop example
def train_simsiam(model, dataloader, optimizer, epochs):
    """
    SimSiam training loop.

    Note: SimSiam works with small batch sizes (256),
    unlike contrastive methods requiring thousands.
    """
    for epoch in range(epochs):
        for (view1, view2), _ in dataloader:
            view1, view2 = view1.cuda(), view2.cuda()

            optimizer.zero_grad()
            loss = model(view1, view2)
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
```

SimSiam demonstrates that effective self-supervised learning doesn't require complex mechanisms. No momentum encoder, no large batch sizes, no memory banks: just a predictor and stop-gradient. This simplicity makes it easier to analyze, implement, and adapt to new domains.
Chen & He provided theoretical analysis showing SimSiam can be understood as an Expectation-Maximization (EM) algorithm. This perspective explains why stop-gradient prevents collapse.
Consider SimSiam as alternating between two steps:
E-step (with stop-gradient): Fix the representation $z$ as the current estimate of the underlying concept. The stop-gradient branch provides this fixed target.
M-step (gradient update): Given the fixed target, update the encoder and predictor so the prediction branch better matches the target. The learning branch performs this optimization.
Define the augmentation distribution as $\mathcal{T}$. For an image $x$, we sample two augmentations $t, t' \sim \mathcal{T}$ to get views $x_1 = t(x)$ and $x_2 = t'(x)$.
SimSiam minimizes: $$\mathcal{L} = \mathbb{E}_{x,\, t,\, t'}\left[ \left\| h\big(f(t(x))\big) - f\big(t'(x)\big) \right\|^2 \right]$$
subject to $\|f(\cdot)\| = 1$ (unit normalization); here $f$ denotes the encoder followed by the projector, and $h$ is the predictor.
The stop-gradient creates an alternating optimization: one sub-problem treats the target branch's outputs as fixed constants and updates the network by gradient descent; the other re-estimates those targets simply by recomputing them with the updated encoder on the next iteration. This alternating structure is stable because each step has a well-defined objective.
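The toy sketch below (purely illustrative; the linear "encoder", "predictor", and noise-based "augmentations" are hypothetical stand-ins, not the paper's setup) makes the alternation explicit: targets are computed and frozen, then a gradient step is taken on the learning branch against those fixed targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup: a linear "encoder" and "predictor" on 8-D inputs.
encoder = nn.Linear(8, 4, bias=False)
predictor = nn.Linear(4, 4, bias=False)
opt = torch.optim.SGD(list(encoder.parameters()) + list(predictor.parameters()), lr=0.1)

x = torch.randn(64, 8)  # a batch of "images"

for step in range(100):
    # Two "augmented views": the same input with independent noise.
    v1 = x + 0.1 * torch.randn_like(x)
    v2 = x + 0.1 * torch.randn_like(x)

    # E-step analogue: compute the targets and freeze them (stop-gradient).
    with torch.no_grad():
        t1, t2 = encoder(v1), encoder(v2)

    # M-step analogue: one gradient step on the learning branch against the fixed targets.
    p1 = predictor(encoder(v1))
    p2 = predictor(encoder(v2))
    loss = -(F.cosine_similarity(p1, t2).mean() + F.cosine_similarity(p2, t1).mean()) / 2

    opt.zero_grad()
    loss.backward()
    opt.step()

# Inspect the per-dimension spread of the encoder outputs after training.
print(encoder(x).std(dim=0))
```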
| Component | Default | Effect of Removal/Change | Interpretation |
|---|---|---|---|
| Stop-gradient | Applied to z | Complete collapse | Creates asymmetric learning objective |
| Predictor MLP | 2-layer with 512 hidden | Collapse | Provides flexible matching function |
| BatchNorm in projector | Yes (affine=False on last) | More sensitive to initialization | Stabilizes feature distributions |
| Output dimension | 2048 | Higher helps slightly, diminishing returns | Representation capacity |
| Predictor bottleneck | 512 (1/4 of output) | Larger bottleneck works | Information compression |
```python
import torch
import torch.nn.functional as F
import numpy as np


def analyze_simsiam_dynamics(
    model: SimSiam,
    dataloader,
    num_batches: int = 100
) -> dict:
    """
    Analyze SimSiam learning dynamics to understand collapse prevention.

    Returns metrics that explain the EM interpretation:
    - Predictor fitting error (E-step quality)
    - Representation consistency (M-step quality)
    """
    model.eval()

    pred_errors = []
    rep_consistencies = []

    for i, ((view1, view2), _) in enumerate(dataloader):
        if i >= num_batches:
            break

        view1, view2 = view1.cuda(), view2.cuda()

        # Enable gradients for analysis
        with torch.enable_grad():
            z1 = model.projector(model.encoder(view1))
            z2 = model.projector(model.encoder(view2))
            p1 = model.predictor(z1)
            p2 = model.predictor(z2)

        # Predictor fitting error: how well does predictor match target?
        z1_norm = F.normalize(z1.detach(), dim=1)
        z2_norm = F.normalize(z2.detach(), dim=1)
        p1_norm = F.normalize(p1, dim=1)
        p2_norm = F.normalize(p2, dim=1)

        pred_error = (
            1 - (p1_norm * z2_norm).sum(dim=1).mean()
            + 1 - (p2_norm * z1_norm).sum(dim=1).mean()
        ) / 2
        pred_errors.append(pred_error.item())

        # Representation consistency: similarity of z1 and z2
        rep_consistency = (z1_norm * z2_norm).sum(dim=1).mean()
        rep_consistencies.append(rep_consistency.item())

    return {
        'mean_pred_error': np.mean(pred_errors),
        'std_pred_error': np.std(pred_errors),
        'mean_rep_consistency': np.mean(rep_consistencies),
        'std_rep_consistency': np.std(rep_consistencies),
        'interpretation': interpret_metrics(
            np.mean(pred_errors), np.mean(rep_consistencies)
        )
    }


def interpret_metrics(pred_error: float, rep_consistency: float) -> str:
    """Interpret the metrics in terms of learning dynamics."""
    if pred_error < 0.1 and rep_consistency > 0.9:
        return "Converged: predictor matches target well, representations are consistent"
    elif pred_error > 0.5:
        return "Still learning: predictor struggling to match targets"
    elif rep_consistency < 0.5:
        return "Diverse augmentations: representations vary significantly between views"
    else:
        return "Normal learning: progressing toward convergence"


def visualize_stop_gradient_effect():
    """
    Demonstrate why stop-gradient prevents collapse.

    Without stop-gradient:
    - Both z1 and z2 receive gradients
    - Easiest way to minimize loss: make everything identical
    - Collapse to constant output

    With stop-gradient:
    - Only z1 receives gradient (through predictor)
    - z2 is a 'moving target' that changes each iteration
    - No trivial solution exists
    """
    # Simplified 2D example
    print("Without stop-gradient (collapse scenario):")
    print("  z1, z2 both updated to minimize ||p1 - z2||")
    print("  Gradient: ∂L/∂z1 AND ∂L/∂z2 both push toward center")
    print("  Result: z1 = z2 = constant, loss = 0 trivially")
    print()
    print("With stop-gradient (learning scenario):")
    print("  Only z1 updated, z2 is detached")
    print("  Gradient: ∂L/∂z1 only, z2 provides target")
    print("  Result: z1 must learn to predict z2 across augmentations")
    print("  This requires learning meaningful, consistent features")
```

While BYOL and SimSiam share the goal of non-contrastive learning, they achieve it through different mechanisms. Understanding their differences helps in choosing the right approach for your use case.
| Aspect | BYOL | SimSiam |
|---|---|---|
| Target network | Separate network (EMA of online) | Same encoder, targets taken via stop-gradient |
| Update mechanism | Momentum update (ξ ← τξ + (1-τ)θ) | Stop-gradient operation |
| Predictor | On online network only | On the learning branch of each loss term (shared weights) |
| Collapse prevention | EMA + predictor + BN | Stop-gradient + predictor |
| Hyperparameters | Momentum τ critical | Simpler, fewer hyperparameters |
BYOL advantages:

- Slightly higher reported linear-probe accuracy in most benchmarks
- The EMA target smooths the learning signal, which helps at smaller batch sizes
- The slowly moving target provides consistent regression targets throughout training

SimSiam advantages:

- Simpler to implement: no momentum encoder or EMA bookkeeping
- Fewer hyperparameters (no momentum coefficient τ to tune or schedule)
- Lower memory and compute overhead, since only one copy of the encoder is kept
```python
import torch
import numpy as np
from typing import Dict, Any


def compare_training_dynamics(
    byol_model: BYOL,
    simsiam_model: SimSiam,
    dataloader,
    epochs: int = 10
) -> Dict[str, Any]:
    """
    Compare training dynamics between BYOL and SimSiam.

    Key observations:
    - BYOL's EMA provides smoother loss curves
    - SimSiam may show more oscillation initially but converges similarly
    - Both avoid collapse through different mechanisms
    """
    results = {'byol': [], 'simsiam': []}

    byol_opt = torch.optim.SGD(
        byol_model.parameters(), lr=0.03, momentum=0.9, weight_decay=1e-4
    )
    simsiam_opt = torch.optim.SGD(
        simsiam_model.parameters(), lr=0.03, momentum=0.9, weight_decay=1e-4
    )

    for epoch in range(epochs):
        byol_losses, simsiam_losses = [], []

        for (view1, view2), _ in dataloader:
            view1, view2 = view1.cuda(), view2.cuda()

            # BYOL forward/backward
            byol_opt.zero_grad()
            byol_loss = byol_model(view1, view2)
            byol_loss.backward()
            byol_opt.step()
            byol_model.update_target_network()  # EMA update
            byol_losses.append(byol_loss.item())

            # SimSiam forward/backward
            simsiam_opt.zero_grad()
            simsiam_loss = simsiam_model(view1, view2)
            simsiam_loss.backward()
            simsiam_opt.step()
            # No momentum update needed!
            simsiam_losses.append(simsiam_loss.item())

        results['byol'].append({
            'epoch': epoch,
            'mean_loss': np.mean(byol_losses),
            'std_loss': np.std(byol_losses)
        })
        results['simsiam'].append({
            'epoch': epoch,
            'mean_loss': np.mean(simsiam_losses),
            'std_loss': np.std(simsiam_losses)
        })

        print(f"Epoch {epoch}:")
        print(f"  BYOL Loss: {np.mean(byol_losses):.4f} ± {np.std(byol_losses):.4f}")
        print(f"  SimSiam Loss: {np.mean(simsiam_losses):.4f} ± {np.std(simsiam_losses):.4f}")

    return results


def analyze_representation_quality(model, dataloader, method_name: str):
    """
    Evaluate representation quality using linear probe accuracy
    and k-NN classification.

    Both BYOL and SimSiam typically reach roughly 70-75% ImageNet
    linear-probe accuracy with a ResNet-50 and competitive k-NN performance.
    """
    model.eval()
    features, labels = [], []

    with torch.no_grad():
        for images, targets in dataloader:
            # Extract features from encoder (before projector)
            if hasattr(model, 'online_encoder'):  # BYOL
                feat = model.online_encoder(images.cuda())
            else:  # SimSiam
                feat = model.encoder(images.cuda())
            features.append(feat.cpu())
            labels.append(targets)

    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)

    # k-NN evaluation (k=20); knn_classifier is assumed to be defined elsewhere
    knn_acc = knn_classifier(features, labels, k=20)
    print(f"{method_name} k-NN Accuracy: {knn_acc:.2f}%")

    return knn_acc
```

Both BYOL and SimSiam benefit from careful training practices. The following techniques maximize representation quality and training stability.
Both methods benefit from cosine annealing with warmup:
$$\eta_t = \eta_{base} \cdot \frac{1}{2}\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$
where $t$ is the current step, $T$ is total steps, and $\eta_{base}$ is the base learning rate.
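A minimal helper along these lines (illustrative; the function name and the linear-warmup choice are assumptions, not from the original text) combines a linear warmup with the cosine decay above:

```python
import math

def lr_at_step(step: int, total_steps: int, base_lr: float, warmup_steps: int = 0) -> float:
    """Linear warmup followed by cosine decay to zero (illustrative helper)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps  # linear ramp from ~0 to base_lr
    # Progress through the cosine phase, in [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Example: a 1000-step schedule with 100 warmup steps
print(lr_at_step(0, 1000, 0.3, warmup_steps=100))    # small warmup value
print(lr_at_step(500, 1000, 0.3, warmup_steps=100))  # mid-training, partially decayed
print(lr_at_step(999, 1000, 0.3, warmup_steps=100))  # close to zero at the end
```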
BYOL's target network momentum τ can be scheduled from 0.996 to 1.0:
$$\tau_t = 1 - (1 - \tau_{base}) \cdot \frac{1}{2}\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$
This gradually freezes the target network, providing increasingly stable targets.
Strong augmentations are crucial for both methods. The standard recipe includes: random resized cropping, horizontal flips, color jitter, random grayscale, Gaussian blur, and (on one view) solarization.
The asymmetry between views (different blur/solarization probabilities) is important for learning robust features.
```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import math


class BYOLAugmentation:
    """
    BYOL's asymmetric augmentation strategy.

    View 1 and View 2 use slightly different augmentations,
    which prevents shortcut solutions and encourages robust features.
    """

    def __init__(self, image_size: int = 224):
        # Common augmentations
        self.base_transforms = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ])

        # View 1: Higher blur probability
        self.view1_transforms = transforms.Compose([
            self.base_transforms,
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))
            ], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # View 2: Lower blur, adds solarization
        self.view2_transforms = transforms.Compose([
            self.base_transforms,
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))
            ], p=0.1),
            transforms.RandomApply([
                transforms.RandomSolarize(threshold=128)
            ], p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, image):
        return self.view1_transforms(image), self.view2_transforms(image)


class MomentumScheduler:
    """
    Cosine schedule for BYOL's target network momentum.

    Starts at base_momentum (e.g., 0.996) and increases toward 1.0,
    progressively freezing the target network.
    """

    def __init__(self, base_momentum: float, total_steps: int):
        self.base_momentum = base_momentum
        self.total_steps = total_steps

    def get_momentum(self, step: int) -> float:
        progress = step / self.total_steps
        return 1 - (1 - self.base_momentum) * (math.cos(math.pi * progress) + 1) / 2


def create_optimizer_and_schedulers(
    model: nn.Module,
    base_lr: float = 0.3,
    weight_decay: float = 1e-4,
    warmup_epochs: int = 10,
    total_epochs: int = 300,
    batch_size: int = 256,
    dataset_size: int = 1_281_167  # ImageNet train size
):
    """
    Create an SGD optimizer with learning rate scheduling.

    The original papers use LARS (Layer-wise Adaptive Rate Scaling) for
    large-batch training; plain SGD is used here for simplicity at
    moderate batch sizes.
    """
    steps_per_epoch = dataset_size // batch_size
    total_steps = total_epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch

    # Separate batch norm and bias from other parameters
    # (note: this simple name filter may miss BN layers inside custom heads)
    param_groups = [
        {'params': [p for n, p in model.named_parameters()
                    if 'bn' not in n and 'bias' not in n],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if 'bn' in n or 'bias' in n],
         'weight_decay': 0}  # No weight decay for BN and bias
    ]

    optimizer = torch.optim.SGD(
        param_groups,
        lr=base_lr * batch_size / 256,  # Linear scaling rule
        momentum=0.9
    )

    # Cosine decay applied after the manual warmup phase
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)

    return optimizer, scheduler, warmup_steps


def train_epoch_with_scheduling(
    model: BYOL,
    dataloader,
    optimizer,
    lr_scheduler,
    momentum_scheduler: MomentumScheduler,
    epoch: int,
    warmup_steps: int,
    global_step: int
) -> int:
    """
    Training epoch with proper scheduling.

    Includes:
    - Learning rate warmup
    - Cosine LR decay
    - Momentum annealing for BYOL
    """
    model.train()

    for (view1, view2), _ in dataloader:
        view1, view2 = view1.cuda(), view2.cuda()

        # LR warmup
        if global_step < warmup_steps:
            lr_scale = (global_step + 1) / warmup_steps
            for pg in optimizer.param_groups:
                pg['lr'] = pg['initial_lr'] * lr_scale

        # Forward and backward
        optimizer.zero_grad()
        loss = model(view1, view2)
        loss.backward()
        optimizer.step()

        # Update target network with scheduled momentum
        model.momentum = momentum_scheduler.get_momentum(global_step)
        model.update_target_network()

        # LR scheduling after warmup
        if global_step >= warmup_steps:
            lr_scheduler.step()

        global_step += 1

    return global_step
```

For very large batch sizes (>1024), consider using LARS (Layer-wise Adaptive Rate Scaling) instead of SGD. LARS adapts the learning rate per-layer based on the ratio of weight norm to gradient norm, enabling stable training with batch sizes up to 32,768.
BYOL and SimSiam represent a fundamental breakthrough in self-supervised learning: negative samples are not necessary for learning high-quality representations. This simplification opens new possibilities for domains where defining negatives is difficult or counterproductive.
| Aspect | BYOL | SimSiam | Recommendation |
|---|---|---|---|
| Implementation complexity | Medium | Simple | Start with SimSiam for prototyping |
| Hyperparameter sensitivity | τ requires tuning | Fewer parameters | SimSiam for robustness |
| Small batch training | Better (EMA stabilizes) | May oscillate | BYOL for batch size < 128 |
| Training speed | Slower (EMA overhead) | Faster | SimSiam for iteration speed |
| Final accuracy | Slightly higher | Competitive | BYOL for production |
| Theoretical clarity | Implicit regularization | EM interpretation | SimSiam for research |
You now understand the principles behind non-contrastive self-supervised learning through BYOL and SimSiam. These methods eliminate negative sample requirements while achieving state-of-the-art representations. Next, we'll explore clustering-based self-supervised methods that provide alternative approaches to learning without explicit labels.