Fig 1.
Model and task. Top, schematic of the transformer model and task.
Bottom, the prompt and output format for the compositional generalization task.
Fig 2.
Model performance on the test set.
Left, the accuracy for different output lengths. Right, the perplexity for different output lengths.
Fig 3.
Schematic of the compositional algorithm.
For Unbind, in ‘B S D’, B is the 1st token (idx1), and ‘B=blue’, so ‘B=blue=idx1’; similarly for ‘D’. For Rebind, in ‘A S B’, A is the 1st token (idx1), and ‘A=red’, so ‘idx1=A=red’; similarly for ‘B’.
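The Unbind and Rebind steps can be illustrated with a small sketch. It uses the example episode from the Fig 11 caption (‘B S D=pink blue pink’, query ‘A S B’); the function names `unbind` and `rebind` and the dictionary representation are illustrative assumptions, not the model's actual mechanism.

```python
# Illustrative sketch only: names and data structures are hypothetical.
def unbind(lhs, rhs_colors, primitives):
    """Replace each color on the function RHS by the LHS index of its primitive."""
    color_to_idx = {
        primitives[tok]: i
        for i, tok in enumerate(lhs, start=1)
        if tok in primitives  # the function symbol 'S' has no color
    }
    return [color_to_idx[c] for c in rhs_colors]


def rebind(lhs, index_pattern, primitives):
    """Fill an index pattern with the colors of the primitives on a new LHS."""
    idx_to_color = {
        i: primitives[tok]
        for i, tok in enumerate(lhs, start=1)
        if tok in primitives
    }
    return [idx_to_color[i] for i in index_pattern]


primitives = {"A": "red", "B": "blue", "D": "pink"}

# Unbind: 'B S D = pink blue pink' becomes the index pattern idx3 idx1 idx3.
pattern = unbind(["B", "S", "D"], ["pink", "blue", "pink"], primitives)

# Rebind: apply the same pattern to the query 'A S B'.
output = rebind(["A", "S", "B"], pattern, primitives)

print(pattern)  # [3, 1, 3]
print(output)   # ['blue', 'red', 'blue']
```

The result matches the prediction ‘blue red blue’ quoted for this episode in the Fig 11 caption.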
Fig 4.
Summary of circuit for compositional generalization.
Top, the example episode’s input and output. For a-e, the yellow boxes indicate self-attention heads and the blue boxes indicate cross-attention heads. Titles refer to the functional attention heads that execute the steps (details in the Circuit Discovery section). We unfold all relevant information superimposed in the tokens’ embeddings and highlight their roles in attention operations. [1]*, the QK alignment discussed in the Primitive-Retrieval Head section. [2]*, the QK alignment discussed in the Primitive-Pairing Head section.
Fig 5.
Circuit diagram of the key attention heads.
Green circles indicate attention heads that contribute most significantly to downstream nodes. Green arrows denote the flow of contributions from upstream nodes to each attention head. The main sub-circuits highlighted are the K-circuit and Q-circuit leading to the Output Head.
Fig 6.
(a) Logit contributions of each decoder head to the logits of correct tokens (as a fraction of the total logits). (b) Attention pattern of Dec-cross-1.5. (c) For Dec-cross-1.5, the percentage of attention focused on the next predicted token. (d) For Dec-cross-1.5, alignment (inner product) between its OV output and the corresponding unembedding vector. We estimated the null distribution by randomly sampling unembedding vectors.
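As a rough illustration of the comparison in panel (d), the sketch below computes an inner-product alignment and a null distribution obtained by sampling unembedding vectors at random. The vectors are synthetic (one-hot unembeddings and a scaled copy as the OV output), and the function name `alignment_with_null` is a hypothetical helper; the actual sampling procedure may differ.

```python
import random


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def alignment_with_null(ov_output, unembed, target, n_samples=1000, seed=0):
    """Return the observed alignment and a null sample of alignments.

    The null is built by pairing the same OV output with unembedding
    vectors drawn uniformly at random (an assumed procedure).
    """
    rng = random.Random(seed)
    observed = dot(ov_output, unembed[target])
    null = [dot(ov_output, rng.choice(unembed)) for _ in range(n_samples)]
    return observed, null


# Synthetic toy data, not taken from the model.
vocab = 8
unembed = [[1.0 if i == j else 0.0 for j in range(vocab)] for i in range(vocab)]
target = 2
ov_output = [3.0 * x for x in unembed[target]]

observed, null = alignment_with_null(ov_output, unembed, target)
null_mean = sum(null) / len(null)
print(observed, null_mean)
```

An observed alignment well above the bulk of the null sample indicates that the OV output points toward the unembedding of the attended token rather than toward an arbitrary token.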
Fig 7.
(a) Top, contributions to Output Head’s performance (percentage of attention on the correct next token) via K. Bottom, attention pattern of Enc-self-1.1. (b) Top, contributions to the Output Head’s performance through the Primitive-Pairing Head’s V. Bottom, attention pattern of Enc-self-0.5.
Fig 8.
Principal Components Analysis (PCA) of token embeddings.
The embeddings are colored by their associated index-in-question. Concretely, for a prompt like ‘A S B | A=red | B=blue | ...’, in (a), points are the Z of ‘A’ and ‘B’ in the Support (A labeled 1st, B labeled 3rd); in (b), points are the Z of ‘red’ and ‘blue’ in the Support (red labeled 1st, blue labeled 3rd); in (c), points are the K of ‘red’ and ‘blue’ in the Support (red labeled 1st, blue labeled 3rd). The distinct clusters suggest strong index information. The R² score quantifies the percentage of total variance explained by the index identity.
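One common way to compute such a score is to compare the variance remaining around each label's mean with the total variance. The sketch below follows that definition on synthetic 2-D points; the helper `r2_by_group` is hypothetical, and the exact estimator used for the figure is not specified in the caption.

```python
def r2_by_group(points, labels):
    """Fraction of total variance explained by group (index) identity.

    R^2 = 1 - SS_within / SS_total, summed over all embedding dimensions.
    """
    n, dim = len(points), len(points[0])
    grand = [sum(p[d] for p in points) / n for d in range(dim)]

    groups = {}
    for p, lab in zip(points, labels):
        groups.setdefault(lab, []).append(p)
    means = {
        lab: [sum(p[d] for p in ps) / len(ps) for d in range(dim)]
        for lab, ps in groups.items()
    }

    ss_total = sum((p[d] - grand[d]) ** 2 for p in points for d in range(dim))
    ss_within = sum(
        (p[d] - means[lab][d]) ** 2
        for p, lab in zip(points, labels)
        for d in range(dim)
    )
    return 1.0 - ss_within / ss_total


# Synthetic points: two tight clusters, labeled by index (1st vs 3rd).
points = [(0.0, 0.0), (0.2, 0.0), (4.0, 1.0), (4.2, 1.0)]
r2_index = r2_by_group(points, [1, 1, 3, 3])     # labels match the clusters
r2_mismatch = r2_by_group(points, [1, 3, 1, 3])  # labels cut across clusters
print(r2_index, r2_mismatch)
```

Labels that match the clusters give a score close to 1, while labels that cut across them give a score close to 0.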
Fig 9.
(a) Contribution to Output Head’s performance via Q. (b) Attention pattern of Dec-cross-0.6.
Fig 10.
The primitive- and function-retrieval heads.
(a) Contribution to Output Head’s performance via Q. (b) Contribution to Output Head’s performance via the RHS-Scanner’s V. (c) Attention pattern of Dec-cross-0.6. (d) and (e) Attention patterns of Enc-self-1.0 and Enc-self-1.2.
Fig 11.
The computation pipeline for index-on-LHS.
(a) In the Primitive-Retrieval head, the pink token retrieves the absolute position of the D token; in the Function-Retrieval head, the pink token retrieves the absolute position of the S token; the RHS-Scanner head then computes the difference of the two values to get the relative position of D on the function LHS. (b) and (c) Ablation results for token embeddings labeled by index-on-LHS. Concretely, for an episode with prompt ‘A S B | A=red | B=blue | D=pink | B S D=pink blue pink | ...’ and prediction ‘SOS blue red blue EOS’, in (b), points are the Z of ‘SOS’ and ‘blue’ in the decoder input tokens (SOS is labeled 3rd, because SOS attends to the pink on the function RHS, and D is the 3rd on the LHS; similarly, blue is labeled 1st); in (c), points are the Q of decoder input tokens (SOS is labeled 3rd, blue is labeled 1st). The R² score quantifies the percentage of total variance explained by the index identity.
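The position arithmetic in panel (a) can be sketched as follows. The tokenization (whitespace-separated tokens with ‘=’ and ‘|’ as separate tokens), the helper `index_on_lhs`, and the assumption that S occupies the 2nd slot of the function LHS are all illustrative choices, not details confirmed by the caption.

```python
# Hypothetical tokenization of the example prompt.
tokens = "A S B | A = red | B = blue | D = pink | B S D = pink blue pink".split()


def index_on_lhs(tokens, primitive, function="S"):
    """Index of `primitive` on the function LHS via a position difference.

    Assumes the function definition is the last '|'-separated segment and
    that the function symbol sits in the 2nd slot of its LHS.
    """
    last_bar = max(i for i, t in enumerate(tokens) if t == "|")
    lhs_start = last_bar + 1
    lhs_end = tokens.index("=", lhs_start)

    # Absolute positions, as retrieved by the two retrieval heads.
    pos_primitive = tokens.index(primitive, lhs_start, lhs_end)
    pos_function = tokens.index(function, lhs_start, lhs_end)

    # The difference gives the position relative to the function symbol.
    offset = pos_primitive - pos_function
    return 2 + offset  # S assumed to be the 2nd token on the LHS


idx_d = index_on_lhs(tokens, "D")
idx_b = index_on_lhs(tokens, "B")
print(idx_d, idx_b)  # 3 1
```

Under these assumptions D maps to the 3rd index and B to the 1st, which agrees with the labels given for SOS and blue in panels (b) and (c).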
Fig 12.
Targeted perturbation experiment.
(a) Schematic illustrating the targeted swap of position embeddings. (b) Attention weights from the Output Head comparing the original correct token and swapped token across three conditions: unperturbed (left), targeted position swap (middle), and random shuffle control (right). (c) Similar comparison as (b), but for final output logits rather than attention weights.
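The two perturbations contrasted in panels (b) and (c) can be sketched on a toy table of position embeddings. The helpers `swap_positions` and `shuffle_positions`, the choice of positions, and the toy vectors are assumptions made for illustration; the caption does not specify which positions are swapped.

```python
import random


def swap_positions(pos_emb, i, j):
    """Targeted perturbation: exchange the embeddings of positions i and j."""
    out = list(pos_emb)
    out[i], out[j] = out[j], out[i]
    return out


def shuffle_positions(pos_emb, seed=0):
    """Control perturbation: randomly permute all position embeddings."""
    out = list(pos_emb)
    random.Random(seed).shuffle(out)
    return out


# Toy position embeddings: one 2-D vector per position (synthetic values).
pos_emb = [(float(p), float(p) ** 2) for p in range(6)]

swapped = swap_positions(pos_emb, 1, 4)   # hypothetical target positions
shuffled = shuffle_positions(pos_emb)

print(swapped)
print(shuffled)
```

The targeted swap changes only the two chosen positions and leaves the rest intact, whereas the shuffle disrupts position information everywhere, which is what makes it a useful control.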
Fig 13.
Two modes of ablation.
Fig 14.
Circuit discovery on a different architecture.
The model consists of 3 encoder/decoder layers with 4 attention heads in each layer. All the important attention heads are re-discovered in the new model, with the only exception that the original Primitive-Pairing and Primitive-Retrieval Heads are now merged as a single head. We speculate that this is due to their similar functions and the reduced number of heads in each layer.