Given a search space containing billions or trillions of possible architectures, how do we find good ones efficiently? This is the central challenge addressed by NAS search strategies.
The search strategy determines which architectures get proposed, in what order, and how feedback from completed evaluations shapes later proposals.
Different strategies make fundamentally different trade-offs between sample efficiency, parallelizability, and the types of architectures they tend to discover.
Master the major NAS search paradigms: RL-based search with policy gradients, evolutionary algorithms with mutation and selection, gradient-based methods like DARTS, and Bayesian optimization. Understand when each approach excels.
Before exploring sophisticated methods, we must understand random search—the simplest strategy and an essential baseline.
Algorithm: sample N architectures uniformly at random from the search space, train and evaluate each one, and return the best. No information from earlier evaluations influences later samples.
Why Random Search Matters:
Research has shown that random search is surprisingly competitive with more complex methods on many NAS benchmarks. This has an important implication: any proposed search strategy should be compared against a well-tuned random-search baseline under the same evaluation budget before claiming an advance.
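The baseline's strength follows directly from the probability that at least one of $N$ independent uniform samples lands in the top $k\%$ of the space, $1 - (1 - k/100)^N$. A quick check of the numbers:

```python
def prob_top_k(k_percent: float, n_samples: int) -> float:
    """Probability that at least one of n uniform random samples
    falls in the top k% of the search space."""
    return 1 - (1 - k_percent / 100) ** n_samples

# Top 1% of architectures:
print(f"{prob_top_k(1, 100):.3f}")   # ~0.634
print(f"{prob_top_k(1, 500):.3f}")   # ~0.993
```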
```python
import random
from typing import Callable

def random_search_nas(
    search_space,
    evaluate_fn: Callable,
    budget: int,
    maximize: bool = True
) -> tuple:
    """
    Random search for NAS.

    Args:
        search_space: Object with sample_random() method
        evaluate_fn: Function that evaluates architecture performance
        budget: Number of architectures to evaluate
        maximize: Whether to maximize (True) or minimize (False)

    Returns:
        (best_architecture, best_performance, history)
    """
    best_arch = None
    best_perf = float('-inf') if maximize else float('inf')
    history = []

    for i in range(budget):
        # Sample random architecture
        arch = search_space.sample_random()

        # Evaluate (this is the expensive part)
        perf = evaluate_fn(arch)
        history.append((arch, perf))

        # Update best
        is_better = perf > best_perf if maximize else perf < best_perf
        if is_better:
            best_perf = perf
            best_arch = arch

        if (i + 1) % 100 == 0:
            print(f"Iteration {i+1}: Best = {best_perf:.4f}")

    return best_arch, best_perf, history

# Expected performance: with N samples, the probability of finding
# a top-k% architecture is 1 - (1 - k/100)^N.
# For top 1%: N=100 gives ~63%, N=500 gives ~99.3%.
```

If architectures in the top 1% perform well, random search with ~500 samples has a 99%+ probability of finding one. This explains why random search works well when search spaces contain many good architectures.
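To make the `search_space` and `evaluate_fn` interfaces concrete, here is a hypothetical toy space (the class name and scoring rule are illustrative, not from any real benchmark) driven by a plain random-search loop:

```python
import random

class ToySearchSpace:
    """Illustrative toy space: 4 layers, 3 candidate ops per layer."""
    operations = ['conv3x3', 'conv5x5', 'skip']

    def sample_random(self):
        return [random.choice(self.operations) for _ in range(4)]

def toy_evaluate(arch):
    # Stand-in for expensive training: pretend conv3x3 layers help most.
    return sum(1.0 if op == 'conv3x3' else 0.2 for op in arch)

space = ToySearchSpace()
best_arch, best_perf = None, float('-inf')
for _ in range(200):
    arch = space.sample_random()
    perf = toy_evaluate(arch)
    if perf > best_perf:
        best_arch, best_perf = arch, perf

# With 200 samples, the best found is usually close to ['conv3x3'] * 4
print(best_arch, best_perf)
```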
The seminal NAS paper (Zoph & Le, 2017) framed architecture search as an RL problem:
The Controller:
The controller is an RNN that autoregressively generates architecture tokens: $$P(a) = \prod_{t=1}^{T} P(a_t | a_1, ..., a_{t-1}; \theta_c)$$
where $\theta_c$ are the controller parameters.
Training with REINFORCE:
Since the reward (validation accuracy) is non-differentiable with respect to the architecture, we use policy gradient:
$$\nabla_{\theta_c} J(\theta_c) = E_{a \sim \pi_{\theta_c}} [(R(a) - b) \nabla_{\theta_c} \log P(a; \theta_c)]$$
where $b$ is a baseline (typically exponential moving average of rewards) to reduce variance.
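To see the update rule in isolation, here is a self-contained toy (pure numpy, illustrative reward function, not the paper's setup) applying REINFORCE with an exponential-moving-average baseline to independent per-layer operation choices. For a categorical policy, $\nabla \log P(a)$ with respect to the logits is the one-hot action minus the probability vector:

```python
import numpy as np

rng = np.random.default_rng(0)
num_layers, num_ops = 4, 3
logits = np.zeros((num_layers, num_ops))   # independent categorical per layer
baseline, decay, lr = 0.0, 0.9, 0.1

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def reward(arch):
    # Toy stand-in for validation accuracy: op 2 is best in every position.
    return float(np.mean(arch == 2))

for step in range(3000):
    probs = softmax(logits)
    # Sample one operation per layer from the current policy
    arch = np.array([rng.choice(num_ops, p=probs[l]) for l in range(num_layers)])
    R = reward(arch)

    # Exponential-moving-average baseline reduces gradient variance
    baseline = decay * baseline + (1 - decay) * R
    advantage = R - baseline

    # For a categorical policy, d log P(a) / d logits = one-hot(a) - probs,
    # so the REINFORCE ascent step is lr * advantage * (one-hot - probs)
    onehot = np.eye(num_ops)[arch]
    logits += lr * advantage * (onehot - probs)

final_probs = softmax(logits)
print(final_probs[:, 2])   # probability of the rewarded op rises well above uniform
```

Without the baseline, every sampled architecture with a positive reward gets reinforced; subtracting the EMA centers the signal so only better-than-average architectures gain probability.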
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NASController(nn.Module):
    """
    RNN controller for RL-based NAS.
    Generates the architecture sequence autoregressively.
    """

    def __init__(
        self,
        num_layers: int,
        num_operations: int,
        hidden_dim: int = 100,
        temperature: float = 1.0
    ):
        super().__init__()
        self.num_layers = num_layers
        self.num_operations = num_operations
        self.temperature = temperature

        # Embedding for operation tokens
        self.op_embedding = nn.Embedding(num_operations, hidden_dim)

        # LSTM controller
        self.lstm = nn.LSTMCell(hidden_dim, hidden_dim)

        # Output head
        self.op_classifier = nn.Linear(hidden_dim, num_operations)

        # Learnable initial states
        self.h0 = nn.Parameter(torch.zeros(1, hidden_dim))
        self.c0 = nn.Parameter(torch.zeros(1, hidden_dim))

    def forward(self, batch_size: int = 1):
        """
        Sample architectures from the controller.

        Returns:
            architectures: Sampled operation indices [B, num_layers]
            log_probs: Log probabilities for each choice
            entropies: Entropy of each distribution
        """
        h = self.h0.expand(batch_size, -1)
        c = self.c0.expand(batch_size, -1)

        architectures = []
        log_probs = []
        entropies = []

        # Start token (token 0 doubles as the start symbol)
        input_token = torch.zeros(
            batch_size, dtype=torch.long, device=self.h0.device
        )

        for layer_idx in range(self.num_layers):
            # Embed previous token
            embed = self.op_embedding(input_token)

            # LSTM step
            h, c = self.lstm(embed, (h, c))

            # Compute operation distribution
            logits = self.op_classifier(h) / self.temperature
            probs = F.softmax(logits, dim=-1)

            # Sample operation
            dist = torch.distributions.Categorical(probs)
            op = dist.sample()

            architectures.append(op)
            log_probs.append(dist.log_prob(op))
            entropies.append(dist.entropy())

            input_token = op

        return (
            torch.stack(architectures, dim=1),
            torch.stack(log_probs, dim=1),
            torch.stack(entropies, dim=1)
        )

def train_controller_step(
    controller,
    rewards,
    log_probs,
    entropies,
    baseline,
    optimizer,
    entropy_weight=0.01
):
    """
    One training step using REINFORCE.
    """
    # Advantage = reward - baseline
    advantages = rewards - baseline

    # Policy gradient loss
    policy_loss = -(log_probs.sum(dim=1) * advantages).mean()

    # Entropy bonus for exploration (subtracted so that
    # higher entropy lowers the loss)
    entropy_bonus = entropies.sum(dim=1).mean()

    loss = policy_loss - entropy_weight * entropy_bonus

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Evolutionary NAS applies principles from biological evolution: maintain a population of architectures, select fit individuals, mutate to create offspring, and iterate.
Regularized Evolution (AmoebaNet):
The key algorithm used in AmoebaNet and many subsequent works:
Why remove oldest, not worst?
Removing the oldest provides "regularization"—it prevents the population from converging too quickly to a local optimum and maintains diversity.
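The aging mechanism needs no special machinery; a fixed-capacity FIFO queue (`deque(maxlen=...)`) implements it directly. This small sketch shows the best individual being evicted purely because it is oldest:

```python
from collections import deque

# Aging population: capacity 3. Appending past capacity evicts the OLDEST
# member regardless of fitness -- even the current best can age out.
population = deque(maxlen=3)
for name, fitness in [('a', 0.9), ('b', 0.5), ('c', 0.6), ('d', 0.7)]:
    population.append({'arch': name, 'fitness': fitness})

print([m['arch'] for m in population])  # ['b', 'c', 'd']: best-so-far 'a' aged out
```

Because the best-so-far is tracked separately in `history`, nothing is lost by evicting it from the live population; what is gained is continued exploration around newer lineages.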
```python
from collections import deque
import copy
import random

def regularized_evolution(
    search_space,
    evaluate_fn,
    population_size: int = 50,
    tournament_size: int = 10,
    num_mutations: int = 1,
    budget: int = 1000
):
    """
    Regularized Evolution for NAS (AmoebaNet-style).

    Key insight: remove the oldest member, not the worst,
    for regularization.
    """
    # Initialize population (FIFO queue)
    population = deque(maxlen=population_size)
    history = []

    # Seed with random architectures
    for _ in range(population_size):
        arch = search_space.sample_random()
        fitness = evaluate_fn(arch)
        population.append({'arch': arch, 'fitness': fitness})
        history.append((arch, fitness))

    evaluations = population_size

    while evaluations < budget:
        # Tournament selection
        tournament = random.sample(list(population), tournament_size)
        parent = max(tournament, key=lambda x: x['fitness'])

        # Mutation
        child_arch = mutate(parent['arch'], search_space, num_mutations)

        # Evaluate child
        child_fitness = evaluate_fn(child_arch)
        evaluations += 1

        # Add child (oldest automatically removed due to maxlen)
        population.append({'arch': child_arch, 'fitness': child_fitness})
        history.append((child_arch, child_fitness))

    # Return best ever found
    best = max(history, key=lambda x: x[1])
    return best[0], best[1], history

def mutate(arch, search_space, num_mutations=1):
    """
    Mutate an architecture by randomly changing one operation
    or connection.
    """
    child = copy.deepcopy(arch)

    for _ in range(num_mutations):
        mutation_type = random.choice(['operation', 'connection'])

        if mutation_type == 'operation':
            # Change a random operation
            node_idx = random.randint(0, len(child.nodes) - 1)
            op_idx = random.randint(0, 1)  # Each node has 2 ops
            child.nodes[node_idx]['ops'][op_idx] = random.choice(
                search_space.operations
            )
        else:
            # Change a connection
            node_idx = random.randint(0, len(child.nodes) - 1)
            conn_idx = random.randint(0, 1)
            valid_inputs = list(range(2 + node_idx))  # Cell inputs + prev nodes
            child.nodes[node_idx]['inputs'][conn_idx] = random.choice(
                valid_inputs
            )

    return child
```

DARTS (Differentiable Architecture Search) revolutionized NAS efficiency by enabling gradient-based optimization of architectures.
Core Idea: relax the discrete choice of operation on each edge into a continuous one. Instead of picking a single operation $o$ from the candidate set $\mathcal{O}$, compute a softmax-weighted mixture: $$\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})}\, o(x)$$ The mixing weights $\alpha$ become ordinary parameters that can be optimized by gradient descent jointly with the network weights $w$.
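DARTS replaces each discrete operation choice with a softmax-weighted mixture of all candidate operations, which is what makes the architecture parameters $\alpha$ differentiable. A pure-numpy sketch with toy stand-in operations (the op set here is illustrative; real DARTS mixes convolutions, pooling, and skip connections):

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

# Toy candidate operations on a feature vector (stand-ins for conv/pool/skip)
ops = [
    lambda x: x,                  # identity / skip
    lambda x: 2.0 * x,            # stand-in for a learned transform
    lambda x: np.zeros_like(x),   # "none" operation
]

def mixed_op(x, alpha):
    """Continuous relaxation: softmax(alpha)-weighted sum of candidate ops."""
    w = softmax(alpha)
    return sum(w[k] * op(x) for k, op in enumerate(ops))

x = np.ones(4)
print(mixed_op(x, np.zeros(3)))             # uniform weights: average of ops
print(mixed_op(x, np.array([10.0, 0.0, 0.0])))  # near-discrete: ~identity output
```

Because `mixed_op` is a smooth function of `alpha`, the gradient of a loss through it tells each edge which operations to upweight; discretizing at the end (argmax over `alpha`) recovers a concrete architecture.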
Bi-Level Optimization:
$$\min_{\alpha} \mathcal{L}_{val}(w^*(\alpha), \alpha)$$ $$\text{s.t. } w^*(\alpha) = \arg\min_w \mathcal{L}_{train}(w, \alpha)$$
Approximation (First-Order DARTS):
Full bi-level optimization requires differentiating through the inner weight optimization, which is expensive. The first-order approximation simply alternates single gradient steps: update $\alpha$ on the validation loss using the current weights $w$ (treating them as a stand-in for $w^*(\alpha)$), then update $w$ on the training loss.
```python
import torch
import torch.nn as nn

class DARTSTrainer:
    """
    DARTS training loop with bi-level optimization.
    """

    def __init__(
        self,
        model,
        arch_optimizer,
        weight_optimizer,
        train_loader,
        val_loader
    ):
        self.model = model
        self.arch_optimizer = arch_optimizer
        self.weight_optimizer = weight_optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader

    def search_epoch(self):
        """One epoch of DARTS search."""
        train_iter = iter(self.train_loader)
        val_iter = iter(self.val_loader)

        for step in range(len(self.train_loader)):
            # Get training and validation batches
            train_x, train_y = next(train_iter)
            try:
                val_x, val_y = next(val_iter)
            except StopIteration:
                val_iter = iter(self.val_loader)
                val_x, val_y = next(val_iter)

            # Step 1: Update architecture weights on validation loss
            self.arch_optimizer.zero_grad()
            val_logits = self.model(val_x)
            val_loss = nn.functional.cross_entropy(val_logits, val_y)
            val_loss.backward()
            self.arch_optimizer.step()

            # Step 2: Update network weights on training loss
            self.weight_optimizer.zero_grad()
            train_logits = self.model(train_x)
            train_loss = nn.functional.cross_entropy(train_logits, train_y)
            train_loss.backward()
            self.weight_optimizer.step()

    def derive_architecture(self):
        """
        Discretize continuous architecture weights to get the
        final architecture.
        """
        genotype = []
        for edge_name, alpha in self.model.arch_parameters():
            # Select the operation with the highest weight
            probs = torch.softmax(alpha, dim=-1)
            best_op_idx = probs.argmax().item()
            genotype.append((edge_name, best_op_idx))
        return genotype
```

DARTS can suffer from "architecture collapse": converging to degenerate solutions dominated by skip connections. Many variants (DARTS+, PC-DARTS, SDARTS) address this through regularization, partial channel connections, or perturbation-based stabilization.
Bayesian Optimization (BO) approaches NAS by building a probabilistic model of the architecture-performance mapping and using it to guide search.
Components: a surrogate model that predicts performance from an architecture encoding, an acquisition function that scores candidates by balancing predicted performance against uncertainty, and an encoding that maps architectures into the surrogate's input space.
Algorithm: (1) evaluate a small set of initial architectures; (2) fit the surrogate to all observations so far; (3) select the candidate that maximizes the acquisition function; (4) evaluate it, add the result to the observations, and repeat until the budget is exhausted.
| Component | Common Choices | Trade-offs |
|---|---|---|
| Surrogate | GP, RF, Neural Network, GNN | GP: principled uncertainty; NN: scales better |
| Acquisition | EI, UCB, Thompson Sampling | EI: exploitation; UCB: tunable exploration |
| Encoding | One-hot, adjacency matrix, path encoding | Encoding quality affects surrogate accuracy |
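The loop can be sketched end to end. Below is a minimal, self-contained toy implementation (a GP surrogate with an RBF kernel over one-hot encodings and an expected-improvement acquisition; the space, objective, and all names are illustrative, not from any NAS benchmark):

```python
import math
from itertools import product

import numpy as np

def rbf(A, B, ls=1.0):
    """RBF kernel matrix between row-vector sets A [n, d] and B [m, d]."""
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * ls ** 2))

def gp_predict(X, y, Xq, noise=1e-4):
    """Zero-mean GP posterior mean and std at query points Xq."""
    K = rbf(X, X) + noise * np.eye(len(X))
    Kq = rbf(X, Xq)                                         # [n, m]
    mu = Kq.T @ np.linalg.solve(K, y)
    var = 1.0 - (Kq * np.linalg.solve(K, Kq)).sum(axis=0)   # rbf(x, x) = 1
    return mu, np.sqrt(np.maximum(var, 1e-12))

def expected_improvement(mu, sigma, best):
    """EI acquisition for maximization."""
    z = (mu - best) / sigma
    cdf = 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))
    pdf = np.exp(-0.5 * z ** 2) / math.sqrt(2.0 * math.pi)
    return (mu - best) * cdf + sigma * pdf

# Toy space: 3 layers, 3 candidate ops each, one-hot encoded per layer.
archs = list(product(range(3), repeat=3))
X_all = np.array([np.eye(3)[list(a)].ravel() for a in archs])

def evaluate(a):
    # Stand-in for validation accuracy: op 2 helps a lot, op 1 a little.
    return sum(op == 2 for op in a) / 3 + 0.1 * sum(op == 1 for op in a) / 3

observed = [0, 1, 2]                      # warm-start evaluations
y = [evaluate(archs[i]) for i in observed]

for _ in range(7):                        # BO iterations
    mu, sigma = gp_predict(X_all[observed], np.array(y), X_all)
    ei = expected_improvement(mu, sigma, max(y))
    ei[observed] = -np.inf                # propose only unevaluated candidates
    nxt = int(np.argmax(ei))
    observed.append(nxt)
    y.append(evaluate(archs[nxt]))

best_arch = archs[observed[int(np.argmax(y))]]
print(best_arch, max(y))
```

The surrogate makes every proposal informed by all previous evaluations, which is why BO's sample efficiency shines when each evaluation is a full training run; the trade-off is that fitting and maximizing the acquisition adds overhead and serializes the loop.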
| Strategy | Sample Efficiency | Parallelizable | Gradient Required | Best For |
|---|---|---|---|---|
| Random | Low | Yes | No | Baseline; large compute budgets |
| RL (REINFORCE) | Medium | Yes | No | Large spaces; flexible generation |
| Evolution | Medium | Yes | No | Discrete spaces; robust optimization |
| DARTS | Very High | No (sequential) | Yes | Differentiable spaces; limited compute |
| Bayesian Opt | Very High | Limited | No | Expensive evaluations; small budgets |
Key Insights from NAS Research:
- Random search is a surprisingly strong baseline; any new strategy should beat it under an equal evaluation budget.
- Sample-efficient methods (DARTS, Bayesian optimization) matter most when each architecture evaluation is expensive.
- Some form of regularization (aging in evolution, entropy bonuses in RL, stabilization in DARTS variants) helps prevent premature convergence to degenerate solutions.
You now understand the major NAS search strategies. Next, we'll explore weight sharing—the technique that made NAS practical by dramatically reducing evaluation cost.