Recommending from millions of items in real time is impossible if each user-item pair requires a full forward pass. The Two-Tower architecture (also called Dual Encoder) solves this with an elegant separation: a user tower and an item tower independently encode their inputs into a shared embedding space, and relevance is scored with a simple inner product.
This separation enables pre-computation: item embeddings are computed once and stored. At serving time, we only run the user tower and perform approximate nearest neighbor (ANN) search to find top items.
Two-tower models are the backbone of production recommendation systems at Google, Meta, TikTok, LinkedIn, and virtually every major platform.
This page covers: (1) The two-tower architecture and why it scales, (2) Tower design for users and items, (3) Training objectives and negative sampling strategies, (4) Handling feature interactions, (5) Production deployment with ANN search, and (6) Evaluation challenges specific to retrieval.
The Core Idea:
Instead of: $$\hat{y}_{ui} = f_{\text{interaction}}(\text{user\_features}_u, \text{item\_features}_i)$$
We decompose into: $$\hat{y}_{ui} = \langle f_{\text{user}}(\text{user\_features}_u), f_{\text{item}}(\text{item\_features}_i) \rangle$$
where $\langle \cdot, \cdot \rangle$ is the inner product.
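To make the decomposition concrete, here is a minimal sketch (with made-up shapes and random tensors standing in for the trained towers) of why the inner-product form lets item embeddings be pre-computed offline and scored cheaply at request time:

```python
import torch

D, num_items = 64, 1_000_000

# Offline: the item tower runs once per item and the results are cached.
item_embs = torch.randn(num_items, D)   # stand-in for f_item(item_features)

# Online: only the user tower runs per request...
user_emb = torch.randn(D)               # stand-in for f_user(user_features)

# ...and every item is scored with one matrix-vector product
# (in production this full scan is replaced by an ANN search).
scores = item_embs @ user_emb           # (num_items,)
top_scores, top_items = scores.topk(10)
```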
Why This Matters:
| Approach | Cost per user request | For 10M items, 100 users/sec |
|---|---|---|
| Full model | O(M × N × D) | 1 trillion FLOPs/sec |
| Two-tower | O(M × D + ANN) | 1 million FLOPs/sec |
The speedup is often 1000x+, making real-time recommendations feasible.
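A quick back-of-the-envelope check of the table, under illustrative assumptions not stated above (roughly 1,000 FLOPs to score one user-item pair with a full model and roughly 10,000 FLOPs for one user-tower forward pass):

```python
num_items, users_per_sec = 10_000_000, 100
flops_per_pair, flops_per_user_pass = 1_000, 10_000   # assumed, for illustration only

full_model = num_items * users_per_sec * flops_per_pair   # 1e12 -> ~1 trillion FLOPs/sec
two_tower = users_per_sec * flops_per_user_pass           # 1e6  -> ~1 million FLOPs/sec
print(full_model // two_tower)                            # ~1,000,000x fewer scoring FLOPs
```

End-to-end speedups are smaller once ANN search, feature fetching, and network overhead are included, which is why 1000x+ is the more realistic figure.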
The Tradeoff:
By forcing this decomposition, we lose the ability to model feature interactions between user and item: the towers cannot communicate during the forward pass. Various techniques, covered later on this page, partially address this limitation.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoTowerModel(nn.Module):
    """
    Two-Tower / Dual Encoder model for large-scale retrieval.

    Separate towers encode users and items into a shared embedding space.
    Matching is performed via dot product for efficiency.
    """

    def __init__(
        self,
        user_feature_dims: dict,   # {feature_name: vocab_size}
        item_feature_dims: dict,
        embedding_dim: int = 32,
        hidden_dims: list = [256, 128],
        output_dim: int = 64,
        dropout: float = 0.2
    ):
        super().__init__()
        self.output_dim = output_dim

        # User tower
        self.user_tower = Tower(
            feature_dims=user_feature_dims,
            embedding_dim=embedding_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            dropout=dropout
        )

        # Item tower (can have different features/architecture)
        self.item_tower = Tower(
            feature_dims=item_feature_dims,
            embedding_dim=embedding_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            dropout=dropout
        )

        # Optional: learnable temperature for softmax
        self.temperature = nn.Parameter(torch.ones(1) * 0.07)

    def forward(
        self,
        user_features: dict,
        item_features: dict,
        return_embeddings: bool = False
    ):
        """
        Compute match scores between users and items.

        Args:
            user_features: {feature_name: tensor of shape (batch,)}
            item_features: {feature_name: tensor}
            return_embeddings: If True, return embeddings instead of scores

        Returns:
            scores: (batch,) dot product scores,
            or (user_emb, item_emb) if return_embeddings=True
        """
        # Encode through towers
        user_emb = self.user_tower(user_features)   # (batch, output_dim)
        item_emb = self.item_tower(item_features)   # (batch, output_dim)

        # L2 normalize for cosine similarity
        user_emb = F.normalize(user_emb, p=2, dim=-1)
        item_emb = F.normalize(item_emb, p=2, dim=-1)

        if return_embeddings:
            return user_emb, item_emb

        # Dot product scores
        scores = (user_emb * item_emb).sum(dim=-1) / self.temperature
        return scores

    def encode_users(self, user_features: dict):
        """Encode users for retrieval."""
        user_emb = self.user_tower(user_features)
        return F.normalize(user_emb, p=2, dim=-1)

    def encode_items(self, item_features: dict):
        """Encode items for indexing."""
        item_emb = self.item_tower(item_features)
        return F.normalize(item_emb, p=2, dim=-1)


class Tower(nn.Module):
    """
    Single tower (encoder) for user or item side.
    Handles multiple categorical and numerical features.
    """

    def __init__(
        self,
        feature_dims: dict,
        embedding_dim: int,
        hidden_dims: list,
        output_dim: int,
        dropout: float
    ):
        super().__init__()

        # Embedding layers for categorical features
        self.embeddings = nn.ModuleDict()
        total_emb_dim = 0
        for name, vocab_size in feature_dims.items():
            self.embeddings[name] = nn.Embedding(vocab_size, embedding_dim)
            total_emb_dim += embedding_dim

        # MLP layers
        layers = []
        prev_dim = total_emb_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        self.mlp = nn.Sequential(*layers)

        self._init_weights()

    def _init_weights(self):
        for emb in self.embeddings.values():
            nn.init.xavier_uniform_(emb.weight)

    def forward(self, features: dict):
        """Encode features into embedding."""
        embeddings = []
        for name, tensor in features.items():
            if name in self.embeddings:
                emb = self.embeddings[name](tensor)
                embeddings.append(emb)

        # Concatenate all embeddings
        x = torch.cat(embeddings, dim=-1)

        # Through MLP
        output = self.mlp(x)
        return output
```

Training two-tower models requires careful consideration of negative sampling and loss functions.
1. In-Batch Negatives (Contrastive Learning):
For a batch of (user, positive_item) pairs, use other items in the batch as negatives:
$$\mathcal{L} = -\log \frac{\exp(s(u_i, v_i^+)/\tau)}{\exp(s(u_i, v_i^+)/\tau) + \sum_{j \neq i} \exp(s(u_i, v_j)/\tau)}$$
where $s(u, v)$ is the similarity score and $\tau$ is the temperature.
Advantages: negatives come for free from items already in the batch, so no separate negative-sampling pipeline or extra item-tower passes are needed, and the number of negatives grows with the batch size.
Disadvantages: negatives follow the (popularity-skewed) distribution of positives in the batch, which over-penalizes popular items unless corrected (e.g., with a log-Q correction), and the negative pool is limited to the batch.
2. Hard Negative Mining:
Sample negatives that the current model scores highly but that are not true positives; these provide a stronger learning signal than random negatives.
```python
def in_batch_softmax_loss(user_emb, item_emb, temperature=0.07):
    """
    In-batch softmax loss (InfoNCE).

    Each user's positive is the corresponding item.
    All other items in batch are negatives.

    Args:
        user_emb: (batch_size, dim) normalized user embeddings
        item_emb: (batch_size, dim) normalized item embeddings
        temperature: Softmax temperature

    Returns:
        Scalar loss
    """
    batch_size = user_emb.size(0)

    # All pairwise similarities: (batch, batch)
    similarities = user_emb @ item_emb.T / temperature

    # Labels: diagonal is positive (user i matches item i)
    labels = torch.arange(batch_size, device=user_emb.device)

    # Cross entropy loss
    loss = F.cross_entropy(similarities, labels)
    return loss


def sampled_softmax_loss(
    user_emb: torch.Tensor,
    pos_item_emb: torch.Tensor,
    neg_item_emb: torch.Tensor,
    temperature: float = 0.07
):
    """
    Softmax loss with explicit negative samples.

    Args:
        user_emb: (batch, dim) user embeddings
        pos_item_emb: (batch, dim) positive item embeddings
        neg_item_emb: (batch, num_neg, dim) negative item embeddings
    """
    batch_size = user_emb.size(0)

    # Positive scores: (batch,)
    pos_scores = (user_emb * pos_item_emb).sum(dim=-1) / temperature

    # Negative scores: (batch, num_neg)
    neg_scores = torch.bmm(
        neg_item_emb, user_emb.unsqueeze(-1)
    ).squeeze(-1) / temperature

    # Combine: positive at index 0
    all_scores = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
    labels = torch.zeros(batch_size, dtype=torch.long, device=user_emb.device)

    return F.cross_entropy(all_scores, labels)


class HardNegativeMiner:
    """
    Hard negative mining for two-tower training.

    Periodically refresh negative candidates using current model.
    Sample negatives with high (but incorrect) similarity.
    """

    def __init__(
        self,
        model: nn.Module,
        item_embeddings: torch.Tensor,
        num_candidates: int = 1000,
        num_hard_negatives: int = 50,
        refresh_steps: int = 1000
    ):
        self.model = model
        self.item_embeddings = item_embeddings  # Pre-computed
        self.num_candidates = num_candidates
        self.num_hard_negatives = num_hard_negatives
        self.refresh_steps = refresh_steps
        self.step = 0
        self.candidate_indices = None

    def sample_negatives(
        self,
        user_emb: torch.Tensor,
        positive_ids: torch.Tensor
    ):
        """
        Sample hard negatives for a batch of users.

        Strategy: From random candidates, select those with highest similarity.
        """
        batch_size = user_emb.size(0)
        device = user_emb.device

        # Refresh candidates periodically
        if self.step % self.refresh_steps == 0:
            self.candidate_indices = torch.randint(
                0, len(self.item_embeddings),
                (self.num_candidates,), device=device
            )
        self.step += 1

        # Get candidate embeddings
        candidates = self.item_embeddings[self.candidate_indices]

        # Score all candidates for all users
        scores = user_emb @ candidates.T  # (batch, num_candidates)

        # Remove positives from consideration
        for i, pos_id in enumerate(positive_ids):
            if pos_id in self.candidate_indices:
                idx = (self.candidate_indices == pos_id).nonzero()
                if len(idx) > 0:
                    scores[i, idx[0]] = float('-inf')

        # Select top-k hardest negatives
        _, top_indices = scores.topk(self.num_hard_negatives, dim=1)
        hard_neg_ids = self.candidate_indices[top_indices]

        return hard_neg_ids
```

For in-batch negatives, larger batch sizes provide more negatives and better gradients. Production systems often use batch sizes of 4096+. Mixed-precision training and gradient accumulation help fit large batches on limited hardware.
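As a minimal sketch of that training recipe (assuming the `in_batch_softmax_loss` above, a generic `train_loader` yielding `user_features`/`item_features` dicts, and an AdamW optimizer, all of which are placeholders here), mixed precision plus gradient accumulation might look like this:

```python
scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4   # e.g. 4 micro-batches of 1024 -> effective batch of 4096

optimizer.zero_grad()
for step, batch in enumerate(train_loader):
    with torch.cuda.amp.autocast():
        # Encode both towers in half precision
        user_emb, item_emb = model(
            batch['user_features'], batch['item_features'],
            return_embeddings=True
        )
        loss = in_batch_softmax_loss(user_emb, item_emb) / accum_steps

    scaler.scale(loss).backward()

    # Step the optimizer once every accum_steps micro-batches
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```

Note that gradient accumulation improves gradient quality but does not add in-batch negatives: each loss term still only sees the items in its micro-batch, so the per-step batch size (or an explicit negative pool) determines how many negatives each example faces.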
The power of two-tower models comes from rich feature engineering:
User Tower Features:
| Feature Type | Examples | Encoding |
|---|---|---|
| Demographics | Age, gender, country | Embeddings or buckets |
| Behavior | Watch history, clicks | Pooled item embeddings |
| Aggregated | Avg rating, click rate | Numerical |
| Temporal | Time of day, day of week | Embeddings |
| Context | Device, location | Embeddings |
Item Tower Features:
| Feature Type | Examples | Encoding |
|---|---|---|
| Metadata | Category, genre, brand | Embeddings |
| Content | Title, description | Text encoder (BERT) |
| Visual | Product images | CNN/ViT |
| Aggregated | Popularity, avg rating | Numerical |
| Graph | Connected items | Pre-trained GNN |
The Interaction Challenge:
Some features are inherently interactive: for example, whether an item's category matches the user's recent interests, or whether its price falls within the user's typical spending range.
Two-tower models can't capture these crosses directly; solutions include late-interaction models and crossing features before they enter the towers.
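One late-interaction variant (a ColBERT-style MaxSim score; an illustrative sketch, not something defined earlier on this page) keeps item representations pre-computable while allowing richer matching by letting each tower emit several vectors per user or item:

```python
import torch
import torch.nn.functional as F

def late_interaction_score(user_vecs: torch.Tensor, item_vecs: torch.Tensor) -> torch.Tensor:
    """
    user_vecs: (batch, n_user_vecs, dim), item_vecs: (batch, n_item_vecs, dim).
    For each user vector, take its best-matching item vector, then sum:
    a richer interaction than a single dot product, while item vectors can
    still be pre-computed and indexed.
    """
    user_vecs = F.normalize(user_vecs, dim=-1)
    item_vecs = F.normalize(item_vecs, dim=-1)
    sim = torch.einsum('bud,bid->bui', user_vecs, item_vecs)   # (batch, n_user, n_item)
    return sim.max(dim=-1).values.sum(dim=-1)                  # (batch,)
```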
```python
class UserHistoryEncoder(nn.Module):
    """
    Encode user interaction history for the user tower.
    Uses attention pooling over item embeddings.
    """

    def __init__(
        self,
        item_embedding: nn.Embedding,
        hidden_dim: int = 128,
        num_heads: int = 4,
        max_history: int = 50
    ):
        super().__init__()
        self.item_embedding = item_embedding
        self.max_history = max_history
        embed_dim = item_embedding.embedding_dim

        # Query vector for attention pooling
        self.query = nn.Parameter(torch.randn(1, embed_dim))

        # Multi-head self-attention
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True
        )

        self.projection = nn.Linear(embed_dim, hidden_dim)

    def forward(
        self,
        history_ids: torch.Tensor,            # (batch, max_history)
        history_mask: torch.Tensor,           # (batch, max_history) 1=valid, 0=pad
        history_weights: torch.Tensor = None  # Optional: interaction weights
    ):
        """
        Encode user history into fixed-size representation.
        """
        batch_size = history_ids.size(0)

        # Get item embeddings
        item_embs = self.item_embedding(history_ids)  # (batch, seq, dim)

        # Apply interaction weights if provided
        if history_weights is not None:
            item_embs = item_embs * history_weights.unsqueeze(-1)

        # Attention pooling with learnable query
        query = self.query.expand(batch_size, 1, -1)  # (batch, 1, dim)

        # Attention masking for padding
        key_padding_mask = ~history_mask.bool()

        # Attend over history
        attended, _ = self.attention(
            query, item_embs, item_embs,
            key_padding_mask=key_padding_mask
        )

        # Project to output dimension
        output = self.projection(attended.squeeze(1))
        return output


class MultiModalItemEncoder(nn.Module):
    """
    Encode items with multiple modalities (IDs, text, images).
    """

    def __init__(
        self,
        num_items: int,
        id_embedding_dim: int = 64,
        text_dim: int = 768,    # BERT output
        image_dim: int = 512,   # ResNet output
        output_dim: int = 128
    ):
        super().__init__()

        # ID embedding (for collaborative signal)
        self.id_embedding = nn.Embedding(num_items, id_embedding_dim)

        # Modality projections
        self.text_projection = nn.Linear(text_dim, output_dim)
        self.image_projection = nn.Linear(image_dim, output_dim)
        self.id_projection = nn.Linear(id_embedding_dim, output_dim)

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(output_dim * 3, output_dim * 2),
            nn.ReLU(),
            nn.Linear(output_dim * 2, output_dim)
        )

    def forward(
        self,
        item_ids: torch.Tensor,
        text_features: torch.Tensor = None,
        image_features: torch.Tensor = None
    ):
        """
        Fuse multiple modalities for item representation.
        """
        # ID embedding
        id_emb = self.id_projection(self.id_embedding(item_ids))

        # Text features (from pre-trained BERT)
        if text_features is not None:
            text_emb = self.text_projection(text_features)
        else:
            text_emb = torch.zeros_like(id_emb)

        # Image features (from pre-trained CNN)
        if image_features is not None:
            image_emb = self.image_projection(image_features)
        else:
            image_emb = torch.zeros_like(id_emb)

        # Concatenate and fuse
        fused = torch.cat([id_emb, text_emb, image_emb], dim=-1)
        output = self.fusion(fused)
        return output
```

With millions of items, exact nearest neighbor search is too slow. Approximate Nearest Neighbor (ANN) algorithms trade perfect recall for speed.
Popular ANN Libraries:
| Library | Method | Strengths |
|---|---|---|
| FAISS | IVF, HNSW, PQ | Facebook, highly optimized |
| ScaNN | Anisotropic quantization | Google, production-ready |
| Annoy | Random projections | Spotify, memory-efficient |
| Milvus | Hybrid indexes | Open-source, GPU support |
Key Trade-offs: recall vs. query latency (searching more clusters or neighbors improves recall but slows each query), memory vs. accuracy (quantization shrinks the index at some cost in precision), and index build time vs. freshness (more elaborate indexes take longer to rebuild).
```python
import faiss
import numpy as np


class ANNIndex:
    """
    Wrapper for FAISS ANN index.
    Supports efficient retrieval of top-K similar items.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_items: int,
        index_type: str = 'IVF',
        nlist: int = 100,    # Number of clusters for IVF
        nprobe: int = 10,    # Clusters to search at query time
        use_gpu: bool = False
    ):
        self.embedding_dim = embedding_dim
        self.index_type = index_type

        if index_type == 'flat':
            # Exact search (for testing/small catalogs)
            self.index = faiss.IndexFlatIP(embedding_dim)
        elif index_type == 'IVF':
            # Inverted file index (good balance)
            quantizer = faiss.IndexFlatIP(embedding_dim)
            self.index = faiss.IndexIVFFlat(
                quantizer, embedding_dim, nlist,
                faiss.METRIC_INNER_PRODUCT
            )
            self.index.nprobe = nprobe
        elif index_type == 'HNSW':
            # Hierarchical Navigable Small World (fastest queries)
            self.index = faiss.IndexHNSWFlat(
                embedding_dim,
                32,  # M: connections per node
                faiss.METRIC_INNER_PRODUCT
            )

        if use_gpu:
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)

        self.is_trained = False
        self.item_ids = None

    def build_index(
        self,
        embeddings: np.ndarray,
        item_ids: np.ndarray = None
    ):
        """
        Build/train the index with item embeddings.

        Args:
            embeddings: (num_items, dim) L2-normalized embeddings
            item_ids: Optional mapping from index to item ID
        """
        # Ensure embeddings are normalized
        embeddings = embeddings.astype(np.float32)
        faiss.normalize_L2(embeddings)

        # Train if needed (IVF requires training)
        if hasattr(self.index, 'train'):
            self.index.train(embeddings)

        # Add embeddings
        self.index.add(embeddings)
        self.item_ids = item_ids
        self.is_trained = True

        print(f"Built index with {self.index.ntotal} items")

    def search(
        self,
        query_embeddings: np.ndarray,
        top_k: int = 100
    ):
        """
        Search for top-K nearest neighbors.

        Args:
            query_embeddings: (num_queries, dim) query vectors
            top_k: Number of neighbors to return

        Returns:
            scores: (num_queries, top_k) similarity scores
            indices: (num_queries, top_k) item indices
        """
        query_embeddings = query_embeddings.astype(np.float32)
        faiss.normalize_L2(query_embeddings)

        scores, indices = self.index.search(query_embeddings, top_k)

        # Map to original item IDs if provided
        if self.item_ids is not None:
            indices = self.item_ids[indices]

        return scores, indices


class TwoTowerServingPipeline:
    """
    Production serving pipeline for two-tower recommendations.
    """

    def __init__(
        self,
        model: TwoTowerModel,
        ann_index: ANNIndex,
        device: str = 'cuda'
    ):
        self.model = model.to(device).eval()
        self.ann_index = ann_index
        self.device = device

    @torch.no_grad()
    def recommend(
        self,
        user_features: dict,
        top_k: int = 100,
        exclude_items: list = None
    ):
        """
        Generate recommendations for a user.

        Steps:
        1. Encode user with user tower
        2. ANN search for candidate items
        3. Optional: filter/exclude items
        """
        # Move features to device
        batch_features = {
            k: v.to(self.device) for k, v in user_features.items()
        }

        # Encode user
        user_emb = self.model.encode_users(batch_features)
        user_emb_np = user_emb.cpu().numpy()

        # ANN search
        scores, item_ids = self.ann_index.search(user_emb_np, top_k * 2)

        # Filter excluded items
        if exclude_items:
            mask = ~np.isin(item_ids, exclude_items)
            # Keep top_k after filtering
            results = []
            for i in range(len(scores)):
                valid = mask[i]
                valid_scores = scores[i][valid][:top_k]
                valid_ids = item_ids[i][valid][:top_k]
                results.append((valid_scores, valid_ids))
            return results

        return scores[:, :top_k], item_ids[:, :top_k]
```

Item embeddings change as the model trains.
Production systems rebuild ANN indices periodically (hourly/daily) or use streaming updates. The index refresh lag means new items may take time to appear in recommendations.
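A hedged sketch of such a refresh job, reusing the `ANNIndex` and model classes above; the batching iterables and the scheduling around this function are hypothetical placeholders:

```python
import numpy as np
import torch

@torch.no_grad()
def rebuild_item_index(model, item_feature_batches, item_id_batches,
                       embedding_dim=64, num_items=10_000_000):
    """Re-encode the catalog with the current item tower and build a fresh index."""
    model.eval()
    all_embs, all_ids = [], []
    for features, ids in zip(item_feature_batches, item_id_batches):
        emb = model.encode_items(features)        # (batch, dim), already L2-normalized
        all_embs.append(emb.cpu().numpy())
        all_ids.append(np.asarray(ids))

    embeddings = np.concatenate(all_embs, axis=0)
    item_ids = np.concatenate(all_ids, axis=0)

    # Build a brand-new index rather than mutating the live one
    new_index = ANNIndex(embedding_dim=embedding_dim, num_items=num_items, index_type='IVF')
    new_index.build_index(embeddings, item_ids)
    return new_index

# Serving side: swap the fresh index into the pipeline once it is built, e.g.
# pipeline.ann_index = rebuild_item_index(model, feature_batches, id_batches)
```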
Two-tower models are retrieval models, not ranking models. Evaluation differs from traditional RecSys metrics:
Offline Metrics:
| Metric | Description | Formula |
|---|---|---|
| Recall@K | Fraction of relevant items retrieved in the top-K | hits@K / total_relevant |
| MRR | Mean reciprocal rank of the first hit | mean(1 / rank of first hit) |
| Hit Rate@K | Whether any relevant item appears in the top-K | mean(1 if hit in top-K else 0) |
| NDCG@K | Normalized discounted cumulative gain; rewards hits at higher positions | DCG@K / IDCG@K |
The Retrieval vs Ranking Distinction:
Two-tower models are typically the first stage (retrieval/candidate generation), narrowing millions of items down to a few hundred candidates; a separate, more expressive ranking model then re-orders those candidates.
Optimize retrieval for recall (don't miss relevant items), ranking for precision (order matters).
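A minimal sketch of that cascade, reusing `TwoTowerServingPipeline` from above; `ranking_model` and `build_pair_features` are hypothetical stand-ins for the second-stage ranker and its feature assembly:

```python
import numpy as np
import torch

@torch.no_grad()
def retrieve_then_rank(pipeline, ranking_model, build_pair_features,
                       user_features, num_candidates=500, top_k=20):
    # Stage 1: high-recall candidate generation (two-tower + ANN) for one user.
    _, candidate_ids = pipeline.recommend(user_features, top_k=num_candidates)
    candidate_ids = np.asarray(candidate_ids).reshape(-1)

    # Stage 2: precision-oriented re-ranking with full user-item interaction features.
    pair_features = build_pair_features(user_features, candidate_ids)
    scores = ranking_model(pair_features).squeeze(-1)            # (num_candidates,)
    order = torch.argsort(scores, descending=True)[:top_k].cpu().numpy()
    return candidate_ids[order]
```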
```python
from torch.utils.data import DataLoader  # needed for the type hint below


def evaluate_retrieval(
    model: TwoTowerModel,
    ann_index: ANNIndex,
    test_interactions: list,   # [(user_features, positive_item_id), ...]
    k_values: list = [10, 50, 100, 500]
):
    """
    Evaluate two-tower retrieval quality.

    For each user, check if the positive item appears in top-K.
    """
    model.eval()
    results = {f'recall@{k}': [] for k in k_values}
    results['mrr'] = []

    max_k = max(k_values)

    with torch.no_grad():
        for user_features, pos_item_id in test_interactions:
            # Encode user
            user_emb = model.encode_users(user_features).cpu().numpy()

            # Retrieve top-K
            _, retrieved_ids = ann_index.search(user_emb, max_k)
            retrieved_ids = retrieved_ids[0]  # First (only) query

            # Check if positive is in results
            if pos_item_id in retrieved_ids:
                rank = np.where(retrieved_ids == pos_item_id)[0][0] + 1
                results['mrr'].append(1.0 / rank)
                for k in k_values:
                    if rank <= k:
                        results[f'recall@{k}'].append(1.0)
                    else:
                        results[f'recall@{k}'].append(0.0)
            else:
                results['mrr'].append(0.0)
                for k in k_values:
                    results[f'recall@{k}'].append(0.0)

    # Aggregate
    metrics = {}
    for key, values in results.items():
        metrics[key] = np.mean(values)

    return metrics


def sample_based_evaluation(
    model: TwoTowerModel,
    test_loader: DataLoader,
    num_negatives: int = 999
):
    """
    Evaluate on sampled negatives (when full retrieval is expensive).

    For each positive, sample random negatives and compute ranking.
    This approximates full retrieval metrics.
    """
    model.eval()
    hit_at_10 = []
    mrr = []

    with torch.no_grad():
        for batch in test_loader:
            user_features = batch['user_features']
            pos_item_features = batch['pos_item_features']
            neg_item_features = batch['neg_item_features']  # (batch, num_neg, ...)

            # Encode
            user_emb = model.encode_users(user_features)
            pos_emb = model.encode_items(pos_item_features)

            # Score positive
            pos_score = (user_emb * pos_emb).sum(dim=-1, keepdim=True)

            # Score negatives
            batch_size = user_emb.size(0)
            neg_scores = []
            for i in range(num_negatives):
                neg_feat = {k: v[:, i] for k, v in neg_item_features.items()}
                neg_emb = model.encode_items(neg_feat)
                neg_score = (user_emb * neg_emb).sum(dim=-1, keepdim=True)
                neg_scores.append(neg_score)
            neg_scores = torch.cat(neg_scores, dim=1)  # (batch, num_neg)

            # Compute rank of positive
            all_scores = torch.cat([pos_score, neg_scores], dim=1)
            ranks = (all_scores > pos_score).sum(dim=1) + 1  # 1-indexed

            hit_at_10.extend((ranks <= 10).float().cpu().tolist())
            mrr.extend((1.0 / ranks).cpu().tolist())

    return {
        'HR@10': np.mean(hit_at_10),
        'MRR': np.mean(mrr)
    }
```

You've mastered deep learning approaches for recommendations: from Neural Collaborative Filtering through autoencoders, sequence models, and graph networks, to production-scale two-tower systems. These techniques power the recommendations you see daily on major platforms. The next module explores production considerations: scaling, fairness, and real-world deployment challenges.