Changes from all commits (32 commits)
- 2fd7cf4: Fix some typos in README.md (JasonGross, Sep 7, 2024)
- 694f222: README: add some hyperlinks and code font (JasonGross, Sep 7, 2024)
- 4b01f89: `legendHandles` => `legend_handles` (JasonGross, Sep 14, 2024)
- e2bf9ad: Fix to support MPS: convert to float32 earlier (JasonGross, Sep 7, 2024)
- ebd1407: Don't fail with "AssertionError: Not enough CUDA devices to support n… (JasonGross, Aug 28, 2024)
- f3ae2c7: Revert to cpu if cuda is not available (JasonGross, Sep 7, 2024)
- cbfac5b: Use pathlib more and have default cache paths be relative to the repo… (JasonGross, Sep 7, 2024)
- 6357ecd: Add support for loading 16-bit models (JasonGross, Sep 7, 2024)
- f973077: Add use of dtype in ipynb (JasonGross, Sep 14, 2024)
- fe5e531: Update torch.load calls to include weights_only parameter (Tauheed-Elahee, Feb 20, 2026)
- 1abd74e: Update all calls to torch.load with required weights_only parameter. (Tauheed-Elahee, Mar 12, 2026)
- c251a5f: Fix TransformerLens model name for GPT-2 (Tauheed-Elahee, Mar 13, 2026)
- 5890d59: Use dynamic device count instead of hardcoded num_devices=2 (Tauheed-Elahee, Mar 13, 2026)
- 1ff27db: Add compatibility for newer sae_lens forward() return type (Tauheed-Elahee, Mar 13, 2026)
- ae86fa8: Fix incorrect --clustering_type arg in README examples (Tauheed-Elahee, Mar 13, 2026)
- d502ab1: Merge remote-tracking branch 'JasonGross/patch-2' into bleeding-edge (Tauheed-Elahee, Mar 12, 2026)
- eb1f5c3: Merge remote-tracking branch 'JasonGross/patch-3' into bleeding-edge (Tauheed-Elahee, Mar 12, 2026)
- 6e12727: Merge remote-tracking branch 'JasonGross/legend_handles' into bleedin… (Tauheed-Elahee, Mar 12, 2026)
- 565264d: Merge remote-tracking branch 'JasonGross/mps-cleanup' into bleeding-edge (Tauheed-Elahee, Mar 12, 2026)
- 663bd5a: Merge remote-tracking branch 'JasonGross/patch-1' into bleeding-edge (Tauheed-Elahee, Mar 12, 2026)
- fdcdda9: Merge remote-tracking branch 'JasonGross/cuda-available-simpler' into… (Tauheed-Elahee, Mar 12, 2026)
- 888a582: Merge remote-tracking branch 'JasonGross/pathlib' into bleeding-edge (Tauheed-Elahee, Mar 12, 2026)
- fbc859f: Merge PR #10 (dtype support) with conflicts resolved: keep PR #11 dev… (Tauheed-Elahee, Mar 12, 2026)
- 3e88a3f: Merge PR #16 (torch.load weights_only) with conflicts resolved: keep … (Tauheed-Elahee, Mar 12, 2026)
- 9e64f26: Merge branch 'tl-model-name' into bleeding-edge (Tauheed-Elahee, Mar 13, 2026)
- ad1946e: Merge branch 'dynamic-num-devices' into bleeding-edge (Tauheed-Elahee, Mar 13, 2026)
- b9399a6: Merge branch 'sae-lens-compatibility' into bleeding-edge (Tauheed-Elahee, Mar 13, 2026)
- cada533: Merge branch 'readme-method-arg' into bleeding-edge (Tauheed-Elahee, Mar 13, 2026)
- bd6ed81: Read BASE_DIR from environment variable with fallback to default (Tauheed-Elahee, Mar 13, 2026)
- 18e77fe: Change default CUDA device from cuda:4 to cuda:0 (Tauheed-Elahee, Mar 13, 2026)
- 6365a40: Remove NVIDIA CUDA 12.1 hard pins from requirements (Tauheed-Elahee, Mar 13, 2026)
- b9dc4ef: Add requirements.in and update requirements.txt via pip-compile (Tauheed-Elahee, Mar 13, 2026)
38 changes: 18 additions & 20 deletions README.md
@@ -6,55 +6,55 @@ This is the github repo for our paper ["Not All Language Model Features Are Line

## Reproducing each figure

Below are instructions to reproduce each figure (aspirationally).

The required pthon packages to run this repo are
The required Python packages to run this repo are
```
transformer_lens sae_lens transformers datasets torch adjustText circuitsvis ipython
```
We recommend you creat a new python venv named multid and install these packages,
We recommend you create a new python venv named multid and install these packages,
either manually using pip or using the existing requirements.txt if you are on a linux
machine with Cuda 12.1:
```
python -m venv multid
pip install -r requirements.txt
OR
pip install transformer_lens sae_lens transformers datasets torch adjustText circuitsvis ipython
```
Let us know if anything does not work with this environment!

### Intervention Experiments

Before running experiments, you should change BASE_DIR in intervention/utils.py to point to a location on your machine where large artifacts can be downloaded and saved (Mistral and Llama 3 take ~60GB and experiment artifacts are ~100GB).
Before running experiments, you should change `BASE_DIR` in [`intervention/utils.py`](./intervention/utils.py) to point to a location on your machine where large artifacts can be downloaded and saved (Mistral and Llama 3 take ~60GB and experiment artifacts are ~100GB).

To reproduce the intervention results, you will first need to run intervention experiments with the following commands:

```
cd intervention
python3 circle_probe_interventions.py day a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
python3 circle_probe_interventions.py month a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
python3 circle_probe_interventions.py day a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
python3 circle_probe_interventions.py month a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
python3 circle_probe_interventions.py day a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
python3 circle_probe_interventions.py month a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
python3 circle_probe_interventions.py day a llama --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
python3 circle_probe_interventions.py month a llama --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin
```
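The `--device` values above are plain torch device strings handled by the scripts' argument parser. A stripped-down, illustrative sketch of such a flag (the real parser defaults to `cuda:0` when `torch.cuda.is_available()` and to `cpu` otherwise):

```python
import argparse

# Illustrative parser for the --device flag used above; the real script
# picks its default based on torch.cuda.is_available().
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", type=str, default="cpu",
    help="Device to use, e.g. cpu, cuda:0, mps",
)
args = parser.parse_args(["--device", "cuda:0"])
print(args.device)  # → cuda:0
```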

You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in intervention/main_text_plots.ipynb.
You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in [`intervention/main_text_plots.ipynb`](./intervention/main_text_plots.ipynb).


After running these intervention experiments, you can reproduce *Figure 6* by running
```
cd intervention
python3 intervene_in_middle_of_circle.py --only_paper_plots
```
and then running the corresponding cell in intervention/main_text_plots.ipynb.
and then running the corresponding cell in [`intervention/main_text_plots.ipynb`](./intervention/main_text_plots.ipynb).

You can reproduce *Figure 13*, *Figure 14*, *Figure 15*, *Table 2*, *Table 3*, and *Table 4* (all from the appendix) by running cells in intervention/appendix_plots.ipynb.


### SAE feature search experiments

Before running experiments, you should again change BASE_DIR in sae_multid_feature_discovery/utils.py to point to a location on your machine where large artifacts can be downloaded and saved.
Before running experiments, you should again change `BASE_DIR` in [`sae_multid_feature_discovery/utils.py`](./sae_multid_feature_discovery/utils.py) to point to a location on your machine where large artifacts can be downloaded and saved.

You will need to generate SAE feature activations to generate the cluster reconstructions. The GPT-2 SAEs will be automatically downloaded when you run the below scripts, while for Mistral you will need to download our pretrained Mistral SAEs from https://www.dropbox.com/scl/fo/hznwqj4fkqvpr7jtx9uxz/AJUe0wKmJS1-fD982PuHb5A?rlkey=ffnq6pm6syssf2p7t98q9kuh1&dl=0 to sae_multid_feature_discovery/saes/mistral_saes. You can generate SAE feature activations with one of the following two commands:
You will need to generate SAE feature activations to generate the cluster reconstructions. The GPT-2 SAEs will be automatically downloaded when you run the below scripts, while for Mistral you will need to download our pretrained Mistral SAEs from https://www.dropbox.com/scl/fo/hznwqj4fkqvpr7jtx9uxz/AJUe0wKmJS1-fD982PuHb5A?rlkey=ffnq6pm6syssf2p7t98q9kuh1&dl=0 to [`sae_multid_feature_discovery/saes/mistral_saes`](./sae_multid_feature_discovery/saes/mistral_saes). You can generate SAE feature activations with one of the following two commands:

```
cd sae_multid_feature_discovery
@@ -64,10 +64,10 @@
python3 generate_feature_occurence_data.py --model_name mistral
```

You can also directly download the gpt-2 layer 7 and Mistral-7B layer 8 activations data from this Dropbox folder: https://www.dropbox.com/scl/fo/frn4tihzkvyesqoumtl9u/AFPEAa6KFb8mY3NTXIEStnA?rlkey=z60j3g45jzhxwc5s5qxmbjvxs&st=da2tzqk5&dl=0. You should put them in the `sae_multid_feature_discovery` directory.

You will also need to generate the actual clusters by running clustering.py, e.g.
You will also need to generate the actual clusters by running `clustering.py`, e.g.
```
python3 clustering.py --model_name gpt-2 --clustering_type spectral --layer 7
python3 clustering.py --model_name mistral --clustering_type graph --layer 8
python3 clustering.py --model_name gpt-2 --method spectral --layer 7
python3 clustering.py --model_name mistral --method graph --layer 8
```

Unfortunately, we did not set a seed when we ran spectral clustering in our original experiments, so the clusters you get from the above command may not be the same as the ones we used in the paper. In the `sae_multid_feature_discovery` directory, we provide the GPT-2 (`gpt-2_layer_7_clusters_spectral_n1000.pkl`) and Mistral-7B (`mistral_layer_8_clusters_cutoff_0.5.pkl`) clusters that were used in the paper. For easy reference, here are the GPT-2 SAE feature indices for the days, weeks, and years clusters we reported in the paper (Figure 1):
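The shipped cluster files are ordinary Python pickles, so they can be written and inspected with the standard library alone. A sketch with stand-in data; the actual structure and feature indices inside the `.pkl` files may differ:

```python
import pickle
import tempfile
from pathlib import Path

# Stand-in cluster data: the real .pkl files' layout may differ.
clusters = {7: [2592, 4445, 4663]}

with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "example_clusters.pkl"
    with open(path, "wb") as f:
        pickle.dump(clusters, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

print(loaded == clusters)  # → True
```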
@@ -120,7 +120,7 @@ To reproduce the residual RGB plots in the paper (*Figure 8*, and *Figure 16*),

## Contact

If you have any questions about the paper or reproducing results, feel free to email jengels@mit.edu.
If you have any questions about the paper or reproducing results, feel free to email [jengels@mit.edu](mailto:jengels@mit.edu).

## Citation

@@ -132,5 +132,3 @@ If you have any questions about the paper or reproducing results, feel free to e
year={2024}
}
```


@@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups):
+ str(start_token + token)
+ "_pca"
+ str(n_pca_dims)
+ ".pt"
+ ".pt",
weights_only=False,
)
flat_activations = activations[order, :] # problem, pca
activations = flat_activations.reshape([mod, mod, n_pca_dims])
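Several commits in this series make the `weights_only` parameter of `torch.load` explicit: `torch.load` deserializes via pickle, which can execute arbitrary code, and newer torch releases require callers to choose (these probe artifacts contain more than raw tensors, hence `weights_only=False`). The underlying safety idea can be illustrated with the standard library's restricted-Unpickler pattern; this is an analogy, not torch's actual implementation:

```python
import io
import pickle

class RestrictedUnpickler(pickle.Unpickler):
    """Refuse to resolve any global, so only plain data can be loaded."""
    def find_class(self, module, name):
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()

# Plain containers and numbers load fine...
print(restricted_loads(pickle.dumps({"layer": 7, "acts": [0.1, 0.2]})))

# ...but anything that names a callable is rejected.
try:
    restricted_loads(pickle.dumps(print))
except pickle.UnpicklingError:
    print("blocked")  # → blocked
```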
@@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups):
+ str(start_token + token)
+ "_pca"
+ str(n_pca_dims)
+ ".pt"
+ ".pt",
weights_only=False,
)
flat_activations = activations[order, :] # problem, pca
activations = flat_activations.reshape([mod, mod, n_pca_dims])
28 changes: 15 additions & 13 deletions intervention/appendix_plots.ipynb
@@ -7,6 +7,7 @@
"outputs": [],
"source": [
"# %%\n",
"from pathlib import Path\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from task import get_acts, get_acts_pca, get_all_acts\n",
@@ -28,7 +29,8 @@
"\n",
"os.makedirs(\"figs/paper_plots\", exist_ok=True)\n",
"\n",
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
"dtype = \"float32\""
]
},
{
@@ -108,8 +110,8 @@
" frameon=False,\n",
" )\n",
"\n",
" for i in range(len(legend.legendHandles)):\n",
" legend.legendHandles[i]._sizes = [2]\n",
" for i in range(len(legend.legend_handles)):\n",
" legend.legend_handles[i]._sizes = [2]\n",
"\n",
" plt.tight_layout()\n",
"\n",
@@ -130,9 +132,9 @@
"for model_name in [\"mistral\", \"llama\"]:\n",
" for task_name in [\"task_name\", \"months_of_year\"]:\n",
" if task_name == \"{task_name}\":\n",
" task = DaysOfWeekTask(model_name=model_name, device=device)\n",
" task = DaysOfWeekTask(model_name=model_name, device=device, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(model_name=model_name, device=device)\n",
" task = MonthsOfYearTask(model_name=model_name, device=device, dtype=dtype)\n",
"\n",
" for keep_same_index in [0, 1]:\n",
" for layer_type in [\"mlp\", \"attention\", \"resid\", \"attention_head\"]:\n",
@@ -186,9 +188,9 @@
"for model_name in [\"mistral\", \"llama\"]:\n",
" for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
" if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(device, model_name=model_name)\n",
" task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(device, model_name=model_name)\n",
" task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n",
"\n",
" for patching_type in [\"mlp\", \"attention\"]:\n",
" fig, ax = plt.subplots(figsize=(10, 5))\n",
@@ -283,9 +285,9 @@
"for model_name in [\"mistral\", \"llama\"]:\n",
" for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
" if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(device, model_name=model_name)\n",
" task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(device, model_name=model_name)\n",
" task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n",
"\n",
" fig, ax = plt.subplots(figsize=(10, 5))\n",
"\n",
@@ -369,9 +371,9 @@
"data = []\n",
"for model_name, task_name in all_top_heads.keys():\n",
" if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(device, model_name=model_name)\n",
" task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(device, model_name=model_name)\n",
" task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n",
"\n",
" acts = get_all_acts(\n",
" task,\n",
@@ -462,7 +464,7 @@
"\n",
"for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
" results_mistral = pd.read_csv(\n",
" f\"{BASE_DIR}/mistral_{task_name}/results.csv\", skipinitialspace=True\n",
" Path(BASE_DIR) / f\"mistral_{task_name}\" / \"results.csv\", skipinitialspace=True\n",
" )\n",
"\n",
" results_mistral = results_mistral.rename(\n",
Expand All @@ -480,7 +482,7 @@
" print(sum(results_mistral[\"mistral_correct\"]))\n",
"\n",
" results_llama = pd.read_csv(\n",
" f\"{BASE_DIR}/llama_{task_name}/results.csv\", skipinitialspace=True\n",
" Path(BASE_DIR) / f\"llama_{task_name}\" / \"results.csv\", skipinitialspace=True\n",
" )\n",
"\n",
" results_llama = results_llama.rename(\n",
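The pathlib changes above swap f-string concatenation for `Path` division. A small illustrative sketch of the equivalence (the directory and task names are made up):

```python
from pathlib import Path

BASE_DIR = "/data/artifacts"  # illustrative
task_name = "days_of_week"

# Both forms point at the same location, but the Path form composes
# cleanly (no manual separators) and yields a Path object to build on.
as_str = f"{BASE_DIR}/mistral_{task_name}/results.csv"
as_path = Path(BASE_DIR) / f"mistral_{task_name}" / "results.csv"

print(as_path.as_posix() == as_str)  # → True
```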
2 changes: 1 addition & 1 deletion intervention/circle_finding_utils.py
@@ -114,7 +114,7 @@ def get_logit_diffs_from_subspace_formula_resid_intervention(
probe_r = probe_r.to(device)
target_embedding_in_q_space = target_to_embedding.to(device) @ probe_r.inverse()

pca_projection_matrix = torch.tensor(pca_projection_matrix).to(device).T.float()
pca_projection_matrix = torch.tensor(pca_projection_matrix).float().to(device).T

all_pcas = (
torch.tensor(
28 changes: 20 additions & 8 deletions intervention/circle_probe_interventions.py
@@ -41,7 +41,15 @@
choices=["llama", "mistral"],
help="Choose 'llama' or 'mistral' model",
)
parser.add_argument("--device", type=int, default=4, help="CUDA device number")
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="Device to use",
)
parser.add_argument(
"--dtype", type=str, default="float32", help="Data type for torch tensors"
)
parser.add_argument(
"--use_inverse_regression_probe",
action="store_true",
@@ -73,7 +81,8 @@
help="Probe on linear representation with center of 0.",
)
args = parser.parse_args()
device = f"cuda:{args.device}"
device = args.device
dtype = args.dtype
day_month_choice = args.problem_type
circle_letter = args.intervene_on
model_name = args.model
@@ -100,7 +109,8 @@
# use_inverse_regression_probe = False
# intervention_pca_k = 5

device = "cuda:4"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = "float32"
circle_letter = "c"
day_month_choice = "day"
model_name = "mistral"
@@ -131,9 +141,9 @@
# %%

if day_month_choice == "day":
task = DaysOfWeekTask(device, model_name=model_name)
task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)
else:
task = MonthsOfYearTask(device, model_name=model_name)
task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)

# %%

@@ -171,7 +181,7 @@
probe_projections = {}
target_to_embeddings = {}

os.makedirs(f"{task.prefix}/circle_probes_{circle_letter}", exist_ok=True)
(task.prefix / f"circle_probes_{circle_letter}").mkdir(exist_ok=True)

all_maes = []
all_r_squareds = []
@@ -262,7 +272,9 @@
"probe_r": probe_r,
"target_to_embedding": target_to_embedding,
},
f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt",
task.prefix
/ f"circle_probes_{circle_letter}"
/ f"{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt",
)

mae = (predictions - multid_targets_train).abs().mean()
@@ -377,7 +389,7 @@
logit_diffs_zero_everything_but_circle,
) = get_logit_diffs_from_subspace_formula_resid_intervention(
task,
probe_projection_qr=probe_projections[((layer, intervention_pca_k))],
probe_projection_qr=probe_projections[(layer, intervention_pca_k)],
pca_k_project=intervention_pca_k,
layer=layer,
token=token,
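The diff above also converts `os.makedirs(..., exist_ok=True)` into `Path.mkdir(exist_ok=True)`. A minimal sketch of the equivalence (directory name illustrative):

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    # Path.mkdir(exist_ok=True) mirrors os.makedirs(..., exist_ok=True);
    # add parents=True when intermediate directories may be missing.
    target = Path(d) / "circle_probes_a"
    target.mkdir(exist_ok=True)
    target.mkdir(exist_ok=True)  # second call is a no-op, no error
    print(target.is_dir())  # → True
```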