In mechanistic interpretability, researchers aim to reverse-engineer neural networks by identifying specific computational pathways called neural circuits. One powerful technique for isolating these circuits is activation substitution, where selected node activations are replaced with reference values to observe the resulting changes in network behavior.
The core idea is straightforward yet profound: by substituting certain activations with their mean values (computed over a reference distribution), we can "knock out" specific components and measure their causal contribution to a model's output. This technique is analogous to genetic knockout experiments in biology, where scientists disable individual genes to understand their function.
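Below is a minimal sketch of this idea in a toy setting. It assumes a layer whose activations feed a single linear readout. The names `baseline_acts`, `readout_weights`, `readout`, and `ablation_effect` are hypothetical and are not part of the original problem. The reference vector is the per-neuron mean over baseline samples. A neuron's causal contribution on one input is estimated as the change in output when that neuron alone is mean-ablated.

```python
from statistics import fmean

# Hypothetical baseline distribution: each row is one sample's layer activations.
baseline_acts = [
    [0.2, -0.1, 0.4],
    [0.0, 0.1, 0.0],
    [0.1, 0.0, 0.2],
]

# Reference values r: the per-neuron (column-wise) mean over the baseline samples.
reference = [fmean(column) for column in zip(*baseline_acts)]

# Hypothetical downstream model: a linear readout of the layer's activations.
readout_weights = [1.0, 2.0, -1.0]

def readout(acts):
    """Weighted sum of the layer's activations (toy stand-in for the rest of the network)."""
    return sum(w * x for w, x in zip(readout_weights, acts))

def ablation_effect(acts, neuron):
    """Change in the readout when only `neuron` is replaced by its mean activation."""
    ablated = list(acts)
    ablated[neuron] = reference[neuron]
    return readout(acts) - readout(ablated)

current = [0.5, -0.3, 0.8]
effects = [ablation_effect(current, i) for i in range(len(current))]
print(effects)  # approximately [0.4, -0.6, -0.6]
```

Because this toy readout is linear, each effect reduces to w_i (a_i − r_i). In a real network, the effect of an ablation is generally nonlinear and propagates through all downstream layers.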
Formal Definition:
Given:
• a: a vector of length n, representing the current outputs of the neurons in a layer
• m: a binary mask of length n, where m[i] = 1 indicates that the i-th neuron should be ablated
• r: a vector of length n, containing the mean activations computed over a baseline distribution
The substituted activation vector s is computed element-wise as:
$$s_i = \begin{cases} r_i & \text{if } m_i = 1 \text{ (substitute with reference)} \\ a_i & \text{if } m_i = 0 \text{ (keep original)} \end{cases}$$
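Every $m_i$ is either 0 or 1, so the case split can also be written as a single arithmetic expression. This form is convenient for vectorized code. Here $\odot$ denotes element-wise multiplication and $\mathbf{1}$ is the all-ones vector of length $n$:

$$s = m \odot r + (\mathbf{1} - m) \odot a, \qquad \text{i.e.} \qquad s_i = m_i\, r_i + (1 - m_i)\, a_i$$

Substituting $m_i = 1$ leaves $s_i = r_i$, and substituting $m_i = 0$ leaves $s_i = a_i$. This reproduces both cases of the definition.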
Interpretation:
• m[i] = 1: the neuron is "ablated"; its input-specific information is removed by replacing its activation with its mean over the baseline distribution.
• m[i] = 0: the neuron's activation passes through unchanged, preserving its contribution.
Your Task: Implement a function that performs activation substitution. The function should take the current activations, a binary mask, and reference values, and return the modified activations in which masked positions are replaced by their corresponding reference values.
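A minimal pure-Python sketch of such a function follows. The name `substitute_activations` is an assumption for illustration. The length and mask checks are also assumptions: the problem statement does not fix a signature or any error-handling behavior.

```python
from typing import List, Sequence

def substitute_activations(
    activations: Sequence[float],
    mask: Sequence[int],
    reference_values: Sequence[float],
) -> List[float]:
    """Replace activations[i] with reference_values[i] wherever mask[i] == 1."""
    n = len(activations)
    # Assumed validation: all three inputs must describe the same n neurons.
    if len(mask) != n or len(reference_values) != n:
        raise ValueError("activations, mask, and reference_values must have the same length")
    # Assumed validation: the mask must be binary.
    if any(m not in (0, 1) for m in mask):
        raise ValueError("mask must contain only 0 or 1")
    # Element-wise case split: take r_i when m_i = 1, otherwise keep a_i.
    return [r if m == 1 else a for a, m, r in zip(activations, mask, reference_values)]

print(substitute_activations([0.5, -0.3, 0.8, 0.2], [1, 0, 1, 0], [0.1, 0.0, 0.2, -0.1]))
# [0.1, -0.3, 0.2, 0.2]
```

The function builds a new list rather than modifying `activations` in place. The caller's original activations therefore stay available, for example to compare model outputs with and without the ablation.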
Example 1:
activations = [0.5, -0.3, 0.8, 0.2]
mask = [1, 0, 1, 0]
reference_values = [0.1, 0.0, 0.2, -0.1]
Output: [0.1, -0.3, 0.2, 0.2]
The mask indicates which activations should be substituted:
• Index 0: mask = 1, so activation 0.5 is replaced with reference value 0.1
• Index 1: mask = 0, so activation -0.3 is preserved
• Index 2: mask = 1, so activation 0.8 is replaced with reference value 0.2
• Index 3: mask = 0, so activation 0.2 is preserved
The resulting vector [0.1, -0.3, 0.2, 0.2] represents the activations after ablating neurons at indices 0 and 2.
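For larger layers, the same substitution is usually vectorized. The sketch below assumes NumPy is available, which the original does not mention. The function name `substitute_activations_np` is hypothetical. It uses `np.where` to select between the reference and current arrays and reproduces the example above.

```python
import numpy as np

def substitute_activations_np(activations, mask, reference_values):
    """Vectorized substitution: reference where mask is 1, original activation elsewhere."""
    a = np.asarray(activations, dtype=float)
    r = np.asarray(reference_values, dtype=float)
    m = np.asarray(mask).astype(bool)  # 1 -> substitute, 0 -> keep
    if not (a.shape == r.shape == m.shape):
        raise ValueError("activations, mask, and reference_values must have the same shape")
    # np.where picks r[i] where m[i] is True and a[i] otherwise.
    return np.where(m, r, a)

result = substitute_activations_np([0.5, -0.3, 0.8, 0.2], [1, 0, 1, 0], [0.1, 0.0, 0.2, -0.1])
print(result.tolist())  # [0.1, -0.3, 0.2, 0.2]
```

Selecting with `np.where` differs from computing `m * r + (1 - m) * a` arithmetically. Suppose an ablated position holds a non-finite activation such as NaN. With `np.where`, that value is simply discarded. With the arithmetic form, `0 * nan` is still NaN, so the NaN would leak into the output.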
Example 2:
activations = [1.0, 2.0, 3.0]
mask = [1, 1, 1]
reference_values = [0.5, 1.5, 2.5]
Output: [0.5, 1.5, 2.5]
When all mask values are 1, every activation is substituted with its corresponding reference value. This represents a complete ablation of the layer, in which every neuron is replaced by its mean activation. The output is therefore the layer's baseline state, in which no neuron carries input-specific information.
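A complete ablation has a visible consequence when it is applied to a batch: every input collapses to the same reference vector. The sketch below assumes NumPy. Its first row is taken from the example above, and its second row is a hypothetical extra input. The mask and reference have shape (n,) and broadcast across the rows of a (batch, n) activation array.

```python
import numpy as np

# First row from the example above; second row is a hypothetical extra input.
batch = np.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
])
mask = np.array([1, 1, 1], dtype=bool)   # ablate every neuron in the layer
reference = np.array([0.5, 1.5, 2.5])    # per-neuron mean activations

# The (n,) mask and reference broadcast against the (batch, n) activations.
ablated = np.where(mask, reference, batch)
print(ablated.tolist())  # [[0.5, 1.5, 2.5], [0.5, 1.5, 2.5]]
```

After a full ablation, the two inputs are indistinguishable at this layer. Any difference that remains in the model's final outputs must therefore come from pathways that bypass the layer.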
Example 3:
activations = [1.5, -2.5, 3.5]
mask = [0, 0, 0]
reference_values = [0.0, 0.0, 0.0]
Output: [1.5, -2.5, 3.5]
When all mask values are 0, no substitution occurs and all original activations are preserved. This is equivalent to a normal forward pass with no ablation. The reference values are ignored entirely, and the network behaves as if no intervention were applied.
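The all-zeros and all-ones cases also work as quick sanity checks for any implementation. The sketch below runs a randomized check. It assumes a minimal helper named `substitute`, which is hypothetical and is redefined here so the snippet runs on its own. The check confirms that an all-zero mask returns the activations unchanged and an all-one mask returns the reference values.

```python
import random

def substitute(activations, mask, reference_values):
    """Hypothetical minimal helper: keep a_i where m_i == 0, take r_i where m_i == 1."""
    return [r if m else a for a, m, r in zip(activations, mask, reference_values)]

rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(100):
    n = rng.randint(1, 8)
    a = [rng.uniform(-5, 5) for _ in range(n)]
    r = [rng.uniform(-5, 5) for _ in range(n)]
    # All-zero mask: no intervention, so the output equals the original activations.
    assert substitute(a, [0] * n, r) == a
    # All-one mask: complete ablation, so the output equals the reference values.
    assert substitute(a, [1] * n, r) == r
print("edge cases hold")
```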
Constraints