Real-world classification problems often involve more than two classes. Handwritten digit recognition has 10 classes. ImageNet classification has 1000. Natural language intent classification may have dozens or hundreds of categories. How do we extend the elegant GP classification framework—developed for binary problems—to these multi-class settings?
The answer involves both modeling choices and computational challenges. We must design likelihoods that map multiple latent functions to class probabilities, handle the increased dimensionality of the latent space, and develop efficient inference algorithms that scale with the number of classes.
This page develops multi-class GP classification from first principles, connecting the various approaches and highlighting the trade-offs between accuracy, calibration, and computational cost.
By the end of this page, you will understand:
1. The one-vs-rest and one-vs-one decomposition strategies
2. The multi-class softmax (multinomial logit) likelihood
3. Multi-output GP formulations for class-specific latent functions
4. Inference approaches including Laplace, EP, and variational methods
5. Computational complexity and scalability considerations
6. Practical guidelines for method selection
Consider a classification problem with $C > 2$ classes. For each input $\mathbf{x}$, we want to predict probabilities for each class:
$$\pi_c(\mathbf{x}) = p(y = c | \mathbf{x}), \quad c \in \{1, \ldots, C\}$$
subject to $\sum_{c=1}^C \pi_c(\mathbf{x}) = 1$ and $\pi_c(\mathbf{x}) \geq 0$.
The Latent Function Approach:
We introduce $C$ latent functions $f_1(\mathbf{x}), ..., f_C(\mathbf{x})$, each drawn from a Gaussian process. The class probabilities are obtained by passing these latent values through a link function.
The Key Question:
How should the $C$ latent GPs be related?
- Independent GPs: each $f_c$ is a separate GP with its own kernel. Simple, but ignores class correlations.
- Shared kernel: all $f_c$ use the same kernel, differing only in their realizations. Captures input-space correlations.
- Intrinsic coregionalization: a multi-output GP framework that explicitly models correlations between the latent functions.
Notation:
For $n$ training points and $C$ classes, let $\mathbf{f}_c = [f_c(\mathbf{x}_1), \ldots, f_c(\mathbf{x}_n)]^\top \in \mathbb{R}^n$ denote the latent values for class $c$ at the training inputs, with prior mean $\mathbf{m}_c$ and kernel matrix $K_c$.
The Stacked Representation:
Often we stack all latent values into a single vector:
$$\mathbf{f} = [\mathbf{f}_1^\top, ..., \mathbf{f}_C^\top]^\top \in \mathbb{R}^{nC}$$
The prior is then an $nC$-dimensional Gaussian, and the inference challenge scales correspondingly.
Multi-class GP classification involves O(n²C²) entries in the full covariance matrix and O(n³C³) operations for exact inference. With n = 1000 and C = 10, that is already 100 million entries and 10¹² operations. Scalability is a central challenge.
Before developing a principled multi-class model, let's consider simpler decomposition strategies that reduce multi-class to binary classification.
One-vs-Rest (OvR) / One-vs-All (OvA):
Train $C$ independent binary classifiers, one per class: classifier $c$ separates class $c$ from all remaining classes.
Prediction: $$\tilde{\pi}_c(\mathbf{x}) = p(y = c | y \in \{c, \text{not-}c\}, \mathbf{x})$$
Normalize to get valid probabilities: $$\pi_c(\mathbf{x}) = \frac{\tilde{\pi}_c(\mathbf{x})}{\sum_{c'=1}^C \tilde{\pi}_{c'}(\mathbf{x})}$$
One-vs-One (OvO):
Train $\binom{C}{2} = C(C-1)/2$ binary classifiers, one for each pair of classes, each trained only on the data from its two classes.
Prediction: use voting, where each classifier votes for one of its two classes and the class with the most votes wins.
Or use pairwise coupling to combine probabilities.
Comparison:
| Aspect | One-vs-Rest | One-vs-One |
|---|---|---|
| Classifiers | $C$ | $C(C-1)/2$ |
| Training data per classifier | All $n$ | ~$2n/C$ |
| Imbalance | Severe | None |
| Scalability | Better | Worse for large $C$ |
Recommendation: For GPs, one-vs-rest is usually preferred due to computational considerations. But both are heuristics—principled multi-class models are better when tractable.
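To make the decomposition concrete, here is a minimal one-vs-rest sketch. It assumes a hypothetical `binary_gpc_factory` that returns a classifier with scikit-learn-style `fit(X, y)` and `predict_proba(X)` (returning $p(y=1|\mathbf{x})$) methods; any binary GP classifier from the previous pages would do. Only the relabeling and the normalization formula above are specific to OvR.

```python
import numpy as np

class OneVsRestGPC:
    """One-vs-rest wrapper: C independent binary GP classifiers.

    `binary_gpc_factory` is a hypothetical callable returning an object
    with fit(X, y) and predict_proba(X) -> p(y=1|x).
    """

    def __init__(self, binary_gpc_factory, n_classes):
        self.C = n_classes
        self.models = [binary_gpc_factory() for _ in range(n_classes)]

    def fit(self, X, y):
        for c, model in enumerate(self.models):
            # Relabel: class c -> 1, everything else -> 0
            model.fit(X, (y == c).astype(int))
        return self

    def predict_proba(self, X):
        # Unnormalized per-class probabilities from each binary classifier
        tilde_pi = np.column_stack(
            [model.predict_proba(X) for model in self.models]
        )
        # Normalize so each row sums to one (a heuristic, as noted above)
        return tilde_pi / tilde_pi.sum(axis=1, keepdims=True)

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)
```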
The principled approach to multi-class classification uses the softmax function to map $C$ latent values to class probabilities.
The Softmax Function:
Given latent values $\mathbf{f}(\mathbf{x}) = [f_1(\mathbf{x}), ..., f_C(\mathbf{x})]^\top$:
$$\pi_c(\mathbf{x}) = \frac{\exp(f_c(\mathbf{x}))}{\sum_{c'=1}^C \exp(f_{c'}(\mathbf{x}))} = \text{softmax}(\mathbf{f}(\mathbf{x}))_c$$
The Categorical Likelihood:
For a label $y \in \{1, \ldots, C\}$:
$$p(y | \mathbf{f}(\mathbf{x})) = \prod_{c=1}^C \pi_c(\mathbf{x})^{\mathbf{1}[y=c]} = \text{Categorical}(y | \boldsymbol{\pi}(\mathbf{x}))$$
Log-likelihood:
$$\log p(y | \mathbf{f}) = f_y - \log\sum_{c=1}^C \exp(f_c)$$
Its negative is the familiar cross-entropy loss used throughout deep learning.
Identifiability:
The softmax is overparameterized: adding a constant to all latent values doesn't change probabilities:
$$\text{softmax}(\mathbf{f} + c\mathbf{1}) = \text{softmax}(\mathbf{f})$$
Solutions:
- Fix one latent function (e.g., $f_C \equiv 0$) and learn only $C - 1$ functions.
- Keep all $C$ functions and let the zero-mean GP prior resolve the redundancy softly.
In practice, using all $C$ latent functions with appropriate priors works well.
Gradient for Optimization:
$$\frac{\partial \log p(y | \mathbf{f})}{\partial f_c} = \mathbf{1}[y=c] - \pi_c$$
This is exactly the 'residual' form we saw in binary classification, now extended to multiple classes.
The log-sum-exp in softmax can overflow. Use the stable form: log∑exp(f_c) = max(f) + log∑exp(f_c - max(f)). This is standard practice and implemented in all deep learning frameworks via log-softmax / log-sum-exp functions.
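The pieces above fit in a few lines of NumPy; the checks at the end illustrate the gradient identity and the shift invariance of the softmax (a minimal sketch with made-up latent values):

```python
import numpy as np

def log_softmax(f):
    """Numerically stable log-softmax: f_c - logsumexp(f)."""
    f_shifted = f - np.max(f)                      # subtract max(f) to avoid overflow
    return f_shifted - np.log(np.sum(np.exp(f_shifted)))

def categorical_log_lik(y, f):
    """log p(y | f) = f_y - log sum_c exp(f_c)."""
    return log_softmax(f)[y]

def grad_log_lik(y, f):
    """d log p(y|f) / df_c = 1[y=c] - pi_c."""
    pi = np.exp(log_softmax(f))
    one_hot = np.zeros_like(f)
    one_hot[y] = 1.0
    return one_hot - pi

f = np.array([2.0, -1.0, 0.5])    # latent values for C = 3 classes (made up)
y = 0

print(np.exp(log_softmax(f)))               # class probabilities, sum to 1
print(categorical_log_lik(y, f))            # log-likelihood of the true class
print(grad_log_lik(y, f))                   # gradient entries sum to 0
print(np.allclose(np.exp(log_softmax(f)),   # softmax(f + c*1) == softmax(f)
                  np.exp(log_softmax(f + 7.3))))
```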
With $C$ latent functions, we need a prior over vector-valued functions. This is the domain of multi-output Gaussian processes.
Independent GPs:
The simplest approach: each $f_c$ is an independent GP:
$$f_c \sim \mathcal{GP}(m_c, k_c)$$
The joint prior over stacked latents:
$$p(\mathbf{f}) = \prod_{c=1}^C \mathcal{N}(\mathbf{f}_c | \mathbf{m}_c, K_c)$$
This block-diagonal covariance structure means classes are a priori independent given inputs.
Shared Kernel:
A common simplification: use the same kernel for all classes:
$$k_1 = k_2 = ... = k_C = k$$
This reduces hyperparameters but still assumes independence between classes.
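For the independent-GP prior, the stacked covariance is block-diagonal, and with a shared kernel all blocks are equal, i.e. $I_C \otimes K$. A minimal sketch (the `rbf_kernel` helper matches the one used in the ICM code below; the lengthscale values are made up):

```python
import numpy as np
from scipy.linalg import block_diag

def rbf_kernel(X1, X2, lengthscale=1.0):
    sq_dist = np.sum((X1[:, None] - X2[None, :])**2, axis=-1)
    return np.exp(-0.5 * sq_dist / lengthscale**2)

X = np.random.randn(20, 2)
C = 3

# Independent GPs with class-specific lengthscales: block-diagonal covariance
K_blocks = [rbf_kernel(X, X, lengthscale=ell) for ell in (0.5, 1.0, 2.0)]
K_independent = block_diag(*K_blocks)                 # (nC, nC)

# Shared kernel: all blocks identical, i.e. I_C ⊗ K
K = rbf_kernel(X, X)
K_shared = np.kron(np.eye(C), K)
print(np.allclose(K_shared, block_diag(*[K] * C)))    # True
```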
Intrinsic Coregionalization Model (ICM):
To model correlations between classes, use:
$$k((\mathbf{x}, c), (\mathbf{x}', c')) = k_\text{input}(\mathbf{x}, \mathbf{x}') \cdot B_{cc'}$$
where $k_\text{input}(\mathbf{x}, \mathbf{x}')$ is an ordinary kernel on the inputs and $B \in \mathbb{R}^{C \times C}$ is a positive semi-definite coregionalization matrix whose entry $B_{cc'}$ gives the covariance between latent functions $c$ and $c'$.
Properties: the diagonal of $B$ sets per-class variances, the off-diagonal entries encode correlations between classes, and $B = I$ recovers independent GPs with a shared input kernel.
Covariance Structure:
For the stacked representation $\mathbf{f} \in \mathbb{R}^{nC}$:
$$\text{Cov}[\mathbf{f}] = B \otimes K$$
where $\otimes$ is the Kronecker product and $K$ is the $n \times n$ input kernel matrix.
This Kronecker structure enables efficient computation: $(B \otimes K)^{-1} = B^{-1} \otimes K^{-1}$ and $|B \otimes K| = |B|^n |K|^C$.
```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

class MultiOutputGPPrior:
    """
    Multi-output GP prior with Intrinsic Coregionalization Model.

    Covariance: Cov[f] = B ⊗ K
    where B is the CxC coregionalization matrix and K is the nxn input kernel matrix.
    """

    def __init__(self, n_classes, kernel_fn, coregion_rank=None):
        self.C = n_classes
        self.kernel_fn = kernel_fn

        # Coregionalization matrix parameterization
        # B = W @ W.T + diag(kappa) for PSD guarantee
        if coregion_rank is None:
            coregion_rank = n_classes  # Full rank
        self.W = np.random.randn(n_classes, coregion_rank) * 0.1
        self.log_kappa = np.zeros(n_classes)

    @property
    def B(self):
        """Coregionalization matrix."""
        kappa = np.exp(self.log_kappa)
        return self.W @ self.W.T + np.diag(kappa)

    def compute_covariance(self, X):
        """
        Compute full covariance matrix for the stacked latent vector.

        Parameters
        ----------
        X : ndarray (n, d)
            Input points

        Returns
        -------
        K_full : ndarray (nC, nC)
            Full covariance matrix (Kronecker structure)
        """
        n = len(X)
        K_input = self.kernel_fn(X, X)

        # Kronecker product: B ⊗ K
        # Result is (nC) x (nC)
        K_full = np.kron(self.B, K_input)

        return K_full

    def log_prob_efficient(self, f_stacked, X):
        """
        Compute log prior probability using the Kronecker structure.

        Uses: (B ⊗ K)^{-1} = B^{-1} ⊗ K^{-1}
              log|B ⊗ K| = n*log|B| + C*log|K|
        """
        n = len(X)
        K_input = self.kernel_fn(X, X) + 1e-6 * np.eye(n)

        # Cholesky factors
        L_K = np.linalg.cholesky(K_input)
        L_B = np.linalg.cholesky(self.B + 1e-6 * np.eye(self.C))

        # Log determinant: n*log|B| + C*log|K|
        log_det_K = 2 * np.sum(np.log(np.diag(L_K)))
        log_det_B = 2 * np.sum(np.log(np.diag(L_B)))
        log_det = n * log_det_B + self.C * log_det_K

        # Reshape f to (n, C) for efficient computation
        F = f_stacked.reshape(n, self.C, order='F')  # Column-major for Kronecker

        # Quadratic form: f^T (B^{-1} ⊗ K^{-1}) f
        #               = tr(F^T K^{-1} F B^{-1})
        K_inv_F = cho_solve(cho_factor(K_input), F)
        B_inv = cho_solve(cho_factor(self.B + 1e-6 * np.eye(self.C)), np.eye(self.C))
        quad_form = np.trace(F.T @ K_inv_F @ B_inv)

        # Log probability
        log_prob = -0.5 * (n * self.C * np.log(2 * np.pi) + log_det + quad_form)

        return log_prob


# Example usage
def rbf_kernel(X1, X2, lengthscale=1.0):
    sq_dist = np.sum((X1[:, None] - X2[None, :])**2, axis=-1)
    return np.exp(-0.5 * sq_dist / lengthscale**2)

n_classes = 3
n_points = 50
X = np.random.randn(n_points, 2)
f = np.random.randn(n_points * n_classes)

prior = MultiOutputGPPrior(n_classes, rbf_kernel)
log_p = prior.log_prob_efficient(f, X)
print(f"Log prior: {log_p:.4f}")
print(f"Coregionalization matrix B:\n{prior.B}")
```

All inference methods for binary GPC extend to multi-class, with increased computational complexity.
Laplace Approximation:
Find the mode of the posterior:
$$\hat{\mathbf{f}} = \arg\max_\mathbf{f} \left[\sum_{i=1}^n \log p(y_i | \mathbf{f}_i) + \log p(\mathbf{f} | X)\right]$$
where $\mathbf{f}_i = [f_1(\mathbf{x}_i), ..., f_C(\mathbf{x}_i)]^\top$.
Gradient (for point $i$, class $c$): $$\frac{\partial \log p(y_i | \mathbf{f}_i)}{\partial f_{ic}} = \mathbf{1}[y_i = c] - \pi_{ic}$$
Hessian (block for point $i$): $$\frac{\partial^2 \log p(y_i | \mathbf{f}_i)}{\partial f_{ic} \partial f_{ic'}} = -\pi_{ic}(\delta_{cc'} - \pi_{ic'})$$
The Hessian is therefore block-diagonal across data points (one $C \times C$ block per point when the latents are grouped by point), where each block is: $$H_i = -\text{diag}(\boldsymbol{\pi}_i) + \boldsymbol{\pi}_i \boldsymbol{\pi}_i^\top$$
Newton-Raphson for Multi-class:
The update is structurally similar to binary:
$$\mathbf{f}^{(t+1)} = (K_{\text{full}}^{-1} + W)^{-1}(W\mathbf{f}^{(t)} + \nabla \log p(\mathbf{y} | \mathbf{f}^{(t)}))$$
where $K_{\text{full}} = B \otimes K$ and $W$ is the block-diagonal Hessian.
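A naive but complete Newton step, written against the class-major stacking $\mathbf{f} = [\mathbf{f}_1^\top, ..., \mathbf{f}_C^\top]^\top$ used above. This is a sketch for small $n$ and $C$ only; the helper names are ours, and the dense solve is exactly the cost discussed next.

```python
import numpy as np

def softmax_rows(F):
    """Row-wise softmax with max-subtraction for stability."""
    F = F - F.max(axis=1, keepdims=True)
    E = np.exp(F)
    return E / E.sum(axis=1, keepdims=True)

def laplace_newton_step(f, K_full, Y_onehot):
    """One Newton update f <- (K^{-1} + W)^{-1} (W f + grad log p(y|f)).

    f        : (nC,) latents, class-major: f[c*n + i] = f_c(x_i)
    K_full   : (nC, nC) prior covariance, e.g. np.kron(B, K)
    Y_onehot : (n, C) one-hot labels
    """
    n, C = Y_onehot.shape
    F = f.reshape(C, n).T                    # (n, C): row i holds f_1..f_C at x_i
    Pi = softmax_rows(F)                     # class probabilities at each point

    # Gradient of the log-likelihood: 1[y_i=c] - pi_ic, restacked class-major
    grad = (Y_onehot - Pi).T.reshape(-1)

    # W = -Hessian: one C x C block per data point, scattered into (nC, nC)
    W = np.zeros((n * C, n * C))
    for i in range(n):
        pi_i = Pi[i]
        H_i = np.diag(pi_i) - np.outer(pi_i, pi_i)
        idx = np.arange(C) * n + i           # positions of point i in each class block
        W[np.ix_(idx, idx)] = H_i

    # (K^{-1} + W)^{-1} v  ==  (I + K W)^{-1} K v, avoiding an explicit K^{-1}
    rhs = K_full @ (W @ f + grad)
    return np.linalg.solve(np.eye(n * C) + K_full @ W, rhs)
```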
Complexity: each Newton step solves an $nC \times nC$ linear system, i.e. $O(n^3C^3)$ time and $O(n^2C^2)$ memory in general; structure in $W$ and in the prior must be exploited to go further.
Exploiting Kronecker Structure:
With the ICM prior, the Newton update requires solving systems with: $$(B \otimes K)^{-1} + W$$
This doesn't have simple Kronecker structure due to $W$, but iterative methods (conjugate gradient) can exploit the structure for matrix-vector products.
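The key primitive for conjugate gradients is a fast matrix-vector product with $B \otimes K$ that never forms the $nC \times nC$ matrix. A minimal sketch using the identity $(B \otimes K)\,\mathrm{vec}(F) = \mathrm{vec}(K F B^\top)$ for the class-major stacking:

```python
import numpy as np

def kron_matvec(B, K, v):
    """Compute (B ⊗ K) @ v without forming the Kronecker product.

    v is class-major: v = vec(F) with F of shape (n, C), column c of F
    holding the values for class c.
    Uses (B ⊗ K) vec(F) = vec(K F B^T): O(n²C + nC²) instead of O(n²C²).
    """
    C = B.shape[0]
    n = K.shape[0]
    F = v.reshape(C, n).T            # undo class-major stacking -> (n, C)
    return (K @ F @ B.T).T.reshape(-1)

# Check against the explicit Kronecker product on a small example
n, C = 30, 4
K = np.random.randn(n, n); K = K @ K.T + n * np.eye(n)
A = np.random.randn(C, C); B = A @ A.T + np.eye(C)
v = np.random.randn(n * C)
print(np.allclose(kron_matvec(B, K, v), np.kron(B, K) @ v))   # True
```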
For moderate C and n, independent GPs (B = I) are often sufficient. This reduces the problem to C independent binary-like problems, each with O(n³) complexity, giving O(Cn³) total. Only use ICM if class correlations are expected and computationally feasible.
Variational inference is particularly attractive for multi-class GPC due to its natural handling of sparse approximations.
The Multi-class ELBO:
$$\mathcal{L} = \sum_{i=1}^n \mathbb{E}_{q(\mathbf{f}_i)}\left[\log p(y_i | \mathbf{f}_i)\right] - \text{KL}(q(\mathbf{f}) || p(\mathbf{f} | X))$$
Sparse Variational Formulation:
With $m$ inducing points per class (or a shared set), we introduce inducing variables $\mathbf{u}_c$ with variational distributions $q(\mathbf{u}_c) = \mathcal{N}(\mathbf{m}_c, S_c)$, and obtain $q(\mathbf{f}_c)$ at the data by conditioning on $\mathbf{u}_c$, exactly as in sparse binary GPC.
For Independent GPs per Class:
The ELBO decomposes: $$\mathcal{L} = \sum_{i=1}^n \mathbb{E}_{q(\mathbf{f}_i)}[\log p(y_i | \mathbf{f}_i)] - \sum_{c=1}^C \text{KL}(q(\mathbf{u}_c) || p(\mathbf{u}_c))$$
The KL term separates into $C$ independent terms, but the expected log-likelihood still couples all classes at each point due to the softmax.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMulticlassGPC(nn.Module):
    """
    Sparse Variational Multi-class Gaussian Process Classification.

    Uses independent GPs per class with shared inducing point locations.
    """

    def __init__(self, X_train, y_train, n_classes, n_inducing=50):
        super().__init__()
        self.n = len(y_train)
        self.C = n_classes
        self.m = n_inducing

        self.X = torch.tensor(X_train, dtype=torch.float32)
        self.y = torch.tensor(y_train, dtype=torch.long)  # Class indices

        # Kernel hyperparameters (shared across classes)
        self.log_lengthscale = nn.Parameter(torch.tensor(0.0))
        self.log_variance = nn.Parameter(torch.tensor(0.0))

        # Shared inducing locations
        indices = torch.randperm(self.n)[:n_inducing]
        self.Z = nn.Parameter(self.X[indices].clone())

        # Variational parameters: q(u_c) = N(m_c, S_c) for each class
        self.q_mu = nn.Parameter(torch.zeros(n_inducing, n_classes))
        # Lower-triangular Cholesky factors for each class
        self.q_L = nn.Parameter(
            torch.stack([torch.eye(n_inducing) for _ in range(n_classes)], dim=-1)
        )

    def kernel(self, X1, X2):
        """RBF kernel."""
        lengthscale = torch.exp(self.log_lengthscale)
        variance = torch.exp(self.log_variance)
        dist_sq = torch.cdist(X1 / lengthscale, X2 / lengthscale, p=2)**2
        return variance * torch.exp(-0.5 * dist_sq)

    def get_q_f(self, X_batch):
        """
        Get variational distribution q(f) at batch points.

        Returns mean and variance for each class at each point.
        """
        batch_size = len(X_batch)
        K_uu = self.kernel(self.Z, self.Z) + 1e-6 * torch.eye(self.m)
        K_bu = self.kernel(X_batch, self.Z)

        L_uu = torch.linalg.cholesky(K_uu)
        A = torch.linalg.solve(L_uu.T, torch.linalg.solve(L_uu, K_bu.T)).T  # K_bu @ K_uu^{-1}

        f_means = []
        f_vars = []

        for c in range(self.C):
            # q(u_c) parameters
            m_c = self.q_mu[:, c]
            L_c = torch.tril(self.q_L[:, :, c])
            S_c = L_c @ L_c.T

            # q(f_c) = N(A @ m_c, k(x,x) - A(K_uu - S_c)A^T)
            f_mean_c = A @ m_c

            # Variance: diagonal only
            v = torch.linalg.solve(L_uu, K_bu.T)
            var_prior = torch.exp(self.log_variance)   # k(x,x) diagonal
            var_reduction = (v**2).sum(dim=0)          # A K_uu A^T diagonal

            # Add back S_c contribution
            w = A @ L_c
            var_addition = (w**2).sum(dim=1)           # A S_c A^T diagonal

            f_var_c = var_prior - var_reduction + var_addition

            f_means.append(f_mean_c)
            f_vars.append(f_var_c)

        # Stack: (batch_size, n_classes)
        return torch.stack(f_means, dim=1), torch.stack(f_vars, dim=1)

    def elbo(self, batch_indices, n_samples=5):
        """Compute sparse variational ELBO for multi-class."""
        batch_size = len(batch_indices)
        X_batch = self.X[batch_indices]
        y_batch = self.y[batch_indices]

        # Get q(f) at batch points
        f_mean, f_var = self.get_q_f(X_batch)  # (batch, C)
        f_std = torch.sqrt(torch.clamp(f_var, min=1e-6))

        # Monte Carlo estimate of E[log p(y|f)]
        epsilon = torch.randn(n_samples, batch_size, self.C)
        f_samples = f_mean.unsqueeze(0) + epsilon * f_std.unsqueeze(0)

        # Log-softmax for numerical stability
        log_probs = F.log_softmax(f_samples, dim=-1)  # (samples, batch, C)

        # Gather log-prob of true class
        y_expanded = y_batch.unsqueeze(0).unsqueeze(-1).expand(n_samples, -1, 1)
        log_lik = log_probs.gather(-1, y_expanded).squeeze(-1)  # (samples, batch)

        expected_ll = log_lik.mean(dim=0).sum() * (self.n / batch_size)

        # KL divergence: sum over classes
        K_uu = self.kernel(self.Z, self.Z) + 1e-6 * torch.eye(self.m)
        L_uu = torch.linalg.cholesky(K_uu)

        kl_total = 0
        for c in range(self.C):
            m_c = self.q_mu[:, c]
            L_c = torch.tril(self.q_L[:, :, c])

            alpha = torch.linalg.solve(L_uu, m_c)
            beta = torch.linalg.solve(L_uu, L_c)

            kl_c = 0.5 * (
                (beta**2).sum() + (alpha**2).sum() - self.m
                + 2 * torch.log(torch.diag(L_uu)).sum()
                - 2 * torch.log(torch.diag(L_c).abs()).sum()
            )
            kl_total = kl_total + kl_c

        return expected_ll - kl_total

    def predict(self, X_test, n_samples=100):
        """Predict class probabilities at test points."""
        with torch.no_grad():
            X_test = torch.tensor(X_test, dtype=torch.float32)
            f_mean, f_var = self.get_q_f(X_test)
            f_std = torch.sqrt(torch.clamp(f_var, min=1e-6))

            # Monte Carlo
            epsilon = torch.randn(n_samples, len(X_test), self.C)
            f_samples = f_mean.unsqueeze(0) + epsilon * f_std.unsqueeze(0)
            probs = F.softmax(f_samples, dim=-1).mean(dim=0)

        return probs.numpy()
```

Computational complexity of the main approaches:

| Method | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Full Laplace | O(n³C³) | O(n²C²) | Exact prior; prohibitive for large n,C |
| Independent Laplace | O(Cn³) | O(n²) | Ignores class correlations |
| Sparse VI (shared Z) | O(nm²C + m³C) | O(nmC + m²C) | Scalable; practical choice |
| One-vs-Rest | O(Cn³) | O(n²) | C binary problems |
| Robustmax | O(Cn³) | O(n²C) | Parallel C problems, shared kernel |
Practical Guidelines:
| Scenario | Recommended Approach |
|---|---|
| Small n, small C (< 5 classes, < 500 points) | Full multi-class Laplace |
| Small n, large C (many classes, < 500 points) | Independent GPs or sparse VI |
| Large n, small C | Sparse multi-class VI |
| Large n, large C | Sparse VI with shared inducing |
| Need calibrated probabilities | Variational or EP (not OvR) |
| Need interpretable correlations | ICM with coregionalization |
Key Hyperparameters: the number of inducing points $m$, the kernel lengthscale(s) and variance (shared or per class), the number of Monte Carlo samples for the expected log-likelihood, and, when using ICM, the rank of the coregionalization matrix $B$.
GPflow provides robust multi-class GPC with sparse variational inference. GPyTorch supports multi-class through variational strategies. These libraries handle the numerical details (stable softmax, reparameterization, efficient Kronecker operations) that are error-prone to implement from scratch.
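For reference, a minimal sketch of sparse variational multi-class classification in GPflow. This assumes GPflow 2.x; argument names and data conventions can differ between versions, so consult the library's multiclass classification tutorial for the exact API.

```python
import numpy as np
import gpflow

# Toy data: 3 classes in 2D; labels as a column of integer class indices
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
Y = rng.integers(0, 3, size=(200, 1)).astype(float)

C, M = 3, 20
Z = X[:M].copy()                                   # initial inducing locations

model = gpflow.models.SVGP(
    kernel=gpflow.kernels.SquaredExponential(),
    likelihood=gpflow.likelihoods.MultiClass(C),   # robustmax link; Softmax also available
    inducing_variable=Z,
    num_latent_gps=C,                              # one latent GP per class
)

gpflow.optimizers.Scipy().minimize(
    model.training_loss_closure((X, Y)), model.trainable_variables
)

probs, _ = model.predict_y(X[:5])                  # per-class probabilities
print(probs.numpy())
```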
Multi-class GP classification extends the binary framework to handle problems with more than two classes, requiring careful attention to both modeling and computation.
Module Complete:
This concludes our comprehensive treatment of GP Classification. We've journeyed from the fundamental challenge of non-Gaussian likelihoods, through three major inference approaches (Laplace, EP, Variational), to multi-class extensions. You now have the conceptual foundation and practical toolkit to apply GP classification to real-world problems, understanding both the power and limitations of each approach.
You have mastered Gaussian Process Classification—from non-Gaussian likelihoods that break conjugacy, through Laplace approximation, Expectation Propagation, and variational inference, to multi-class extensions with softmax likelihoods. These methods power uncertainty-aware classification in applications from medical diagnosis to autonomous systems.