Coupling layers are the architectural breakthrough that made normalizing flows practical for high-dimensional data. Before coupling layers, flows faced an awkward trade-off: transformations with tractable Jacobian determinants were too simple to be expressive, while expressive transformations required intractable $O(d^3)$ Jacobian determinant computations.
Coupling layers resolve this dilemma through a clever insight: partition the input dimensions and transform one partition using an arbitrary function of the other. This structure yields a triangular Jacobian with tractable $O(d)$ determinant computation, while allowing the transformation itself to be arbitrarily complex (parameterized by deep neural networks).
This single innovation, introduced in NICE (2014) and refined in RealNVP (2016), unlocked flows for images, audio, and other high-dimensional domains.
In this lesson you will understand the coupling layer construction, derive its triangular Jacobian structure, implement additive and affine coupling layers, and learn how to compose layers with permutations for full expressiveness.
A coupling layer partitions the input into two parts and transforms one part based on the other.
The Partition:
Given input $\mathbf{z} \in \mathbb{R}^d$, split it into two parts: $\mathbf{z}_A = \mathbf{z}_{1:k}$ (the first $k$ dimensions) and $\mathbf{z}_B = \mathbf{z}_{k+1:d}$ (the remaining $d-k$ dimensions).
The Transformation:
$$\mathbf{x}_A = \mathbf{z}_A$$ $$\mathbf{x}_B = g(\mathbf{z}_B; \theta(\mathbf{z}_A))$$
where $g$ is a transformation of $\mathbf{z}_B$ that is invertible for any fixed parameters (for example, an element-wise affine map), and $\theta(\mathbf{z}_A)$ is a conditioner network that computes those parameters from $\mathbf{z}_A$. Because $\mathbf{x}_A = \mathbf{z}_A$ passes through unchanged, the inverse can recompute $\theta(\mathbf{x}_A)$ exactly and undo $g$, so $\theta$ itself never needs to be invertible.
The Jacobian Structure:
The Jacobian of this transformation has a special block structure:
$$J = \begin{bmatrix} \frac{\partial \mathbf{x}_A}{\partial \mathbf{z}_A} & \frac{\partial \mathbf{x}_A}{\partial \mathbf{z}_B} \\ \frac{\partial \mathbf{x}_B}{\partial \mathbf{z}_A} & \frac{\partial \mathbf{x}_B}{\partial \mathbf{z}_B} \end{bmatrix} = \begin{bmatrix} \mathbf{I}_k & \mathbf{0} \\ \frac{\partial \mathbf{x}_B}{\partial \mathbf{z}_A} & \frac{\partial g}{\partial \mathbf{z}_B} \end{bmatrix}$$
This is a block lower triangular matrix! Its determinant is: $$\det(J) = \det(\mathbf{I}_k) \cdot \det\left(\frac{\partial g}{\partial \mathbf{z}_B}\right) = \det\left(\frac{\partial g}{\partial \mathbf{z}_B}\right)$$
The complex coupling network $\theta(\mathbf{z}_A)$ doesn't appear in the determinant at all—only the transformation $g$ of the $B$ dimensions matters.
The conditioner network $\theta(\mathbf{z}_A)$ can be arbitrarily complex—it could be a 100-layer ResNet—and it doesn't affect the Jacobian determinant computation. All that complexity is 'free' from a tractability standpoint. Only the transformation $g$ of $\mathbf{z}_B$ must have a tractable Jacobian.
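To make this concrete, here is a small self-contained check of the block structure using autograd; the tanh and squaring functions are arbitrary stand-ins for the conditioner networks, and the 4-dimensional input is just for illustration:

```python
import torch

# Toy coupling map on a 4-D input: copy z_A = z[:2], transform z_B = z[2:]
# with a scale and shift that depend only on z_A (stand-ins for real networks).
def coupling(z):
    z_A, z_B = z[:2], z[2:]
    s = torch.tanh(z_A)        # pretend "scale network"
    t = z_A ** 2               # pretend "translation network"
    return torch.cat([z_A, z_B * torch.exp(s) + t])

z = torch.randn(4)
J = torch.autograd.functional.jacobian(coupling, z)

print(J[:2, 2:])   # upper-right block d x_A / d z_B: all zeros
# Block triangular => det(J) is the product of the diagonal of d x_B / d z_B
print(torch.det(J), torch.exp(torch.tanh(z[:2]).sum()))   # should match up to float error
```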
The simplest coupling layer uses additive coupling, introduced in NICE:
$$\mathbf{x}_A = \mathbf{z}_A$$ $$\mathbf{x}_B = \mathbf{z}_B + m(\mathbf{z}_A)$$
where $m: \mathbb{R}^k \to \mathbb{R}^{d-k}$ is an arbitrary neural network.
Properties:
- Inversion is trivial: $\mathbf{z}_B = \mathbf{x}_B - m(\mathbf{x}_A)$, reusing the same network $m$.
- Volume-preserving: $\frac{\partial \mathbf{x}_B}{\partial \mathbf{z}_B} = \mathbf{I}$, so $\det(J) = 1$ and $\log|\det(J)| = 0$.
- $m$ can be any network; it is never inverted, only re-evaluated.
Additive coupling is extremely simple, but it is volume-preserving ($\det(J) = 1$), which limits expressiveness: all density changes must come from warping the base distribution, not from local expansion or contraction of volume.
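To make the volume-preservation claim concrete: the shift $m(\mathbf{z}_A)$ does not depend on $\mathbf{z}_B$, so the lower-right Jacobian block is the identity,

$$\frac{\partial \mathbf{x}_B}{\partial \mathbf{z}_B} = \mathbf{I}_{d-k} \quad\Rightarrow\quad \det(J) = 1, \qquad \log|\det(J)| = 0.$$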
```python
import torch
import torch.nn as nn

class AdditiveCoupling(nn.Module):
    """
    Additive coupling layer (NICE).

    x_A = z_A
    x_B = z_B + m(z_A)

    Volume-preserving: log|det(J)| = 0
    """
    def __init__(self, dim, hidden_dim=256, mask_type='left'):
        super().__init__()
        self.dim = dim
        self.split = dim // 2

        # Conditioner network: arbitrary architecture
        self.m = nn.Sequential(
            nn.Linear(self.split, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim - self.split)
        )

        # Which half to transform
        self.mask_type = mask_type

    def forward(self, z):
        if self.mask_type == 'left':
            z_A, z_B = z[:, :self.split], z[:, self.split:]
        else:
            z_B, z_A = z[:, :self.split], z[:, self.split:]

        x_A = z_A
        x_B = z_B + self.m(z_A)

        if self.mask_type == 'left':
            x = torch.cat([x_A, x_B], dim=1)
        else:
            x = torch.cat([x_B, x_A], dim=1)

        # Log determinant is 0 (volume-preserving)
        log_det = torch.zeros(z.shape[0], device=z.device)
        return x, log_det

    def inverse(self, x):
        if self.mask_type == 'left':
            x_A, x_B = x[:, :self.split], x[:, self.split:]
        else:
            x_B, x_A = x[:, :self.split], x[:, self.split:]

        z_A = x_A
        z_B = x_B - self.m(x_A)  # Simple subtraction!

        if self.mask_type == 'left':
            z = torch.cat([z_A, z_B], dim=1)
        else:
            z = torch.cat([z_B, z_A], dim=1)

        log_det = torch.zeros(x.shape[0], device=x.device)
        return z, log_det
```

Affine coupling extends additive coupling with learnable scaling, introduced in RealNVP:
$$\mathbf{x}_A = \mathbf{z}_A$$ $$\mathbf{x}_B = \mathbf{z}_B \odot \exp(s(\mathbf{z}_A)) + t(\mathbf{z}_A)$$
where $s, t: \mathbb{R}^k \to \mathbb{R}^{d-k}$ are neural networks producing the log-scale and translation, and $\odot$ denotes element-wise multiplication.
Properties:
- Inversion is closed-form: $\mathbf{z}_B = (\mathbf{x}_B - t(\mathbf{x}_A)) \odot \exp(-s(\mathbf{x}_A))$.
- The log-determinant is a simple sum: $\log|\det(J)| = \sum_i s_i(\mathbf{z}_A)$.
- Not volume-preserving: the flow can locally expand or contract density.
The scaling allows non-volume-preserving transformations, dramatically increasing expressiveness.
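The log-determinant used in the implementation below follows directly from the Jacobian structure: the lower-right block is diagonal, so

$$\frac{\partial \mathbf{x}_B}{\partial \mathbf{z}_B} = \mathrm{diag}\!\big(\exp(s(\mathbf{z}_A))\big) \quad\Rightarrow\quad \log|\det(J)| = \sum_{i} s_i(\mathbf{z}_A).$$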
```python
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    """
    Affine coupling layer (RealNVP).

    x_A = z_A
    x_B = z_B * exp(s(z_A)) + t(z_A)

    log|det(J)| = sum(s(z_A))
    """
    def __init__(self, dim, hidden_dim=256, mask_type='left'):
        super().__init__()
        self.dim = dim
        self.split = dim // 2
        self.mask_type = mask_type

        # Network outputs both scale and translation
        self.net = nn.Sequential(
            nn.Linear(self.split, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * (dim - self.split))  # s and t
        )

        # Initialize to identity transform
        self.net[-1].weight.data.zero_()
        self.net[-1].bias.data.zero_()

    def forward(self, z):
        z_A, z_B = self._split(z)

        # Get scale and translation
        st = self.net(z_A)
        s, t = st.chunk(2, dim=1)

        # Affine transformation
        x_A = z_A
        x_B = z_B * torch.exp(s) + t

        x = self._merge(x_A, x_B)

        # Log determinant = sum of scales
        log_det = s.sum(dim=1)
        return x, log_det

    def inverse(self, x):
        x_A, x_B = self._split(x)

        st = self.net(x_A)
        s, t = st.chunk(2, dim=1)

        # Inverse affine
        z_A = x_A
        z_B = (x_B - t) * torch.exp(-s)

        z = self._merge(z_A, z_B)

        # Log det of inverse is negative
        log_det = -s.sum(dim=1)
        return z, log_det

    def _split(self, x):
        if self.mask_type == 'left':
            return x[:, :self.split], x[:, self.split:]
        else:
            return x[:, self.split:], x[:, :self.split]

    def _merge(self, x_A, x_B):
        if self.mask_type == 'left':
            return torch.cat([x_A, x_B], dim=1)
        else:
            return torch.cat([x_B, x_A], dim=1)
```

The scale factor exp(s) can cause numerical issues if s is too large or small. Common remedies: (1) Initialize the scale network to output zeros (identity transform), (2) Clamp s to a reasonable range like [-5, 5], (3) Use a tanh activation on s and scale appropriately.
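As one concrete way to apply remedy (3), the sketch below subclasses the `AffineCoupling` above and bounds the log-scale with a scaled tanh. The class name and the `scale_limit` value are illustrative choices, not part of RealNVP:

```python
import torch

# Assumes AffineCoupling from the previous block is in scope.
class BoundedAffineCoupling(AffineCoupling):
    """Affine coupling with the log-scale squashed into (-scale_limit, scale_limit)."""
    def __init__(self, dim, hidden_dim=256, mask_type='left', scale_limit=3.0):
        super().__init__(dim, hidden_dim, mask_type)
        self.scale_limit = scale_limit

    def _bounded(self, s_raw):
        # Keeps exp(s) within [exp(-limit), exp(limit)], avoiding overflow/underflow
        return self.scale_limit * torch.tanh(s_raw)

    def forward(self, z):
        z_A, z_B = self._split(z)
        s_raw, t = self.net(z_A).chunk(2, dim=1)
        s = self._bounded(s_raw)
        x = self._merge(z_A, z_B * torch.exp(s) + t)
        return x, s.sum(dim=1)

    def inverse(self, x):
        x_A, x_B = self._split(x)
        s_raw, t = self.net(x_A).chunk(2, dim=1)
        s = self._bounded(s_raw)
        z = self._merge(x_A, (x_B - t) * torch.exp(-s))
        return z, -s.sum(dim=1)
```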
A single coupling layer only transforms half the dimensions. To transform all dimensions, we must alternate which dimensions are transformed across layers.
Alternating Masks:
The simplest strategy alternates which half is transformed from layer to layer: odd-numbered layers update $\mathbf{z}_B$ conditioned on $\mathbf{z}_A$, and even-numbered layers update $\mathbf{z}_A$ conditioned on $\mathbf{z}_B$.
This ensures all dimensions are eventually transformed.
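A minimal sketch of this alternation, reusing the `AdditiveCoupling` class defined above (the dimensionality and batch size are arbitrary):

```python
import torch

# Two additive coupling layers with opposite masks: together they touch every dimension.
dim = 6
layers = [AdditiveCoupling(dim, mask_type='left'),
          AdditiveCoupling(dim, mask_type='right')]

z = torch.randn(8, dim)
x, total_log_det = z, torch.zeros(8)
for layer in layers:
    x, ld = layer(x)
    total_log_det += ld      # stays 0: additive coupling is volume-preserving

# Invert in reverse order to recover the input exactly
for layer in reversed(layers):
    x, _ = layer.inverse(x)
print(torch.allclose(x, z, atol=1e-5))   # True
```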
Permutations Between Layers:
More sophisticated approaches permute dimensions between coupling layers:
| Strategy | Description | Jacobian Det | Pros/Cons |
|---|---|---|---|
| Alternating masks | Flip which half is conditioned | 1 | Simple, but limited mixing |
| Reverse permutation | Reverse dimension order | ±1 | Deterministic, moderate mixing |
| Random permutation | Random shuffle (fixed) | ±1 | Better mixing, no learning |
| 1×1 convolution | Learned linear mixing | $\det(\mathbf{W})$ | Full flexibility, needs LU decomposition |
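The 1×1-convolution row relies on a (P, L, U) factorization of the mixing matrix, which reduces the determinant to the diagonal of a triangular factor:

$$\mathbf{W} = \mathbf{P}\mathbf{L}\mathbf{U} \quad\Rightarrow\quad \log|\det(\mathbf{W})| = \sum_i \log|U_{ii}|,$$

since $\mathbf{P}$ is a permutation ($|\det\mathbf{P}| = 1$) and $\mathbf{L}$ is unit lower triangular ($\det\mathbf{L} = 1$). This makes the log-determinant an $O(c)$ computation instead of $O(c^3)$, which is exactly what the `Invertible1x1Conv` implementation below exploits.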
```python
import torch
import torch.nn as nn

class ReversePermutation(nn.Module):
    """Simple reverse permutation layer."""
    def forward(self, z):
        return torch.flip(z, dims=[1]), torch.zeros(z.shape[0], device=z.device)

    def inverse(self, x):
        return torch.flip(x, dims=[1]), torch.zeros(x.shape[0], device=x.device)

class Invertible1x1Conv(nn.Module):
    """
    Learned permutation via invertible 1x1 convolution (Glow).

    Uses LU decomposition for efficient determinant computation.
    """
    def __init__(self, num_channels):
        super().__init__()
        # Initialize as random rotation
        W = torch.randn(num_channels, num_channels)
        Q, _ = torch.linalg.qr(W)

        # LU decomposition for efficient det computation
        P, L, U = torch.linalg.lu(Q)
        self.register_buffer('P', P)
        self.L = nn.Parameter(L)
        self.U = nn.Parameter(U)
        self.register_buffer('L_mask', torch.tril(torch.ones_like(L), -1))
        self.register_buffer('U_mask', torch.triu(torch.ones_like(U), 1))

    def _get_weight(self):
        L = self.L * self.L_mask + torch.eye(self.L.shape[0], device=self.L.device)
        U = self.U * self.U_mask + torch.diag(torch.diag(self.U))
        return self.P @ L @ U

    def forward(self, z):
        W = self._get_weight()
        x = z @ W.T

        # Log det = sum of log|diagonal of U|
        log_det = torch.sum(torch.log(torch.abs(torch.diag(self.U))))
        log_det = log_det * torch.ones(z.shape[0], device=z.device)
        return x, log_det

    def inverse(self, x):
        W = self._get_weight()
        W_inv = torch.linalg.inv(W)
        z = x @ W_inv.T

        log_det = -torch.sum(torch.log(torch.abs(torch.diag(self.U))))
        log_det = log_det * torch.ones(x.shape[0], device=x.device)
        return z, log_det
```

A complete flow is built by stacking coupling blocks, each consisting of:
1. An ActNorm layer (per-feature affine normalization),
2. A permutation (reverse permutation or learned invertible 1×1 convolution),
3. An affine coupling layer.
Stacking many such blocks creates an expressive flow that can model complex distributions while maintaining tractable likelihood computation.
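The coupling block below also uses an ActNorm layer, which is not defined in this section. Here is a minimal sketch, assuming ActNorm is a per-feature affine normalization as in Glow, with the data-dependent initialization omitted for brevity:

```python
import torch
import torch.nn as nn

class ActNorm(nn.Module):
    """Minimal ActNorm sketch: per-feature scale and shift (data-dependent init omitted)."""
    def __init__(self, dim):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, z):
        x = z * torch.exp(self.log_scale) + self.shift
        log_det = self.log_scale.sum() * torch.ones(z.shape[0], device=z.device)
        return x, log_det

    def inverse(self, x):
        z = (x - self.shift) * torch.exp(-self.log_scale)
        log_det = -self.log_scale.sum() * torch.ones(x.shape[0], device=x.device)
        return z, log_det
```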
```python
import torch
import torch.nn as nn

class CouplingBlock(nn.Module):
    """
    Complete coupling block: ActNorm + Permutation + AffineCoupling
    """
    def __init__(self, dim, hidden_dim=256, use_1x1_conv=True):
        super().__init__()
        self.actnorm = ActNorm(dim)
        if use_1x1_conv:
            self.permute = Invertible1x1Conv(dim)
        else:
            self.permute = ReversePermutation()
        self.coupling = AffineCoupling(dim, hidden_dim)

    def forward(self, z):
        log_det = torch.zeros(z.shape[0], device=z.device)

        x, ld = self.actnorm.forward(z)
        log_det += ld

        x, ld = self.permute.forward(x)
        log_det += ld

        x, ld = self.coupling.forward(x)
        log_det += ld

        return x, log_det

    def inverse(self, x):
        log_det = torch.zeros(x.shape[0], device=x.device)

        z, ld = self.coupling.inverse(x)
        log_det += ld

        z, ld = self.permute.inverse(z)
        log_det += ld

        z, ld = self.actnorm.inverse(z)
        log_det += ld

        return z, log_det

class CouplingFlow(nn.Module):
    """Complete flow model with multiple coupling blocks."""
    def __init__(self, dim, num_blocks=8, hidden_dim=256):
        super().__init__()
        self.blocks = nn.ModuleList([
            CouplingBlock(dim, hidden_dim) for _ in range(num_blocks)
        ])

    def forward(self, z):
        log_det = torch.zeros(z.shape[0], device=z.device)
        x = z
        for block in self.blocks:
            x, ld = block.forward(x)
            log_det += ld
        return x, log_det

    def inverse(self, x):
        log_det = torch.zeros(x.shape[0], device=x.device)
        z = x
        for block in reversed(self.blocks):
            z, ld = block.inverse(z)
            log_det += ld
        return z, log_det
```

You now understand the coupling layer architecture that powers modern flow models. Next, we'll see how RealNVP and Glow build on these foundations with multi-scale architectures and additional innovations to achieve state-of-the-art image generation.
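As a final sanity check that the pieces fit together, here is a hedged sketch of likelihood-based training with the `CouplingFlow` above; the Gaussian base distribution, the random stand-in data batch, and the optimizer settings are illustrative assumptions:

```python
import torch

dim = 8
flow = CouplingFlow(dim, num_blocks=4, hidden_dim=128)
base = torch.distributions.Normal(torch.zeros(dim), torch.ones(dim))
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

x = torch.randn(64, dim)   # stand-in batch of training data

# Change of variables: log p(x) = log p_Z(z) + log|det dz/dx|,
# where z = f^{-1}(x) and inverse() returns exactly that log-determinant.
z, log_det = flow.inverse(x)
log_prob = base.log_prob(z).sum(dim=1) + log_det
loss = -log_prob.mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Sampling: push base samples through the forward (z -> x) direction
with torch.no_grad():
    samples, _ = flow.forward(base.sample((16,)))
```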