Stochastic Gradient Descent (SGD) works well for matrix factorization, but it has a fundamental limitation: it's inherently sequential. Each update depends on the previous one, making parallelization difficult.
For systems like Netflix (100M+ users, 10K+ items) or Spotify (500M+ users, 100M+ tracks), we need algorithms that can leverage distributed computing. Alternating Least Squares (ALS) provides exactly this—a method where each step is embarrassingly parallel, enabling training on clusters of thousands of machines.
This page covers the ALS algorithm, why fixing one matrix makes the problem convex, how to solve the resulting least squares problems efficiently, parallelization strategies, and when to choose ALS over SGD.
The matrix factorization objective is non-convex in P and Q jointly:
min_{P,Q} Σ_{(u,i) ∈ K} (r_ui - p_u · q_i)² + λ(||P||² + ||Q||²)

where K is the set of observed (user, item) ratings.
However, here's the key insight: if Q is held fixed, the objective is a convex (in fact, quadratic) function of P alone, and if P is held fixed, it is convex in Q. Each of these subproblems can be solved exactly in closed form.

This suggests an alternating optimization strategy:

1. Fix Q and solve for the optimal P.
2. Fix P and solve for the optimal Q.
3. Repeat until the objective stops improving.

Each step reduces (or at worst leaves unchanged) the objective, guaranteeing convergence to a local minimum of the joint non-convex problem, as the sketch below illustrates.
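Here is a minimal sketch of that monotone decrease on a tiny synthetic ratings matrix (all sizes, values, and variable names are made up for illustration): it applies the per-row closed-form updates derived just below and prints the regularized objective after each half-step.

```python
import numpy as np

# Illustrative toy setup (sizes and values are made up for this sketch)
rng = np.random.default_rng(0)
n_users, n_items, k, lam = 30, 20, 5, 0.1
R = rng.integers(1, 6, size=(n_users, n_items)).astype(float)
M = rng.random((n_users, n_items)) < 0.3          # mask of observed ratings
P = rng.normal(0, 0.1, (n_users, k))
Q = rng.normal(0, 0.1, (n_items, k))

def objective(P, Q):
    err = (R - P @ Q.T) * M                        # only observed entries count
    return np.sum(err**2) + lam * (np.sum(P**2) + np.sum(Q**2))

for sweep in range(5):
    # Fix Q, solve each user's ridge regression in closed form
    for u in range(n_users):
        idx = np.where(M[u])[0]
        if idx.size:
            Qu = Q[idx]
            P[u] = np.linalg.solve(Qu.T @ Qu + lam * np.eye(k), Qu.T @ R[u, idx])
    loss_after_users = objective(P, Q)
    # Fix P, solve each item's ridge regression in closed form
    for i in range(n_items):
        idx = np.where(M[:, i])[0]
        if idx.size:
            Pi = P[idx]
            Q[i] = np.linalg.solve(Pi.T @ Pi + lam * np.eye(k), Pi.T @ R[idx, i])
    print(sweep, loss_after_users, objective(P, Q))   # non-increasing sequence
```

On this toy problem the printed objective values decrease monotonically, which is exactly the convergence guarantee stated above.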
When Q is fixed, each user's p_u depends only on their ratings and the fixed Q—not on other users. We can update all users in parallel! Similarly, when P is fixed, all items can be updated in parallel. This is 'embarrassingly parallel' computation.
Consider updating user u's latent vector p_u while Q is fixed. The objective terms involving p_u are:
L_u = Σ_{i ∈ I_u} (r_ui - p_u · q_i)² + λ||p_u||²
Where I_u is the set of items rated by user u. This is a ridge regression problem!
Let:

- Q_u be the |I_u| × k matrix whose rows are the factors q_i of the items user u has rated
- r_u be the vector of user u's ratings on those items
Then p_u solves: p_u = (Q_u^T Q_u + λI)^{-1} Q_u^T r_u
This is the standard ridge regression closed-form solution.
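As a quick numerical sanity check (a standalone sketch with made-up sizes, separate from the implementation below), the gradient of L_u should vanish at this closed-form solution:

```python
import numpy as np

rng = np.random.default_rng(42)
k, n_rated, lam = 8, 25, 0.1                  # illustrative sizes
Q_u = rng.normal(size=(n_rated, k))           # item factors for items user u rated
r_u = rng.normal(size=n_rated)                # the corresponding ratings

# Closed-form ridge solution: p_u = (Q_u^T Q_u + λI)^{-1} Q_u^T r_u
p_u = np.linalg.solve(Q_u.T @ Q_u + lam * np.eye(k), Q_u.T @ r_u)

# Gradient of L_u = ||r_u - Q_u p_u||^2 + λ||p_u||^2 should be ~0 at the minimizer
grad = -2 * Q_u.T @ (r_u - Q_u @ p_u) + 2 * lam * p_u
print(np.max(np.abs(grad)))                   # numerically zero (floating-point noise)
```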
```python
import numpy as np
from typing import List, Dict


class ALSMatrixFactorization:
    """
    Matrix Factorization using Alternating Least Squares.
    """

    def __init__(self, n_users: int, n_items: int,
                 n_factors: int = 50, reg: float = 0.1):
        self.n_users = n_users
        self.n_items = n_items
        self.n_factors = n_factors
        self.reg = reg

        # Initialize factors randomly
        self.P = np.random.normal(0, 0.1, (n_users, n_factors))
        self.Q = np.random.normal(0, 0.1, (n_items, n_factors))

    def _solve_user(self, user_id: int, item_ids: List[int],
                    ratings: np.ndarray) -> np.ndarray:
        """Solve for p_u given fixed Q using ridge regression."""
        Q_u = self.Q[item_ids]  # |I_u| x k
        # (Q_u^T @ Q_u + λI)^{-1} @ Q_u^T @ r_u
        A = Q_u.T @ Q_u + self.reg * np.eye(self.n_factors)
        b = Q_u.T @ ratings
        return np.linalg.solve(A, b)

    def _solve_item(self, item_id: int, user_ids: List[int],
                    ratings: np.ndarray) -> np.ndarray:
        """Solve for q_i given fixed P using ridge regression."""
        P_i = self.P[user_ids]  # |U_i| x k
        A = P_i.T @ P_i + self.reg * np.eye(self.n_factors)
        b = P_i.T @ ratings
        return np.linalg.solve(A, b)

    def fit(self, user_items: Dict[int, List[tuple]],
            item_users: Dict[int, List[tuple]],
            n_epochs: int = 15):
        """
        Train using ALS.

        Args:
            user_items: {user_id: [(item_id, rating), ...]}
            item_users: {item_id: [(user_id, rating), ...]}
        """
        for epoch in range(n_epochs):
            # Step 1: Fix Q, update all P
            for u in range(self.n_users):
                if u in user_items and user_items[u]:
                    items = [x[0] for x in user_items[u]]
                    ratings = np.array([x[1] for x in user_items[u]])
                    self.P[u] = self._solve_user(u, items, ratings)

            # Step 2: Fix P, update all Q
            for i in range(self.n_items):
                if i in item_users and item_users[i]:
                    users = [x[0] for x in item_users[i]]
                    ratings = np.array([x[1] for x in item_users[i]])
                    self.Q[i] = self._solve_item(i, users, ratings)

            # Compute loss
            loss = self._compute_loss(user_items)
            print(f"Epoch {epoch+1}: Loss = {loss:.4f}")

    def _compute_loss(self, user_items: Dict) -> float:
        """Compute regularized squared error."""
        sq_err = 0.0
        count = 0
        for u, items in user_items.items():
            for (i, r) in items:
                pred = np.dot(self.P[u], self.Q[i])
                sq_err += (r - pred) ** 2
                count += 1
        reg_term = self.reg * (np.sum(self.P**2) + np.sum(self.Q**2))
        return sq_err + reg_term

    def predict(self, user_id: int, item_id: int) -> float:
        return np.dot(self.P[user_id], self.Q[item_id])
```

ALS's structure enables several parallelization approaches:
Data Parallelism (Most Common): Distribute users across workers. Each worker:

- receives a (broadcast) copy of the fixed item factors Q,
- solves the independent ridge regressions for the users in its partition,
- writes the updated rows of P back.

The item half-step is handled the same way with the roles of P and Q swapped.
Computational Bottlenecks:

- Forming each Gram matrix Q_u^T Q_u costs O(|I_u| · k²), so a full user sweep is O(nnz · k²), where nnz is the total number of observed ratings.
- Each k × k solve costs O(k³), giving O((n_users + n_items) · k³) per epoch.
- At cluster scale, communicating the factor matrices between workers can dominate the arithmetic.

A rough back-of-envelope for these costs is sketched below.
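To make these costs concrete, here is a rough back-of-envelope estimate (the workload numbers are hypothetical placeholders, not figures from this page):

```python
# Rough per-epoch FLOP estimate for ALS; all workload numbers are hypothetical.
n_users, n_items = 10_000_000, 100_000     # assumed catalogue sizes
nnz = 1_000_000_000                        # assumed number of observed ratings
k = 50                                     # latent dimension

gram_flops = 2 * nnz * k * k               # accumulating Q_u^T Q_u and P_i^T P_i
solve_flops = (n_users + n_items) * k**3   # one k x k solve per user and per item
print(f"Gram accumulation: ~{gram_flops:.1e} FLOPs/epoch")
print(f"Linear solves:     ~{solve_flops:.1e} FLOPs/epoch")
# Both land on the order of 10^12 FLOPs per epoch for this hypothetical workload,
# which is why the per-user/per-item solves are spread across many workers.
```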
| Framework | Approach | Scale |
|---|---|---|
| Spark MLlib | RDD partitioning, broadcast variables | 100M+ users |
| Parameter Server | Async updates, model sharding | Billion+ scale |
| GPU (cuMF) | Matrix ops on GPU, batched solves | 10-100x speedup |
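The batched-solve idea from the last row can be sketched in plain NumPy (a simplified illustration, not the cuMF implementation): instead of solving each user's k × k system in a Python loop, accumulate all Gram matrices at once and hand them to a single batched `np.linalg.solve` call.

```python
import numpy as np

def batched_user_update(user_idx, item_idx, ratings, Q, n_users, reg):
    """Solve every user's ridge regression in one batched call.

    user_idx, item_idx, ratings are parallel 1-D arrays, one entry per
    observed rating (a COO-style layout assumed for this sketch).
    """
    k = Q.shape[1]
    A = np.tile(reg * np.eye(k), (n_users, 1, 1))              # (n_users, k, k)
    b = np.zeros((n_users, k))

    Qi = Q[item_idx]                                           # (nnz, k) item factors per rating
    # Scatter-add the per-rating outer products and targets into each user's system.
    # (Materializing all outer products at once is fine for a sketch; a real
    # implementation would chunk this to bound memory.)
    np.add.at(A, user_idx, Qi[:, :, None] * Qi[:, None, :])    # accumulate Q_u^T Q_u
    np.add.at(b, user_idx, ratings[:, None] * Qi)              # accumulate Q_u^T r_u

    # One batched solve for all users (users with no ratings get p_u = 0)
    return np.linalg.solve(A, b[..., None])[..., 0]
```

This is the same pattern GPU implementations exploit: the per-user loop collapses into a handful of large batched operations.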
```python
from multiprocessing import Pool

import numpy as np


def parallel_als_step(args):
    """Solve one ridge regression given fixed factors (used for both user and item updates)."""
    user_id, item_ids, ratings, Q, reg, n_factors = args
    Q_u = Q[item_ids]
    A = Q_u.T @ Q_u + reg * np.eye(n_factors)
    b = Q_u.T @ np.array(ratings)
    return user_id, np.linalg.solve(A, b)


class ParallelALS:
    def __init__(self, n_users, n_items, n_factors=50, reg=0.1, n_workers=4):
        self.n_users = n_users
        self.n_items = n_items
        self.n_factors = n_factors
        self.reg = reg
        self.n_workers = n_workers
        self.P = np.random.normal(0, 0.1, (n_users, n_factors))
        self.Q = np.random.normal(0, 0.1, (n_items, n_factors))

    def fit(self, user_items, item_users, n_epochs=15):
        for epoch in range(n_epochs):
            # Parallel user updates.
            # Note: self.Q is serialized into every task tuple here; a real
            # distributed system would share/broadcast it instead.
            user_args = [
                (u, [x[0] for x in user_items[u]],
                 [x[1] for x in user_items[u]],
                 self.Q, self.reg, self.n_factors)
                for u in user_items
            ]
            with Pool(self.n_workers) as pool:
                results = pool.map(parallel_als_step, user_args)
            for user_id, p_u in results:
                self.P[user_id] = p_u

            # Parallel item updates (same solve, with the roles of P and Q swapped)
            item_args = [
                (i, [x[0] for x in item_users[i]],
                 [x[1] for x in item_users[i]],
                 self.P, self.reg, self.n_factors)
                for i in item_users
            ]
            with Pool(self.n_workers) as pool:
                results = pool.map(parallel_als_step, item_args)
            for item_id, q_i in results:
                self.Q[item_id] = q_i

            print(f"Epoch {epoch+1} complete")
```

ALS typically converges in 10-20 iterations vs. 50-100+ epochs for SGD. However, each ALS iteration is more expensive (matrix solves vs. vector updates). Total wall-clock time depends heavily on hardware and parallelization.
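Finally, a minimal usage sketch, assuming the `ALSMatrixFactorization` class defined above is in scope; the ratings are randomly generated purely for illustration:

```python
import numpy as np
from collections import defaultdict

rng = np.random.default_rng(7)
n_users, n_items = 200, 100

# Build the two index structures that fit() expects
user_items, item_users = defaultdict(list), defaultdict(list)
for _ in range(3000):
    u, i = int(rng.integers(n_users)), int(rng.integers(n_items))
    r = float(rng.integers(1, 6))
    user_items[u].append((i, r))
    item_users[i].append((u, r))

model = ALSMatrixFactorization(n_users, n_items, n_factors=20, reg=0.1)
model.fit(user_items, item_users, n_epochs=10)
print(model.predict(user_id=0, item_id=5))
```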
You now understand ALS optimization: the alternating convexity insight, closed-form solutions via ridge regression, and parallelization strategies. Next, we'll explore regularization techniques that prevent overfitting in matrix factorization models.