The concepts we've explored—computational graphs, forward propagation, reverse-mode autodiff, and topological ordering—are implemented in production-grade frameworks used by millions of practitioners worldwide. Understanding how these frameworks implement these concepts, and why they make different design decisions, is essential for effective deep learning practice.
The modern framework landscape has evolved dramatically.
Today's practitioner must understand multiple frameworks, their strengths, and when to use each. This knowledge is increasingly important as models and teams grow in complexity.
By the end of this page, you'll understand: (1) Design philosophies of PyTorch, TensorFlow, and JAX, (2) How each framework implements computational graphs and autodiff, (3) Performance characteristics and optimization strategies, (4) Ecosystem strengths (deployment, distributed training, research), (5) Framework selection criteria for different use cases, and (6) Best practices for each framework.
PyTorch, developed by Meta AI, popularized the define-by-run (dynamic graph) paradigm that has largely won the deep learning framework wars. Its design philosophy prioritizes Pythonic usability and research flexibility.
PyTorch builds the computational graph during execution:
- Every tensor created with requires_grad=True tracks the operation that created it.
- That tracking happens through grad_fn objects, which store the corresponding backward functions.
- Calling .backward() traverses this linked structure of grad_fn nodes via depth-first search (DFS).
- The graph is freed after the backward pass unless you pass retain_graph=True.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============ Understanding PyTorch's Graph Building ============

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)

# Each operation creates a new tensor with grad_fn
z = x * y  # z has grad_fn=<MulBackward0>
print(f"z.grad_fn: {z.grad_fn}")

w = z.sum()  # w has grad_fn=<SumBackward0>
print(f"w.grad_fn: {w.grad_fn}")

# grad_fn contains references to input grad_fns
print(f"w's next functions: {w.grad_fn.next_functions}")
# Shows the MulBackward0 that created z

# ============ Custom Autograd Function ============

class MySigmoid(torch.autograd.Function):
    """
    Custom autograd function demonstrating how PyTorch
    implements forward/backward for primitives.
    """

    @staticmethod
    def forward(ctx, x):
        """
        Forward pass. ctx is a context object for passing info to backward.
        """
        result = 1 / (1 + torch.exp(-x))
        # Save for backward - this is what gets cached
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass.
        grad_output is the upstream gradient (dL/d(sigmoid output))
        Returns gradient w.r.t. each input.
        """
        result, = ctx.saved_tensors
        # sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
        grad_input = grad_output * result * (1 - result)
        return grad_input

# Using custom function
my_sigmoid = MySigmoid.apply
x = torch.randn(5, requires_grad=True)
y = my_sigmoid(x)
loss = y.sum()
loss.backward()
print(f"Custom sigmoid gradient: {x.grad}")

# ============ Gradient Accumulation and Zeroing ============

# Gradients ACCUMULATE by default - crucial for RNNs, gradient accumulation
x = torch.tensor([1.0], requires_grad=True)

# First backward
(x ** 2).backward()
print(f"After first backward: {x.grad}")  # 2.0

# Second backward (without zeroing)
(x ** 3).backward()
print(f"After second backward: {x.grad}")  # 2.0 + 3.0 = 5.0 (accumulated!)

# Must zero gradients before each new computation
x.grad.zero_()
(x ** 2).backward()
print(f"After zeroing and new backward: {x.grad}")  # 2.0

# ============ Gradient Checkpointing ============
from torch.utils.checkpoint import checkpoint

class LargeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(1024, 1024) for _ in range(20)
        ])

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            # Checkpoint every 4th layer to save memory
            if i % 4 == 0:
                x = checkpoint(lambda l, x: F.relu(l(x)), layer, x)
            else:
                x = F.relu(layer(x))
        return x
```

Always use torch.no_grad() for inference (15-30% speedup). Use optimizer.zero_grad(set_to_none=True) instead of zero_grad() for slight speedup. Enable cudnn.benchmark=True for fixed-size inputs. Use torch.compile() in PyTorch 2.0+ for significant speedups.
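To make the first of those tips concrete, here is a minimal inference sketch; `model` and `batch` are placeholder names for a trained module and a prepared input tensor:

```python
# Placeholder names: `model` is any trained nn.Module, `batch` a ready input tensor.
model.eval()                 # switch off dropout / batch-norm updates
with torch.no_grad():        # no graph is recorded, so less memory and faster execution
    preds = model(batch)
```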
TensorFlow, developed by Google, was originally built around the static graph (define-then-run) paradigm. TensorFlow 2.0 made eager execution the default while retaining the benefits of graph compilation through tf.function.
TensorFlow 1.x (2015-2018):
- Static graphs defined ahead of time and executed in sessions (tf.Graph, tf.Session)

TensorFlow 2.x (2019-present):

- Eager execution by default, with tf.function for graph compilation

TensorFlow maintains sophisticated graph infrastructure:

- Graph tracing (@tf.function): traces a Python function into an optimized graph
```python
import tensorflow as tf
from tensorflow import keras

# ============ Eager vs Graph Mode ============

# Eager execution (default in TF2)
x = tf.constant([1.0, 2.0, 3.0])
y = x ** 2  # Executes immediately
print(f"Eager result: {y}")  # tf.Tensor([1. 4. 9.], ...)

# Graph mode with tf.function
@tf.function
def square_fn(x):
    return x ** 2

# First call: traces the function and compiles graph
result = square_fn(tf.constant([1.0, 2.0]))
# Subsequent calls: uses compiled graph (fast)
result = square_fn(tf.constant([3.0, 4.0]))

# ============ GradientTape for Autodiff ============

x = tf.Variable([3.0, 4.0])

# Record operations within tape context
with tf.GradientTape() as tape:
    y = x ** 2
    z = tf.reduce_sum(y)

# Compute gradient (this is the backward pass)
grad = tape.gradient(z, x)
print(f"Gradient: {grad}")  # [6. 8.] = 2*x

# Persistent tape for multiple gradient computations
x = tf.Variable([3.0])
with tf.GradientTape(persistent=True) as tape:
    y = x ** 2
    z = x ** 3

grad_y = tape.gradient(y, x)  # 6.0
grad_z = tape.gradient(z, x)  # 27.0
del tape  # Must delete persistent tape

# ============ Higher-Order Gradients ============

x = tf.Variable(3.0)

with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
        y = x ** 3  # y = x^3
    first_order = inner_tape.gradient(y, x)  # dy/dx = 3x^2 = 27
second_order = outer_tape.gradient(first_order, x)  # d²y/dx² = 6x = 18

print(f"First derivative: {first_order}")   # 27.0
print(f"Second derivative: {second_order}")  # 18.0

# ============ Understanding tf.function Tracing ============

@tf.function
def trace_demo(x):
    print(f"Tracing with input shape: {x.shape}")  # Printed during tracing
    tf.print("Executing with value: ", x)          # Printed during execution
    return x + 1

# First call with shape (2,) - traces
trace_demo(tf.constant([1.0, 2.0]))
# Second call with same shape - reuses trace
trace_demo(tf.constant([3.0, 4.0]))
# Third call with different shape - traces AGAIN
trace_demo(tf.constant([1.0, 2.0, 3.0]))

# ============ Custom Training Loop ============

model = keras.Sequential([
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10)
])

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function  # Compile for performance
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# Training would look like:
# for batch_x, batch_y in dataset:
#     loss = train_step(batch_x, batch_y)
```

tf.function creates a new trace for each unique input shape/dtype combination. Side effects in traced functions execute ONCE during tracing, not during execution. Python state captured during tracing becomes static. Use tf.Variable for mutable state inside tf.function.
JAX, also from Google, takes a radically different approach: functional transformations. Rather than defining a framework with specific classes and constructs, JAX provides composable transformations that turn pure functions into differentiated, vectorized, JIT-compiled versions.
JAX is built around function transformations:
- jax.grad: transform function → gradient function
- jax.jit: transform function → XLA-compiled function
- jax.vmap: transform function → batched function
- jax.pmap: transform function → parallel-across-devices function

These transformations compose: you can apply jit(grad(vmap(f))) to get a JIT-compiled batched gradient function.
JAX requires functions to be pure (no side effects, same inputs → same outputs). This constraint is what makes tracing, JIT compilation, and the free composition of transformations safe and predictable.
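A small sketch of why purity matters: under jit, the Python body of a function runs only while it is being traced, so hidden Python state (the `counter` below is purely illustrative) silently stops updating on later calls.

```python
import jax
import jax.numpy as jnp

counter = 0  # hidden Python state makes this function impure

@jax.jit
def impure_step(x):
    global counter
    counter += 1          # side effect: executes only during tracing
    return x + 1

impure_step(jnp.ones(3))
impure_step(jnp.ones(3))
impure_step(jnp.ones(3))
print(counter)  # 1, not 3: the Python body ran once, at trace time
```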
```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# ============ Basic JAX Operations ============

# JAX arrays work like NumPy
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x) ** 2
print(f"Result: {y}")

# ============ Automatic Differentiation ============

def f(x):
    """Simple scalar function."""
    return jnp.sin(x) ** 2

# grad transforms f into a function that computes df/dx
df = grad(f)
print(f"f(π/4) = {f(jnp.pi/4):.4f}")    # 0.5
print(f"f'(π/4) = {df(jnp.pi/4):.4f}")  # 2*sin(π/4)*cos(π/4) = 1.0

# Higher-order gradients via composition
d2f = grad(grad(f))  # Second derivative
print(f"f''(π/4) = {d2f(jnp.pi/4):.4f}")  # 2*cos(2x) at π/4 = 0

# ============ Gradients for Multi-Input Functions ============

def loss(params, x, y):
    """MSE loss for linear regression."""
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

# grad with argnums specifies which argument to diff w.r.t.
grad_loss = grad(loss, argnums=0)  # Gradient w.r.t. params

params = (2.0, 1.0)  # w=2, b=1
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([3.0, 5.0, 7.0])

grads = grad_loss(params, x, y)
print(f"Gradients: dw={grads[0]:.4f}, db={grads[1]:.4f}")

# ============ JIT Compilation ============

def slow_function(x):
    """Function that would be slow without JIT."""
    for _ in range(100):
        x = jnp.sin(x) + jnp.cos(x)
    return x

# Without JIT
import time
x = jnp.ones(1000)

start = time.time()
result = slow_function(x)
print(f"Without JIT: {time.time() - start:.4f}s")

# With JIT - first call compiles
fast_function = jit(slow_function)
_ = fast_function(x)  # Compilation happens here

# Subsequent calls are fast
start = time.time()
result = fast_function(x)
print(f"With JIT (after compile): {time.time() - start:.4f}s")

# ============ Vectorization with vmap ============

def single_example_loss(w, x, y):
    """Loss for a single example."""
    pred = jnp.dot(w, x)
    return (pred - y) ** 2

# Manually batching is verbose and error-prone
# vmap transforms to handle batches automatically

batch_loss = vmap(single_example_loss, in_axes=(None, 0, 0))
# None means w is not batched; 0 means x and y have batch dim at axis 0

w = jnp.array([1.0, 2.0, 3.0])
x_batch = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [0.0, 0.0, 1.0]])
y_batch = jnp.array([1.0, 2.0, 3.0])

losses = batch_loss(w, x_batch, y_batch)
print(f"Batched losses: {losses}")  # [0. 0. 0.] (perfect predictions)

# ============ Composing Transformations ============

def model(params, x):
    """Simple MLP."""
    w1, b1, w2, b2 = params
    h = jax.nn.relu(x @ w1 + b1)
    return h @ w2 + b2

def loss_fn(params, x, y):
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)

# Compose: JIT-compile the gradient of the loss
fast_grad = jit(grad(loss_fn))

# Now fast_grad is a JIT-compiled gradient function
# params = ...
# grads = fast_grad(params, x_batch, y_batch)  # Fast!
```

JAX provides primitives, not high-level abstractions. For neural networks, use Flax (model definition) and Optax (optimizers). Other libraries: Haiku, Equinox, and pjit for distributed training. The ecosystem is growing rapidly.
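As a sketch of how those libraries slot in, the snippet below applies a single Optax update step to the loss_fn from the block above. It assumes optax is installed and that `params`, `x_batch`, and `y_batch` with compatible shapes already exist; parameter initialization is omitted.

```python
# Assumes `loss_fn`, `params`, `x_batch`, `y_batch` are defined with compatible shapes;
# optax is a separate package (pip install optax).
import optax

tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)

grads = jax.grad(loss_fn)(params, x_batch, y_batch)   # gradient pytree
updates, opt_state = tx.update(grads, opt_state)      # optimizer transform
params = optax.apply_updates(params, updates)         # apply the update
```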
Each framework makes different trade-offs. Understanding these helps you choose the right tool for your use case.
| Feature | PyTorch | TensorFlow 2 | JAX |
|---|---|---|---|
| Execution Model | Eager (dynamic) | Eager + tf.function | Eager + jit |
| Graph Construction | Implicit during forward | GradientTape / tf.function | Tracing via jit |
| Autodiff Approach | Reverse mode, tape-based | Reverse mode, GradientTape | Forward + reverse via transformations |
| Primary API | Imperative (classes) | Keras + functional | Functional (pure functions) |
| State Management | Stateful (nn.Module) | Stateful (tf.Variable) | Stateless (explicit params) |
| JIT Compilation | torch.compile (2.0+) | tf.function + XLA | jax.jit + XLA |
| Distributed Training | DDP, FSDP | tf.distribute | pjit, pmap |
| Mobile/Edge Deploy | TorchServe, torch.jit | TF Lite, TF.js | Via TF Lite export |
| Learning Curve | Gentle | Moderate | Steep (functional style) |
The best practitioners are framework-agnostic. Understanding the underlying concepts (computational graphs, autodiff, optimization) enables moving between frameworks easily. Concepts transfer; syntax is learnable.
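As a concrete illustration of how the same underlying concept maps onto each API, here is the derivative of x² at x = 3 (which is 2x = 6) computed in all three frameworks; the sketch assumes all three libraries are installed:

```python
import torch
import tensorflow as tf
import jax

# PyTorch: the graph is recorded implicitly during the forward pass
x = torch.tensor(3.0, requires_grad=True)
(x ** 2).backward()
print(x.grad)                  # tensor(6.)

# TensorFlow: operations are recorded explicitly on a GradientTape
x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x ** 2
print(tape.gradient(y, x))     # tf.Tensor(6.0, ...)

# JAX: grad transforms the function itself
print(jax.grad(lambda x: x ** 2)(3.0))   # 6.0
```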
Performance optimization in deep learning spans multiple levels. Understanding these across frameworks helps maximize hardware utilization.
Single operations (matmul, conv) must be optimized at the kernel level, and all frameworks dispatch to vendor-tuned libraries (for example cuDNN and cuBLAS on NVIDIA GPUs) automatically.
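These kernel choices are mostly automatic, but a few switches expose them. The PyTorch flags below are a hedged sketch of commonly used knobs; availability depends on your PyTorch version and GPU.

```python
import torch

# Let cuDNN benchmark candidate convolution kernels and keep the fastest
# (pays off when input shapes stay fixed across iterations)
torch.backends.cudnn.benchmark = True

# Allow TF32 matmul kernels on Ampere+ GPUs: faster, slightly lower precision
torch.set_float32_matmul_precision("high")
```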
When the framework sees the full graph, it can apply whole-graph optimizations such as operator fusion and memory planning:
```python
# ============ PyTorch 2.0 Compilation ============
import torch

model = ...  # Your model

# torch.compile analyzes model and generates optimized code
optimized_model = torch.compile(model)

# Different modes trade compile time for performance
optimized_model = torch.compile(model, mode="reduce-overhead")  # Lower latency
optimized_model = torch.compile(model, mode="max-autotune")     # Highest throughput

# ============ Mixed Precision Training ============

# PyTorch AMP (Automatic Mixed Precision)
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    with autocast():  # Automatic float16 where safe
        outputs = model(batch)
        loss = criterion(outputs, targets)

    # Scale loss to prevent gradient underflow in float16
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# TensorFlow Mixed Precision
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import mixed_precision

# Set global policy
mixed_precision.set_global_policy('mixed_float16')

# Model will use float16 for compute, float32 for variables
model = keras.Sequential([...])
model.compile(optimizer='adam', loss='mse')

# ============ Memory Optimization ============

# Gradient Checkpointing (PyTorch)
from torch.utils.checkpoint import checkpoint_sequential

class DeepModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(1024, 1024) for _ in range(50)
        ])

    def forward(self, x):
        # Checkpoint every segment of 5 layers
        segments = [self.layers[i:i+5] for i in range(0, 50, 5)]
        return checkpoint_sequential(segments, 5, x)

# ============ Data Loading Optimization ============

# PyTorch DataLoader optimization
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=8,            # Parallel data loading
    pin_memory=True,          # Faster CPU→GPU transfer
    prefetch_factor=2,        # Prefetch batches
    persistent_workers=True   # Keep workers alive
)

# TensorFlow tf.data optimization
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = (dataset
    .shuffle(buffer_size=10000)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)  # Overlap data prep with compute
    .cache()                     # Cache in memory after first epoch
)
```

| Optimization | Impact | Framework Support |
|---|---|---|
| JIT Compilation | 2-10x speedup | PyTorch 2.0, tf.function, jax.jit |
| Mixed Precision (FP16) | 2x speedup, 50% memory | All frameworks |
| Gradient Checkpointing | 50-75% memory reduction | All frameworks |
| Operator Fusion | 20-50% speedup | Automatic with JIT |
| Batch Size Tuning | Variable | Manual tuning |
| Data Pipeline Async | Hide I/O latency | DataLoader, tf.data |
| Channels-Last Format (NHWC) | 10-30% CNN speedup | PyTorch 2.0, TF default |
Use profiling tools to identify bottlenecks before optimizing. PyTorch: torch.profiler. TensorFlow: TensorBoard profiler. JAX: jax.profiler. Don't optimize blindly—measure first.
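For example, a minimal torch.profiler session might look like the following; `model`, `batch`, and `targets` are placeholder names, and the same idea applies to the TensorBoard and JAX profilers:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Profile one training step to see where time is actually spent
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    output = model(batch)
    loss = torch.nn.functional.cross_entropy(output, targets)
    loss.backward()

# Top operators by GPU time: optimize these first
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```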
Modern large models require distributed training across multiple devices. Each framework provides different abstractions for this.
The most common strategy: replicate the model on every device and split each batch of data across them.
Effective batch size = Per-device batch × Number of devices
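A quick worked example of that relationship, together with the common (but heuristic) linear learning-rate scaling rule that often accompanies it:

```python
# Worked example: 8 GPUs, 32 samples per GPU
per_device_batch = 32
num_devices = 8
effective_batch = per_device_batch * num_devices     # 256

# Heuristic: scale the learning rate linearly with the effective batch size
base_lr, base_batch = 1e-3, 32
scaled_lr = base_lr * effective_batch / base_batch   # 8e-3
print(effective_batch, scaled_lr)
```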
```python
# ============ PyTorch DDP (Distributed Data Parallel) ============

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size):
    setup(rank, world_size)

    model = MyModel().to(rank)
    model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.Adam(model.parameters())

    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()  # Gradients synchronized automatically
        optimizer.step()

    dist.destroy_process_group()

# Launch with torchrun:
# torchrun --nproc_per_node=4 train.py

# ============ PyTorch FSDP (Fully Sharded Data Parallel) ============

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

# FSDP shards model parameters across devices
# Enables training models larger than single GPU memory
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
)

# ============ TensorFlow Distribution Strategy ============

import tensorflow as tf

# Multi-GPU on single machine
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = create_model()
    model.compile(optimizer='adam', loss='mse')

model.fit(dataset, epochs=10)

# Multi-machine distributed
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# TPU strategy
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# ============ JAX Distributed Training ============

import jax
from functools import partial
from jax.experimental.pjit import pjit   # import path / kwarg names vary across JAX versions
from jax.sharding import PartitionSpec as P

# JAX uses explicit sharding specifications
@partial(pjit, in_axis_resources=(P('data'), P()),
         out_axis_resources=P('data'))
def train_step(batch, params):
    def loss_fn(params):
        return compute_loss(params, batch)
    grads = jax.grad(loss_fn)(params)
    # Gradients across devices are automatically synchronized
    return update_params(params, grads)

# Model parallelism in JAX
@partial(pjit, in_axis_resources=(None, P('model')),
         out_axis_resources=P('model'))
def forward(x, params):
    return model(params, x)
```

Regardless of framework, certain patterns and practices lead to more reliable, performant, and maintainable code.
```python
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

class TrainingPipeline:
    """
    Robust training pattern incorporating best practices.
    Framework: PyTorch (similar patterns apply to TF/JAX)
    """

    def __init__(self, model, train_loader, val_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Reproducibility
        self._set_seeds(config.seed)

        # Device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=config.epochs
        )

        # Mixed precision
        self.scaler = torch.cuda.amp.GradScaler()

        # Logging
        self.writer = SummaryWriter(config.log_dir)

        # Checkpointing
        self.best_val_loss = float('inf')

    def _set_seeds(self, seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        import numpy as np
        np.random.seed(seed)
        import random
        random.seed(seed)

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad(set_to_none=True)  # Slightly faster

            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                output = self.model(data)
                loss = nn.functional.cross_entropy(output, target)

            # Scaled backward pass
            self.scaler.scale(loss).backward()

            # Gradient clipping
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.max_grad_norm
            )

            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()

            # Log every N steps
            if batch_idx % 100 == 0:
                step = epoch * len(self.train_loader) + batch_idx
                self.writer.add_scalar('train/loss', loss.item(), step)

        return total_loss / len(self.train_loader)

    def validate(self, epoch):
        self.model.eval()
        val_loss = 0
        correct = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)

                with torch.cuda.amp.autocast():
                    output = self.model(data)
                    val_loss += nn.functional.cross_entropy(output, target).item()

                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()

        val_loss /= len(self.val_loader)
        accuracy = correct / len(self.val_loader.dataset)

        # Logging
        self.writer.add_scalar('val/loss', val_loss, epoch)
        self.writer.add_scalar('val/accuracy', accuracy, epoch)

        # Checkpointing
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.save_checkpoint(epoch, is_best=True)

        return val_loss, accuracy

    def save_checkpoint(self, epoch, is_best=False):
        state = {
            'epoch': epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config,
        }

        path = f"{self.config.checkpoint_dir}/checkpoint_epoch{epoch}.pt"
        torch.save(state, path)

        if is_best:
            best_path = f"{self.config.checkpoint_dir}/best_model.pt"
            torch.save(state, best_path)

    def run(self):
        for epoch in range(self.config.epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, accuracy = self.validate(epoch)
            self.scheduler.step()

            print(f"Epoch {epoch}: train_loss={train_loss:.4f}, "
                  f"val_loss={val_loss:.4f}, accuracy={accuracy:.4f}")
```

Modern deep learning frameworks implement the computational graph concepts we've studied throughout this module.
Each framework makes different trade-offs, and understanding these helps you leverage their strengths effectively.
You've completed the Computational Graphs module. You now understand computational graphs, forward propagation, reverse-mode automatic differentiation, topological ordering, and how the major frameworks implement these ideas in practice.
These foundations underpin everything in neural network training. The next modules build on them to explore specific neural network architectures and training techniques.
Congratulations! You've mastered the fundamentals of computational graphs—the representation that enables automatic differentiation and makes training neural networks possible. You understand both the theory and how major frameworks implement it in practice. This knowledge will serve you throughout your deep learning journey.