When a deep neural network classifies an image as 'golden retriever', which pixels made the difference? When a language model predicts the next word, which previous tokens were influential? Saliency maps answer these questions by leveraging the gradients that flow through trained networks.
The intuition is elegant: if changing a pixel would significantly change the model's output, that pixel must be important for the prediction. Gradients quantify exactly this sensitivity—they tell us the partial derivative of the output with respect to each input feature.
Saliency maps have become fundamental to neural network interpretability, particularly in computer vision and NLP. They provide pixel-level or token-level attribution that can reveal what models have learned, expose spurious correlations, and help debug unexpected predictions.
This page covers: (1) The mathematical foundations of gradient-based saliency, (2) Vanilla gradients and their interpretation, (3) Gradient × Input for better attributions, (4) Integrated Gradients for principled attribution, (5) SmoothGrad for noise reduction, (6) Guided Backpropagation and DeconvNets, (7) GradCAM and activation-based methods, (8) Saliency for NLP, and (9) Limitations and failure modes.
Gradient-based saliency methods rest on a simple idea: use the derivative of the output with respect to the input to measure input importance.
Setting:
Let $f:\mathbb{R}^n \rightarrow \mathbb{R}$ be a neural network that maps an input $x \in \mathbb{R}^n$ (e.g., an image with n pixels) to a scalar output (e.g., logit for a specific class). The gradient of $f$ at $x$ is:
$$\nabla_x f(x) = \left[\frac{\partial f}{\partial x_1}, \frac{\partial f}{\partial x_2}, ..., \frac{\partial f}{\partial x_n}\right]$$
The $i$-th component $\frac{\partial f}{\partial x_i}$ tells us: How much would $f$ change if we slightly perturbed $x_i$?
Interpretation as Sensitivity:
Gradients measure local sensitivity. A pixel with zero gradient at the current value might still be crucial—perhaps the network's response to that pixel has saturated. This limitation motivates more sophisticated methods like Integrated Gradients.
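To make both points concrete—gradient as partial derivative, and the saturation caveat—here is a minimal, self-contained sketch. The toy clamped function is an illustrative assumption, not part of the methods discussed below:

```python
import torch

# A toy "network": f(x) = sum(min(x, 1)) saturates once a feature exceeds 1
def f(x):
    return torch.clamp(x, max=1.0).sum()

x = torch.tensor([0.3, 2.5], requires_grad=True)
f(x).backward()

# The gradient is 1 for the unsaturated feature and 0 for the saturated one,
# even though x[1] clearly matters (f would drop by 1 if x[1] were set to 0)
print(x.grad)  # tensor([1., 0.])
```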
The simplest saliency method, introduced by Simonyan et al. (2013), takes the gradient of the class score with respect to the input image:
$$S(x) = \left|\nabla_x f_c(x)\right|$$
where $f_c(x)$ is the score for class $c$ (typically the predicted or target class), and the absolute value gives unsigned importance.
For Images (3 channels):
For RGB images, each pixel has 3 gradient values. Common aggregations take the maximum absolute value across channels, the L2 norm across channels, or the sum of absolute values across channels (all three appear in the vanilla_saliency helper below).
The result is a 2D saliency map with the same spatial dimensions as the input image.
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
import requests

# Load pre-trained model
model = models.resnet50(pretrained=True)
model.eval()

# ImageNet preprocessing
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load sample image
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')

# Preprocess and enable gradient computation
input_tensor = preprocess(image).unsqueeze(0)  # [1, 3, 224, 224]
input_tensor.requires_grad = True

# Forward pass
output = model(input_tensor)
predicted_class = output.argmax(dim=1).item()
print(f"Predicted class: {predicted_class} (Golden Retriever = 207)")

# Backward pass for target class
model.zero_grad()
output[0, predicted_class].backward()

# Get gradients
gradients = input_tensor.grad.data[0]  # [3, 224, 224]

# Create saliency map
def vanilla_saliency(gradients, method='max'):
    """Convert 3-channel gradients to 2D saliency map."""
    gradients = gradients.abs()
    if method == 'max':
        saliency = gradients.max(dim=0)[0]
    elif method == 'l2':
        saliency = gradients.pow(2).sum(dim=0).sqrt()
    elif method == 'sum':
        saliency = gradients.sum(dim=0)
    return saliency.numpy()

saliency = vanilla_saliency(gradients)

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original image
axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original Image')
axes[0].axis('off')

# Raw saliency
axes[1].imshow(saliency, cmap='hot')
axes[1].set_title('Vanilla Gradient Saliency')
axes[1].axis('off')

# Overlay
axes[2].imshow(image.resize((224, 224)))
axes[2].imshow(saliency, cmap='hot', alpha=0.5)
axes[2].set_title('Overlay')
axes[2].axis('off')

plt.tight_layout()
plt.savefig('vanilla_saliency.png', dpi=150)
plt.show()

# Note: Vanilla gradients are often noisy and scattered
```

You'll notice that vanilla gradient saliency maps are often noisy, with importance scattered across the image. This is because gradients fluctuate rapidly in deep networks with ReLU activations. More advanced methods address this noise problem.
A simple but powerful improvement multiplies gradients by input values:
$$S(x)_i = x_i \cdot \frac{\partial f}{\partial x_i}$$
Intuition:
The gradient tells us sensitivity, but doesn't account for the actual input value. A pixel might have a large gradient, but if its value is zero (or close to the mean after normalization), changing it has no practical effect.
Multiplying by the input yields a first-order Taylor approximation of the output around a zero baseline:
$$f(x) \approx f(0) + \sum_i x_i \cdot \frac{\partial f}{\partial x_i}\bigg|_{x=0}$$
(Note: this is exact only for linear networks; for deep networks it's an approximation)
This is equivalent to element-wise Hadamard product: $$S = x \odot \nabla_x f(x)$$
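For a purely linear model the Taylor relation above is exact, which a tiny sketch can verify (the toy weights and dimensions below are illustrative assumptions, not from the original example):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 1, bias=False)   # linear model: f(x) = w·x, so f(0) = 0

x = torch.randn(4, requires_grad=True)
out = lin(x).squeeze()
out.backward()

# For a linear model, Gradient × Input sums exactly to f(x) - f(0)
gxi = (x * x.grad).sum()
print(out.item(), gxi.item())  # the two values match
```

The full image example below applies the same idea to the ResNet-50 pipeline.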
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
import requests

# Load model and image (same setup as before)
model = models.resnet50(pretrained=True)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')

input_tensor = preprocess(image).unsqueeze(0)
input_tensor.requires_grad = True

output = model(input_tensor)
predicted_class = output.argmax(dim=1).item()

model.zero_grad()
output[0, predicted_class].backward()

gradients = input_tensor.grad.data[0]   # [3, 224, 224]
input_values = input_tensor.data[0]     # [3, 224, 224]

# Gradient × Input
grad_times_input = gradients * input_values

def aggregate_saliency(saliency_3d):
    """Aggregate 3-channel saliency to 2D."""
    return saliency_3d.abs().max(dim=0)[0].numpy()

vanilla_saliency = aggregate_saliency(gradients)
gxi_saliency = aggregate_saliency(grad_times_input)

# Visualization comparison
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(vanilla_saliency, cmap='hot')
axes[1].set_title('Vanilla Gradient')
axes[1].axis('off')

axes[2].imshow(gxi_saliency, cmap='hot')
axes[2].set_title('Gradient × Input')
axes[2].axis('off')

# Show difference
axes[3].imshow(gxi_saliency - vanilla_saliency, cmap='seismic', vmin=-0.1, vmax=0.1)
axes[3].set_title('Difference (G×I - Vanilla)')
axes[3].axis('off')

plt.tight_layout()
plt.savefig('gradient_times_input.png', dpi=150)
plt.show()

# Key insight: Gradient × Input often gives sharper, more focused attributions
# because it downweights gradients where input values are near zero
```

Integrated Gradients (IG) (Sundararajan et al., 2017) is one of the most principled attribution methods. It addresses a fundamental problem with vanilla gradients: they only capture local sensitivity, missing important features that the model has already 'saturated' on.
Core Idea:
Instead of computing gradients only at the input $x$, integrate gradients along a path from a baseline $x'$ (typically a black image or zero vector) to the actual input $x$:
$$IG_i(x) = (x_i - x'_i) \times \int_0^1 \frac{\partial f(x' + \alpha(x - x'))}{\partial x_i} d\alpha$$
This captures the total effect of each feature as we move from baseline to input.
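In practice the integral is approximated by a Riemann sum over $m$ steps (the implementation below uses a trapezoidal variant of this sum):

$$IG_i(x) \approx (x_i - x'_i) \times \frac{1}{m} \sum_{k=1}^{m} \frac{\partial f\left(x' + \tfrac{k}{m}(x - x')\right)}{\partial x_i}$$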
Desirable axioms that IG satisfies: (1) Sensitivity—a feature that differs from the baseline and changes the prediction receives non-zero attribution; (2) Implementation invariance—two networks that compute the same function receive the same attributions; (3) Completeness—attributions sum to $f(x) - f(x')$, which the code below verifies.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
import requests

# Load model
model = models.resnet50(pretrained=True)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
input_tensor = preprocess(image).unsqueeze(0)

def integrated_gradients(model, input_tensor, target_class, baseline=None, steps=50):
    """
    Compute Integrated Gradients attribution.

    Args:
        model: PyTorch model
        input_tensor: Input to explain [1, C, H, W]
        target_class: Class to compute attribution for
        baseline: Baseline input (default: zeros)
        steps: Number of integration steps

    Returns:
        Attribution map [C, H, W]
    """
    if baseline is None:
        baseline = torch.zeros_like(input_tensor)

    # Generate interpolation points
    alphas = torch.linspace(0, 1, steps + 1)

    # Path from baseline to input
    scaled_inputs = torch.stack([
        baseline + alpha * (input_tensor - baseline)
        for alpha in alphas
    ]).squeeze(1)  # [steps+1, C, H, W]
    scaled_inputs.requires_grad = True

    # Forward pass for all interpolated inputs
    outputs = model(scaled_inputs)

    # Backward pass for target class
    model.zero_grad()
    outputs[:, target_class].sum().backward()

    gradients = scaled_inputs.grad  # [steps+1, C, H, W]

    # Riemann sum approximation of integral
    avg_gradients = (gradients[:-1] + gradients[1:]) / 2
    integrated_grads = avg_gradients.mean(dim=0)  # [C, H, W]

    # Multiply by (input - baseline)
    attribution = (input_tensor.squeeze() - baseline.squeeze()) * integrated_grads

    return attribution

# Compute Integrated Gradients
output = model(input_tensor)
predicted_class = output.argmax(dim=1).item()

ig_attribution = integrated_gradients(model, input_tensor, predicted_class, steps=100)

# Aggregate to 2D
ig_saliency = ig_attribution.abs().max(dim=0)[0].detach().numpy()

# Compare with vanilla gradient
input_tensor.requires_grad = True
output = model(input_tensor)
model.zero_grad()
output[0, predicted_class].backward()
vanilla_grads = input_tensor.grad.data[0]
vanilla_saliency = vanilla_grads.abs().max(dim=0)[0].numpy()

# Visualization
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(vanilla_saliency, cmap='hot')
axes[1].set_title('Vanilla Gradient')
axes[1].axis('off')

axes[2].imshow(ig_saliency, cmap='hot')
axes[2].set_title('Integrated Gradients (100 steps)')
axes[2].axis('off')

# Verify completeness: sum of attributions ≈ f(x) - f(baseline)
baseline_output = model(torch.zeros_like(input_tensor))
diff = (output[0, predicted_class] - baseline_output[0, predicted_class]).item()
attr_sum = ig_attribution.sum().item()
print(f"f(x) - f(baseline): {diff:.4f}")
print(f"Sum of attributions: {attr_sum:.4f}")
print(f"Completeness ratio: {attr_sum/diff:.4f}")  # Should be ≈ 1.0

axes[3].text(0.5, 0.5,
             f"Completeness Check:\n\n"
             f"f(x) - f(baseline) = {diff:.4f}\n"
             f"Sum of IG = {attr_sum:.4f}\n"
             f"Ratio = {attr_sum/diff:.4f}",
             ha='center', va='center', fontsize=14, transform=axes[3].transAxes)
axes[3].axis('off')
axes[3].set_title('Axiom Verification')

plt.tight_layout()
plt.savefig('integrated_gradients.png', dpi=150)
plt.show()
```

The baseline is crucial for Integrated Gradients.
Common choices: (1) Black image (zeros) for vision, (2) Padding token embeddings for NLP, (3) Mean image from training data, (4) Gaussian noise. The baseline should represent 'absence of information'. Results can vary with different baselines—consider averaging over multiple baselines.
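A minimal sketch of baseline averaging, assuming the integrated_gradients helper, model, input_tensor, and predicted_class from the example above are in scope (the specific baseline choices here are illustrative):

```python
import torch

# Average Integrated Gradients over several baselines to reduce baseline sensitivity
baselines = [
    torch.zeros_like(input_tensor),        # black image
    torch.ones_like(input_tensor),         # constant bright image (in normalized space)
    torch.randn_like(input_tensor) * 0.1,  # small Gaussian noise
]

attributions = [
    integrated_gradients(model, input_tensor, predicted_class, baseline=b, steps=50)
    for b in baselines
]
avg_attribution = torch.stack(attributions).mean(dim=0)
```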
Vanilla gradients are often noisy because of the sharp, non-smooth nature of ReLU activations. SmoothGrad (Smilkov et al., 2017) addresses this by averaging gradients over multiple noisy versions of the input:
$$\hat{S}(x) = \frac{1}{n} \sum_{i=1}^{n} \nabla_x f(x + \epsilon_i), \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2)$$
Intuition:
By adding noise and averaging, we smooth out the sharp fluctuations in gradients that come from the piece-wise linear nature of ReLUs. The result is visually cleaner saliency maps that often better match human intuitions.
Key hyperparameters: the noise standard deviation $\sigma$ (typically 10–20% of the input range; noise_level=0.15 below) and the number of samples $n$ (around 50 usually suffices; more samples reduce variance at higher compute cost).
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
import requests

# Load model and image
model = models.resnet50(pretrained=True)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
input_tensor = preprocess(image).unsqueeze(0)

def smoothgrad(model, input_tensor, target_class, n_samples=50, noise_level=0.15):
    """
    Compute SmoothGrad saliency map.

    Args:
        model: PyTorch model
        input_tensor: Input to explain [1, C, H, W]
        target_class: Class to compute saliency for
        n_samples: Number of noisy samples
        noise_level: Noise standard deviation as fraction of the input range

    Returns:
        Smoothed gradient saliency [C, H, W]
    """
    # Calculate noise std based on input range
    sigma = noise_level * (input_tensor.max() - input_tensor.min())

    all_gradients = []
    for _ in range(n_samples):
        # Add Gaussian noise
        noise = torch.randn_like(input_tensor) * sigma
        noisy_input = input_tensor + noise
        noisy_input.requires_grad = True

        # Forward + backward
        output = model(noisy_input)
        model.zero_grad()
        output[0, target_class].backward()

        all_gradients.append(noisy_input.grad.data.clone())

    # Average gradients
    avg_gradient = torch.stack(all_gradients).mean(dim=0).squeeze(0)
    return avg_gradient

# Compute SmoothGrad with different parameters
output = model(input_tensor)
predicted_class = output.argmax(dim=1).item()

smoothgrad_50 = smoothgrad(model, input_tensor, predicted_class, n_samples=50, noise_level=0.15)
smoothgrad_200 = smoothgrad(model, input_tensor, predicted_class, n_samples=200, noise_level=0.15)

# Vanilla gradient for comparison
input_tensor.requires_grad = True
output = model(input_tensor)
model.zero_grad()
output[0, predicted_class].backward()
vanilla_grad = input_tensor.grad.data[0]

# Aggregate to 2D
def to_2d(grads):
    return grads.abs().max(dim=0)[0].numpy()

# Visualization
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(to_2d(vanilla_grad), cmap='hot')
axes[1].set_title('Vanilla Gradient')
axes[1].axis('off')

axes[2].imshow(to_2d(smoothgrad_50), cmap='hot')
axes[2].set_title('SmoothGrad (n=50)')
axes[2].axis('off')

axes[3].imshow(to_2d(smoothgrad_200), cmap='hot')
axes[3].set_title('SmoothGrad (n=200)')
axes[3].axis('off')

plt.tight_layout()
plt.savefig('smoothgrad.png', dpi=150)
plt.show()

# SmoothGrad can be combined with other methods
def smooth_integrated_gradients(model, input_tensor, target_class,
                                n_samples=20, noise_level=0.1, ig_steps=30):
    """Combine SmoothGrad with Integrated Gradients.

    Uses the integrated_gradients function defined in the previous example.
    """
    sigma = noise_level * (input_tensor.max() - input_tensor.min())
    all_ig = []
    for _ in range(n_samples):
        noise = torch.randn_like(input_tensor) * sigma
        noisy_input = input_tensor + noise
        # Compute IG for this noisy version
        ig = integrated_gradients(model, noisy_input, target_class, steps=ig_steps)
        all_ig.append(ig)
    return torch.stack(all_ig).mean(dim=0)
```

| Method | Principle | Pros | Cons |
|---|---|---|---|
| Vanilla Gradient | ∂f/∂x at input | Fast, simple | Noisy, local only |
| Gradient × Input | x · ∂f/∂x | Sharper, considers input | Still local |
| Integrated Gradients | Path integral from baseline | Axiomatic, complete | Requires baseline choice |
| SmoothGrad | Average over noisy inputs | Less noise, visually clean | Slower, hyperparameters |
Guided Backpropagation and DeconvNets modify the backward pass through ReLU activations to produce visually sharper saliency maps.
Standard Backpropagation through ReLU: $$\left.\frac{\partial f}{\partial x}\right|_{\text{ReLU}} = \mathbf{1}_{x>0} \cdot \frac{\partial f}{\partial \text{ReLU}(x)}$$
Gradient flows back only where the forward activation was positive.
DeconvNet (Zeiler & Fergus, 2014): $$\left.\frac{\partial f}{\partial x}\right|_{\text{Deconv}} = \mathbf{1}_{\frac{\partial f}{\partial \text{ReLU}(x)}>0} \cdot \frac{\partial f}{\partial \text{ReLU}(x)}$$
Only passes positive gradients (regardless of forward activation).
Guided Backpropagation (Springenberg et al., 2015): $$\left.\frac{\partial f}{\partial x}\right|_{\text{Guided}} = \mathbf{1}_{x>0} \cdot \mathbf{1}_{\frac{\partial f}{\partial \text{ReLU}(x)}>0} \cdot \frac{\partial f}{\partial \text{ReLU}(x)}$$
Only passes where both forward activation AND gradient are positive. This produces the sharpest results but has been criticized for not reflecting actual model behavior.
```python
import copy
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
import requests

class GuidedBackpropReLU(torch.autograd.Function):
    """ReLU with guided backpropagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Guided: mask where input > 0 AND gradient > 0
        grad_input[input < 0] = 0        # Standard ReLU backward
        grad_input[grad_output < 0] = 0  # Guided: also mask negative gradients
        return grad_input

class GuidedReLU(nn.Module):
    """Module version of guided ReLU."""
    def forward(self, x):
        return GuidedBackpropReLU.apply(x)

class GuidedBackpropModel(nn.Module):
    """Wrapper that replaces ReLUs with GuidedBackpropReLU."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self._replace_relu_with_guided()

    def _replace_relu_with_guided(self):
        def replace_relu(module):
            for name, child in module.named_children():
                if isinstance(child, nn.ReLU):
                    setattr(module, name, GuidedReLU())
                else:
                    replace_relu(child)
        replace_relu(self.model)

    def forward(self, x):
        return self.model(x)

# Load model and image
model = models.resnet50(pretrained=True)
model.eval()

# Create guided version from a deep copy, so the original model keeps standard ReLUs
guided_model = GuidedBackpropModel(copy.deepcopy(model))
guided_model.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
input_tensor = preprocess(image).unsqueeze(0)

# Vanilla gradient
input_vanilla = input_tensor.clone().requires_grad_(True)
output = model(input_vanilla)
predicted_class = output.argmax(dim=1).item()
model.zero_grad()
output[0, predicted_class].backward()
vanilla_grad = input_vanilla.grad[0].clone()

# Guided backprop
input_guided = input_tensor.clone().requires_grad_(True)
output_guided = guided_model(input_guided)
guided_model.zero_grad()
output_guided[0, predicted_class].backward()
guided_grad = input_guided.grad[0].clone()

# Visualization
def normalize_gradient(grad, percentile=99):
    """Normalize gradients for visualization."""
    grad = grad.abs().max(dim=0)[0].numpy()
    vmax = np.percentile(grad, percentile)
    grad = np.clip(grad / vmax, 0, 1)
    return grad

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(normalize_gradient(vanilla_grad), cmap='gray')
axes[1].set_title('Vanilla Gradient')
axes[1].axis('off')

axes[2].imshow(normalize_gradient(guided_grad), cmap='gray')
axes[2].set_title('Guided Backpropagation')
axes[2].axis('off')

plt.tight_layout()
plt.savefig('guided_backprop.png', dpi=150)
plt.show()
```

While Guided Backpropagation produces visually appealing results, research has shown it is essentially an edge detector that doesn't reflect the model's actual reasoning. 'Sanity Checks for Saliency Maps' (Adebayo et al., 2018) shows guided backprop looks similar even for randomly initialized networks. Use with caution for interpretation; prefer Integrated Gradients for faithful attributions.
GradCAM (Gradient-weighted Class Activation Mapping) produces coarse localization maps by combining gradients with convolutional feature maps. Unlike input-space saliency, GradCAM operates in feature space and produces smooth, class-discriminative visualizations.
The Method:
1. For a target class $c$, compute gradients of the score $y^c$ with respect to the feature maps $A^k$ of a convolutional layer (typically the last conv layer).
2. Global-average-pool the gradients to get importance weights: $$\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}$$
3. Form a weighted combination of the feature maps and apply a ReLU: $$L_{\text{GradCAM}}^c = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)$$
4. Upsample the result to the input size for visualization.
ReLU captures only positive influences on the class.
```python
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
import requests

model = models.resnet50(pretrained=True)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
input_tensor = preprocess(image).unsqueeze(0)

class GradCAM:
    """GradCAM implementation for any CNN."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = None
        self.gradients = None
        # Register hooks
        target_layer.register_forward_hook(self._save_features)
        target_layer.register_full_backward_hook(self._save_gradients)

    def _save_features(self, module, input, output):
        self.feature_maps = output.detach()

    def _save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, target_class=None):
        # Forward pass
        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()

        # Global average pool gradients -> weights
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)  # [1, C, 1, 1]

        # Weighted combination
        cam = (weights * self.feature_maps).sum(dim=1, keepdim=True)  # [1, 1, H, W]
        cam = F.relu(cam)  # Only positive contributions

        # Normalize to [0, 1]
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        # Upsample to input size
        cam = F.interpolate(cam, size=(224, 224), mode='bilinear', align_corners=False)

        return cam.squeeze().numpy(), target_class

# Apply GradCAM to last convolutional layer
target_layer = model.layer4[-1].conv3
gradcam = GradCAM(model, target_layer)

cam, predicted_class = gradcam(input_tensor)

# Visualization
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(cam, cmap='jet')
axes[1].set_title(f'GradCAM (class {predicted_class})')
axes[1].axis('off')

# Overlay
axes[2].imshow(image.resize((224, 224)))
axes[2].imshow(cam, cmap='jet', alpha=0.5)
axes[2].set_title('GradCAM Overlay')
axes[2].axis('off')

# Guided GradCAM: element-wise multiply guided backprop with GradCAM
# (combines sharpness of guided backprop with class-specificity of GradCAM)

# Simple comparison with vanilla gradient
input_tensor.requires_grad = True
output = model(input_tensor)
model.zero_grad()
output[0, predicted_class].backward()
vanilla_grad = input_tensor.grad[0].abs().max(dim=0)[0].numpy()

axes[3].imshow(vanilla_grad, cmap='hot')
axes[3].set_title('Vanilla Gradient (for comparison)')
axes[3].axis('off')

plt.tight_layout()
plt.savefig('gradcam.png', dpi=150)
plt.show()

# GradCAM for different classes (useful for multi-object images)
imagenet_classes = {207: 'golden_retriever', 231: 'Shetland sheepdog'}  # Example

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(image.resize((224, 224)))
axes[0].set_title('Original')
axes[0].axis('off')

for idx, (class_id, class_name) in enumerate(imagenet_classes.items()):
    cam, _ = gradcam(input_tensor, target_class=class_id)
    axes[idx + 1].imshow(image.resize((224, 224)))
    axes[idx + 1].imshow(cam, cmap='jet', alpha=0.5)
    axes[idx + 1].set_title(f'GradCAM: {class_name}')
    axes[idx + 1].axis('off')

plt.tight_layout()
plt.savefig('gradcam_classes.png', dpi=150)
plt.show()
```

Saliency methods apply to text by computing gradients with respect to input embeddings (not discrete tokens). The resulting token-level attributions highlight which words influenced the prediction.
Challenges for NLP:
Discrete Inputs: We can't directly differentiate w.r.t. tokens. We differentiate w.r.t. embeddings instead.
Embedding Space: Word embeddings are learned representations—gradient magnitude depends on embedding geometry.
Sequential Nature: Token importance may depend on context, which gradients may not fully capture.
Common approaches: take the L2 norm of the gradient with respect to each token's embedding, or the per-token dot product of gradient and embedding (Gradient × Input); Integrated Gradients can also be applied with a padding or zero-embedding baseline. The code below implements the first two.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertForSequenceClassification

# Load sentiment model
tokenizer = BertTokenizer.from_pretrained('textattack/bert-base-uncased-SST-2')
model = BertForSequenceClassification.from_pretrained('textattack/bert-base-uncased-SST-2')
model.eval()

sentence = "This movie was absolutely wonderful and I loved every minute of it"
inputs = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

def gradient_token_saliency(model, inputs, target_class=None):
    """Compute gradient-based saliency for each token."""
    # Get word embeddings; detach so the tensor is a leaf and .grad is populated
    embeddings = model.bert.embeddings.word_embeddings(inputs['input_ids']).detach()
    embeddings.requires_grad = True

    # Forward pass using embeddings directly (position/type embeddings are added internally)
    logits = model(
        inputs_embeds=embeddings,
        attention_mask=inputs['attention_mask'],
        token_type_ids=inputs.get('token_type_ids')
    ).logits

    if target_class is None:
        target_class = logits.argmax(dim=1).item()

    # Backward
    model.zero_grad()
    logits[0, target_class].backward()

    # Gradient L2 norm per token
    gradients = embeddings.grad.squeeze()               # [seq_len, embed_dim]
    saliency = gradients.norm(dim=1).detach().numpy()   # [seq_len]

    return saliency, target_class

def gradient_times_input_saliency(model, inputs, target_class=None):
    """Gradient × Input saliency for tokens."""
    embeddings = model.bert.embeddings.word_embeddings(inputs['input_ids']).detach()
    embeddings.requires_grad = True

    logits = model(
        inputs_embeds=embeddings,
        attention_mask=inputs['attention_mask'],
        token_type_ids=inputs.get('token_type_ids')
    ).logits

    if target_class is None:
        target_class = logits.argmax(dim=1).item()

    model.zero_grad()
    logits[0, target_class].backward()

    # Gradient × Input (sum across embedding dimension, then absolute value)
    gxi = (embeddings.grad * embeddings).sum(dim=-1).abs().squeeze().detach().numpy()

    return gxi, target_class

# Compute saliency
grad_saliency, pred = gradient_token_saliency(model, inputs)
gxi_saliency, _ = gradient_times_input_saliency(model, inputs)

# Get prediction
with torch.no_grad():
    logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0]

labels = ['Negative', 'Positive']
print(f"Prediction: {labels[pred]} ({probs[pred].item():.3f})")

# Visualization
def plot_token_saliency(tokens, saliency, title='Token Saliency'):
    """Bar plot of token-level saliency."""
    fig, ax = plt.subplots(figsize=(14, 4))

    # Normalize
    saliency = saliency / saliency.max()
    colors = plt.cm.Reds(saliency)

    positions = range(len(tokens))
    ax.bar(positions, saliency, color=colors)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=10)
    ax.set_ylabel('Normalized Saliency')
    ax.set_title(title)
    ax.set_ylim(0, 1.1)
    plt.tight_layout()
    return fig

fig1 = plot_token_saliency(tokens, grad_saliency, 'Gradient Norm Saliency')
plt.savefig('nlp_saliency_grad.png', dpi=150)

fig2 = plot_token_saliency(tokens, gxi_saliency, 'Gradient × Input Saliency')
plt.savefig('nlp_saliency_gxi.png', dpi=150)
plt.show()

# Highlight text visualization
def highlight_text(tokens, saliency, predicted_label):
    """Create HTML-like visualization with highlighted words."""
    saliency_norm = saliency / saliency.max()
    highlighted = []
    for token, sal in zip(tokens, saliency_norm):
        # Map saliency to background color intensity
        if token in ['[CLS]', '[SEP]', '[PAD]']:
            highlighted.append(token)
        else:
            highlighted.append(f"<span style='background:rgba(255,0,0,{sal:.2f})'>{token}</span>")
    return ' '.join(highlighted)

html = highlight_text(tokens, gxi_saliency, labels[pred])
print("\nHighlighted (intensity = saliency):")
# In Jupyter this would render nicely
print(' '.join([f"[{t}: {s:.2f}]" for t, s in zip(tokens, gxi_saliency / gxi_saliency.max())]))
```

Saliency maps look compelling but have fundamental limitations that every practitioner must understand.
Key Research Finding: 'Sanity Checks for Saliency Maps' (Adebayo et al., 2018)
This influential paper proposed tests to verify whether saliency methods actually explain model behavior:
Model Parameter Randomization Test: Randomize network weights and check if saliency changes. If saliency looks similar for trained vs random networks, it doesn't explain the model.
Data Randomization Test: Train on randomized labels and check saliency. Meaningful explanations should differ.
Sobering Results: Guided Backpropagation and Guided GradCAM fail these tests—they produce nearly identical saliency maps for trained and randomly initialized networks. Vanilla gradients and GradCAM, in contrast, pass.
| Method | Model Randomization Test | Data Randomization Test | Reliability |
|---|---|---|---|
| Vanilla Gradient | ✓ Pass | ✓ Pass | Faithful but noisy |
| Gradient × Input | ✓ Pass | ✓ Pass | Faithful |
| Integrated Gradients | ✓ Pass | ✓ Pass | Faithful, principled |
| SmoothGrad | ✓ Pass | ✓ Pass | Faithful, less noisy |
| Guided Backprop | ✗ Fail | ✗ Fail | NOT faithful—edge detector |
| GradCAM | ✓ Pass | ✓ Pass | Faithful, class-specific |
| Guided GradCAM | ✗ Fail | ✗ Fail | Combines faithful + unfaithful |
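A minimal sketch of the model parameter randomization test, assuming the model, input_tensor, and predicted_class from the vision examples above are in scope; the simplified re-initialization (all conv/linear layers at once, rather than the paper's cascading randomization) and the Spearman comparison are illustrative choices:

```python
import copy
import torch
import torch.nn as nn
from scipy.stats import spearmanr

def vanilla_saliency_map(model, input_tensor, target_class):
    """Absolute-gradient saliency, aggregated over channels."""
    x = input_tensor.clone().detach().requires_grad_(True)
    output = model(x)
    model.zero_grad()
    output[0, target_class].backward()
    return x.grad[0].abs().max(dim=0)[0].numpy()

# Saliency for the trained model
sal_trained = vanilla_saliency_map(model, input_tensor, predicted_class)

# Randomize the weights of a copy of the model
random_model = copy.deepcopy(model)
for m in random_model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight, std=0.02)
random_model.eval()

sal_random = vanilla_saliency_map(random_model, input_tensor, predicted_class)

# If the two maps are highly rank-correlated, the saliency method is not
# actually explaining what the trained model learned
rho, _ = spearmanr(sal_trained.ravel(), sal_random.ravel())
print(f"Spearman correlation, trained vs. randomized: {rho:.3f}")
```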
Even 'faithful' saliency methods have an interpretation gap. A pixel with high saliency means the model is sensitive to it—but humans and models may focus on the same pixels for different reasons. A model might use texture while humans expect shape. Saliency confirms sensitivity but doesn't explain reasoning.
Gradient-based saliency methods provide powerful tools for understanding neural network predictions, provided you use them correctly: prefer methods that pass the sanity checks (vanilla gradients, Integrated Gradients, SmoothGrad, GradCAM), validate explanations rather than trusting visual appeal, and remember that saliency shows sensitivity, not reasoning.
What's Next:
Gradient-based methods reveal sensitivity but don't directly show what concepts the model has learned. In the next page, we'll explore Concept Activation Vectors (CAVs)—a method for testing whether models use human-interpretable concepts in their decision-making. This provides a higher-level view of model behavior that complements pixel-level saliency.
You now have a comprehensive understanding of gradient-based saliency methods. You can implement vanilla gradients, Gradient × Input, Integrated Gradients, SmoothGrad, and GradCAM. More importantly, you understand their limitations and can critically evaluate saliency-based explanations using sanity checks and validation.