So far, we've treated outputs as single values (regression), categorical choices (classification), or independent label sets (multi-label). However, many important prediction tasks produce structured outputs with rich internal dependencies.
In structured prediction, output components are interdependent—the probability of one output element depends on others. A word in translation depends on surrounding words. A parse tree must satisfy grammatical constraints. These dependencies fundamentally change how we design output layers and loss functions.
This page covers: the structured prediction problem formulation, sequence-to-sequence outputs with attention, autoregressive generation, non-autoregressive approaches, tree and graph structured outputs, set prediction with Hungarian matching, spatial output structures, and connections between neural networks and classical structured prediction methods.
Formal definition:
Given input $\mathbf{x}$, predict structured output $\mathbf{y} = (y_1, y_2, \ldots, y_T)$ where $T$ may vary and components have dependencies. The goal is to find:
$$\hat{\mathbf{y}} = \arg\max_{\mathbf{y} \in \mathcal{Y}} p(\mathbf{y} | \mathbf{x})$$
where $\mathcal{Y}$ is the (often exponentially large) space of valid outputs.
The key challenge:
For a sequence of length $T$ with vocabulary $V$, there are $|V|^T$ possible outputs. Exhaustive search is impossible. We need factorizations that make the distribution tractable to train, and approximate inference algorithms (greedy decoding, beam search, dynamic programming) to search the output space at prediction time.
Factorization approaches:
Autoregressive: Decompose into sequential conditionals $$p(\mathbf{y}|\mathbf{x}) = \prod_{t=1}^T p(y_t | y_{<t}, \mathbf{x})$$
Non-autoregressive: Predict all positions independently (faster but ignores dependencies) $$p(\mathbf{y}|\mathbf{x}) \approx \prod_{t=1}^T p(y_t | \mathbf{x})$$
Latent variable: Model through latent structure $$p(\mathbf{y}|\mathbf{x}) = \sum_z p(z|\mathbf{x}) p(\mathbf{y}|z, \mathbf{x})$$
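To make the autoregressive vs. non-autoregressive contrast concrete, here is a minimal sketch (the toy modules, sizes, and variable names are my own, not from the text) of how the two factorizations turn into training losses: the autoregressive decoder is teacher-forced on the gold prefix $y_{<t}$, while the non-autoregressive head predicts every position from the encoded input alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal sketch (hypothetical modules and sizes).
vocab, hidden, T = 20, 32, 6
y = torch.randint(0, vocab, (1, T))                  # gold output sequence [1, T]
x_ctx = torch.randn(1, hidden)                       # stand-in for an encoded input x

# Autoregressive: logits at step t condition on the gold prefix y_{<t} (teacher forcing).
embed = nn.Embedding(vocab, hidden)
gru = nn.GRU(hidden, hidden, batch_first=True)
proj = nn.Linear(hidden, vocab)

bos = torch.zeros(1, 1, hidden)                      # placeholder BOS embedding
prefix = torch.cat([bos, embed(y[:, :-1])], dim=1)   # decoder inputs are shifted gold tokens
h0 = x_ctx.unsqueeze(0)                              # condition the decoder on x
ar_hidden, _ = gru(prefix, h0)
ar_nll = F.cross_entropy(proj(ar_hidden).transpose(1, 2), y, reduction='sum')

# Non-autoregressive: every position is predicted from x alone, in parallel.
pos_proj = nn.Linear(hidden, vocab * T)
na_logits = pos_proj(x_ctx).view(1, T, vocab)
na_nll = F.cross_entropy(na_logits.transpose(1, 2), y, reduction='sum')

print(f"AR  -log p(y|x) = {ar_nll.item():.2f}  (sum over conditionals)")
print(f"NAR -log p(y|x) = {na_nll.item():.2f}  (positions independent given x)")
```

Both losses have the same form (a sum of per-position cross-entropies); the difference is what each step's logits are allowed to condition on.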
| Output Type | Example Tasks | Typical Architecture | Inference |
|---|---|---|---|
| Variable-length sequence | Translation, captioning, summarization | Encoder-decoder, Transformer | Beam search, nucleus sampling |
| Fixed-length sequence | Named entity recognition, POS tagging | BiLSTM/Transformer + CRF | Viterbi decoding |
| Tree | Constituency parsing, code generation | Pointer networks, tree-structured decoders | Greedy/beam parsing |
| Graph | Molecule generation, scene graphs | Graph neural networks, VAEs | Iterative refinement |
| Spatial (dense) | Segmentation, depth estimation | Encoder-decoder CNNs, U-Net | Direct output |
| Set | Object detection, slot filling | DETR-style transformers | Hungarian matching |
The structure of the output often constrains what's valid. In sequence tagging, B-PER must not follow I-LOC. In parsing, brackets must balance. In graph generation, no duplicate edges. Your output layer and training must respect these constraints.
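As one way to encode such constraints, the sketch below (the tag set and masking logic are illustrative, not a prescribed recipe) builds a boolean mask of allowed BIO transitions and uses it to forbid invalid ones at decoding time.

```python
import torch

# Minimal sketch: encode BIO constraints as a transition mask (illustrative tag set).
tags = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]
K = len(tags)

allowed = torch.ones(K, K, dtype=torch.bool)   # allowed[i, j]: transition i -> j is valid
for j, tj in enumerate(tags):
    if tj.startswith("I-"):
        ent = tj[2:]
        for i, ti in enumerate(tags):
            # I-X may only follow B-X or I-X of the same entity type
            allowed[i, j] = ti in (f"B-{ent}", f"I-{ent}")

# At decoding time, invalid transitions can be masked out of the score matrix:
transition_scores = torch.randn(K, K)
transition_scores = transition_scores.masked_fill(~allowed, float("-inf"))
print(allowed.int())
```

The same mask can be applied to the transition matrix of the CRF layer introduced later in this page.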
Sequence-to-sequence (Seq2Seq) models transform an input sequence into an output sequence of (potentially) different length. The standard architecture pairs an encoder, which maps the input into contextual representations, with a decoder that generates the output one token at a time while attending over the encoder states.
The decoder output layer:
At each time step $t$, the decoder hidden state $h_t$ is projected to vocabulary logits:
$$z_t = W_{\text{out}} h_t + b_{\text{out}}$$ $$p(y_t | y_{<t}, \mathbf{x}) = \text{softmax}(z_t)$$
This is multi-class classification at each step, but with the crucial difference that the class being predicted depends on what was generated before.
Training with teacher forcing:
During training, we feed ground-truth previous tokens (not model predictions): $$\mathcal{L} = -\sum_{t=1}^T \log p(y_t^* | y_{<t}^*, \mathbf{x})$$
This stabilizes training but creates exposure bias—the model never sees its own errors during training.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Seq2SeqDecoder(nn.Module):
    """
    Autoregressive decoder with attention.
    Standard output layer: project hidden state to vocabulary.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Token embeddings
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Decoder RNN
        self.rnn = nn.LSTM(
            embed_dim + hidden_dim,  # Input: embed + context from attention
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
        )

        # Attention over encoder outputs
        self.attention = nn.MultiheadAttention(
            hidden_dim, num_heads=4, batch_first=True
        )

        # Output projection: hidden_dim -> vocab_size
        # This is the KEY output layer for seq2seq
        self.output_projection = nn.Linear(hidden_dim, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward_step(
        self,
        y_prev: torch.Tensor,            # Previous token: [batch]
        hidden: tuple,                   # LSTM hidden state (or None to start fresh)
        encoder_outputs: torch.Tensor,   # [batch, src_len, hidden_dim]
    ):
        """
        Single decoding step.

        Returns:
            logits: [batch, vocab_size] - output scores over the vocabulary
            hidden: Updated hidden state
        """
        batch_size = y_prev.size(0)

        # Initialize hidden state lazily (needed when decoding from scratch)
        if hidden is None:
            h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim,
                              device=y_prev.device)
            hidden = (h_0, torch.zeros_like(h_0))

        # Embed previous token
        embedded = self.embedding(y_prev).unsqueeze(1)  # [batch, 1, embed_dim]
        embedded = self.dropout(embedded)

        # Attention context
        query = hidden[0][-1].unsqueeze(1)  # [batch, 1, hidden_dim]
        context, _ = self.attention(query, encoder_outputs, encoder_outputs)

        # RNN input: embedding + context
        rnn_input = torch.cat([embedded, context], dim=-1)

        # Decode
        output, hidden = self.rnn(rnn_input, hidden)  # [batch, 1, hidden_dim]

        # Output logits
        logits = self.output_projection(output.squeeze(1))  # [batch, vocab_size]

        return logits, hidden

    def forward(
        self,
        encoder_outputs: torch.Tensor,
        target_tokens: torch.Tensor,
        hidden: tuple = None,
    ):
        """
        Forward pass with teacher forcing.

        Args:
            encoder_outputs: [batch, src_len, hidden_dim]
            target_tokens: [batch, tgt_len] (includes BOS, excludes EOS)
            hidden: Initial hidden state

        Returns:
            logits: [batch, tgt_len, vocab_size]
        """
        batch_size, tgt_len = target_tokens.shape

        # Initialize hidden state if not provided
        if hidden is None:
            h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim,
                              device=target_tokens.device)
            c_0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim,
                              device=target_tokens.device)
            hidden = (h_0, c_0)

        all_logits = []
        for t in range(tgt_len):
            y_prev = target_tokens[:, t]
            logits, hidden = self.forward_step(y_prev, hidden, encoder_outputs)
            all_logits.append(logits)

        return torch.stack(all_logits, dim=1)  # [batch, tgt_len, vocab_size]


def seq2seq_loss(logits, targets, pad_idx=0):
    """
    Cross-entropy loss for seq2seq, ignoring padding.

    Args:
        logits: [batch, tgt_len, vocab_size]
        targets: [batch, tgt_len] ground truth tokens
        pad_idx: Padding token to ignore
    """
    # Flatten for cross-entropy
    batch_size, tgt_len, vocab_size = logits.shape
    logits_flat = logits.reshape(-1, vocab_size)
    targets_flat = targets.reshape(-1)

    loss = F.cross_entropy(
        logits_flat,
        targets_flat,
        ignore_index=pad_idx,
    )
    return loss


# Example
vocab_size = 10000
decoder = Seq2SeqDecoder(vocab_size, embed_dim=256, hidden_dim=512)

# Simulate encoder outputs
batch_size = 8
src_len = 20
tgt_len = 15

encoder_outputs = torch.randn(batch_size, src_len, 512)
target_tokens = torch.randint(0, vocab_size, (batch_size, tgt_len))

# Forward
logits = decoder(encoder_outputs, target_tokens)
print(f"Logits shape: {logits.shape}")  # [8, 15, 10000]

# Loss
loss = seq2seq_loss(logits, target_tokens)
print(f"Loss: {loss:.4f}")
```

A common technique: tie the output projection weights to the input embedding weights (transposed). If the embedding matrix is [vocab_size, embed_dim], the output projection reuses the same matrix; this requires the decoder's output dimension to equal embed_dim (or an extra projection). Tying reduces parameters and often improves performance: `model.output_projection.weight = model.embedding.weight`.
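A minimal sketch of that weight-tying trick, assuming the decoder's output dimension equals the embedding dimension (the variable names here are illustrative):

```python
import torch.nn as nn

# Weight tying sketch: output projection shares the embedding matrix.
vocab_size, dim = 10000, 512
embedding = nn.Embedding(vocab_size, dim)                    # weight: [vocab_size, dim]
output_projection = nn.Linear(dim, vocab_size, bias=False)   # weight: [vocab_size, dim]
output_projection.weight = embedding.weight                  # one shared Parameter
```

Because `nn.Linear` stores its weight as [out_features, in_features], the two shapes already line up, so no explicit transpose is needed in code.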
At inference time, we don't have ground-truth targets. The model must generate tokens autoregressively, feeding its own predictions back as input. The decoding strategy significantly affects output quality.
Greedy decoding: $$y_t = \arg\max_v p(v | y_{<t}, \mathbf{x})$$ Fast but often produces repetitive, low-quality outputs.
Beam search: Maintain top-$k$ partial hypotheses, extend each, keep top-$k$ across all extensions.
Sampling strategies: temperature sampling (sharpen or flatten the distribution before sampling), top-$k$ sampling (sample only from the $k$ most likely tokens), and nucleus (top-$p$) sampling (sample from the smallest set of tokens whose cumulative probability exceeds $p$).
These strategies don't change the output layer—they change how we use the probabilities it produces.
```python
import torch
import torch.nn.functional as F


def greedy_decode(model, encoder_outputs, bos_idx, eos_idx, max_len=100):
    """
    Greedy decoding: always pick most likely token.
    """
    batch_size = encoder_outputs.size(0)
    device = encoder_outputs.device

    # Start with BOS token
    generated = torch.full((batch_size, 1), bos_idx, dtype=torch.long, device=device)
    hidden = None

    for _ in range(max_len):
        logits, hidden = model.forward_step(
            generated[:, -1], hidden, encoder_outputs
        )

        # Greedy: pick highest probability
        next_token = logits.argmax(dim=-1)
        generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

        # Stop if all sequences have generated EOS
        if (next_token == eos_idx).all():
            break

    return generated


def top_k_sampling(logits, k=50, temperature=1.0):
    """
    Top-k sampling: sample from top k tokens.
    """
    logits = logits / temperature

    # Get top k
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)

    # Sample from top k
    probs = F.softmax(top_k_logits, dim=-1)
    sampled_idx = torch.multinomial(probs, num_samples=1)

    # Map back to vocabulary
    return torch.gather(top_k_indices, -1, sampled_idx).squeeze(-1)


def nucleus_sampling(logits, p=0.9, temperature=1.0):
    """
    Nucleus (top-p) sampling: sample from smallest set
    whose cumulative probability >= p.
    """
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)

    # Sort probabilities descending
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    # Cumulative probabilities
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

    # Find cutoff: mask tokens beyond the p threshold
    sorted_mask = cumsum_probs - sorted_probs > p
    sorted_probs[sorted_mask] = 0.0

    # Renormalize
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    # Sample
    sampled_idx = torch.multinomial(sorted_probs, num_samples=1)

    # Map back
    return torch.gather(sorted_indices, -1, sampled_idx).squeeze(-1)


def sample_decode(
    model,
    encoder_outputs,
    bos_idx,
    eos_idx,
    max_len=100,
    method='nucleus',
    temperature=1.0,
    top_k=50,
    top_p=0.9,
):
    """
    Sampling-based decoding with various strategies.
    """
    batch_size = encoder_outputs.size(0)
    device = encoder_outputs.device

    generated = torch.full((batch_size, 1), bos_idx, dtype=torch.long, device=device)
    hidden = None

    for _ in range(max_len):
        logits, hidden = model.forward_step(
            generated[:, -1], hidden, encoder_outputs
        )

        if method == 'greedy':
            next_token = logits.argmax(dim=-1)
        elif method == 'top_k':
            next_token = top_k_sampling(logits, k=top_k, temperature=temperature)
        elif method == 'nucleus':
            next_token = nucleus_sampling(logits, p=top_p, temperature=temperature)
        elif method == 'temperature':
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
        else:
            raise ValueError(f"Unknown method: {method}")

        generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

        if (next_token == eos_idx).all():
            break

    return generated


# Beam search (simplified)
def beam_search(model, encoder_outputs, bos_idx, eos_idx, beam_width=5, max_len=100):
    """
    Beam search decoding.
    Maintains top-k hypotheses at each step.
    """
    batch_size = encoder_outputs.size(0)
    assert batch_size == 1, "Beam search implemented for batch_size=1"
    device = encoder_outputs.device
    vocab_size = model.vocab_size

    # Initial beam: just BOS
    beams = [(torch.tensor([bos_idx], device=device), 0.0, None)]  # (tokens, score, hidden)
    completed = []

    for step in range(max_len):
        candidates = []

        for tokens, score, hidden in beams:
            if tokens[-1] == eos_idx:
                completed.append((tokens, score))
                continue

            # Get next token probabilities
            logits, new_hidden = model.forward_step(
                tokens[-1:], hidden, encoder_outputs
            )
            log_probs = F.log_softmax(logits, dim=-1).squeeze(0)

            # Top k extensions
            top_probs, top_indices = torch.topk(log_probs, beam_width)

            for prob, idx in zip(top_probs, top_indices):
                new_tokens = torch.cat([tokens, idx.unsqueeze(0)])
                new_score = score + prob.item()
                candidates.append((new_tokens, new_score, new_hidden))

        # Keep top beam_width candidates
        candidates.sort(key=lambda x: x[1], reverse=True)
        beams = candidates[:beam_width]

        if not beams:
            break

    # Add remaining beams to completed
    completed.extend(beams)

    # Return best hypothesis (length-normalized score)
    completed.sort(key=lambda x: x[1] / len(x[0]), reverse=True)
    return completed[0][0] if completed else beams[0][0]


print("=== Decoding Strategies ===")
print("Available strategies: greedy, temperature, top_k, nucleus, beam_search")
```

While the output layer defines the probability distribution, the decoding strategy determines what we actually generate. For creative tasks (dialogue, stories), use sampling. For translation or summarization, beam search often works better. The best strategy depends on the application.
For fixed-length sequence labeling (POS tagging, NER, slot filling), each input token gets a label. While you could use independent softmax per position, this ignores label dependencies—for example, in BIO tagging, I-PER cannot follow B-LOC.
Conditional Random Fields (CRF) add a structured output layer that models transitions between labels:
$$p(\mathbf{y}|\mathbf{x}) = \frac{\exp\left(\sum_t \phi(y_t, \mathbf{x}, t) + \sum_t \psi(y_{t-1}, y_t)\right)}{Z(\mathbf{x})}$$
where $\phi(y_t, \mathbf{x}, t)$ is the emission score for tag $y_t$ at position $t$ (produced by the neural network), $\psi(y_{t-1}, y_t)$ is a learned transition score between adjacent tags, and $Z(\mathbf{x})$ is the partition function that normalizes over all possible tag sequences.
Architecture: BiLSTM + CRF
Input → Embeddings → BiLSTM → Linear(2d, K) → CRF → Output Labels
The neural network produces emission scores; the CRF layer adds transitions and performs structured inference.
```python
import torch
import torch.nn as nn


class CRFLayer(nn.Module):
    """
    Conditional Random Field layer for sequence labeling.

    Adds learnable transition scores between labels and uses
    dynamic programming for training (forward algorithm) and
    inference (Viterbi decoding).
    """

    def __init__(self, num_tags: int, batch_first: bool = True):
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first

        # Transition scores: transitions[i, j] = score of transitioning from tag i to tag j
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))

        # Start and end transition scores
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions, tags, mask=None):
        """
        Compute negative log-likelihood loss.

        Args:
            emissions: [batch, seq_len, num_tags] emission scores
            tags: [batch, seq_len] gold tags
            mask: [batch, seq_len] valid positions (1 = valid)

        Returns:
            Negative log-likelihood loss
        """
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)

        # Numerator: score of gold sequence
        gold_score = self._score_sequence(emissions, tags, mask)

        # Denominator: log partition function (sum over all sequences)
        log_Z = self._compute_log_partition(emissions, mask)

        # NLL = log(Z) - gold_score
        nll = log_Z - gold_score
        return nll.mean()

    def _score_sequence(self, emissions, tags, mask):
        """Score of a specific tag sequence."""
        batch_size, seq_len, _ = emissions.shape

        score = self.start_transitions[tags[:, 0]]
        score += emissions[:, 0].gather(1, tags[:, 0:1]).squeeze(1)

        for t in range(1, seq_len):
            # Transition score (previous tag -> current tag)
            trans_score = self.transitions[tags[:, t - 1], tags[:, t]]
            # Emission score
            emit_score = emissions[:, t].gather(1, tags[:, t:t + 1]).squeeze(1)
            # Mask out padding
            score += (trans_score + emit_score) * mask[:, t].float()

        # End transition
        last_tag_idx = mask.sum(dim=1).long() - 1
        last_tags = tags.gather(1, last_tag_idx.unsqueeze(1)).squeeze(1)
        score += self.end_transitions[last_tags]

        return score

    def _compute_log_partition(self, emissions, mask):
        """
        Compute log partition function using the forward algorithm.
        Dynamic programming over all possible sequences.
        """
        batch_size, seq_len, num_tags = emissions.shape

        # Initialize: start transitions + first emissions
        alpha = self.start_transitions + emissions[:, 0]  # [batch, num_tags]

        for t in range(1, seq_len):
            # Broadcast for all transitions
            # alpha: [batch, num_tags] -> [batch, num_tags, 1]
            # transitions: [num_tags, num_tags]
            # emissions: [batch, num_tags]
            # new_alpha[j] = logsumexp_i(alpha[i] + trans[i,j] + emit[j])
            emit = emissions[:, t].unsqueeze(1)         # [batch, 1, num_tags]
            trans = self.transitions.unsqueeze(0)       # [1, num_tags, num_tags]
            alpha_exp = alpha.unsqueeze(2)              # [batch, num_tags, 1]

            scores = alpha_exp + trans + emit           # [batch, num_tags, num_tags]
            new_alpha = torch.logsumexp(scores, dim=1)  # [batch, num_tags]

            # Mask: keep old alpha for padded positions
            alpha = torch.where(
                mask[:, t].unsqueeze(1),
                new_alpha,
                alpha
            )

        # End transitions
        alpha = alpha + self.end_transitions

        return torch.logsumexp(alpha, dim=1)  # [batch]

    def decode(self, emissions, mask=None):
        """
        Viterbi decoding: find best tag sequence.

        Returns:
            best_tags: [batch, seq_len]
        """
        batch_size, seq_len, num_tags = emissions.shape

        if mask is None:
            mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=emissions.device)

        # Viterbi forward pass
        score = self.start_transitions + emissions[:, 0]
        history = []

        for t in range(1, seq_len):
            broadcast_score = score.unsqueeze(2)
            broadcast_emission = emissions[:, t].unsqueeze(1)

            combined = broadcast_score + self.transitions + broadcast_emission
            best_score, best_idx = combined.max(dim=1)

            history.append(best_idx)
            score = torch.where(mask[:, t].unsqueeze(1), best_score, score)

        # End transitions
        score = score + self.end_transitions
        _, best_last_tag = score.max(dim=1)

        # Backtrack
        best_tags = [best_last_tag]
        for hist in reversed(history):
            best_last_tag = hist.gather(1, best_last_tag.unsqueeze(1)).squeeze(1)
            best_tags.append(best_last_tag)

        best_tags.reverse()
        return torch.stack(best_tags, dim=1)


class BiLSTM_CRF(nn.Module):
    """
    BiLSTM + CRF for sequence labeling.
    """

    def __init__(
        self,
        vocab_size: int,
        num_tags: int,
        embed_dim: int = 100,
        hidden_dim: int = 256,
        dropout: float = 0.5,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        # Emission layer: hidden -> tag scores
        self.emission = nn.Linear(hidden_dim, num_tags)

        # CRF layer
        self.crf = CRFLayer(num_tags)

    def forward(self, tokens, tags, mask=None):
        """Training: compute CRF loss."""
        emissions = self._get_emissions(tokens)
        return self.crf(emissions, tags, mask)

    def decode(self, tokens, mask=None):
        """Inference: Viterbi decoding."""
        emissions = self._get_emissions(tokens)
        return self.crf.decode(emissions, mask)

    def _get_emissions(self, tokens):
        embedded = self.dropout(self.embedding(tokens))
        lstm_out, _ = self.lstm(embedded)
        emissions = self.emission(lstm_out)
        return emissions


# Example usage
model = BiLSTM_CRF(vocab_size=5000, num_tags=9)  # BIO tags for NER

batch_size, seq_len = 4, 20
tokens = torch.randint(0, 5000, (batch_size, seq_len))
tags = torch.randint(0, 9, (batch_size, seq_len))

# Training
loss = model(tokens, tags)
print(f"CRF loss: {loss:.4f}")

# Inference
predictions = model.decode(tokens)
print(f"Predicted tags shape: {predictions.shape}")
```

CRF helps most when label constraints are important (BIO tagging) or when neighboring labels are strongly correlated. For simple tagging with independent labels, softmax may suffice. Transformers with large context sometimes learn constraints implicitly, making CRF less beneficial.
Some outputs are unordered sets: object detection (set of bounding boxes), slot filling (set of entities), molecule property prediction (set of atoms). The challenge: output order is arbitrary, but losses typically compare ordered sequences.
DETR's solution: Hungarian matching
Predict a fixed-size set of outputs, then use the Hungarian algorithm to find the optimal assignment between predictions and ground truth that minimizes total cost. The matching cost for a (prediction, ground-truth) pair typically combines a classification term (how probable the true class is under the prediction) and box terms (L1 distance and, in DETR, a generalized IoU term).
Key insight: The loss becomes permutation-invariant—the model isn't penalized for predicting objects in a different order than the annotation.
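To see the matching step in isolation, here is a toy example (the cost values are made up) using SciPy's `linear_sum_assignment`, the same routine used in the full implementation below:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: rows = 4 predictions, cols = 2 ground-truth objects.
# Each entry stands for the combined matching cost of that (prediction, target) pair.
cost = np.array([
    [0.9, 2.1],
    [0.2, 1.5],
    [1.8, 0.3],
    [1.1, 1.0],
])

pred_idx, tgt_idx = linear_sum_assignment(cost)   # minimizes total matched cost
pairs = [(int(p), int(t)) for p, t in zip(pred_idx, tgt_idx)]
print(pairs)                                      # [(1, 0), (2, 1)]
print("total cost:", cost[pred_idx, tgt_idx].sum())  # 0.5
```

Unmatched predictions (rows 0 and 3 here) are later trained to predict the "no object" class.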
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment


class SetPredictionHead(nn.Module):
    """
    Set prediction with Hungarian matching.

    Each query predicts: (class_logits, box_coordinates).
    Training uses bipartite matching for assignment.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_classes: int,
        num_queries: int = 100,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.num_classes = num_classes

        # Learnable object queries
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # Prediction heads (per query)
        self.class_head = nn.Linear(hidden_dim, num_classes + 1)  # +1 for "no object"
        self.box_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 4),  # (cx, cy, w, h)
            nn.Sigmoid(),              # Normalize to [0, 1]
        )

    def forward(self, encoder_output):
        """
        Args:
            encoder_output: [batch, hidden_dim] or [batch, seq_len, hidden_dim]

        Returns:
            class_logits: [batch, num_queries, num_classes+1]
            box_pred: [batch, num_queries, 4]
        """
        batch_size = encoder_output.size(0)

        # Get query embeddings: [batch, num_queries, hidden_dim]
        queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)

        # For simplicity, assume encoder_output is a global feature.
        # In DETR, this goes through a transformer decoder with cross-attention.
        if encoder_output.dim() == 2:
            features = encoder_output.unsqueeze(1).expand(-1, self.num_queries, -1)
        else:
            features = encoder_output.mean(dim=1, keepdim=True)
            features = features.expand(-1, self.num_queries, -1)

        # Combine queries with features (simplified)
        query_features = queries + features

        # Predictions
        class_logits = self.class_head(query_features)
        box_pred = self.box_head(query_features)

        return class_logits, box_pred


class HungarianMatcher:
    """
    Computes optimal bipartite matching between predictions and targets.
    """

    def __init__(self, cost_class=1.0, cost_box=1.0, cost_giou=1.0):
        self.cost_class = cost_class
        self.cost_box = cost_box
        self.cost_giou = cost_giou

    def __call__(self, outputs, targets):
        """
        Args:
            outputs: (class_logits, box_pred)
                class_logits: [batch, num_queries, num_classes+1]
                box_pred: [batch, num_queries, 4]
            targets: list of dicts with 'labels' and 'boxes'

        Returns:
            List of (pred_idx, target_idx) tuples for each batch element
        """
        class_logits, box_pred = outputs
        batch_size, num_queries, _ = class_logits.shape

        indices = []

        for b in range(batch_size):
            tgt_labels = targets[b]['labels']
            tgt_boxes = targets[b]['boxes']
            num_targets = len(tgt_labels)

            if num_targets == 0:
                indices.append(([], []))
                continue

            # Class cost: -log(predicted prob of true class)
            probs = F.softmax(class_logits[b], dim=-1)
            class_cost = -probs[:, tgt_labels]  # [num_queries, num_targets]

            # Box L1 cost
            box_cost = torch.cdist(box_pred[b], tgt_boxes, p=1)

            # Total cost matrix
            C = (self.cost_class * class_cost + self.cost_box * box_cost)

            # Hungarian matching
            C_np = C.detach().cpu().numpy()
            pred_idx, tgt_idx = linear_sum_assignment(C_np)

            indices.append((
                torch.tensor(pred_idx, device=class_logits.device),
                torch.tensor(tgt_idx, device=class_logits.device)
            ))

        return indices


def set_criterion(outputs, targets, indices, num_classes):
    """
    Compute loss using Hungarian matched pairs.
    """
    class_logits, box_pred = outputs
    batch_size = class_logits.size(0)

    # Classification loss (cross-entropy on matched pairs + "no object" for unmatched).
    # Create target labels: matched queries get the true class, unmatched get "no object".
    target_classes = torch.full(
        (batch_size, class_logits.size(1)),
        num_classes,  # "no object" class
        dtype=torch.long,
        device=class_logits.device
    )

    for b, (pred_idx, tgt_idx) in enumerate(indices):
        if len(pred_idx) > 0:
            target_classes[b, pred_idx] = targets[b]['labels'][tgt_idx]

    class_loss = F.cross_entropy(
        class_logits.flatten(0, 1),
        target_classes.flatten(),
    )

    # Box regression loss (only on matched pairs)
    box_losses = []
    for b, (pred_idx, tgt_idx) in enumerate(indices):
        if len(pred_idx) > 0:
            src_boxes = box_pred[b][pred_idx]
            tgt_boxes = targets[b]['boxes'][tgt_idx]
            box_losses.append(F.l1_loss(src_boxes, tgt_boxes))

    box_loss = torch.stack(box_losses).mean() if box_losses else torch.tensor(0.0)

    return class_loss + 5 * box_loss  # Weight box loss higher


# Example
model = SetPredictionHead(hidden_dim=256, num_classes=80, num_queries=100)
matcher = HungarianMatcher()

# Simulate
encoder_out = torch.randn(2, 256)
class_logits, box_pred = model(encoder_out)

# Ground truth (variable number of objects per image)
targets = [
    {'labels': torch.tensor([0, 1, 5]), 'boxes': torch.rand(3, 4)},
    {'labels': torch.tensor([2]), 'boxes': torch.rand(1, 4)},
]

indices = matcher((class_logits, box_pred), targets)
loss = set_criterion((class_logits, box_pred), targets, indices, num_classes=80)

print(f"Output shapes: class={class_logits.shape}, boxes={box_pred.shape}")
print(f"Matched indices: {[(len(p), len(t)) for p, t in indices]}")
print(f"Loss: {loss:.4f}")
```

DETR (DEtection TRansformer) showed that set prediction with Hungarian matching can replace complex hand-crafted components (anchor boxes, NMS) in object detection. The key insight: treat detection as direct set prediction, making the output layer conceptually simple.
Many vision tasks require dense predictions—one output for every input pixel: semantic segmentation (a class per pixel), depth estimation (a real value per pixel), and keypoint detection (a heatmap per keypoint type).
Encoder-decoder architecture:
Input Image → Encoder (downsample) → Decoder (upsample) → Dense Output
The encoder extracts hierarchical features; the decoder upsamples to original resolution. Skip connections (U-Net style) preserve spatial detail.
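As a sketch of what such a decoder stage can look like (the module and channel sizes are illustrative, not a reference implementation), the block below upsamples, concatenates the encoder's skip feature, and refines the result:

```python
import torch
import torch.nn as nn

class UpBlock(nn.Module):
    """Minimal U-Net-style decoder block (illustrative):
    upsample, concatenate the encoder skip feature, then refine with convs."""
    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.refine = nn.Sequential(
            nn.Conv2d(out_ch + skip_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x, skip):
        x = self.up(x)                   # double the spatial resolution
        x = torch.cat([x, skip], dim=1)  # re-inject high-resolution detail from the encoder
        return self.refine(x)

# Example: decoder feature at 16x16, encoder skip feature at 32x32
x = torch.randn(1, 512, 16, 16)
skip = torch.randn(1, 256, 32, 32)
out = UpBlock(512, 256, 256)(x, skip)
print(out.shape)  # torch.Size([1, 256, 32, 32])
```

Stacking several such blocks progressively restores resolution; the task-specific heads below then operate on the resulting decoder features.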
Output layer considerations: segmentation uses a 1×1 convolution producing per-pixel class logits trained with cross-entropy; depth uses bounded regression (e.g., a sigmoid output scaled to a plausible depth range); keypoints use per-channel sigmoid heatmaps; and most heads must upsample their predictions back to the input resolution.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SegmentationHead(nn.Module):
    """
    Output layer for semantic segmentation.
    Produces per-pixel class logits.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        hidden_channels: int = 256,
    ):
        super().__init__()
        self.num_classes = num_classes

        # Refinement convolutions
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
        )

        # Per-pixel classifier (1x1 conv)
        self.classifier = nn.Conv2d(hidden_channels, num_classes, kernel_size=1)

    def forward(self, features, target_size=None):
        """
        Args:
            features: Decoder output [batch, in_channels, H, W]
            target_size: (H, W) to upsample to if needed

        Returns:
            logits: [batch, num_classes, H, W]
        """
        x = self.conv(features)
        logits = self.classifier(x)

        # Upsample to target size if needed
        if target_size is not None and logits.shape[2:] != target_size:
            logits = F.interpolate(
                logits, size=target_size, mode='bilinear', align_corners=False
            )

        return logits

    def predict(self, features, target_size=None):
        """Get class predictions per pixel."""
        logits = self.forward(features, target_size)
        return logits.argmax(dim=1)


class DepthHead(nn.Module):
    """
    Output layer for monocular depth estimation.
    Produces per-pixel depth values.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int = 256,
        min_depth: float = 0.1,
        max_depth: float = 100.0,
    ):
        super().__init__()
        self.min_depth = min_depth
        self.max_depth = max_depth

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
        )

        # Single-channel output with sigmoid
        self.regressor = nn.Conv2d(hidden_channels, 1, kernel_size=1)

    def forward(self, features, target_size=None):
        """
        Returns depth map in [min_depth, max_depth].
        """
        x = self.conv(features)
        depth_normalized = torch.sigmoid(self.regressor(x))

        # Scale to depth range
        depth = self.min_depth + (self.max_depth - self.min_depth) * depth_normalized

        if target_size is not None and depth.shape[2:] != target_size:
            depth = F.interpolate(
                depth, size=target_size, mode='bilinear', align_corners=False
            )

        return depth.squeeze(1)  # [batch, H, W]


class KeypointHeatmapHead(nn.Module):
    """
    Output layer for keypoint detection.
    Produces K heatmaps, one per keypoint type.
    """

    def __init__(
        self,
        in_channels: int,
        num_keypoints: int,
        hidden_channels: int = 256,
    ):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
        )

        # One channel per keypoint
        self.heatmap = nn.Conv2d(hidden_channels, num_keypoints, kernel_size=1)

    def forward(self, features):
        """
        Returns heatmaps (probabilities) per keypoint.
        """
        x = self.conv(features)
        heatmaps = torch.sigmoid(self.heatmap(x))
        return heatmaps  # [batch, num_keypoints, H, W]

    def get_keypoints(self, heatmaps, threshold=0.5):
        """
        Extract keypoint coordinates from heatmaps.
        Returns (x, y, confidence) per keypoint.
        """
        batch_size, num_kp, H, W = heatmaps.shape

        keypoints = []
        for b in range(batch_size):
            sample_kps = []
            for k in range(num_kp):
                hm = heatmaps[b, k]
                max_val = hm.max()

                if max_val >= threshold:
                    max_idx = hm.argmax()
                    y = max_idx // W
                    x = max_idx % W
                    sample_kps.append((x.item(), y.item(), max_val.item()))
                else:
                    sample_kps.append(None)

            keypoints.append(sample_kps)

        return keypoints


# Segmentation loss with class weighting
def segmentation_loss(logits, targets, class_weights=None, ignore_index=-100):
    """
    Cross-entropy loss for segmentation.

    Args:
        logits: [batch, num_classes, H, W]
        targets: [batch, H, W] class indices
        class_weights: Optional weight per class for imbalance
        ignore_index: Label to ignore (e.g., unlabeled pixels)
    """
    loss = F.cross_entropy(
        logits,
        targets,
        weight=class_weights,
        ignore_index=ignore_index,
    )
    return loss


# Example usage
print("=== Spatial Output Heads ===")

# Segmentation
seg_head = SegmentationHead(in_channels=256, num_classes=21)
features = torch.randn(2, 256, 64, 64)
seg_logits = seg_head(features, target_size=(256, 256))
print(f"Segmentation output: {seg_logits.shape}")  # [2, 21, 256, 256]

# Depth
depth_head = DepthHead(in_channels=256, min_depth=0.1, max_depth=80.0)
depth = depth_head(features, target_size=(256, 256))
print(f"Depth output: {depth.shape}")  # [2, 256, 256]
print(f"Depth range: {depth.min():.2f} to {depth.max():.2f}")

# Keypoints
kp_head = KeypointHeatmapHead(in_channels=256, num_keypoints=17)
heatmaps = kp_head(features)
print(f"Keypoint heatmaps: {heatmaps.shape}")  # [2, 17, 64, 64]
```

Don't upsample too aggressively in one step. Use progressive upsampling with skip connections. Output stride (the ratio of input to output resolution) should be no larger than 8 for good boundary quality. For fine details, use an output stride of 4 or even 1 (dilated convolutions help).
Structured outputs require going beyond simple classification or regression. The output layer must respect the structure of the prediction space—whether that's sequences, trees, graphs, sets, or spatial structures. The design choices for output activation, loss function, and inference algorithm are tightly coupled.
You have now completed the comprehensive study of neural network output layers. From simple regression to complex structured outputs, you understand how to design output layers that match the prediction task's requirements. This knowledge is foundational for building effective neural network solutions across all domains of machine learning.