diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..7301f76 100644 --- a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -139,8 +139,12 @@ def next_batch_activations(): forward_pass = ae.forward(activations) if isinstance(forward_pass, tuple): hidden_sae = forward_pass[1] - else: + elif hasattr(forward_pass, "feature_acts"): hidden_sae = forward_pass.feature_acts + else: + # Newer sae_lens returns reconstructed tensor from forward(); + # use encode() to get feature activations instead + hidden_sae = ae.encode(activations) nonzero_sae = hidden_sae.abs() > 1e-6 nonzero_sae_values = hidden_sae[nonzero_sae]