All the methods we've examined so far are discriminative—they directly model $P(Y|X)$, the conditional distribution of labels given features. But there's a fundamentally different approach: generative models that model the joint distribution $P(X, Y) = P(X|Y)P(Y)$.
The key insight: unlabeled data tells us about $P(X)$. If we model $P(X|Y)$ for each class, observing unlabeled data helps us understand where each class lives in feature space. This refined understanding of class structure improves classification even without knowing the labels.
By the end of this page, you will understand the generative approach to SSL, Gaussian mixture models trained on partially labeled data, EM for semi-supervised learning, deep generative models (VAE, GAN) for SSL, and when generative approaches outperform discriminative ones.
The Fundamental Difference:
Discriminative: Learn $P(Y|X)$ directly. Unlabeled data (only $X$, no $Y$) seems useless—what does the input tell us about the decision boundary?
Generative: Learn $P(X, Y) = P(X|Y)P(Y)$. Unlabeled data helps estimate $P(X)$, which constrains $P(X|Y)$ since $P(X) = \sum_k P(X|Y=k)P(Y=k)$.
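Once the joint model is fit, classification is just Bayes' rule applied to the learned densities:
$$P(Y=k \mid X=x) = \frac{P(x \mid Y=k)\, P(Y=k)}{\sum_{k'} P(x \mid Y=k')\, P(Y=k')}$$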
How Unlabeled Data Helps Generatively:
Consider a two-class Gaussian mixture, where class $k \in \{0, 1\}$ generates points from $\mathcal{N}(\mu_k, \Sigma_k)$.
With only labeled data, we estimate $\mu_0, \mu_1$ from a few examples—high variance. But unlabeled data provides additional constraints: $$P(X) = \pi_0 \mathcal{N}(\mu_0, \Sigma_0) + \pi_1 \mathcal{N}(\mu_1, \Sigma_1)$$
The unlabeled data must be explained by this mixture. This constrains the parameters, effectively regularizing our estimates.
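As a quick illustration of that regularizing effect, here is a minimal sketch using scikit-learn's `GaussianMixture` (the means, sample sizes, and seed are arbitrary choices): compare mean estimates from a handful of labels with a mixture fit to unlabeled data alone.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)

# Ground truth: two 1-D Gaussian classes with means -2 and +2
X0 = rng.normal(-2.0, 1.0, size=(500, 1))
X1 = rng.normal(+2.0, 1.0, size=(500, 1))

# Labeled-only estimate from 5 examples per class: high variance
mu0_hat, mu1_hat = X0[:5].mean(), X1[:5].mean()

# Unlabeled-only mixture fit: P(X) alone pins down the component means
# (up to which component corresponds to which label)
gmm = GaussianMixture(n_components=2, random_state=0)
gmm.fit(np.vstack([X0, X1]))

print("labeled-only means:", mu0_hat, mu1_hat)
print("mixture means from unlabeled data:", gmm.means_.ravel())
```

The unlabeled fit recovers the component means up to a label swap, which even a few labeled points are enough to resolve.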
The Gaussian Mixture Model (GMM) is the classic generative model for semi-supervised learning. We assume each class is generated from a Gaussian:
$$P(X|Y=k) = \mathcal{N}(X; \mu_k, \Sigma_k)$$
The joint log-likelihood over labeled and unlabeled data:
$$\log P(\mathcal{D}) = \sum_{i \in L} \log P(x_i, y_i) + \sum_{j \in U} \log P(x_j)$$
$$= \sum_{i \in L} \log [\pi_{y_i} \mathcal{N}(x_i; \mu_{y_i}, \Sigma_{y_i})] + \sum_{j \in U} \log \sum_k \pi_k \mathcal{N}(x_j; \mu_k, \Sigma_k)$$
```python
import numpy as np
from scipy.stats import multivariate_normal
from scipy.special import logsumexp


class SemiSupervisedGMM:
    """
    Gaussian Mixture Model for semi-supervised classification.
    Uses EM to learn from labeled and unlabeled data jointly.
    """

    def __init__(self, n_classes, n_features, regularization=1e-6):
        self.n_classes = n_classes
        self.n_features = n_features
        self.reg = regularization

    def _initialize(self, X_labeled, y_labeled, X_unlabeled):
        """Initialize parameters from labeled data."""
        X_all = np.vstack([X_labeled, X_unlabeled])
        self.means_ = np.zeros((self.n_classes, self.n_features))
        self.covs_ = np.zeros((self.n_classes, self.n_features, self.n_features))
        self.weights_ = np.zeros(self.n_classes)

        for k in range(self.n_classes):
            X_k = X_labeled[y_labeled == k]
            if len(X_k) > 0:
                self.means_[k] = X_k.mean(axis=0)
                if len(X_k) > 1:
                    self.covs_[k] = np.cov(X_k.T) + self.reg * np.eye(self.n_features)
                else:
                    self.covs_[k] = np.eye(self.n_features)
            else:
                # No labeled examples for this class: perturb the global mean
                self.means_[k] = X_all.mean(axis=0) + np.random.randn(self.n_features) * 0.1
                self.covs_[k] = np.eye(self.n_features)
            self.weights_[k] = max(1, (y_labeled == k).sum())

        self.weights_ /= self.weights_.sum()

    def _e_step(self, X_labeled, y_labeled, X_unlabeled):
        """
        E-step: compute responsibilities for unlabeled data.
        Labeled data keeps hard assignments.
        """
        n_unlabeled = len(X_unlabeled)

        # Responsibilities for unlabeled data
        log_resp = np.zeros((n_unlabeled, self.n_classes))
        for k in range(self.n_classes):
            log_resp[:, k] = (np.log(self.weights_[k])
                              + multivariate_normal.logpdf(X_unlabeled, self.means_[k], self.covs_[k]))

        # Normalize (softmax in log space)
        log_resp -= logsumexp(log_resp, axis=1, keepdims=True)
        resp_unlabeled = np.exp(log_resp)

        # Hard assignments for labeled data
        resp_labeled = np.zeros((len(X_labeled), self.n_classes))
        resp_labeled[np.arange(len(y_labeled)), y_labeled] = 1.0

        return resp_labeled, resp_unlabeled

    def _m_step(self, X_labeled, resp_labeled, X_unlabeled, resp_unlabeled):
        """M-step: update parameters using all data."""
        X_all = np.vstack([X_labeled, X_unlabeled])
        resp_all = np.vstack([resp_labeled, resp_unlabeled])
        n_k = resp_all.sum(axis=0) + 1e-10  # Effective count per class

        for k in range(self.n_classes):
            # Weighted mean
            self.means_[k] = (resp_all[:, k:k+1].T @ X_all) / n_k[k]
            # Weighted covariance
            diff = X_all - self.means_[k]
            self.covs_[k] = (diff.T @ (resp_all[:, k:k+1] * diff)) / n_k[k]
            self.covs_[k] += self.reg * np.eye(self.n_features)  # Regularization

        self.weights_ = n_k / n_k.sum()

    def fit(self, X_labeled, y_labeled, X_unlabeled, max_iter=100, tol=1e-4):
        """Fit using the EM algorithm."""
        self._initialize(X_labeled, y_labeled, X_unlabeled)
        prev_ll = -np.inf

        for iteration in range(max_iter):
            # E-step
            resp_labeled, resp_unlabeled = self._e_step(X_labeled, y_labeled, X_unlabeled)
            # M-step
            self._m_step(X_labeled, resp_labeled, X_unlabeled, resp_unlabeled)
            # Check convergence via the log-likelihood
            ll = self._log_likelihood(X_labeled, y_labeled, X_unlabeled)
            if abs(ll - prev_ll) < tol:
                print(f"Converged at iteration {iteration}")
                break
            prev_ll = ll

        return self

    def _log_likelihood(self, X_labeled, y_labeled, X_unlabeled):
        """Compute the joint log-likelihood of labeled and unlabeled data."""
        ll = 0.0

        # Labeled: log P(x, y)
        for x, y in zip(X_labeled, y_labeled):
            ll += np.log(self.weights_[y])
            ll += multivariate_normal.logpdf(x, self.means_[y], self.covs_[y])

        # Unlabeled: log P(x) = log sum_k P(x|k) P(k)
        for x in X_unlabeled:
            log_probs = [np.log(self.weights_[k])
                         + multivariate_normal.logpdf(x, self.means_[k], self.covs_[k])
                         for k in range(self.n_classes)]
            ll += logsumexp(log_probs)

        return ll

    def predict(self, X):
        """Predict class labels."""
        log_probs = np.zeros((len(X), self.n_classes))
        for k in range(self.n_classes):
            log_probs[:, k] = (np.log(self.weights_[k])
                               + multivariate_normal.logpdf(X, self.means_[k], self.covs_[k]))
        return log_probs.argmax(axis=1)

    def predict_proba(self, X):
        """Predict class probabilities."""
        log_probs = np.zeros((len(X), self.n_classes))
        for k in range(self.n_classes):
            log_probs[:, k] = (np.log(self.weights_[k])
                               + multivariate_normal.logpdf(X, self.means_[k], self.covs_[k]))
        log_probs -= logsumexp(log_probs, axis=1, keepdims=True)
        return np.exp(log_probs)
```

If the data isn't actually Gaussian, the model fits the wrong structure. Notoriously, unlabeled data can hurt performance when the generative model is misspecified: EM pulls the parameters toward explaining the features, not toward classification accuracy.
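A minimal usage sketch for the class above, on synthetic two-Gaussian data with only five labels per class (the names, seed, and sample sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
n_per_class = 200
X0 = rng.multivariate_normal([-2, 0], np.eye(2), n_per_class)
X1 = rng.multivariate_normal([+2, 0], np.eye(2), n_per_class)
X = np.vstack([X0, X1])
y = np.array([0] * n_per_class + [1] * n_per_class)

# Keep only 5 labels per class; treat the rest as unlabeled
labeled_idx = np.concatenate([np.arange(5), n_per_class + np.arange(5)])
unlabeled_mask = np.ones(len(X), dtype=bool)
unlabeled_mask[labeled_idx] = False

model = SemiSupervisedGMM(n_classes=2, n_features=2)
model.fit(X[labeled_idx], y[labeled_idx], X[unlabeled_mask])

acc = (model.predict(X[unlabeled_mask]) == y[unlabeled_mask]).mean()
print(f"Accuracy on held-out (unlabeled) points: {acc:.3f}")
```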
Modern deep generative models extend the generative paradigm to complex, high-dimensional data like images.
Semi-Supervised VAE (Kingma et al., 2014):
The VAE framework naturally extends to SSL. The model (Kingma et al.'s "M2") treats the label as a latent variable whenever it is unobserved: a categorical prior $p(y)$, a Gaussian prior $p(z) = \mathcal{N}(0, I)$, and a decoder $p(x|y,z)$, together with an inference network $q(z|x,y)$ and a classifier $q(y|x)$.
For labeled data, maximize $\log p(x, y)$. For unlabeled data, marginalize over $y$: $\log p(x) = \log \sum_y p(x, y)$.
The ELBO becomes:
$$\mathcal{L}_{\text{labeled}}(x, y) = \mathbb{E}_{q(z|x,y)}[\log p(x|y,z)] - D_{KL}(q(z|x,y) \,\|\, p(z))$$
$$\mathcal{L}_{\text{unlabeled}}(x) = \sum_y q(y|x)\, \mathcal{L}_{\text{labeled}}(x, y) + H(q(y|x))$$
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SemiSupervisedVAE(nn.Module):
    """
    Semi-supervised VAE (the M2 model from Kingma et al., 2014).
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.latent_dim = latent_dim

        # Encoder q(z|x,y)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + n_classes, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.enc_mu = nn.Linear(hidden_dim, latent_dim)
        self.enc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder p(x|y,z)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + n_classes, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

        # Classifier q(y|x)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, n_classes),
        )

    def encode(self, x, y_onehot):
        """Encode to latent distribution parameters."""
        xy = torch.cat([x, y_onehot], dim=1)
        h = self.encoder(xy)
        return self.enc_mu(h), self.enc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y_onehot):
        """Decode from latent + class."""
        zy = torch.cat([z, y_onehot], dim=1)
        return self.decoder(zy)

    def forward_labeled(self, x, y):
        """Negative ELBO (plus classification loss) for labeled data."""
        y_onehot = F.one_hot(y, self.n_classes).float()
        mu, logvar = self.encode(x, y_onehot)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z, y_onehot)

        # Reconstruction loss
        recon_loss = F.mse_loss(x_recon, x, reduction='sum')
        # KL divergence
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Classification loss (auxiliary term so q(y|x) learns from labels)
        logits = self.classifier(x)
        class_loss = F.cross_entropy(logits, y, reduction='sum')

        return recon_loss + kl_loss + class_loss

    def forward_unlabeled(self, x):
        """Negative ELBO for unlabeled data (marginalize over y)."""
        # q(y|x): inferred label distribution
        logits = self.classifier(x)
        qy = F.softmax(logits, dim=1)

        total_loss = 0
        for k in range(self.n_classes):
            y_onehot = F.one_hot(torch.tensor([k]), self.n_classes).float()
            y_onehot = y_onehot.expand(len(x), -1).to(x.device)

            mu, logvar = self.encode(x, y_onehot)
            z = self.reparameterize(mu, logvar)
            x_recon = self.decode(z, y_onehot)

            recon = F.mse_loss(x_recon, x, reduction='none').sum(dim=1)
            kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)

            # Weight each hypothesized label by q(y=k|x)
            total_loss += qy[:, k] * (recon + kl)

        # Entropy bonus H(q(y|x))
        entropy = -(qy * torch.log(qy + 1e-10)).sum(dim=1)
        return total_loss.sum() - entropy.sum()
```
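Training alternates over labeled and unlabeled batches and simply sums the two objectives. A minimal sketch using the class above (the data loaders, dimensions, and learning rate are placeholders):

```python
import torch

# labeled_loader yields (x, y) batches, unlabeled_loader yields x batches
# (hypothetical DataLoaders; zip truncates to the shorter of the two, so in
# practice the small labeled loader is usually cycled).
model = SemiSupervisedVAE(input_dim=784, hidden_dim=256, latent_dim=32, n_classes=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    for (x_l, y_l), x_u in zip(labeled_loader, unlabeled_loader):
        loss = model.forward_labeled(x_l, y_l) + model.forward_unlabeled(x_u)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```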
GANs for Semi-Supervised Learning:
Salimans et al. (2016), building on the GAN framework of Goodfellow et al., proposed modifying the discriminator to output $K+1$ classes: the $K$ real classes plus an extra "fake" class for generated samples.
The discriminator learns to: assign labeled real examples to their true class among the first $K$ outputs, assign generated samples to the extra class $K+1$, and assign unlabeled real examples to one of the first $K$ classes (i.e., low probability on "fake").
The generator provides an additional training signal by producing "hard" examples near the decision boundary.
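One plausible way to write this $K{+}1$-class discriminator objective in code (a sketch of the formulation in Salimans et al., 2016, not a reference implementation; the `logits_*` tensors are assumed to have shape `(batch, K+1)` with index $K$ as the fake class):

```python
import torch
import torch.nn.functional as F

def discriminator_loss(logits_real_labeled, y, logits_real_unlabeled, logits_fake):
    """Sketch of a K+1-class discriminator loss for semi-supervised GANs."""
    fake_class = logits_real_labeled.shape[1] - 1

    # 1) Labeled real data: ordinary cross-entropy against the true class
    loss_supervised = F.cross_entropy(logits_real_labeled, y)

    # 2) Unlabeled real data: should not be assigned to the fake class.
    #    -log P(real | x) = -(logsumexp over real-class logits - logsumexp over all logits)
    log_p_all = torch.logsumexp(logits_real_unlabeled, dim=1)
    log_p_real = torch.logsumexp(logits_real_unlabeled[:, :fake_class], dim=1)
    loss_unlabeled = -(log_p_real - log_p_all).mean()

    # 3) Generated data: should be assigned to the fake class
    fake_targets = torch.full((len(logits_fake),), fake_class, dtype=torch.long)
    loss_fake = F.cross_entropy(logits_fake, fake_targets)

    return loss_supervised + loss_unlabeled + loss_fake
```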
When should you prefer a generative approach over a discriminative one? A rough guide:

| Scenario | Recommendation |
|---|---|
| Very few labels (< 10 per class) | Generative (more sample efficient) |
| Moderate labels, good model | Discriminative (simpler, often better) |
| Model misspecification likely | Discriminative (robust to wrong assumptions) |
| Need uncertainty estimates | Generative (probabilistic framework) |
| High-dimensional complex data | Hybrid (deep generative + discriminative) |
Modern practice often combines both: use a generative model (VAE, flow) to learn representations from unlabeled data, then train a discriminative classifier on these representations using labeled data. This gets the best of both worlds—generative structure learning with discriminative classification.
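A sketch of one common version of this hybrid recipe: pretrain an unsupervised encoder on the unlabeled pool, then fit a simple discriminative classifier on the learned representations of the labeled set. Here `vae_encoder` stands in for any pretrained encoder mapping `x` to `(mu, logvar)`, and `X_labeled`, `y_labeled`, `X_test` are placeholder arrays.

```python
import torch
from sklearn.linear_model import LogisticRegression

def encode_means(encoder, X):
    """Use the posterior mean of q(z|x) as the representation of each input."""
    with torch.no_grad():
        mu, _ = encoder(torch.as_tensor(X, dtype=torch.float32))
    return mu.numpy()

# Step 1: encoder pretrained on the unlabeled pool (not shown here).
# Step 2: fit a discriminative classifier on representations of the labeled set.
Z_labeled = encode_means(vae_encoder, X_labeled)
clf = LogisticRegression(max_iter=1000).fit(Z_labeled, y_labeled)

# Step 3: classify new points through the same encoder.
y_pred = clf.predict(encode_means(vae_encoder, X_test))
```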
Module Complete!
You've now mastered the full spectrum of semi-supervised learning methods.
These techniques form the foundation for working with limited labeled data—a constant challenge in real-world machine learning.
Congratulations! You've completed Module 2: Semi-Supervised Methods. You now have a comprehensive toolkit for learning from limited labels, from simple self-training to sophisticated deep generative models.