diff --git a/README.md b/README.md index 7e7f6a8..81a640f 100644 --- a/README.md +++ b/README.md @@ -6,18 +6,18 @@ This is the github repo for our paper ["Not All Language Model Features Are Line ## Reproducing each figure -Below are instructions to reproduce each figure (aspirationally). +Below are instructions to reproduce each figure (aspirationally). -The required pthon packages to run this repo are +The required Python packages to run this repo are ``` transformer_lens sae_lens transformers datasets torch adjustText circuitsvis ipython ``` -We recommend you creat a new python venv named multid and install these packages, +We recommend you create a new python venv named multid and install these packages, either manually using pip or using the existing requirements.txt if you are on a linux machine with Cuda 12.1: ``` python -m venv multid -pip install -r requirements.txt +pip install -r requirements.txt OR pip install transformer_lens sae_lens transformers datasets torch adjustText circuitsvis ipython ``` @@ -25,36 +25,36 @@ Let us know if anything does not work with this environment! ### Intervention Experiments -Before running experiments, you should change BASE_DIR in intervention/utils.py to point to a location on your machine where large artifacts can be downloaded and saved (Mistral and Llama 3 take ~60GB and experiment artifacts are ~100GB). +Before running experiments, you should change `BASE_DIR` in [`intervention/utils.py`](./intervention/utils.py) to point to a location on your machine where large artifacts can be downloaded and saved (Mistral and Llama 3 take ~60GB and experiment artifacts are ~100GB). To reproduce the intervention results, you will first need to run intervention experiments with the following commands: ``` cd intervention -python3 circle_probe_interventions.py day a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py month a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py day a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py month a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py day a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py month a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py day a llama --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py month a llama --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin ``` -You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in intervention/main_text_plots.ipynb. +You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in [`intervention/main_text_plots.ipynb`](./intervention/main_text_plots.ipynb). -After running these intervention experiments, you can reproduce *Figure 6* by running +After running these intervention experiments, you can reproduce *Figure 6* by running ``` cd intervention python3 intervene_in_middle_of_circle.py --only_paper_plots ``` -and then running the corresponding cell in intervention/main_text_plots.ipynb. +and then running the corresponding cell in [`intervention/main_text_plots.ipynb`](./intervention/main_text_plots.ipynb). You can reproduce *Figure 13*, *Figure 14*, *Figure 15*, *Table 2*, *Table 3*, and *Table 4* (all from the appendix) by running cells in intervention/appendix_plots.ipynb. ### SAE feature search experiments -Before running experiments, you should again change BASE_DIR in sae_multid_feature_discovery/utils.py to point to a location on your machine where large artifacts can be downloaded and saved. +Before running experiments, you should again change `BASE_DIR` in [`sae_multid_feature_discovery/utils.py`](./sae_multid_feature_discovery/utils.py) to point to a location on your machine where large artifacts can be downloaded and saved. -You will need to generate SAE feature activations to generate the cluster reconstructions. The GPT-2 SAEs will be automatically downloaded when you run the below scripts, while for Mistral you will need to download our pretrained Mistral SAEs from https://www.dropbox.com/scl/fo/hznwqj4fkqvpr7jtx9uxz/AJUe0wKmJS1-fD982PuHb5A?rlkey=ffnq6pm6syssf2p7t98q9kuh1&dl=0 to sae_multid_feature_discovery/saes/mistral_saes. You can generate SAE feature activations with one of the following two commands: +You will need to generate SAE feature activations to generate the cluster reconstructions. The GPT-2 SAEs will be automatically downloaded when you run the below scripts, while for Mistral you will need to download our pretrained Mistral SAEs from https://www.dropbox.com/scl/fo/hznwqj4fkqvpr7jtx9uxz/AJUe0wKmJS1-fD982PuHb5A?rlkey=ffnq6pm6syssf2p7t98q9kuh1&dl=0 to [`sae_multid_feature_discovery/saes/mistral_saes`](./sae_multid_feature_discovery/saes/mistral_saes). You can generate SAE feature activations with one of the following two commands: ``` cd sae_multid_feature_discovery @@ -64,10 +64,10 @@ python3 generate_feature_occurence_data.py --model_name mistral You can also directly download the gpt-2 layer 7 and Mistral-7B layer 8 activations data from this Dropbox folder: https://www.dropbox.com/scl/fo/frn4tihzkvyesqoumtl9u/AFPEAa6KFb8mY3NTXIEStnA?rlkey=z60j3g45jzhxwc5s5qxmbjvxs&st=da2tzqk5&dl=0. You should put them in the `sae_multid_feature_discovery` directory. -You will also need to generate the actual clusters by running clustering.py, e.g. +You will also need to generate the actual clusters by running `clustering.py`, e.g. ``` -python3 clustering.py --model_name gpt-2 --clustering_type spectral --layer 7 -python3 clustering.py --model_name mistral --clustering_type graph --layer 8 +python3 clustering.py --model_name gpt-2 --method spectral --layer 7 +python3 clustering.py --model_name mistral --method graph --layer 8 ``` Unfortunately, we did not set a seed when we ran spectral clustering in our original experiments, so the clusters you get from the above command may not be the same as the ones we used in the paper. In the `sae_multid_feature_discovery` directory, we provide the GPT-2 (`gpt-2_layer_7_clusters_spectral_n1000.pkl`) and Mistral-7B (`mistral_layer_8_clusters_cutoff_0.5.pkl`) clusters that were used in the paper. For easy reference, here are the GPT-2 SAE feature indices for the days, weeks, and years clusters we reported in the paper (Figure 1): @@ -120,7 +120,7 @@ To reproduce the residual RGB plots in the paper (*Figure 8*, and *Figure 16*), ## Contact -If you have any questions about the paper or reproducing results, feel free to email jengels@mit.edu. +If you have any questions about the paper or reproducing results, feel free to email [jengels@mit.edu](mailto:jengels@mit.edu). ## Citation @@ -132,5 +132,3 @@ If you have any questions about the paper or reproducing results, feel free to e year={2024} } ``` - - diff --git a/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py b/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py index 2c85289..9706026 100644 --- a/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py +++ b/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py @@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups): + str(start_token + token) + "_pca" + str(n_pca_dims) - + ".pt" + + ".pt", + weights_only=False, ) flat_activations = activations[order, :] # problem, pca activations = flat_activations.reshape([mod, mod, n_pca_dims]) diff --git a/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py b/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py index 419e824..b99bba1 100644 --- a/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py +++ b/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py @@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups): + str(start_token + token) + "_pca" + str(n_pca_dims) - + ".pt" + + ".pt", + weights_only=False, ) flat_activations = activations[order, :] # problem, pca activations = flat_activations.reshape([mod, mod, n_pca_dims]) diff --git a/intervention/appendix_plots.ipynb b/intervention/appendix_plots.ipynb index e64984b..cc5971f 100644 --- a/intervention/appendix_plots.ipynb +++ b/intervention/appendix_plots.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "# %%\n", + "from pathlib import Path\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from task import get_acts, get_acts_pca, get_all_acts\n", @@ -28,7 +29,8 @@ "\n", "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n", "\n", - "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"" + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "dtype = \"float32\"" ] }, { @@ -108,8 +110,8 @@ " frameon=False,\n", " )\n", "\n", - " for i in range(len(legend.legendHandles)):\n", - " legend.legendHandles[i]._sizes = [2]\n", + " for i in range(len(legend.legend_handles)):\n", + " legend.legend_handles[i]._sizes = [2]\n", "\n", " plt.tight_layout()\n", "\n", @@ -130,9 +132,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"task_name\", \"months_of_year\"]:\n", " if task_name == \"{task_name}\":\n", - " task = DaysOfWeekTask(model_name=model_name, device=device)\n", + " task = DaysOfWeekTask(model_name=model_name, device=device, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(model_name=model_name, device=device)\n", + " task = MonthsOfYearTask(model_name=model_name, device=device, dtype=dtype)\n", "\n", " for keep_same_index in [0, 1]:\n", " for layer_type in [\"mlp\", \"attention\", \"resid\", \"attention_head\"]:\n", @@ -186,9 +188,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " for patching_type in [\"mlp\", \"attention\"]:\n", " fig, ax = plt.subplots(figsize=(10, 5))\n", @@ -283,9 +285,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " fig, ax = plt.subplots(figsize=(10, 5))\n", "\n", @@ -369,9 +371,9 @@ "data = []\n", "for model_name, task_name in all_top_heads.keys():\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " acts = get_all_acts(\n", " task,\n", @@ -462,7 +464,7 @@ "\n", "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " results_mistral = pd.read_csv(\n", - " f\"{BASE_DIR}/mistral_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"mistral_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", "\n", " results_mistral = results_mistral.rename(\n", @@ -480,7 +482,7 @@ " print(sum(results_mistral[\"mistral_correct\"]))\n", "\n", " results_llama = pd.read_csv(\n", - " f\"{BASE_DIR}/llama_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"llama_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", "\n", " results_llama = results_llama.rename(\n", diff --git a/intervention/circle_finding_utils.py b/intervention/circle_finding_utils.py index 8b556ee..82866c3 100644 --- a/intervention/circle_finding_utils.py +++ b/intervention/circle_finding_utils.py @@ -114,7 +114,7 @@ def get_logit_diffs_from_subspace_formula_resid_intervention( probe_r = probe_r.to(device) target_embedding_in_q_space = target_to_embedding.to(device) @ probe_r.inverse() - pca_projection_matrix = torch.tensor(pca_projection_matrix).to(device).T.float() + pca_projection_matrix = torch.tensor(pca_projection_matrix).float().to(device).T all_pcas = ( torch.tensor( diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..4f894e9 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -41,7 +41,15 @@ choices=["llama", "mistral"], help="Choose 'llama' or 'mistral' model", ) - parser.add_argument("--device", type=int, default=4, help="CUDA device number") + parser.add_argument( + "--device", + type=str, + default="cuda:0" if torch.cuda.is_available() else "cpu", + help="Device to use", + ) + parser.add_argument( + "--dtype", type=str, default="float32", help="Data type for torch tensors" + ) parser.add_argument( "--use_inverse_regression_probe", action="store_true", @@ -73,7 +81,8 @@ help="Probe on linear representation with center of 0.", ) args = parser.parse_args() - device = f"cuda:{args.device}" + device = args.device + dtype = args.dtype day_month_choice = args.problem_type circle_letter = args.intervene_on model_name = args.model @@ -100,7 +109,8 @@ # use_inverse_regression_probe = False # intervention_pca_k = 5 - device = "cuda:4" + device = "cuda:0" if torch.cuda.is_available() else "cpu" + dtype = "float32" circle_letter = "c" day_month_choice = "day" model_name = "mistral" @@ -131,9 +141,9 @@ # %% if day_month_choice == "day": - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype) else: - task = MonthsOfYearTask(device, model_name=model_name) + task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype) # %% @@ -171,7 +181,7 @@ probe_projections = {} target_to_embeddings = {} -os.makedirs(f"{task.prefix}/circle_probes_{circle_letter}", exist_ok=True) +(task.prefix / f"circle_probes_{circle_letter}").mkdir(exist_ok=True) all_maes = [] all_r_squareds = [] @@ -262,7 +272,9 @@ "probe_r": probe_r, "target_to_embedding": target_to_embedding, }, - f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt", + task.prefix + / f"circle_probes_{circle_letter}" + / f"{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt", ) mae = (predictions - multid_targets_train).abs().mean() @@ -377,7 +389,7 @@ logit_diffs_zero_everything_but_circle, ) = get_logit_diffs_from_subspace_formula_resid_intervention( task, - probe_projection_qr=probe_projections[((layer, intervention_pca_k))], + probe_projection_qr=probe_projections[(layer, intervention_pca_k)], pca_k_project=intervention_pca_k, layer=layer, token=token, diff --git a/intervention/compare_circle_intervention_types.py b/intervention/compare_circle_intervention_types.py index 621c2d7..a2f8799 100644 --- a/intervention/compare_circle_intervention_types.py +++ b/intervention/compare_circle_intervention_types.py @@ -61,17 +61,21 @@ # %% -mistral_pcas = pickle.load(open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb")).components_[1:3, :] +mistral_pcas = pickle.load( + open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb") +).components_[1:3, :] # %% # Get original probe data -original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt") +original_probe = torch.load( + f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt", + weights_only=False, +) original_probe_data = [] for layer in [6, 7, 8, 9, 10]: - ( logit_diffs_before, logit_diffs_after, @@ -98,8 +102,30 @@ average_zero_circle = np.mean(logit_diffs_zero_circle) average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle) - original_probe_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle)) - original_probe_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle)) + original_probe_data.append( + ( + layer, + average_before, + average_after, + average_replace_pca, + average_replace_all, + average_average_ablate, + average_zero_circle, + average_zero_everything_but_circle, + ) + ) + original_probe_data.append( + ( + layer, + logit_diffs_before, + logit_diffs_after, + logit_diffs_replace_pca, + logit_diffs_replace_all, + logit_diffs_average_ablate, + logit_diffs_zero_circle, + logit_diffs_zero_everything_but_circle, + ) + ) # %% @@ -119,21 +145,15 @@ current_probe_dimension = 0 if probe_on_cos: multid_targets[:, current_probe_dimension] = torch.cos(w * oned_targets) - target_to_embedding[:, current_probe_dimension] = torch.cos( - w * torch.arange(p) - ) + target_to_embedding[:, current_probe_dimension] = torch.cos(w * torch.arange(p)) current_probe_dimension += 1 if probe_on_sin: multid_targets[:, current_probe_dimension] = torch.sin(w * oned_targets) - target_to_embedding[:, current_probe_dimension] = torch.sin( - w * torch.arange(p) - ) + target_to_embedding[:, current_probe_dimension] = torch.sin(w * torch.arange(p)) current_probe_dimension += 1 if probe_on_centered_linear: multid_targets[:, current_probe_dimension] = oned_targets - (p - 1) / 2 - target_to_embedding[:, current_probe_dimension] = ( - torch.arange(p) - (p - 1) / 2 - ) + target_to_embedding[:, current_probe_dimension] = torch.arange(p) - (p - 1) / 2 current_probe_dimension += 1 assert current_probe_dimension == probe_dimension @@ -144,9 +164,7 @@ projections = (acts_train @ mistral_pcas.T).float() -least_squares_sol = torch.linalg.lstsq( - projections, multid_targets_train -).solution +least_squares_sol = torch.linalg.lstsq(projections, multid_targets_train).solution probe_q, probe_r = torch.linalg.qr(least_squares_sol) @@ -159,7 +177,6 @@ mistral_data = [] for layer in [6, 7, 8, 9, 10]: - ( logit_diffs_before, logit_diffs_after, @@ -187,18 +204,41 @@ average_zero_circle = np.mean(logit_diffs_zero_circle) average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle) - mistral_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle)) - mistral_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle)) + mistral_data.append( + ( + layer, + average_before, + average_after, + average_replace_pca, + average_replace_all, + average_average_ablate, + average_zero_circle, + average_zero_everything_but_circle, + ) + ) + mistral_data.append( + ( + layer, + logit_diffs_before, + logit_diffs_after, + logit_diffs_replace_pca, + logit_diffs_replace_all, + logit_diffs_average_ablate, + logit_diffs_zero_circle, + logit_diffs_zero_everything_but_circle, + ) + ) # %% - original_probe_varying_layer_data = [] for layer in [6, 7, 8, 9, 10]: - - original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt") + original_probe = torch.load( + f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt", + weights_only=False, + ) ( logit_diffs_before, @@ -226,8 +266,30 @@ average_zero_circle = np.mean(logit_diffs_zero_circle) average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle) - original_probe_varying_layer_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle)) - original_probe_varying_layer_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle)) + original_probe_varying_layer_data.append( + ( + layer, + average_before, + average_after, + average_replace_pca, + average_replace_all, + average_average_ablate, + average_zero_circle, + average_zero_everything_but_circle, + ) + ) + original_probe_varying_layer_data.append( + ( + layer, + logit_diffs_before, + logit_diffs_after, + logit_diffs_replace_pca, + logit_diffs_replace_all, + logit_diffs_average_ablate, + logit_diffs_zero_circle, + logit_diffs_zero_everything_but_circle, + ) + ) # %% @@ -235,7 +297,10 @@ pickle.dump(original_probe_data, open("figs/original_probe_data.pkl", "wb")) pickle.dump(mistral_data, open("figs/mistral_data.pkl", "wb")) -pickle.dump(original_probe_varying_layer_data, open("figs/original_probe_varying_layer_data.pkl", "wb")) +pickle.dump( + original_probe_varying_layer_data, + open("figs/original_probe_varying_layer_data.pkl", "wb"), +) # %% @@ -246,19 +311,25 @@ # Get means average_after_original_probe = [x[2] for x in original_probe_data[::2]] average_after_mistral = [x[2] for x in mistral_data[::2]] -average_after_original_probe_varying_layer = [x[2] for x in original_probe_varying_layer_data[::2]] +average_after_original_probe_varying_layer = [ + x[2] for x in original_probe_varying_layer_data[::2] +] print(average_after_original_probe[0]) print(average_after_mistral[0]) print(average_after_original_probe_varying_layer[0]) import scipy + + def mean_confidence_interval(data, confidence=0.96): a = 1.0 * np.array(data) n = len(a) m, se = np.mean(a), scipy.stats.sem(a) h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1) return m, m - h, m + h + + # Get confidence intervals original_probe_means = [] original_probe_lower = [] @@ -288,22 +359,13 @@ def mean_confidence_interval(data, confidence=0.96): varying_layer_upper.append(upper) ax.plot(x, original_probe_means, label="Intervene with Layer 8 Probe", marker="o") -ax.fill_between(x, - original_probe_lower, - original_probe_upper, - alpha=0.3) +ax.fill_between(x, original_probe_lower, original_probe_upper, alpha=0.3) ax.plot(x, mistral_means, label="Intervene with SAE Subspace", marker="o") -ax.fill_between(x, - mistral_lower, - mistral_upper, - alpha=0.3) +ax.fill_between(x, mistral_lower, mistral_upper, alpha=0.3) ax.plot(x, varying_layer_means, label="Intervene with Probe", marker="o") -ax.fill_between(x, - varying_layer_lower, - varying_layer_upper, - alpha=0.3) +ax.fill_between(x, varying_layer_lower, varying_layer_upper, alpha=0.3) ax.set_xlabel("Layer") ax.set_xticks(x) @@ -318,19 +380,46 @@ def mean_confidence_interval(data, confidence=0.96): # Map each target value to a consistent color based on its position in the circle cmap = plt.get_cmap("tab10") -days_of_week = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"] +days_of_week = [ + "Monday", + "Tuesday", + "Wednesday", + "Thursday", + "Friday", + "Saturday", + "Sunday", +] added_labels = set() for i in range(len(projections)): if int(oned_targets[i]) not in added_labels: added_labels.add(int(oned_targets[i])) - plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10, label=days_of_week[int(oned_targets[i])]) + plt.plot( + projections[i, 0], + projections[i, 1], + ".", + color=cmap(int(oned_targets[i])), + markersize=10, + label=days_of_week[int(oned_targets[i])], + ) else: - plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10) + plt.plot( + projections[i, 0], + projections[i, 1], + ".", + color=cmap(int(oned_targets[i])), + markersize=10, + ) # Sort legend by days of the week handles, labels = ax.get_legend_handles_labels() order = np.argsort([days_of_week.index(label) for label in labels]) -ax.legend([handles[idx] for idx in order], [labels[idx] for idx in order], loc="upper left", bbox_to_anchor=(-0.1, 1.2), ncol=4) +ax.legend( + [handles[idx] for idx in order], + [labels[idx] for idx in order], + loc="upper left", + bbox_to_anchor=(-0.1, 1.2), + ncol=4, +) ax.set_xlabel("Projection onto second SAE PCA component") ax.set_ylabel("Projection onto third SAE PCA component") diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..9ee675a 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -1,5 +1,6 @@ # %% +from pathlib import Path import os from utils import setup_notebook, BASE_DIR @@ -7,11 +8,12 @@ import numpy as np import transformer_lens +import torch from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching -device = "cuda:4" +device = "cuda:0" if torch.cuda.is_available() else "cpu" # # %% @@ -39,19 +41,20 @@ class DaysOfWeekTask: - def __init__(self, device, model_name="mistral", n_devices=None): + def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"): self.device = device self.model_name = model_name self.n_devices = n_devices + self.dtype = dtype + # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = days_of_week - self.prefix = f"{BASE_DIR}{model_name}_days_of_week/" - if not os.path.exists(self.prefix): - os.makedirs(self.prefix) + self.prefix = Path(BASE_DIR) / f"{model_name}_days_of_week" + self.prefix.mkdir(parents=True, exist_ok=True) self.num_tokens_in_answer = 1 @@ -148,11 +151,15 @@ def generate_problems(self): def get_model(self): if self.n_devices is None: - self.n_devices = 2 if "llama" == self.model_name else 1 + self.n_devices = ( + min(2, max(1, torch.cuda.device_count())) + if "llama" == self.model_name + else 1 + ) if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( - "mistral-7b", device=self.device, n_devices=self.n_devices + "mistral-7b", device=self.device, n_devices=self.n_devices, dtype=self.dtype ) elif self.model_name == "llama": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( @@ -160,6 +167,7 @@ def get_model(self): "meta-llama/Meta-Llama-3-8B", device=self.device, n_devices=self.n_devices, + dtype=self.dtype, ) return self._lazy_model diff --git a/intervention/intervene_in_middle_of_circle.py b/intervention/intervene_in_middle_of_circle.py index 2fb75fb..028175c 100644 --- a/intervention/intervene_in_middle_of_circle.py +++ b/intervention/intervene_in_middle_of_circle.py @@ -40,7 +40,10 @@ def vary_wthin_circle(circle_letter, duration, layer, token, pca_k, all_points): model = task.get_model() circle_projection_qr = torch.load( - f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt" + task.prefix + / f"circle_probes_{circle_letter}" + / f"cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt", + weights_only=False, ) for problem in task.generate_problems(): @@ -257,6 +260,7 @@ def get_circle_hook(layer, circle_point): parser.add_argument( "--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu" ) + parser.add_argument("--dtype", type=str, default="float32") args = parser.parse_args() @@ -265,7 +269,7 @@ def get_circle_hook(layer, circle_point): if args.only_paper_plots: task_level_granularity = "day" model_name = "mistral" - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=args.dtype) layer = 5 bs = range(2, 6) pca_k = 5 @@ -282,9 +286,13 @@ def get_circle_hook(layer, circle_point): bs = range(1, 13) for b in bs: if task_level_granularity == "day": - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask( + device, model_name=model_name, dtype=args.dtype + ) elif task_level_granularity == "month": - task = MonthsOfYearTask(device, model_name=model_name) + task = MonthsOfYearTask( + device, model_name=model_name, dtype=args.dtype + ) else: raise ValueError(f"Unknown {task_level_granularity}") for pca_k in [5]: diff --git a/intervention/main_text_plots.ipynb b/intervention/main_text_plots.ipynb index 2266226..763de8d 100644 --- a/intervention/main_text_plots.ipynb +++ b/intervention/main_text_plots.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "# %%\n", + "from pathlib import Path\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from task import get_acts, get_acts_pca\n", @@ -25,7 +26,9 @@ "\n", "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n", "\n", - "torch.set_grad_enabled(False)" + "torch.set_grad_enabled(False)\n", + "device = \"cpu\"\n", + "dtype = \"float32\"" ] }, { @@ -57,7 +60,7 @@ "\n", "\n", "# Left plot\n", - "task = DaysOfWeekTask(\"cpu\", \"mistral\")\n", + "task = DaysOfWeekTask(device, \"mistral\", dtype=dtype)\n", "problems = task.generate_problems()\n", "tokens = task.allowable_tokens\n", "acts = get_acts_pca(task, layer=30, token=task.a_token, pca_k=2)[0]\n", @@ -88,7 +91,7 @@ "ax1.set_ylim(-8, 8)\n", "\n", "# Right plot\n", - "task = MonthsOfYearTask(\"cpu\", \"llama\")\n", + "task = MonthsOfYearTask(device, \"llama\", dtype=dtype)\n", "problems = task.generate_problems()\n", "tokens = task.allowable_tokens\n", "acts = get_acts_pca(task, layer=3, token=task.a_token, pca_k=2)[0]\n", @@ -309,7 +312,7 @@ " columnspacing=1,\n", " handlelength=0.8,\n", ")\n", - "for legobj in leg.legendHandles:\n", + "for legobj in leg.legend_handles:\n", " legobj.set_linewidth(1.5)\n", "\n", "fig.add_artist(\n", @@ -350,7 +353,7 @@ "s = 0.1\n", "\n", "\n", - "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n", + "task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n", "layer = 5\n", "token = task.a_token\n", "durations = range(2, 6)\n", @@ -404,7 +407,7 @@ " columnspacing=0,\n", ")\n", "for i in range(circle_size):\n", - " lgnd.legendHandles[i]._sizes = [10]\n", + " lgnd.legend_handles[i]._sizes = [10]\n", "\n", "plt.show()\n", "\n", @@ -430,7 +433,7 @@ "fig = plt.figure(figsize=(1.65, 1.5))\n", "ax = plt.gca()\n", "\n", - "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n", + "task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n", "acts = get_acts(task, layer_fetch=25, token_fetch=task.before_c_token)\n", "\n", "problems = task.generate_problems()\n", @@ -516,7 +519,7 @@ "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " for model_name in [\"mistral\", \"llama\"]:\n", " results = pd.read_csv(\n", - " f\"{BASE_DIR}/{model_name}_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"{model_name}_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", " number_correct = results[\"best_token\"] == results[\"ground_truth\"]\n", " print(task_name, model_name, np.sum(number_correct))\n", @@ -524,13 +527,13 @@ "# GPT 2\n", "from transformer_lens import HookedTransformer\n", "\n", - "model = HookedTransformer.from_pretrained(\"gpt2\")\n", + "model = HookedTransformer.from_pretrained(\"gpt2\", device=device, dtype=dtype)\n", "\n", "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(\"cpu\", model_name=\"gpt2\")\n", + " task = DaysOfWeekTask(device, model_name=\"gpt2\", dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(\"cpu\", model_name=\"gpt2\")\n", + " task = MonthsOfYearTask(device, model_name=\"gpt2\", dtype=dtype)\n", " problems = task.generate_problems()\n", " answer_logits = [model.to_single_token(token) for token in task.allowable_tokens]\n", " num_correct = 0\n", @@ -546,7 +549,7 @@ ], "metadata": { "kernelspec": { - "display_name": "multiplexing", + "display_name": "multid", "language": "python", "name": "python3" }, @@ -560,7 +563,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..1dea61b 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -1,17 +1,20 @@ # %% import os +from pathlib import Path from utils import setup_notebook, BASE_DIR setup_notebook() +import torch import numpy as np import transformer_lens +import torch from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching -device = "cuda:4" +device = "cuda:0" if torch.cuda.is_available() else "cpu" # # %% @@ -61,19 +64,20 @@ class MonthsOfYearTask: - def __init__(self, device, model_name="mistral", n_devices=None): + def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"): self.device = device self.model_name = model_name self.n_devices = n_devices + self.dtype = dtype + # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = months_of_year - self.prefix = f"{BASE_DIR}{model_name}_months_of_year/" - if not os.path.exists(self.prefix): - os.makedirs(self.prefix) + self.prefix = Path(BASE_DIR) / f"{model_name}_months_of_year" + self.prefix.mkdir(parents=True, exist_ok=True) self.num_tokens_in_answer = 1 @@ -159,11 +163,18 @@ def generate_problems(self): def get_model(self): if self.n_devices is None: - self.n_devices = 2 if "llama" == self.model_name else 1 + self.n_devices = ( + min(2, max(1, torch.cuda.device_count())) + if "llama" == self.model_name + else 1 + ) if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( - "mistral-7b", device=self.device, n_devices=self.n_devices + "mistral-7b", + device=self.device, + n_devices=self.n_devices, + dtype=self.dtype, ) elif self.model_name == "llama": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( @@ -171,6 +182,7 @@ def get_model(self): "meta-llama/Meta-Llama-3-8B", device=self.device, n_devices=self.n_devices, + dtype=self.dtype, ) return self._lazy_model diff --git a/intervention/task.py b/intervention/task.py index 5c02857..06cea1c 100644 --- a/intervention/task.py +++ b/intervention/task.py @@ -1,3 +1,4 @@ +from pathlib import Path from utils import BASE_DIR # Need this import to set the huggingface cache directory import os import numpy as np @@ -39,10 +40,10 @@ def generate_and_save_acts( forward_batch_size = 2 num_tokens_to_generate = task.num_tokens_in_answer all_problems = task.generate_problems() - output_file = task.prefix + "results.csv" + output_file = task.prefix / "results.csv" if save_results_csv: - os.makedirs(task.prefix, exist_ok=True) + task.prefix.mkdir(parents=True, exist_ok=True) model_best_addition = "" if not save_best_logit else ", best_token" with open(output_file, "w") as f: f.write( @@ -98,7 +99,7 @@ def generate_and_save_acts( print(tensors.shape) torch.save( tensors, - f"{task.prefix}{save_file_prefix}{current_problem_index}.pt", + task.prefix / f"{save_file_prefix}{current_problem_index}.pt", ) if save_results_csv: @@ -146,7 +147,7 @@ def get_all_acts( all_problems = task.generate_problems() all_problems_already_generated = True for i in range(len(all_problems)): - if not os.path.exists(f"{task.prefix}{save_file_prefix}{i}.pt"): + if not (task.prefix / f"{save_file_prefix}{i}.pt").exists(): all_problems_already_generated = False break if not all_problems_already_generated or force_regenerate: @@ -163,7 +164,9 @@ def get_all_acts( all_acts = [] for i in range(0, len(all_problems)): tensors = torch.load( - f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu" + task.prefix / f"{save_file_prefix}{i}.pt", + map_location="cpu", + weights_only=False, ) all_acts.append(tensors) if len(all_acts) > 1: @@ -186,9 +189,9 @@ def get_acts( if save_file_prefix != "" and save_file_prefix[-1] != "_": save_file_prefix += "_" file_name = ( - f"{task.prefix}{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt" + task.prefix / f"{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt" ) - if not os.path.exists(file_name) or force_regenerate: + if not file_name.exists() or force_regenerate: print(file_name, "not exists") all_acts = get_all_acts( task, names_filter=names_filter, save_file_prefix=save_file_prefix @@ -196,12 +199,12 @@ def get_acts( for layer in range(all_acts.shape[1]): for token in range(all_acts.shape[2]): file_name = ( - f"{task.prefix}{save_file_prefix}layer{layer}_token{token}.pt" + task.prefix / f"{save_file_prefix}layer{layer}_token{token}.pt" ) torch.save( all_acts[:, layer, token, :].detach().cpu().clone(), file_name ) - data = torch.load(file_name) + data = torch.load(file_name, weights_only=False) if normalize_rms: eps = 1e-5 scale = (data.pow(2).mean(-1, keepdim=True) + eps).sqrt() @@ -218,11 +221,21 @@ def get_acts_pca( names_filter=lambda x: "resid_post" in x or "hook_embed" in x, save_file_prefix="", ): - act_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt" - pca_pkl_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl" - os.makedirs(f"{task.prefix}/pca/{save_file_prefix}", exist_ok=True) + act_file_name = ( + task.prefix + / "pca" + / save_file_prefix + / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt" + ) + pca_pkl_file_name = ( + task.prefix + / "pca" + / save_file_prefix + / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl" + ) + (task.prefix / "pca" / save_file_prefix).mkdir(parents=True, exist_ok=True) - if not os.path.exists(act_file_name) or not os.path.exists(pca_pkl_file_name): + if not act_file_name.exists() or not pca_pkl_file_name.exists(): acts = get_acts( task, layer, @@ -235,13 +248,23 @@ def get_acts_pca( pca_acts = pca_object.transform(acts) torch.save(pca_acts, act_file_name) pkl.dump(pca_object, open(pca_pkl_file_name, "wb")) - return torch.load(act_file_name), pkl.load(open(pca_pkl_file_name, "rb")) + return torch.load(act_file_name, weights_only=False), pkl.load( + open(pca_pkl_file_name, "rb") + ) def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): - act_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt" - pls_pkl_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl" - os.makedirs(f"{task.prefix}/pls", exist_ok=True) + act_file_name = ( + task.prefix + / "pls" + / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt" + ) + pls_pkl_file_name = ( + task.prefix + / "pls" + / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl" + ) + (task.prefix / "pls").mkdir(parents=True, exist_ok=True) # if not os.path.exists(act_file_name) or not os.path.exists(pls_pkl_file_name): if True: @@ -255,7 +278,9 @@ def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): torch.save(torch.tensor(pls_acts), act_file_name) pkl.dump(pls, open(pls_pkl_file_name, "wb")) - return torch.load(act_file_name), pkl.load(open(pls_pkl_file_name, "rb")) + return torch.load(act_file_name, weights_only=False), pkl.load( + open(pls_pkl_file_name, "rb") + ) def _set_plotting_sizes(): diff --git a/intervention/utils.py b/intervention/utils.py index 216a26c..dcc1e97 100644 --- a/intervention/utils.py +++ b/intervention/utils.py @@ -1,9 +1,10 @@ import os import dill as pickle +from pathlib import Path -BASE_DIR = "/data/scratch/jae/" +BASE_DIR = Path(os.environ.get("BASE_DIR", Path(__file__).parent.parent / "cache")) -os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}/.cache/" +os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/" def setup_notebook(): diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..233f48f --- /dev/null +++ b/requirements.in @@ -0,0 +1,144 @@ +jupyterlab +ipywidgets +numpy +pandas +torch +tensorflow +transformers +huggingface_hub +hf +accelerate +adjustText +aiohappyeyeballs +aiohttp +aiosignal +anyio +asttokens +async-timeout +attrs +automated-interpretability +babe +beartype +better-abc +blobfile +boostedblob +certifi +charset-normalizer +circuitsvis +click +comm +config2py +contourpy +cycler +datasets +debugpy +decorator +dill +docker-pycreds +dol +einops +exceptiongroup +executing +fancy-einsum +filelock +fonttools +frozenlist +fsspec +gitdb +GitPython +gprof2dot +graze +h11 +httpcore +httpx +i2 +idna +importlib_metadata +importlib_resources +iniconfig +jaxtyping +jedi +Jinja2 +joblib +jupyter_client +jupyter_core +kiwisolver +lxml +markdown-it-py +MarkupSafe +matplotlib +matplotlib-inline +mdurl +mpmath +multidict +multiprocess +nest-asyncio +networkx +nltk +orjson +packaging +parso +patsy +pexpect +pillow +platformdirs +plotly +plotly-express +pluggy +prompt_toolkit +protobuf +psutil +ptyprocess +pure_eval +py2store +pyarrow +pyarrow-hotfix +pycryptodomex +Pygments +pyparsing +pytest +pytest-profiling +python-dateutil +python-dotenv +pytz +PyYAML +pyzmq +regex +requests +rich +sae-lens +safetensors +scikit-learn +scipy +sentencepiece +sentry-sdk +setproctitle +shellingham +six +smmap +sniffio +stack-data +statsmodels +sympy +tenacity +threadpoolctl +tiktoken +tokenizers +tomli +tornado +tqdm +traitlets +transformer-lens +triton +typeguard +typer +typing_extensions +tzdata +urllib3 +uvloop +wandb +wcwidth +xxhash +yarl +zipp +zstandard diff --git a/requirements.txt b/requirements.txt index 11d7bfb..ea3f7c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,154 +1,1008 @@ -accelerate==0.33.0 -adjustText==1.2.0 -aiohappyeyeballs==2.3.4 -aiohttp==3.10.0 -aiosignal==1.3.1 -anyio==4.4.0 -asttokens==2.4.1 -async-timeout==4.0.3 -attrs==24.1.0 -automated-interpretability==0.0.5 +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# pip-compile requirements.in +# +#euporie==2.10.4 + # euporie +absl-py==2.4.0 + # via + # keras + # tensorboard + # tensorflow +accelerate==1.12.0 + # via + # -r requirements.in + # transformer-lens +adjusttext==1.3.0 + # via -r requirements.in +aiohappyeyeballs==2.6.1 + # via + # -r requirements.in + # aiohttp +aiohttp==3.13.3 + # via + # -r requirements.in + # boostedblob + # fsspec +aiosignal==1.4.0 + # via + # -r requirements.in + # aiohttp +annotated-doc==0.0.4 + # via typer +annotated-types==0.7.0 + # via pydantic +anyio==4.12.1 + # via + # -r requirements.in + # httpx + # jupyter-server +argon2-cffi==25.1.0 + # via jupyter-server +argon2-cffi-bindings==25.1.0 + # via argon2-cffi +arrow==1.4.0 + # via isoduration +asttokens==3.0.1 + # via + # -r requirements.in + # stack-data +astunparse==1.6.3 + # via tensorflow +async-lru==2.1.0 + # via jupyterlab +async-timeout==5.0.1 + # via -r requirements.in +attrs==25.4.0 + # via + # -r requirements.in + # aiohttp + # jsonschema + # referencing +automated-interpretability==0.0.23 + # via -r requirements.in babe==0.0.7 + # via + # -r requirements.in + # sae-lens +babel==2.18.0 + # via jupyterlab-server beartype==0.14.1 + # via + # -r requirements.in + # transformer-lens +beautifulsoup4==4.14.3 + # via nbconvert better-abc==0.0.3 + # via + # -r requirements.in + # transformer-lens +bleach[css]==6.3.0 + # via nbconvert blobfile==2.1.1 -boostedblob==0.15.4 -certifi==2024.7.4 -charset-normalizer==3.3.2 -circuitsvis==1.43.2 -click==8.1.7 -comm==0.2.2 -config2py==0.1.36 -contourpy==1.2.1 + # via + # -r requirements.in + # automated-interpretability +boostedblob==0.15.6 + # via + # -r requirements.in + # automated-interpretability +certifi==2026.1.4 + # via + # -r requirements.in + # httpcore + # httpx + # requests + # sentry-sdk +cffi==2.0.0 + # via argon2-cffi-bindings +charset-normalizer==3.4.4 + # via + # -r requirements.in + # requests +circuitsvis==1.43.3 + # via -r requirements.in +click==8.3.1 + # via + # -r requirements.in + # nltk + # typer + # wandb +comm==0.2.3 + # via + # -r requirements.in + # ipykernel + # ipywidgets +config2py==0.1.46 + # via + # -r requirements.in + # py2store +contourpy==1.3.3 + # via + # -r requirements.in + # matplotlib +cuda-bindings==12.9.4 + # via torch +cuda-pathfinder==1.3.4 + # via cuda-bindings cycler==0.12.1 -datasets==2.20.0 -debugpy>=1.6.5 -decorator==5.1.1 -dill==0.3.8 + # via + # -r requirements.in + # matplotlib +datasets==4.5.0 + # via + # -r requirements.in + # sae-lens + # transformer-lens +debugpy==1.8.20 + # via + # -r requirements.in + # ipykernel +decorator==5.2.1 + # via + # -r requirements.in + # ipython +defusedxml==0.7.1 + # via nbconvert +dill==0.4.0 + # via + # -r requirements.in + # datasets + # multiprocess docker-pycreds==0.4.0 -dol==0.2.55 -einops==0.8.0 -exceptiongroup==1.2.2 -executing==2.0.1 + # via -r requirements.in +docstring-parser==0.17.0 + # via simple-parsing +dol==0.3.38 + # via + # -r requirements.in + # config2py + # graze + # py2store +einops==0.8.2 + # via + # -r requirements.in + # transformer-lens +exceptiongroup==1.3.1 + # via -r requirements.in +executing==2.2.1 + # via + # -r requirements.in + # stack-data fancy-einsum==0.0.3 -filelock==3.15.4 -fonttools==4.53.1 -frozenlist==1.4.1 -fsspec==2024.5.0 -gitdb==4.0.11 -GitPython==3.1.43 -gprof2dot==2024.6.6 -graze==0.1.24 -h11==0.14.0 -httpcore==1.0.5 -httpx==0.27.0 -huggingface-hub==0.24.5 -i2==0.1.18 -idna==3.7 -importlib_metadata==8.2.0 -importlib_resources==6.4.0 -iniconfig==2.0.0 -ipykernel==6.29.5 -ipython==8.26.0 -jaxtyping==0.2.33 -jedi==0.19.1 -Jinja2==3.1.4 -joblib==1.4.2 -jupyter_client==8.6.2 -jupyter_core==5.7.2 -kiwisolver==1.4.5 + # via + # -r requirements.in + # transformer-lens +fastjsonschema==2.21.2 + # via nbformat +filelock==3.24.2 + # via + # -r requirements.in + # blobfile + # datasets + # hf + # huggingface-hub + # torch + # transformers +flatbuffers==25.12.19 + # via tensorflow +fonttools==4.61.1 + # via + # -r requirements.in + # matplotlib +fqdn==1.5.1 + # via jsonschema +frozenlist==1.8.0 + # via + # -r requirements.in + # aiohttp + # aiosignal +fsspec[http]==2025.10.0 + # via + # -r requirements.in + # datasets + # hf + # huggingface-hub + # torch +gast==0.7.0 + # via tensorflow +gitdb==4.0.12 + # via + # -r requirements.in + # gitpython +gitpython==3.1.46 + # via + # -r requirements.in + # wandb +google-pasta==0.2.0 + # via tensorflow +gprof2dot==2025.4.14 + # via + # -r requirements.in + # pytest-profiling +graze==0.1.39 + # via + # -r requirements.in + # babe +grpcio==1.78.0 + # via + # tensorboard + # tensorflow +h11==0.16.0 + # via + # -r requirements.in + # httpcore +h5py==3.15.1 + # via + # keras + # tensorflow +hf==1.1.0 + # via -r requirements.in +hf-xet==1.2.0 + # via + # hf + # huggingface-hub +httpcore==1.0.9 + # via + # -r requirements.in + # httpx +httpx==0.28.1 + # via + # -r requirements.in + # automated-interpretability + # datasets + # hf + # jupyterlab +huggingface-hub==0.36.2 + # via + # -r requirements.in + # accelerate + # datasets + # tokenizers + # transformer-lens + # transformers +i2==0.1.63 + # via + # -r requirements.in + # config2py +idna==3.11 + # via + # -r requirements.in + # anyio + # httpx + # jsonschema + # requests + # yarl +importlib-metadata==8.7.1 + # via + # -r requirements.in + # circuitsvis +importlib-resources==6.5.2 + # via + # -r requirements.in + # py2store +iniconfig==2.3.0 + # via + # -r requirements.in + # pytest +ipykernel==7.2.0 + # via jupyterlab +ipython==9.10.0 + # via + # ipykernel + # ipywidgets +ipython-pygments-lexers==1.1.1 + # via ipython +ipywidgets==8.1.8 + # via -r requirements.in +isoduration==20.11.0 + # via jsonschema +jaxtyping==0.3.9 + # via + # -r requirements.in + # transformer-lens +jedi==0.19.2 + # via + # -r requirements.in + # ipython +jinja2==3.1.6 + # via + # -r requirements.in + # jupyter-server + # jupyterlab + # jupyterlab-server + # nbconvert + # torch +joblib==1.5.3 + # via + # -r requirements.in + # nltk + # scikit-learn +json5==0.13.0 + # via jupyterlab-server +jsonpointer==3.0.0 + # via jsonschema +jsonschema[format-nongpl]==4.26.0 + # via + # jupyter-events + # jupyterlab-server + # nbformat +jsonschema-specifications==2025.9.1 + # via jsonschema +jupyter-client==8.8.0 + # via + # -r requirements.in + # ipykernel + # jupyter-server + # nbclient +jupyter-core==5.9.1 + # via + # -r requirements.in + # ipykernel + # jupyter-client + # jupyter-server + # jupyterlab + # nbclient + # nbconvert + # nbformat +jupyter-events==0.12.0 + # via jupyter-server +jupyter-lsp==2.3.0 + # via jupyterlab +jupyter-server==2.17.0 + # via + # jupyter-lsp + # jupyterlab + # jupyterlab-server + # notebook-shim +jupyter-server-terminals==0.5.4 + # via jupyter-server +jupyterlab==4.5.3 + # via -r requirements.in +jupyterlab-pygments==0.3.0 + # via nbconvert +jupyterlab-server==2.28.0 + # via jupyterlab +jupyterlab-widgets==3.0.16 + # via ipywidgets +keras==3.13.2 + # via tensorflow +kiwisolver==1.4.9 + # via + # -r requirements.in + # matplotlib +lark==1.3.1 + # via rfc3987-syntax +libclang==18.1.1 + # via tensorflow lxml==4.9.4 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -matplotlib==3.9.1 -matplotlib-inline==0.1.7 + # via + # -r requirements.in + # blobfile + # boostedblob +markdown==3.10.2 + # via tensorboard +markdown-it-py==4.0.0 + # via + # -r requirements.in + # rich +markupsafe==3.0.3 + # via + # -r requirements.in + # jinja2 + # nbconvert + # werkzeug +matplotlib==3.10.8 + # via + # -r requirements.in + # adjusttext +matplotlib-inline==0.2.1 + # via + # -r requirements.in + # ipykernel + # ipython mdurl==0.1.2 + # via + # -r requirements.in + # markdown-it-py +mistune==3.2.0 + # via nbconvert +ml-dtypes==0.5.4 + # via + # keras + # tensorflow mpmath==1.3.0 -multidict==6.0.5 -multiprocess==0.70.16 + # via + # -r requirements.in + # sympy +multidict==6.7.1 + # via + # -r requirements.in + # aiohttp + # yarl +multiprocess==0.70.18 + # via + # -r requirements.in + # datasets +namex==0.1.0 + # via keras +narwhals==2.16.0 + # via plotly +nbclient==0.10.4 + # via nbconvert +nbconvert==7.17.0 + # via jupyter-server +nbformat==5.10.4 + # via + # jupyter-server + # nbclient + # nbconvert nest-asyncio==1.6.0 -networkx==3.3 -nltk==3.8.1 + # via + # -r requirements.in + # ipykernel +networkx==3.6.1 + # via + # -r requirements.in + # torch +nltk==3.9.2 + # via + # -r requirements.in + # sae-lens +notebook-shim==0.2.4 + # via jupyterlab numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.18.1 -nvidia-nvjitlink-cu12==12.6.20 -nvidia-nvtx-cu12==12.1.105 -orjson==3.10.6 -packaging==24.1 -pandas==2.2.2 -parso==0.8.4 -patsy==0.5.6 + # via + # -r requirements.in + # accelerate + # adjusttext + # automated-interpretability + # circuitsvis + # contourpy + # datasets + # h5py + # keras + # matplotlib + # ml-dtypes + # pandas + # patsy + # plotly-express + # scikit-learn + # scipy + # statsmodels + # tensorboard + # tensorflow + # transformer-lens + # transformers +nvidia-cublas-cu12==12.8.4.1 + # via + # -r requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.90 + # via + # -r requirements.in + # torch +nvidia-cuda-nvrtc-cu12==12.8.93 + # via + # -r requirements.in + # torch +nvidia-cuda-runtime-cu12==12.8.90 + # via + # -r requirements.in + # torch +nvidia-cudnn-cu12==9.10.2.21 + # via + # -r requirements.in + # torch +nvidia-cufft-cu12==11.3.3.83 + # via + # -r requirements.in + # torch +nvidia-cufile-cu12==1.13.1.3 + # via torch +nvidia-curand-cu12==10.3.9.90 + # via + # -r requirements.in + # torch +nvidia-cusolver-cu12==11.7.3.90 + # via + # -r requirements.in + # torch +nvidia-cusparse-cu12==12.5.8.93 + # via + # -r requirements.in + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.7.1 + # via torch +nvidia-nccl-cu12==2.27.5 + # via + # -r requirements.in + # torch +nvidia-nvjitlink-cu12==12.8.93 + # via + # -r requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvshmem-cu12==3.4.5 + # via torch +nvidia-nvtx-cu12==12.8.90 + # via + # -r requirements.in + # torch +opt-einsum==3.4.0 + # via tensorflow +optree==0.18.0 + # via keras +orjson==3.11.7 + # via + # -r requirements.in + # automated-interpretability +packaging==26.0 + # via + # -r requirements.in + # accelerate + # datasets + # hf + # huggingface-hub + # ipykernel + # jupyter-events + # jupyter-server + # jupyterlab + # jupyterlab-server + # keras + # matplotlib + # nbconvert + # plotly + # pytest + # statsmodels + # tensorboard + # tensorflow + # transformers + # wandb + # wheel +pandas==3.0.1 + # via + # -r requirements.in + # babe + # datasets + # plotly-express + # statsmodels + # transformer-lens +pandocfilters==1.5.1 + # via nbconvert +parso==0.8.6 + # via + # -r requirements.in + # jedi +patsy==1.0.2 + # via + # -r requirements.in + # plotly-express + # statsmodels pexpect==4.9.0 -pillow==10.4.0 -platformdirs==4.2.2 -plotly==5.23.0 + # via + # -r requirements.in + # ipython +pillow==12.1.1 + # via + # -r requirements.in + # matplotlib + # tensorboard +platformdirs==4.9.2 + # via + # -r requirements.in + # jupyter-core + # wandb +plotly==6.5.2 + # via + # -r requirements.in + # plotly-express + # sae-lens plotly-express==0.4.1 -pluggy==1.5.0 -prompt_toolkit==3.0.47 -protobuf==5.27.3 -psutil==6.0.0 + # via + # -r requirements.in + # sae-lens +pluggy==1.6.0 + # via + # -r requirements.in + # pytest +prometheus-client==0.24.1 + # via jupyter-server +prompt-toolkit==3.0.52 + # via + # -r requirements.in + # ipython +propcache==0.4.1 + # via + # aiohttp + # yarl +protobuf==6.33.5 + # via + # -r requirements.in + # tensorboard + # tensorflow + # transformer-lens + # wandb +psutil==7.2.2 + # via + # -r requirements.in + # accelerate + # ipykernel ptyprocess==0.7.0 -pure_eval==0.2.3 -py2store==0.1.20 -pyarrow==17.0.0 -pyarrow-hotfix==0.6 -pycryptodomex==3.20.0 -Pygments==2.18.0 -pyparsing==3.1.2 -pytest==8.3.2 -pytest-profiling==1.7.0 + # via + # -r requirements.in + # pexpect + # terminado +pure-eval==0.2.3 + # via + # -r requirements.in + # stack-data +py2store==0.1.22 + # via + # -r requirements.in + # babe +pyarrow==23.0.1 + # via + # -r requirements.in + # datasets +pyarrow-hotfix==0.7 + # via -r requirements.in +pycparser==3.0 + # via cffi +pycryptodomex==3.23.0 + # via + # -r requirements.in + # blobfile + # boostedblob +pydantic==2.12.5 + # via wandb +pydantic-core==2.41.5 + # via pydantic +pygments==2.19.2 + # via + # -r requirements.in + # ipython + # ipython-pygments-lexers + # nbconvert + # pytest + # rich +pyparsing==3.3.2 + # via + # -r requirements.in + # matplotlib +pytest==9.0.2 + # via + # -r requirements.in + # pytest-profiling +pytest-profiling==1.8.1 + # via -r requirements.in python-dateutil==2.9.0.post0 -python-dotenv==1.0.1 -pytz==2024.1 -PyYAML==6.0.1 -pyzmq==26.0.0 -regex==2024.7.24 -requests==2.32.3 -rich==13.7.1 -sae-lens==3.13.1 -safetensors==0.4.3 -scikit-learn==1.5.1 -scipy==1.14.0 -sentencepiece==0.2.0 -sentry-sdk==2.12.0 -setproctitle==1.3.3 + # via + # -r requirements.in + # arrow + # jupyter-client + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # -r requirements.in + # sae-lens +python-json-logger==4.0.0 + # via jupyter-events +pytz==2025.2 + # via -r requirements.in +pyyaml==6.0.3 + # via + # -r requirements.in + # accelerate + # datasets + # hf + # huggingface-hub + # jupyter-events + # sae-lens + # transformers + # wandb +pyzmq==27.1.0 + # via + # -r requirements.in + # ipykernel + # jupyter-client + # jupyter-server +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications + # jupyter-events +regex==2026.1.15 + # via + # -r requirements.in + # nltk + # tiktoken + # transformers +requests==2.32.5 + # via + # -r requirements.in + # datasets + # graze + # huggingface-hub + # jupyterlab-server + # tensorflow + # tiktoken + # transformers + # wandb +rfc3339-validator==0.1.4 + # via + # jsonschema + # jupyter-events +rfc3986-validator==0.1.1 + # via + # jsonschema + # jupyter-events +rfc3987-syntax==1.1.0 + # via jsonschema +rich==14.3.2 + # via + # -r requirements.in + # keras + # transformer-lens + # typer +rpds-py==0.30.0 + # via + # jsonschema + # referencing +sae-lens==6.37.1 + # via -r requirements.in +safetensors==0.7.0 + # via + # -r requirements.in + # accelerate + # sae-lens + # transformers +scikit-learn==1.8.0 + # via + # -r requirements.in + # automated-interpretability +scipy==1.17.0 + # via + # -r requirements.in + # adjusttext + # plotly-express + # scikit-learn + # statsmodels +send2trash==2.1.0 + # via jupyter-server +sentencepiece==0.2.1 + # via + # -r requirements.in + # transformer-lens +sentry-sdk==2.53.0 + # via + # -r requirements.in + # wandb +setproctitle==1.3.7 + # via -r requirements.in shellingham==1.5.4 -six==1.16.0 -smmap==5.0.1 + # via + # -r requirements.in + # hf + # typer +simple-parsing==0.1.8 + # via sae-lens +six==1.17.0 + # via + # -r requirements.in + # astunparse + # docker-pycreds + # google-pasta + # pytest-profiling + # python-dateutil + # rfc3339-validator + # tensorflow +smmap==5.0.2 + # via + # -r requirements.in + # gitdb sniffio==1.3.1 + # via -r requirements.in +soupsieve==2.8.3 + # via beautifulsoup4 stack-data==0.6.3 -statsmodels==0.14.2 -sympy==1.13.1 -tenacity==9.0.0 -threadpoolctl==3.5.0 -tiktoken==0.6.0 -tokenizers==0.19.1 -tomli==2.0.1 -torch==2.1.2 -tornado==6.4.1 -tqdm==4.66.4 + # via + # -r requirements.in + # ipython +statsmodels==0.14.6 + # via + # -r requirements.in + # plotly-express +sympy==1.14.0 + # via + # -r requirements.in + # torch +tenacity==9.1.4 + # via + # -r requirements.in + # sae-lens +tensorboard==2.20.0 + # via tensorflow +tensorboard-data-server==0.7.2 + # via tensorboard +tensorflow==2.20.0 + # via -r requirements.in +termcolor==3.3.0 + # via tensorflow +terminado==0.18.1 + # via + # jupyter-server + # jupyter-server-terminals +threadpoolctl==3.6.0 + # via + # -r requirements.in + # scikit-learn +tiktoken==0.12.0 + # via + # -r requirements.in + # automated-interpretability +tinycss2==1.4.0 + # via bleach +tokenizers==0.22.2 + # via + # -r requirements.in + # transformers +tomli==2.4.0 + # via -r requirements.in +torch==2.10.0 + # via + # -r requirements.in + # accelerate + # circuitsvis + # transformer-lens +tornado==6.5.4 + # via + # -r requirements.in + # ipykernel + # jupyter-client + # jupyter-server + # jupyterlab + # terminado +tqdm==4.67.3 + # via + # -r requirements.in + # datasets + # hf + # huggingface-hub + # nltk + # transformer-lens + # transformers traitlets==5.14.3 -transformer-lens==2.3.0 -transformers==4.43.3 -triton==2.1.0 -typeguard==2.13.3 -typer==0.12.3 -typing_extensions==4.12.2 -tzdata==2024.1 -urllib3==2.2.2 -uvloop==0.19.0 -wandb==0.17.5 -wcwidth==0.2.13 -xxhash==3.4.1 -yarl==1.9.4 -zipp==3.19.2 -zstandard==0.22.0 + # via + # -r requirements.in + # ipykernel + # ipython + # ipywidgets + # jupyter-client + # jupyter-core + # jupyter-events + # jupyter-server + # jupyterlab + # matplotlib-inline + # nbclient + # nbconvert + # nbformat +transformer-lens==2.17.0 + # via + # -r requirements.in + # sae-lens +transformers==4.57.6 + # via + # -r requirements.in + # sae-lens + # transformer-lens + # transformers-stream-generator +transformers-stream-generator==0.0.5 + # via transformer-lens +triton==3.6.0 + # via + # -r requirements.in + # torch +typeguard==4.5.0 + # via + # -r requirements.in + # transformer-lens +typer==0.24.0 + # via + # -r requirements.in + # typer-slim +typer-slim==0.24.0 + # via + # hf + # transformers +typing-extensions==4.15.0 + # via + # -r requirements.in + # aiosignal + # anyio + # beautifulsoup4 + # exceptiongroup + # grpcio + # hf + # huggingface-hub + # optree + # pydantic + # pydantic-core + # referencing + # sae-lens + # simple-parsing + # tensorflow + # torch + # transformer-lens + # typeguard + # typing-inspection + # wandb +typing-inspection==0.4.2 + # via pydantic +tzdata==2025.3 + # via + # -r requirements.in + # arrow +uri-template==1.3.0 + # via jsonschema +urllib3==2.6.3 + # via + # -r requirements.in + # blobfile + # requests + # sentry-sdk +uvloop==0.22.1 + # via + # -r requirements.in + # boostedblob +wadler-lindig==0.1.7 + # via jaxtyping +wandb==0.25.0 + # via + # -r requirements.in + # transformer-lens +wcwidth==0.6.0 + # via + # -r requirements.in + # prompt-toolkit +webcolors==25.10.0 + # via jsonschema +webencodings==0.5.1 + # via + # bleach + # tinycss2 +websocket-client==1.9.0 + # via jupyter-server +werkzeug==3.1.5 + # via tensorboard +wheel==0.46.3 + # via astunparse +widgetsnbextension==4.0.15 + # via ipywidgets +wrapt==2.1.1 + # via tensorflow +xxhash==3.6.0 + # via + # -r requirements.in + # datasets +yarl==1.22.0 + # via + # -r requirements.in + # aiohttp +zipp==3.23.0 + # via + # -r requirements.in + # importlib-metadata +zstandard==0.25.0 + # via -r requirements.in + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..f12dff5 100644 --- a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -1,12 +1,12 @@ # %% - +from pathlib import Path import os from utils import BASE_DIR # hopefully this will help with memory fragmentation os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" -os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}.cache/" +os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/" import einops import numpy as np @@ -32,7 +32,7 @@ model_name = "mistral-7b" batch_size = 16 layers_to_evaluate = [8, 16, 24] - num_devices = 2 + num_devices = max(1, t.cuda.device_count()) sae_hidden_size = 65536 else: @@ -44,16 +44,17 @@ num_workers = 8 sae_hidden_size = 24576 +tl_model_name = "gpt2" if model_name == "gpt-2" else model_name model = transformer_lens.HookedTransformer.from_pretrained( - model_name, device=device, n_devices=num_devices + tl_model_name, device=device, n_devices=num_devices ) ctx_len = 256 num_sae_activations_to_save = 10**9 -save_folder = f"{BASE_DIR}{model_name}" -os.makedirs(save_folder, exist_ok=True) +save_folder = Path(BASE_DIR) / model_name +save_folder.mkdir(exist_ok=True, parents=True) t.set_grad_enabled(False) @@ -139,8 +140,12 @@ def next_batch_activations(): forward_pass = ae.forward(activations) if isinstance(forward_pass, tuple): hidden_sae = forward_pass[1] - else: + elif hasattr(forward_pass, "feature_acts"): hidden_sae = forward_pass.feature_acts + else: + # Newer sae_lens returns reconstructed tensor from forward(); + # use encode() to get feature activations instead + hidden_sae = ae.encode(activations) nonzero_sae = hidden_sae.abs() > 1e-6 nonzero_sae_values = hidden_sae[nonzero_sae] diff --git a/sae_multid_feature_discovery/saes/sparse_autoencoder.py b/sae_multid_feature_discovery/saes/sparse_autoencoder.py index 85a9f57..ba61316 100755 --- a/sae_multid_feature_discovery/saes/sparse_autoencoder.py +++ b/sae_multid_feature_discovery/saes/sparse_autoencoder.py @@ -182,10 +182,12 @@ def load_from_pretrained(cls, path: str): if path.endswith(".pt"): try: if torch.backends.mps.is_available(): - state_dict = torch.load(path, map_location="mps") + state_dict = torch.load( + path, map_location="mps", weights_only=False + ) state_dict["cfg"].device = "mps" else: - state_dict = torch.load(path) + state_dict = torch.load(path, weights_only=False) except Exception as e: raise IOError(f"Error loading the state dictionary from .pt file: {e}") diff --git a/sae_multid_feature_discovery/utils.py b/sae_multid_feature_discovery/utils.py index a229782..370290f 100644 --- a/sae_multid_feature_discovery/utils.py +++ b/sae_multid_feature_discovery/utils.py @@ -1,7 +1,9 @@ +from pathlib import Path from huggingface_hub import hf_hub_download import os -BASE_DIR = "/data/scratch/jae/" +BASE_DIR = Path(os.environ.get("BASE_DIR", Path(__file__).parent.parent / "cache")) + def get_gpt2_sae(device, layer): from sae_lens import SAE @@ -9,7 +11,7 @@ def get_gpt2_sae(device, layer): return SAE.from_pretrained( release="gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml sae_id=f"blocks.{layer}.hook_resid_pre", # won't always be a hook point - device=device + device=device, )[0]