From dd818247fb3421ec1a4721e1347e329c886e620d Mon Sep 17 00:00:00 2001 From: Zaki Alaoui Date: Tue, 19 Aug 2025 18:27:11 +0000 Subject: [PATCH 1/6] move over to leg --- bscope/ic/semantic.py | 13 ++++++++++++- bscope/scope.py | 3 +-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/bscope/ic/semantic.py b/bscope/ic/semantic.py index 7da7dbf..a06c901 100644 --- a/bscope/ic/semantic.py +++ b/bscope/ic/semantic.py @@ -514,4 +514,15 @@ def get_channels(self, mode_indices: np.ndarray, n: int = 10, concept_idx=None, important_channels.extend(channels) - return list(important_channels) + return list(set(important_channels)) # Return only unique channel indices + def get_concept_correlations(self, concept_idx: int) -> np.ndarray: + """ + Get correlations between a specific concept and all modes. + + Args: + concept_idx: The index of the concept + + Returns: + Array of correlations between the concept and each mode + """ + return self.corr_mtx[:, concept_idx] \ No newline at end of file diff --git a/bscope/scope.py b/bscope/scope.py index c0ad76f..512a03e 100644 --- a/bscope/scope.py +++ b/bscope/scope.py @@ -103,8 +103,7 @@ def log_stop(self): for i in range(self.num_layers): self.log_gradients[i] = np.concatenate(self.log_gradients[i]) self.log_activations[i] = np.concatenate(self.log_activations[i]) - self.log_contributions[i] = np.concatenate( - self.log_contributions[i]) + self.log_contributions[i] = np.concatenate(self.log_contributions[i]) self.logging = False From 89fc04598e0e3be02f579136d9ffe7a190f605fd Mon Sep 17 00:00:00 2001 From: Zaki Alaoui Date: Tue, 19 Aug 2025 21:52:26 +0000 Subject: [PATCH 2/6] edit --- bscope/ic/semantic_analyzer.py | 101 +-------------------------------- 1 file changed, 1 insertion(+), 100 deletions(-) diff --git a/bscope/ic/semantic_analyzer.py b/bscope/ic/semantic_analyzer.py index 0b2f40d..639604e 100644 --- a/bscope/ic/semantic_analyzer.py +++ b/bscope/ic/semantic_analyzer.py @@ -417,103 +417,4 @@ def 
get_concepts_from_path(self, concept): return ordered_names - -if __name__ == "__main__": - sem = SemanticAnalyzer() - embed() - - - - - Args: - n: Number of top channels to return (or number of std deviations if method='std') - concept_idx: The index of the concept - method: Method to select channels ('argsort' or 'std') - - Returns: - Array of important channel indices - """ - avg_contribution = self.get_average_contribution(concept_idx) - - if method == 'argsort': - top_channels = np.argsort(avg_contribution)[-n:][::-1] - elif method == 'std': - mean = np.mean(avg_contribution) - std = np.std(avg_contribution) - threshold = mean + n * std - top_channels = np.where(avg_contribution > threshold)[0] - else: - raise ValueError(f"Unknown method: {method}. Use 'argsort' or 'std'.") - - return list(top_channels) - - def get_top_modes(self, concept_idx: int, n_modes: int = 5) -> np.ndarray: - """ - Get the top modes (features) that are most correlated with a concept. - - Args: - concept_idx: The index of the concept - n_modes: Number of top modes to return - - Returns: - Array of indices of the top n_modes that correlate with the concept - """ - # Sort modes by correlation with the concept - concept_modes = np.argsort(self.corr_mtx[:, concept_idx]) - # Get top n_modes (in descending order) - top_modes = concept_modes[-n_modes:][::-1] - return top_modes - def get_channels(self, mode_indices: np.ndarray, n: int = 10, concept_idx=None, - method: str = 'argsort') -> List[np.ndarray]: - """ - Get the important channels for specified modes. - - Args: - mode_indices: Array of mode indices to analyze - n: Number of top channels per mode to return. 
If None, returns all channels ranked by importance - concept_idx: Optional concept index (not used in this implementation) - method: Method to select channels ('argsort' or 'std') - - Returns: - List of important channel indices across all modes - """ - important_channels = [] - for mode_idx in mode_indices: - mode = self.dictionary[mode_idx] - - if method == 'argsort': - if n is None: - # Return all channels sorted by importance (highest to lowest) - channels = np.argsort(mode)[::-1] # Full sorted array in descending order - else: - # Get top n channels with highest values - channels = np.argsort(mode)[-n:][::-1] - elif method == 'std': - if n is None: - # For std method with n=None, sort by deviation from mean (highest to lowest) - mean = np.mean(mode) - deviations = (mode - mean) / np.std(mode) - channels = np.argsort(deviations)[::-1] # Full sorted array in descending order - else: - # Get channels that are x standard deviations above the mean - mean = np.mean(mode) - std = np.std(mode) - threshold = mean + n * std # using n as the number of std devs - channels = np.where(mode > threshold)[0] - else: - raise ValueError(f"Unknown method: {method}. Use 'argsort' or 'std' ") - - important_channels.extend(channels) - - return list(set(important_channels)) # Return only unique channel indices - def get_concept_correlations(self, concept_idx: int) -> np.ndarray: - """ - Get correlations between a specific concept and all modes. 
- - Args: - concept_idx: The index of the concept - - Returns: - Array of correlations between the concept and each mode - """ - return self.corr_mtx[:, concept_idx] + \ No newline at end of file From 0efc416d9b0d995c64ec249021fcc978e8168be5 Mon Sep 17 00:00:00 2001 From: Zaki Alaoui Date: Thu, 25 Sep 2025 04:29:34 +0000 Subject: [PATCH 3/6] surprisal --- bscope/ic/custom_dataset.py | 6 +- bscope/ic/mode_summary.py | 232 ++++++++++++++++++++++++++++----- bscope/ic/models.py | 2 +- bscope/ic/semantic_analyzer.py | 2 +- bscope/sae.py | 6 +- bscope/scope.py | 96 ++++++++++++-- bscope/utils.py | 3 + 7 files changed, 292 insertions(+), 55 deletions(-) diff --git a/bscope/ic/custom_dataset.py b/bscope/ic/custom_dataset.py index 30a5c53..61f0627 100644 --- a/bscope/ic/custom_dataset.py +++ b/bscope/ic/custom_dataset.py @@ -992,9 +992,9 @@ def __init__(self, root: Union[str, Path], split: str = "train", subsample = Non idxs.extend(np.random.choice(range(i * 50, (i + 1) * 50), size=subsample, replace=False).tolist()) self.subsample_idxs = idxs - if self.subsample_idxs is not None: - print('Use a different dataloader without subsampling') - input() + # if self.subsample_idxs is not None: + # print('Use a different dataloader without subsampling') + # input() def __len__(self) -> int: if self.subsample_idxs is not None: diff --git a/bscope/ic/mode_summary.py b/bscope/ic/mode_summary.py index 96ccd6b..36213c6 100644 --- a/bscope/ic/mode_summary.py +++ b/bscope/ic/mode_summary.py @@ -202,7 +202,7 @@ def get_top_modes(self, layer_idx: int, concept_name: str, method: str = 'percen correlations = layer.corr_mtx[:, concept_idx] # Use select_significant_indices to get top modes - modes = bscope.select_significant_indices( + modes = select_significant_indices( correlations, method=method, param=param, @@ -257,7 +257,7 @@ def get_top_channels(self, layer_idx: int, concept_name: str, # Get top channels for this atom - top_channels = bscope.select_significant_indices( + top_channels = 
select_significant_indices( atom, method=channel_method, param=channel_param, @@ -376,7 +376,7 @@ def find_similar_concepts_by_channels( results = [] # Get all concepts to check - syn = bic.SemanticAnalyzer('/data/codec/hierarchy_metadata/misc/semantic_indexes_test.json') + syn = bic.SemanticAnalyzer('/home/zalaoui/semantic_indexes_test.json') _, imagenet_class_names = syn.get_all_imagenet_masks(list(range(1000))) for concept_label in imagenet_class_names: @@ -454,6 +454,9 @@ def print_similar_concepts( print("=" * 80) for i, (concept, shared_count, overlap_ratio, shared_channels) in enumerate(results[:top_n]): + if concept == 'black_grouse': + print("HOLY BLACK GROUSE👀👀👀") + print(f"{i+1:2d}. {concept:20s} | " f"Shared: {shared_count:2d} | " f"Overlap: {overlap_ratio:.1%}") @@ -478,7 +481,8 @@ def plot_mode_comparison( # Remove seed_top_n_channels parameter channel_method: str = 'std', # Add these parameters to be consistent channel_param: float = 2.0, - concat=False # with find_similar_concepts_by_channels + concat=False, # with find_similar_concepts_by_channels, + text=False ): """ Plot seed concept's top mode and similar concepts' top modes with shared channels highlighted.
@@ -549,14 +553,15 @@ def plot_mode_comparison( ax.set_title(f'SEED: {seed_concept} (mode {seed_mode_idx}, corr={seed_correlations[seed_mode_idx]:.3f})', fontsize=8, fontweight='bold', color='blue') ax.set_ylabel('Activation', fontsize=12) + + if text: - # Highlight seed's top channels in blue - for ch in seed_top_channels_set: - ax.axvline(x=ch, color='blue', linestyle='-', alpha=0.2, linewidth=1.5) - - ax.text(0.02, 0.95, f'{len(seed_top_channels_set)} top channels', - transform=ax.transAxes, va='top', fontsize=10, - bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.3)) + for ch in seed_top_channels_set: + ax.axvline(x=ch, color='blue', linestyle='-', alpha=0.2, linewidth=1.5) + + ax.text(0.02, 0.95, f'{len(seed_top_channels_set)} top channels', + transform=ax.transAxes, va='top', fontsize=10, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.3)) # 2. Plot SIMILAR CONCEPTS for i, (concept_name, shared_count, overlap_ratio, shared_channels) in enumerate(display_results): @@ -589,20 +594,22 @@ def plot_mode_comparison( f'Shared: {shared_count}/{len(seed_top_channels_set)} ({overlap_ratio:.1%})', fontsize=8) ax.set_ylabel('Activation', fontsize=8) + + if text: - # Highlight shared channels in RED - for ch in shared_channels: - ax.axvline(x=ch, color='red', linestyle='-', alpha=0.2, linewidth=2) - - # Highlight seed's non-shared top channels in light blue - non_shared_seed_channels = seed_top_channels_set - set(shared_channels) - for ch in non_shared_seed_channels: - ax.axvline(x=ch, color='lightblue', linestyle='--', alpha=0.4, linewidth=1) - - # Add legend info - ax.text(0.02, 0.95, f'{len(shared_channels)} shared channels', - transform=ax.transAxes, va='top', fontsize=10, - bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral", alpha=0.7)) + # Highlight shared channels in RED + for ch in shared_channels: + ax.axvline(x=ch, color='red', linestyle='-', alpha=0.2, linewidth=2) + + # Highlight seed's non-shared top channels in 
light blue + non_shared_seed_channels = seed_top_channels_set - set(shared_channels) + for ch in non_shared_seed_channels: + ax.axvline(x=ch, color='lightblue', linestyle='--', alpha=0.4, linewidth=1) + + # Add legend info + ax.text(0.02, 0.95, f'{len(shared_channels)} shared channels', + transform=ax.transAxes, va='top', fontsize=10, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral", alpha=0.7)) except Exception as e: ax.text(0.5, 0.5, f'Error loading {concept_name}:\n{str(e)}', @@ -621,14 +628,15 @@ def plot_mode_comparison( fontsize=8, fontweight='bold') # Create legend - from matplotlib.lines import Line2D - legend_elements = [ - Line2D([0], [0], color='blue', lw=2, label=f'{seed_concept} top channels'), - Line2D([0], [0], color='red', lw=2, label='Shared channels'), - Line2D([0], [0], color='lightblue', lw=1, linestyle='--', alpha=0.6, - label=f'{seed_concept} non-shared channels') - ] - fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98)) + if text: + from matplotlib.lines import Line2D + legend_elements = [ + Line2D([0], [0], color='blue', lw=2, label=f'{seed_concept} top channels'), + Line2D([0], [0], color='red', lw=2, label='Shared channels'), + Line2D([0], [0], color='lightblue', lw=1, linestyle='--', alpha=0.6, + label=f'{seed_concept} non-shared channels') + ] + fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98)) plt.tight_layout() plt.subplots_adjust(top=0.92) # Make room for suptitle and legend @@ -867,7 +875,165 @@ def print_discovery_network(self, discovery_info: Dict): if concept != seed: print(f" {concept}: {channels}") - +def select_significant_indices(vector, method='threshold', param=0.8, min_indices=1, max_indices=None): + """ + Select indices that contribute most to the overall sum of the vector. 
+ + Parameters: + ----------- + vector : array-like + Input vector of values + method : str, optional (default='threshold') + Method to use for selecting indices: + - 'threshold': Select indices that contribute to param (e.g. 0.8) of the total sum + - 'percentile': Select indices above the param (e.g. 90th) percentile + - 'top_n': Select the top param (e.g. 10) indices by value + - 'kmeans': Use k-means clustering to separate significant from non-significant values + - 'otsu': Use Otsu's thresholding method (common in image processing) + param : float or int, optional + Parameter specific to the chosen method + min_indices : int, optional (default=1) + Minimum number of indices to return + max_indices : int, optional (default=None) + Maximum number of indices to return + + Returns: + -------- + significant_indices : numpy.ndarray + Array of indices that contribute most to the total sum + """ + vector = np.asarray(vector) + n = len(vector) + + if max_indices is None: + max_indices = n + + # Handle edge cases + if n == 0: + return np.array([], dtype=int) + + if method == 'threshold': + # Sort indices by their values in descending order + sorted_indices = np.argsort(-vector) + cumsum = np.cumsum(vector[sorted_indices]) + total_sum = cumsum[-1] + + # Find how many indices we need to reach the threshold + if total_sum == 0: # Handle zero-sum case + return np.array([0], dtype=int) + + # Find indices that contribute to param (e.g. 
80%) of the total sum + threshold_idx = np.searchsorted(cumsum / total_sum, param) + threshold_idx = max(min_indices, min(threshold_idx + 1, max_indices)) + return sorted_indices[:threshold_idx] + + elif method == 'percentile': + # Select indices above a certain percentile + threshold = np.percentile(vector, 100 - param) + significant_indices = np.where(vector >= threshold)[0] + + # Adjust if we have too few or too many indices + if len(significant_indices) < min_indices: + sorted_indices = np.argsort(-vector) + significant_indices = sorted_indices[:min_indices] + elif len(significant_indices) > max_indices: + sorted_significant = sorted(significant_indices, key=lambda i: -vector[i]) + significant_indices = np.array(sorted_significant[:max_indices]) + + return significant_indices + + elif method == 'top_n': + # Select the top N indices + n_indices = min(max(min_indices, int(param)), max_indices) + return np.argsort(-vector)[:n_indices] + + elif method == 'kmeans': + # Use k-means to separate significant from non-significant values + from sklearn.cluster import KMeans + + # Reshape for k-means + X = vector.reshape(-1, 1) + + # Try to estimate the optimal number of clusters if not specified + if param == 0: + from sklearn.metrics import silhouette_score + scores = [] + for k in range(2, min(10, n)): + kmeans = KMeans(n_clusters=k, random_state=42).fit(X) + score = silhouette_score(X, kmeans.labels_) + scores.append(score) + param = np.argmax(scores) + 2 # Add 2 because we started from k=2 + + # Apply k-means clustering + kmeans = KMeans(n_clusters=int(param), random_state=42).fit(X) + + # Get the cluster with the highest mean value + cluster_means = [np.mean(vector[kmeans.labels_ == i]) for i in range(int(param))] + top_cluster = np.argmax(cluster_means) + + # Get indices belonging to the top cluster + significant_indices = np.where(kmeans.labels_ == top_cluster)[0] + + # Sort by value within the cluster and apply min/max constraints + significant_indices = 
sorted(significant_indices, key=lambda i: -vector[i]) + significant_indices = np.array(significant_indices[:max_indices]) + + if len(significant_indices) < min_indices: + sorted_indices = np.argsort(-vector) + missing = min_indices - len(significant_indices) + extra_indices = [i for i in sorted_indices if i not in significant_indices][:missing] + significant_indices = np.append(significant_indices, extra_indices) + + return significant_indices + + elif method == 'otsu': + # Otsu's method to find optimal threshold + # Normalize to 0-255 range for Otsu's algorithm + if np.max(vector) == np.min(vector): + # Handle constant vectors + return np.array([0], dtype=int) + + normalized = ((vector - np.min(vector)) / (np.max(vector) - np.min(vector)) * 255).astype(np.uint8) + threshold = signal.threshold_otsu(normalized) + + # Convert back to original scale + original_threshold = threshold / 255 * (np.max(vector) - np.min(vector)) + np.min(vector) + significant_indices = np.where(vector >= original_threshold)[0] + + # Apply min/max constraints + if len(significant_indices) < min_indices: + sorted_indices = np.argsort(-vector) + significant_indices = sorted_indices[:min_indices] + elif len(significant_indices) > max_indices: + sorted_significant = sorted(significant_indices, key=lambda i: -vector[i]) + significant_indices = np.array(sorted_significant[:max_indices]) + + return significant_indices + elif method == 'std': + # Select indices above param standard deviations from the mean + mean_val = np.mean(vector) + std_val = np.std(vector) + + if std_val == 0: # Handle constant vectors + sorted_indices = np.argsort(-vector) + return sorted_indices[:min_indices] + + threshold = mean_val + param * std_val + significant_indices = np.where(vector >= threshold)[0] + + # Adjust if we have too few or too many indices + if len(significant_indices) < min_indices: + sorted_indices = np.argsort(-vector) + significant_indices = sorted_indices[:min_indices] + elif len(significant_indices) > 
max_indices: + sorted_significant = sorted(significant_indices, key=lambda i: -vector[i]) + significant_indices = np.array(sorted_significant[:max_indices]) + + return significant_indices + + else: + raise ValueError(f"Unknown method: {method}") + # Add these methods to your ModeAnalyzer class by copying them in diff --git a/bscope/ic/models.py b/bscope/ic/models.py index 2c722a2..0d1d757 100644 --- a/bscope/ic/models.py +++ b/bscope/ic/models.py @@ -10,7 +10,7 @@ from IPython import embed from .custom_dataset import CustomImageNetDataset -def get_model(which_model, return_layers=False, imagenet_path='/data/codec/imagenet', device='cuda', subsample=None,subclasses=None,dataloader_only=False,**kwargs): +def get_model(which_model, return_layers=False, imagenet_path='/mnt/data/imagenet', device='cuda', subsample=None,subclasses=None,dataloader_only=False,**kwargs): if which_model == 'resnet50': weights = ResNet50_Weights.IMAGENET1K_V1 diff --git a/bscope/ic/semantic_analyzer.py b/bscope/ic/semantic_analyzer.py index 639604e..edc348e 100644 --- a/bscope/ic/semantic_analyzer.py +++ b/bscope/ic/semantic_analyzer.py @@ -22,7 +22,7 @@ from scipy import signal class SemanticAnalyzer: - def __init__(self, semantic_hierarchy_path ='/data/codec/hierarchy_metadata/misc/semantic_indexes_test.json'): + def __init__(self, semantic_hierarchy_path ='/home/zalaoui/semantic_indexes_test.json'): self.data = self.load_data(semantic_hierarchy_path) diff --git a/bscope/sae.py b/bscope/sae.py index da5bc42..b72b437 100644 --- a/bscope/sae.py +++ b/bscope/sae.py @@ -95,7 +95,7 @@ def __init__(self, data_dim, num_atoms, mlp_hidden_dim=512): self.layers = nn.ModuleDict() # self.layers['layernorm1'] = nn.LayerNorm(data_dim, elementwise_affine=True) self.layers['layer1'] = nn.Linear(data_dim, self.mlp_hidden_dim, bias=True) - self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + # self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, 
elementwise_affine=True) self.layers['dropout1'] = nn.Dropout(p=0.05) # Add dropout layer with p=0.2 self.layers['relu1'] = nn.ReLU()# Add sigmoid activation @@ -129,13 +129,13 @@ def __init__(self, data_dim, num_atoms, mlp_hidden_dim=512): # self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['dropout1'] = nn.Dropout(p=0.05) # Add dropout layer with p=0.2 self.layers['relu1'] = nn.ReLU()# Add sigmoid activation - self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + # self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['layer2'] = nn.Linear(self.mlp_hidden_dim, self.mlp_hidden_dim, bias=True) # self.layers['layernorm2'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['dropout2'] = nn.Dropout(p=0.05) # Add dropout layer with p=0.2 self.layers['relu2'] = nn.ReLU() # Add ReLU activation - self.layers['layernorm2'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + # self.layers['layernorm2'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['layer3'] = nn.Linear(self.mlp_hidden_dim, num_atoms, bias=False) self.layers['sigmoid'] = nn.Sigmoid() # Add sigmoid activation diff --git a/bscope/scope.py b/bscope/scope.py index 57e4341..059b1e5 100644 --- a/bscope/scope.py +++ b/bscope/scope.py @@ -67,6 +67,10 @@ def use_act_grad(self): def use_act_normgrad(self): self.contribution_type='act_normgrad' + + def use_input_int_grad(self, steps=20): + self.contribution_type = 'input_int_grad' + self.steps = steps def use_normact_normgrad(self): self.contribution_type='normact_normgrad' @@ -77,6 +81,11 @@ def use_jacobians(self): def wrt_entropy(self): self.contribution_target = 'entropy' self.softmax = True + + def wrt_firing_rate_sum(self, target_neurons=None, softmax=False): + self.contribution_target = 'firing_rate_sum' + self.target_neurons = target_neurons # None = sum all neurons + 
self.softmax = softmax def wrt_output_neuron(self, neuron_index=0, softmax=False): self.contribution_target = 'output_neuron' @@ -88,13 +97,22 @@ def wrt_topk(self, k=5, softmax=True): self.k = k self.softmax = softmax + def wrt_surprisal(self, softmax=False): + self.contribution_target = 'surprisal' + self.surprisal_mu = None + self.surprisal_sigma_inv = None + self.softmax = softmax # Raw neural outputs + + def set_surprisal_stats(self, mu, sigma_inv): + self.surprisal_mu = mu + self.surprisal_sigma_inv = sigma_inv + def log_start(self, reduction=None): self.logging = True self.log_gradients = [[] for i in range(self.num_layers)] self.log_activations = [[] for i in range(self.num_layers)] self.log_contributions = [[] for i in range(self.num_layers)] - self.log_outputs = [] self.reduction = reduction @@ -120,8 +138,28 @@ def backward_pass(self, y): sorted, indices = torch.topk(y, self.k, dim=-1) sorted.sum().backward() + elif self.contribution_target == 'firing_rate_sum': + if self.target_neurons is not None: + target_output = y[:, self.target_neurons].sum() + else: + target_output = y.sum() + target_output.backward() + + elif self.contribution_target == 'surprisal': + if self.surprisal_mu is None or self.surprisal_sigma_inv is None: + raise ValueError("Surprisal statistics not set. 
Call set_surprisal_stats() first.") + + # Convert numpy arrays to tensors on the right device + mu_tensor = torch.from_numpy(self.surprisal_mu).to(y.device).float() + sigma_inv_tensor = torch.from_numpy(self.surprisal_sigma_inv).to(y.device).float() + + # Compute (y - μ)ᵀ Σ⁻¹ (y - μ) + centered = y - mu_tensor # [batch_size, n_neurons] + surprisal = 0.5 * torch.sum((centered @ sigma_inv_tensor) * centered, dim=1) # [batch_size] + surprisal.sum().backward() + else: - raise ValueError(f"Unknown contribution target: {target}") + raise ValueError(f"Unknown contribution target: {self.contribution_target}") def __call__(self, input_tensor): if self.contribution_type is None: @@ -142,7 +180,8 @@ def __call__(self, input_tensor): stim = input_tensor.unsqueeze(0) stim.requires_grad = True self.steps = 1 - + elif self.contribution_type == 'input_int_grad': + stim = interpolate_stim(input_tensor, self.steps) # Same as int_grad else: raise ValueError( f"Unknown contribution type: {self.contribution_type}") @@ -150,20 +189,33 @@ def __call__(self, input_tensor): self.last_activations = [[] for i in range(self.num_layers)] self.last_gradients = [[] for i in range(self.num_layers)] self.last_outputs = [] + self.input_gradients = [] + self.input_activations = [] for step in range(self.steps): + self.model.zero_grad() # get device of the model device = next(self.model.parameters()).device + current_stim = stim[step].to(device) # Store it + current_stim.requires_grad = True # Make sure it requires grad y = self.model(stim[step].to(device)) + if self.contribution_type == 'input_int_grad': + y = self.model(current_stim) # CORRECT - use the same tensor if self.softmax: y = torch.softmax(y, dim=-1) self.backward_pass(y) + if self.contribution_type == 'input_int_grad': + if current_stim.grad is not None: + self.input_gradients.append(current_stim.grad.detach().cpu().numpy()) + self.input_activations.append(current_stim.detach().cpu().numpy()) +
self.last_outputs.append(y.detach().cpu().numpy()) + [ self.last_activations[i].append(self.inspector.activations[i]) for i in range(self.num_layers) @@ -281,6 +333,22 @@ def __call__(self, input_tensor): cam = np.maximum(cam, 0) contributions.append(cam) + + + elif self.contribution_type == 'input_int_grad': + contributions = [] + for layer in range(self.num_layers): + + + input_int_grads = interneuron_integral_approximation( + self.input_gradients, # All interpolated inputs you collected + self.input_activations # All input gradients you collected + ) + contributions.append(input_int_grads) + + self.activations = [self.last_activations[layer][-1] for layer in range(self.num_layers)] + self.gradients = [self.last_gradients[layer][-1] for layer in range(self.num_layers)] + else: raise ValueError( f"Unknown contribution type: {self.contribution_type}") @@ -318,6 +386,7 @@ def __call__(self, input_tensor): self.log_contributions[layer].append(c) + def corrupt_stim(stim, sigma=0.03, steps=10): """ Corrupt the stimulus by adding Gaussian noise. @@ -333,21 +402,20 @@ def corrupt_stim(stim, sigma=0.03, steps=10): return corrupted_stim + def interpolate_stim(stim, steps=10): - """ - Interpolate the stimulus to create a series of inputs for integrated gradients. 
- """ - stim = stim.detach() + # Don't detach, but also don't stack first baseline = torch.zeros_like(stim) - interp_stim = [ - baseline + (float(i) / steps) * (stim - baseline) - for i in range(0, steps + 1) - ] + interp_stim = [] + for i in range(0, steps + 1): + # Create each interpolated stimulus with gradients enabled + alpha = float(i) / steps + interpolated = baseline + alpha * (stim - baseline) + interpolated = interpolated.detach().requires_grad_(True) # Make each one a leaf + interp_stim.append(interpolated) - interp_stim = torch.stack(interp_stim) - interp_stim.requires_grad = True - return interp_stim + return interp_stim # Return list, not stacked tensor def interneuron_integral_approximation(acts, grads): diff --git a/bscope/utils.py b/bscope/utils.py index 18fc862..e1e7194 100644 --- a/bscope/utils.py +++ b/bscope/utils.py @@ -1,6 +1,9 @@ import numpy as np import torch import matplotlib as mpl +import bscope +import bscope.ic as bic +from scipy import signal Epsilon = 1e-6 From c9a912752424c45e90ba893797a70a17ead949d7 Mon Sep 17 00:00:00 2001 From: Zaki Alaoui Date: Sat, 29 Nov 2025 21:53:02 +0000 Subject: [PATCH 4/6] constrative top2 --- bscope/scope.py | 236 ++++++++++++++++++++++++++++++------------------ 1 file changed, 147 insertions(+), 89 deletions(-) diff --git a/bscope/scope.py b/bscope/scope.py index 059b1e5..83c8317 100644 --- a/bscope/scope.py +++ b/bscope/scope.py @@ -9,6 +9,28 @@ import numpy as np +def torch_normalize_batch_across_cyx(array): + # array shape: [B, C, Y, X] + + # Get batch size + b_size = array.shape[0] + + # Reshape to [B, C*Y*X] to compute norm across all C, Y, X dimensions + reshaped = array.reshape(b_size, -1) + + # Compute the norm for each batch item + norms = torch.norm(reshaped, dim=1, keepdim=True) + + # Avoid division by zero + norms = torch.clamp(norms, min=1e-8) + + # Normalize each batch item + normalized_reshaped = reshaped / norms + + # Reshape back to original shape + normalized = 
normalized_reshaped.reshape(array.shape) + + return normalized def normalize_batch_across_cyx(array): # array shape: [B, C, Y, X] @@ -39,12 +61,13 @@ class Scope: log_X is the logging of X """ - def __init__(self, model, layer_list): + def __init__(self, model, layer_list, to_numpy=True): model.eval() self.model = model self.layer_list = layer_list - - self.inspector = Inspector(layer_list) + + self.to_numpy = to_numpy + self.inspector = Inspector(layer_list, to_numpy=to_numpy) self.num_layers = len(layer_list) self.reduction = None @@ -67,10 +90,6 @@ def use_act_grad(self): def use_act_normgrad(self): self.contribution_type='act_normgrad' - - def use_input_int_grad(self, steps=20): - self.contribution_type = 'input_int_grad' - self.steps = steps def use_normact_normgrad(self): self.contribution_type='normact_normgrad' @@ -78,23 +97,19 @@ def use_normact_normgrad(self): def use_jacobians(self): self.contribution_type = 'jacobians' - def wrt_entropy(self): + def wrt_entropy(self, softmax=True): self.contribution_target = 'entropy' - self.softmax = True - - def wrt_firing_rate_sum(self, target_neurons=None, softmax=False): - self.contribution_target = 'firing_rate_sum' - self.target_neurons = target_neurons # None = sum all neurons - self.softmax = softmax + + self.softmax=softmax def wrt_output_neuron(self, neuron_index=0, softmax=False): self.contribution_target = 'output_neuron' self.neuron_index = neuron_index self.softmax = softmax - def wrt_topk(self, k=5, softmax=True): - self.contribution_target = 'topk' - self.k = k + def wrt_firing_rate_sum(self, target_neurons=None, softmax=False): + self.contribution_target = 'firing_rate_sum' + self.target_neurons = target_neurons self.softmax = softmax def wrt_surprisal(self, softmax=False): @@ -107,12 +122,26 @@ def set_surprisal_stats(self, mu, sigma_inv): self.surprisal_mu = mu self.surprisal_sigma_inv = sigma_inv + def wrt_topk(self, k=5, softmax=True): + self.contribution_target = 'topk' + self.k = k + self.softmax 
= softmax + + def wrt_sum(self, softmax=False): + self.softmax = softmax + self.contribution_target = 'sum' + + def wrt_contrastive_top2(self, softmax=True): + self.contribution_target = 'contrastive_top2' + self.softmax = softmax + def log_start(self, reduction=None): self.logging = True self.log_gradients = [[] for i in range(self.num_layers)] self.log_activations = [[] for i in range(self.num_layers)] self.log_contributions = [[] for i in range(self.num_layers)] + self.log_outputs = [] self.reduction = reduction @@ -121,7 +150,8 @@ def log_stop(self): for i in range(self.num_layers): self.log_gradients[i] = np.concatenate(self.log_gradients[i]) self.log_activations[i] = np.concatenate(self.log_activations[i]) - self.log_contributions[i] = np.concatenate(self.log_contributions[i]) + self.log_contributions[i] = np.concatenate( + self.log_contributions[i]) self.logging = False @@ -138,12 +168,8 @@ def backward_pass(self, y): sorted, indices = torch.topk(y, self.k, dim=-1) sorted.sum().backward() - elif self.contribution_target == 'firing_rate_sum': - if self.target_neurons is not None: - target_output = y[:, self.target_neurons].sum() - else: - target_output = y.sum() - target_output.backward() + elif self.contribution_target == 'sum': + y.sum().backward() elif self.contribution_target == 'surprisal': if self.surprisal_mu is None or self.surprisal_sigma_inv is None: @@ -158,6 +184,15 @@ def backward_pass(self, y): surprisal = 0.5 * torch.sum((centered @ sigma_inv_tensor) * centered, dim=1) # [batch_size] surprisal.sum().backward() + elif self.contribution_target == 'contrastive_top2': + top2_values, top2_indices = torch.topk(y, 2, dim=-1) + + # Compute difference: top-1 minus top-2 + contrastive_score = top2_values[:, 0] - top2_values[:, 1] + + # Backward through the difference + contrastive_score.sum().backward() + else: raise ValueError(f"Unknown contribution target: {self.contribution_target}") @@ -173,15 +208,14 @@ def __call__(self, input_tensor): # Prepare 
the stimulus if self.contribution_type == 'int_grad': - stim = interpolate_stim(input_tensor, self.steps) + self.stim = interpolate_stim(input_tensor, self.steps) elif self.contribution_type == 'smooth_grad': - stim = corrupt_stim(input_tensor, self.sigma, self.steps) + self.stim = corrupt_stim(input_tensor, self.sigma, self.steps) elif self.contribution_type == 'act_normgrad' or self.contribution_type == 'normact_normgrad' or self.contribution_type == 'jacobians' or self.contribution_type == 'act_grad': - stim = input_tensor.unsqueeze(0) - stim.requires_grad = True - self.steps = 1 - elif self.contribution_type == 'input_int_grad': - stim = interpolate_stim(input_tensor, self.steps) # Same as int_grad + self.stim = input_tensor.unsqueeze(0) + # self.stim.requires_grad = True + self.steps = 0 + else: raise ValueError( f"Unknown contribution type: {self.contribution_type}") @@ -189,33 +223,20 @@ def __call__(self, input_tensor): self.last_activations = [[] for i in range(self.num_layers)] self.last_gradients = [[] for i in range(self.num_layers)] self.last_outputs = [] - self.input_gradients = [] - self.input_activations = [] - for step in range(self.steps): - + for step in range(self.steps+1): self.model.zero_grad() # get device of the model device = next(self.model.parameters()).device - current_stim = stim[step].to(device) # Store it - current_stim.requires_grad = True # Make sure it requires grad - y = self.model(stim[step].to(device)) - if self.contribution_type == 'input_int_grad': - y = self.model(current_stim) # CORRECT - use the same tensor + y = self.model(self.stim[step].to(device)) if self.softmax: y = torch.softmax(y, dim=-1) self.backward_pass(y) - if self.contribution_type == 'input_int_grad': - if current_stim.grad is not None: - self.input_gradients.append(current_stim.grad.detach().cpu().numpy()) - self.input_activations.append(current_stim.detach().cpu().numpy()) - self.last_outputs.append(y.detach().cpu().numpy()) - [ 
self.last_activations[i].append(self.inspector.activations[i]) for i in range(self.num_layers) @@ -226,20 +247,30 @@ def __call__(self, input_tensor): ] self.last_activations = [ - np.array(self.last_activations[i]) for i in range(self.num_layers) + self.last_activations[i] for i in range(self.num_layers) ] self.last_gradients = [ - np.array(self.last_gradients[i]) for i in range(self.num_layers) + self.last_gradients[i] for i in range(self.num_layers) ] + + if self.to_numpy: + self.last_activations = [np.array(self.last_activations[i]) for i in range(self.num_layers)] + self.last_gradients = [np.array(self.last_gradients[i]) for i in range(self.num_layers)] if self.contribution_type == 'act_normgrad': contributions = [] for layer in range(self.num_layers): act = self.last_activations[layer][0] - grad = self.last_gradients[layer][0] - norm_grad = normalize_batch_across_cyx(grad) - contributions.append(np.array(act * norm_grad)) + + if self.to_numpy: + grad = self.last_gradients[layer][0] + norm_grad = normalize_batch_across_cyx(grad) + contributions.append(np.array(act * norm_grad)) + else: + grad = self.last_gradients[layer][0] + norm_grad = torch_normalize_batch_across_cyx(grad) + contributions.append(act * norm_grad) self.activations = [ self.last_activations[layer][0] for layer in range(self.num_layers) @@ -254,9 +285,16 @@ def __call__(self, input_tensor): for layer in range(self.num_layers): act = self.last_activations[layer][0] grad = self.last_gradients[layer][0] - norm_act = normalize_batch_across_cyx(act) - norm_grad = normalize_batch_across_cyx(grad) - contributions.append(np.array(norm_act* norm_grad)) + + + if self.to_numpy: + contributions.append(np.array(norm_act* norm_grad)) + norm_act = normalize_batch_across_cyx(act) + norm_grad = normalize_batch_across_cyx(grad) + else: + contributions.append(norm_act* norm_grad) + norm_act = torch_normalize_batch_across_cyx(act) + norm_grad = torch_normalize_batch_across_cyx(grad) self.activations = [ 
self.last_activations[layer][0] @@ -272,7 +310,11 @@ def __call__(self, input_tensor): for layer in range(self.num_layers): act = self.last_activations[layer][0] grad = self.last_gradients[layer][0] - contributions.append(np.array(act * grad)) + + if self.to_numpy: + contributions.append(np.array(act * grad)) + else: + contributions.append(act * grad) self.activations = [ self.last_activations[layer][0] @@ -287,10 +329,16 @@ def __call__(self, input_tensor): for layer in range(self.num_layers): interp_activations = self.last_activations[layer] interp_gradients = self.last_gradients[layer] - contributions.append( - np.array( - interneuron_integral_approximation( - interp_activations, interp_gradients))) + + if self.to_numpy: + contributions.append( + np.array( + interneuron_integral_approximation( + interp_activations, interp_gradients))) + else: + contributions.append( + torch_interneuron_integral_approximation( + interp_activations, interp_gradients)) self.activations = [ self.last_activations[layer][-1] @@ -333,22 +381,6 @@ def __call__(self, input_tensor): cam = np.maximum(cam, 0) contributions.append(cam) - - - elif self.contribution_type == 'input_int_grad': - contributions = [] - for layer in range(self.num_layers): - - - input_int_grads = interneuron_integral_approximation( - self.input_gradients, # All interpolated inputs you collected - self.input_activations # All input gradients you collected - ) - contributions.append(input_int_grads) - - self.activations = [self.last_activations[layer][-1] for layer in range(self.num_layers)] - self.gradients = [self.last_gradients[layer][-1] for layer in range(self.num_layers)] - else: raise ValueError( f"Unknown contribution type: {self.contribution_type}") @@ -372,9 +404,9 @@ def __call__(self, input_tensor): c = c.sum((2, 3)) if 'patch_ei_split' in self.reduction: - g = ei_split(g, patch=True) - a = ei_split(a, patch=True) - c = ei_split(c, patch=True) + g = ei_split(g, dim=-1) + a = ei_split(a, dim=-1) + c = 
ei_split(c, dim=-1) if 'patch_sum' in self.reduction: g = g.sum(1) @@ -385,6 +417,7 @@ def __call__(self, input_tensor): self.log_activations[layer].append(a) self.log_contributions[layer].append(c) + return y def corrupt_stim(stim, sigma=0.03, steps=10): @@ -402,22 +435,47 @@ def corrupt_stim(stim, sigma=0.03, steps=10): return corrupted_stim - def interpolate_stim(stim, steps=10): - # Don't detach, but also don't stack first + """ + Interpolate the stimulus to create a series of inputs for integrated gradients. + """ + stim = stim.detach() baseline = torch.zeros_like(stim) - interp_stim = [] - for i in range(0, steps + 1): - # Create each interpolated stimulus with gradients enabled - alpha = float(i) / steps - interpolated = baseline + alpha * (stim - baseline) - interpolated = interpolated.detach().requires_grad_(True) # Make each one a leaf - interp_stim.append(interpolated) + interp_stim = [ + baseline + (float(i) / steps) * (stim - baseline) + for i in range(0, steps + 1) + ] + + interp_stim = torch.stack(interp_stim) + interp_stim.requires_grad = True + return interp_stim - return interp_stim # Return list, not stacked tensor +def torch_interneuron_integral_approximation(acts, grads): + """ + Trapezoidal integral approximation for integrated gradients. + """ + igs = [] + for i, (a, g) in enumerate(zip(acts, grads)): + if i == 0: + last_act = a + continue + diff_act = a - last_act + last_act = a + + trapezoidal = grads[i - 1] + grads[i] + + trapezoidal /= 2 + + ig = trapezoidal * diff_act + igs.append(ig) + + igs = torch.stack(igs) + igs = torch.sum(igs, axis=0) + + return igs def interneuron_integral_approximation(acts, grads): """ Trapezoidal integral approximation for integrated gradients. 
@@ -441,4 +499,4 @@ def interneuron_integral_approximation(acts, grads): igs = np.array(igs) igs = np.sum(igs, axis=0) - return igs + return igs \ No newline at end of file From 562aea1cfbeff3dfc0f8b7bc419aec7f1848092d Mon Sep 17 00:00:00 2001 From: Zaki Alaoui Date: Wed, 14 Jan 2026 19:05:20 +0000 Subject: [PATCH 5/6] nonnegativity change to Encoder, StSAE, and SigThresh --- bscope/ic/evaluation.py | 179 +++- bscope/ic/mode_summary.py | 1506 +++++++++++++++++--------------- bscope/ic/semantic_analyzer.py | 521 ++++++++++- bscope/sae.py | 108 ++- 4 files changed, 1567 insertions(+), 747 deletions(-) diff --git a/bscope/ic/evaluation.py b/bscope/ic/evaluation.py index 1dc1ba4..6735536 100644 --- a/bscope/ic/evaluation.py +++ b/bscope/ic/evaluation.py @@ -1,3 +1,148 @@ +# import tqdm +# import torchvision +# import torch.nn as nn +# import torch +# import numpy as np +# import tqdm +# import torchvision +# import torch.nn as nn +# import torch +# import numpy as np + + +# def calculate_accuracy(model, val_loader, device='cuda'): +# """ +# Evaluate a model on the validation dataset and return top-1 and top-5 accuracy. 
+ +# Args: +# model: PyTorch model to evaluate +# val_loader: DataLoader for the validation dataset +# device: Device to run evaluation on ('cuda' or 'cpu') + +# Returns: +# top1_acc: Top-1 accuracy as a percentage +# top5_acc: Top-5 accuracy as a percentage +# """ +# model.eval() +# model = model.to(device) + +# correct_1 = 0 +# correct_5 = 0 +# total = 0 + +# with torch.no_grad(): +# for inputs, targets in tqdm.tqdm(val_loader): +# inputs, targets = inputs.to(device), targets.to(device) + +# # Forward pass +# outputs = model(inputs) + +# # Top-1 accuracy +# _, predicted = outputs.max(1) +# correct_1 += (predicted == targets).sum().item() + +# # Top-5 accuracy +# _, top5_predicted = outputs.topk(5, 1) +# for i in range(targets.size(0)): +# if targets[i] in top5_predicted[i]: +# correct_5 += 1 + +# total += targets.size(0) + +# top1_acc = 100 * correct_1 / total +# top5_acc = 100 * correct_5 / total + +# return top1_acc, top5_acc + + + +# def calculate_class_accuracy(model, +# val_loader, +# num_classes=1000, +# device='cuda:1', +# target_classes=None, +# target_topk=5, +# nontarget_topk=1): + +# if isinstance(target_classes, int): +# target_classes = [target_classes] +# elif target_classes is None: +# target_classes = [] + +# model.eval() +# model = model.to(device) + +# correct_per_class = torch.zeros(num_classes, device=device) +# total_per_class = torch.zeros(num_classes, device=device) + +# with torch.no_grad(): +# for inputs, targets in tqdm.tqdm(val_loader): +# inputs, targets = inputs.to(device), targets.to(device) +# outputs = model(inputs) + +# # Get top max(target_topk, nontarget_topk) predictions once for efficiency +# max_k = max(target_topk, nontarget_topk) +# _, pred_all = outputs.topk(max_k, dim=1, largest=True, sorted=True) + +# for i in range(targets.size(0)): +# label = targets[i].item() +# total_per_class[label] += 1 + +# # Choose dynamic top-k +# topk = target_topk if label in target_classes else nontarget_topk +# pred = pred_all[i][:topk] + +# 
if label in pred: +# correct_per_class[label] += 1 + +# class_accuracy = torch.zeros(num_classes, device=device) +# for i in range(num_classes): +# if total_per_class[i] > 0: +# class_accuracy[i] = (correct_per_class[i] / +# total_per_class[i]) * 100 + +# return class_accuracy.detach().cpu().numpy(), total_per_class.int() + + +# def calculate_subsample_accuracy(model, +# val_loader, +# subclasses, +# topk=5, +# device='cuda:1'): + + +# model.eval() +# model = model.to(device) + +# correct_per_class = torch.zeros(len(subclasses), device=device) +# n_subsample = len(val_loader.dataset) // len(subclasses) +# print('N subsample ', n_subsample) + +# label_mapping = {label: idx for idx, label in enumerate(subclasses)} + +# with torch.no_grad(): +# for inputs, targets in tqdm.tqdm(val_loader): +# inputs, targets = inputs.to(device), targets.to(device) +# outputs = model(inputs) + +# # Get top max(target_topk, nontarget_topk) predictions once for efficiency +# _, pred_all = outputs.topk(topk, dim=1, largest=True, sorted=True) + +# for i in range(targets.size(0)): +# label = targets[i].item() +# pred = pred_all[i][:topk] + +# if label in pred: +# correct_per_class[label_mapping[label]] += 1 + +# class_accuracy = torch.zeros(len(subclasses), device=device) +# num_classes = len(subclasses) +# for i in range(num_classes): +# class_accuracy[i] = (correct_per_class[i] / +# n_subsample) * 100 + +# return class_accuracy.detach().cpu().numpy() + import tqdm import torchvision import torch.nn as nn @@ -107,38 +252,50 @@ def calculate_class_accuracy(model, def calculate_subsample_accuracy(model, val_loader, subclasses, - topk=5, device='cuda:1'): model.eval() model = model.to(device) - correct_per_class = torch.zeros(len(subclasses), device=device) + correct_per_class_top_1 = torch.zeros(len(subclasses), device=device) + correct_per_class_top_5 = torch.zeros(len(subclasses), device=device) + n_subsample = len(val_loader.dataset) // len(subclasses) print('N subsample ', n_subsample) 
label_mapping = {label: idx for idx, label in enumerate(subclasses)} with torch.no_grad(): - for inputs, targets in tqdm.tqdm(val_loader): + for inputs, targets in val_loader: inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs) # Get top max(target_topk, nontarget_topk) predictions once for efficiency - _, pred_all = outputs.topk(topk, dim=1, largest=True, sorted=True) + _, pred_all = outputs.topk(5, dim=1, largest=True, sorted=True) + for i in range(targets.size(0)): label = targets[i].item() - pred = pred_all[i][:topk] - if label in pred: - correct_per_class[label_mapping[label]] += 1 + topk = 1 + pred_1 = pred_all[i][:topk] + + topk = 5 + pred_5 = pred_all[i][:topk] + + if label in pred_1: + correct_per_class_top_1[label_mapping[label]] += 1 + if label in pred_5: + correct_per_class_top_5[label_mapping[label]] += 1 + + top1_class_accuracy = torch.zeros(len(subclasses), device=device) + top5_class_accuracy = torch.zeros(len(subclasses), device=device) - class_accuracy = torch.zeros(len(subclasses), device=device) num_classes = len(subclasses) for i in range(num_classes): - class_accuracy[i] = (correct_per_class[i] / - n_subsample) * 100 + top1_class_accuracy[i] = (correct_per_class_top_1[i] / n_subsample) * 100 + top5_class_accuracy[i] = (correct_per_class_top_5[i] / n_subsample) * 100 + - return class_accuracy.detach().cpu().numpy() + return top1_class_accuracy.cpu().numpy(), top5_class_accuracy.cpu().numpy() \ No newline at end of file diff --git a/bscope/ic/mode_summary.py b/bscope/ic/mode_summary.py index 36213c6..45a8c17 100644 --- a/bscope/ic/mode_summary.py +++ b/bscope/ic/mode_summary.py @@ -6,6 +6,14 @@ import bscope.ic as bic from typing import List, Tuple, Dict, Union, Optional from dataclasses import dataclass +import numpy as np +from IPython import embed +import h5py as h5 +import matplotlib.pyplot as plt +import bscope +import bscope.ic as bic +from typing import List, Tuple, Dict, Union, Optional +from dataclasses 
import dataclass @dataclass class LayerSummary: @@ -19,6 +27,7 @@ class LayerSummary: dictionary: Dictionary/atoms matrix from SAE (features Γ— channels) """ corr_mtx: np.ndarray + imgnet_corr_mtx: np.ndarray # Correlation matrix with ImageNet classes loadings: np.ndarray dictionary: np.ndarray @@ -31,6 +40,7 @@ class LayerSummary: def __post_init__(self): self.corr_mtx[np.isnan(self.corr_mtx)] = 0 # Replace NaNs with zeros + self.imgnet_corr_mtx[np.isnan(self.imgnet_corr_mtx)] = 0 self.num_modes = self.dictionary.shape[0] @@ -53,7 +63,18 @@ def __init__(self, h5_path): else str(label).strip().split('.')[0] for label in raw_labels] else: self.mask_labels = None + + # Load imgnet mask labels if available + if 'imgnet_mask_labels' in self.file: + raw_imgnet_labels = self.file['imgnet_mask_labels'][:] + self.imgnet_mask_labels = [label.decode('utf-8').strip() if isinstance(label, bytes) + else str(label).strip() for label in raw_imgnet_labels] + else: + self.imgnet_mask_labels = None + + self.mask_matrix = self.file['mask_matrix'][:] if 'mask_matrix' in self.file else None + self.imgnet_mask_matrix = self.file['imgnet_mask_matrix'][:] if 'imgnet_mask_matrix' in self.file else None self.layer_idxs = np.sort([int(l) for l in list(self.file['layers'].keys())]) self.layers = [] @@ -62,818 +83,885 @@ def __init__(self, h5_path): layer_data = self.file['layers'][layer_key] corr_mtx = layer_data['corr_mtx'][:] + imgnet_corr_mtx = layer_data['imgnet_corr_mtx'][:] corr_mtx=corr_mtx.T + imgnet_corr_mtx=imgnet_corr_mtx.T loadings = layer_data['loadings'][:] dictionary = layer_data['dictionary'][:] aggregated_data = layer_data['data_agg'][:] if 'data' in layer_data else None aggregated_reconstruction = layer_data['reconstructed_agg'][:] - r2 = layer_data['r2'][()] if 'r2' in layer_data else None + r2 = layer_data.attrs['r2'][()] if 'r2' in layer_data.attrs else None - self.layers.append(LayerSummary(corr_mtx, loadings, dictionary, layer_idx, r2, aggregated_data, 
aggregated_reconstruction)) + self.layers.append(LayerSummary(corr_mtx, imgnet_corr_mtx, loadings, dictionary, layer_idx, r2, aggregated_data, aggregated_reconstruction)) +# @dataclass +# class LayerSummary: +# """ +# A class to summarize the information of a single layer in the mode summary. + +# Attributes: +# corr_mtx: Correlation matrix between loadings and mask matrix +# r2: R-squared values for the layer +# loadings: Loadings matrix from SAE (samples Γ— features) +# dictionary: Dictionary/atoms matrix from SAE (features Γ— channels) +# """ +# corr_mtx: np.ndarray +# loadings: np.ndarray +# dictionary: np.ndarray -class ModeAnalyzer: - """ - A class for analyzing modes and channels using ModeSummary data. - """ +# idx: Optional[int] = None # Optional index for the layer, if needed +# r2: Optional[int] = None # Optional R-squared values for the layer + +# aggregated_data: Optional[np.ndarray] = None # Aggregated data if available +# aggregated_reconstruction: Optional[np.ndarray] = None # Aggregated reconstruction if available + +# def __post_init__(self): +# self.corr_mtx[np.isnan(self.corr_mtx)] = 0 # Replace NaNs with zeros +# self.num_modes = self.dictionary.shape[0] + + +# class ModeSummary: +# def __init__(self, h5_path): +# """ +# Initialize the ModeSummary with a path to an HDF5 file. 
+ +# Args: +# h5_path: Path to the HDF5 file containing mode summary data +# """ +# self.h5_path = h5_path +# self.file = h5.File(h5_path, 'r') + +# # Load mask labels if available +# if 'mask_labels' in self.file: +# raw_labels = self.file['mask_labels'][:] +# # Convert bytes to strings, strip whitespace, and extract base name +# self.mask_labels = [label.decode('utf-8').strip().split('.')[0] if isinstance(label, bytes) +# else str(label).strip().split('.')[0] for label in raw_labels] +# else: +# self.mask_labels = None +# self.mask_matrix = self.file['mask_matrix'][:] if 'mask_matrix' in self.file else None + +# self.layer_idxs = np.sort([int(l) for l in list(self.file['layers'].keys())]) +# self.layers = [] +# for layer_idx in self.layer_idxs: +# layer_key = str(layer_idx) +# layer_data = self.file['layers'][layer_key] + +# corr_mtx = layer_data['corr_mtx'][:] +# corr_mtx=corr_mtx.T +# loadings = layer_data['loadings'][:] +# dictionary = layer_data['dictionary'][:] +# aggregated_data = layer_data['data_agg'][:] if 'data' in layer_data else None +# aggregated_reconstruction = layer_data['reconstructed_agg'][:] + +# r2 = layer_data['r2'][()] if 'r2' in layer_data else None + +# self.layers.append(LayerSummary(corr_mtx, loadings, dictionary, layer_idx, r2, aggregated_data, aggregated_reconstruction)) + + + +# class ModeAnalyzer: +# """ +# A class for analyzing modes and channels using ModeSummary data. +# """ - def __init__(self, mode_summary_path: ModeSummary): - """ - Initialize the ModeAnalyzer with a ModeSummary instance. +# def __init__(self, mode_summary_path: ModeSummary): +# """ +# Initialize the ModeAnalyzer with a ModeSummary instance. 
- Args: - mode_summary: ModeSummary instance containing the data - contributions: Optional contributions array for sample analysis - """ - self.summary= ModeSummary(mode_summary_path) +# Args: +# mode_summary: ModeSummary instance containing the data +# contributions: Optional contributions array for sample analysis +# """ +# self.summary= ModeSummary(mode_summary_path) - def find_concept_indices(self, concept_name: str) -> List[int]: - """ - Find the indices of concepts that match the given name. +# def find_concept_indices(self, concept_name: str) -> List[int]: +# """ +# Find the indices of concepts that match the given name. - Args: - concept_name: The name of the concept to find +# Args: +# concept_name: The name of the concept to find - Returns: - List of matching concept indices - """ - if self.summary.mask_labels is None: - raise ValueError("No mask labels available in the ModeSummary") +# Returns: +# List of matching concept indices +# """ +# if self.summary.mask_labels is None: +# raise ValueError("No mask labels available in the ModeSummary") - matching_indices = [] - concept_name_lower = concept_name.lower() - - for i, label in enumerate(self.summary.mask_labels): - # Handle both string and bytes labels - if isinstance(label, bytes): - label_str = label.decode('utf-8') - else: - label_str = str(label) +# matching_indices = [] +# concept_name_lower = concept_name.lower() + +# for i, label in enumerate(self.summary.mask_labels): +# # Handle both string and bytes labels +# if isinstance(label, bytes): +# label_str = label.decode('utf-8') +# else: +# label_str = str(label) - # Extract base name (same as SemanticAnalyzer) - base_name = label_str.split('.')[0].strip().lower() +# # Extract base name (same as SemanticAnalyzer) +# base_name = label_str.split('.')[0].strip().lower() - if base_name == concept_name_lower: - matching_indices.append(i) +# if base_name == concept_name_lower: +# matching_indices.append(i) - return matching_indices +# return 
matching_indices - def get_concept_info(self, concept_name: str, select_first: bool = True) -> Tuple[int, str]: - """ - Get information about a concept. +# def get_concept_info(self, concept_name: str, select_first: bool = True) -> Tuple[int, str]: +# """ +# Get information about a concept. - Args: - concept_name: The name of the concept to find - select_first: If True, automatically select the first match if multiple found +# Args: +# concept_name: The name of the concept to find +# select_first: If True, automatically select the first match if multiple found - Returns: - Tuple of (concept_index, concept_label) +# Returns: +# Tuple of (concept_index, concept_label) - Raises: - ValueError: If no matching concepts found or multiple matches found and select_first is False - """ - matching_indices = self.find_concept_indices(concept_name) +# Raises: +# ValueError: If no matching concepts found or multiple matches found and select_first is False +# """ +# matching_indices = self.find_concept_indices(concept_name) - if not matching_indices: - raise ValueError(f"No concepts found with name '{concept_name}'") +# if not matching_indices: +# raise ValueError(f"No concepts found with name '{concept_name}'") - if len(matching_indices) > 1: - print(f"Found multiple matching concepts:") - for idx in matching_indices: - print(f" {idx}: {self.summary.mask_labels[idx]}") +# if len(matching_indices) > 1: +# print(f"Found multiple matching concepts:") +# for idx in matching_indices: +# print(f" {idx}: {self.summary.mask_labels[idx]}") - if select_first: - concept_idx = matching_indices[0] - print(f"Using the first match: {self.summary.mask_labels[concept_idx]}") - else: - raise ValueError(f"Multiple concepts found with name '{concept_name}'. 
Set select_first=True to use the first match.") - else: - concept_idx = matching_indices[0] +# if select_first: +# concept_idx = matching_indices[0] +# print(f"Using the first match: {self.summary.mask_labels[concept_idx]}") +# else: +# raise ValueError(f"Multiple concepts found with name '{concept_name}'. Set select_first=True to use the first match.") +# else: +# concept_idx = matching_indices[0] - return concept_idx, self.summary.mask_labels[concept_idx] +# return concept_idx, self.summary.mask_labels[concept_idx] - def get_layer(self, layer_idx: int) -> LayerSummary: - """ - Get a specific layer from the ModeSummary. +# def get_layer(self, layer_idx: int) -> LayerSummary: +# """ +# Get a specific layer from the ModeSummary. - Args: - layer_idx: Index of the layer to retrieve +# Args: +# layer_idx: Index of the layer to retrieve - Returns: - LayerSummary for the specified layer +# Returns: +# LayerSummary for the specified layer - Raises: - ValueError: If layer_idx is not found - """ - for layer in self.summary.layers: - if layer.idx == layer_idx: - return layer +# Raises: +# ValueError: If layer_idx is not found +# """ +# for layer in self.summary.layers: +# if layer.idx == layer_idx: +# return layer - raise ValueError(f"Layer {layer_idx} not found. Available layers: {self.summary.layer_idxs}") +# raise ValueError(f"Layer {layer_idx} not found. Available layers: {self.summary.layer_idxs}") - def get_top_modes(self, layer_idx: int, concept_name: str, method: str = 'percentile', - param: float = 0.7, min_indices: int = 1, max_indices: int = 50, - select_first: bool = True) -> np.ndarray: - """ - Get top modes for a specific layer and concept using select_significant_indices. +# def get_top_modes(self, layer_idx: int, concept_name: str, method: str = 'percentile', +# param: float = 0.7, min_indices: int = 1, max_indices: int = 50, +# select_first: bool = True) -> np.ndarray: +# """ +# Get top modes for a specific layer and concept using select_significant_indices. 
- Args: - layer_idx: Index of the layer - concept_name: Name of the concept to analyze - method: Method for select_significant_indices ('threshold', 'percentile', etc.) - param: Parameter for the selection method - min_indices: Minimum number of indices to return - max_indices: Maximum number of indices to return - select_first: Whether to auto-select first concept match +# Args: +# layer_idx: Index of the layer +# concept_name: Name of the concept to analyze +# method: Method for select_significant_indices ('threshold', 'percentile', etc.) +# param: Parameter for the selection method +# min_indices: Minimum number of indices to return +# max_indices: Maximum number of indices to return +# select_first: Whether to auto-select first concept match - Returns: - Array of mode indices - """ - # Get the layer - layer = self.get_layer(layer_idx) +# Returns: +# Array of mode indices +# """ +# # Get the layer +# layer = self.get_layer(layer_idx) - # Get concept info - concept_idx, concept_label = self.get_concept_info(concept_name, select_first) +# # Get concept info +# concept_idx, concept_label = self.get_concept_info(concept_name, select_first) - # Get correlations for this concept (equivalent to your rs array) - correlations = layer.corr_mtx[:, concept_idx] +# # Get correlations for this concept (equivalent to your rs array) +# correlations = layer.corr_mtx[:, concept_idx] - # Use select_significant_indices to get top modes - modes = select_significant_indices( - correlations, - method=method, - param=param, - min_indices=min_indices, - max_indices=max_indices - ) +# # Use select_significant_indices to get top modes +# modes = select_significant_indices( +# correlations, +# method=method, +# param=param, +# min_indices=min_indices, +# max_indices=max_indices +# ) - return modes +# return modes - def get_top_channels(self, layer_idx: int, concept_name: str, - mode_method: str = 'percentile', mode_param: float = 0.7, - mode_min_indices: int = 1, mode_max_indices: int = 
50, - channel_method: str = 'percentile', channel_param: float = 0.5, - channel_min_indices: int = 1, channel_max_indices: int = 50, - select_first: bool = True, concat: bool = False) -> List[int]: - """ - Get top channels for a specific layer and concept by first getting top modes, - then getting channels from those modes' dictionary vectors. +# def get_top_channels(self, layer_idx: int, concept_name: str, +# mode_method: str = 'percentile', mode_param: float = 0.7, +# mode_min_indices: int = 1, mode_max_indices: int = 50, +# channel_method: str = 'percentile', channel_param: float = 0.5, +# channel_min_indices: int = 1, channel_max_indices: int = 50, +# select_first: bool = True, concat: bool = False) -> List[int]: +# """ +# Get top channels for a specific layer and concept by first getting top modes, +# then getting channels from those modes' dictionary vectors. - Args: - layer_idx: Index of the layer - concept_name: Name of the concept to analyze - mode_method: Method for selecting modes - mode_param: Parameter for mode selection - mode_min_indices: Minimum number of modes - mode_max_indices: Maximum number of modes - channel_method: Method for selecting channels from each mode - channel_param: Parameter for channel selection - channel_min_indices: Minimum number of channels per mode - channel_max_indices: Maximum number of channels per mode - select_first: Whether to auto-select first concept match - concat: If True, map doubled channel indices back to original channel space +# Args: +# layer_idx: Index of the layer +# concept_name: Name of the concept to analyze +# mode_method: Method for selecting modes +# mode_param: Parameter for mode selection +# mode_min_indices: Minimum number of modes +# mode_max_indices: Maximum number of modes +# channel_method: Method for selecting channels from each mode +# channel_param: Parameter for channel selection +# channel_min_indices: Minimum number of channels per mode +# channel_max_indices: Maximum number of channels 
per mode +# select_first: Whether to auto-select first concept match +# concat: If True, map doubled channel indices back to original channel space - Returns: - List of unique channel indices across all selected modes - """ - # Get the layer - layer = self.get_layer(layer_idx) +# Returns: +# List of unique channel indices across all selected modes +# """ +# # Get the layer +# layer = self.get_layer(layer_idx) - # Get top modes first - top_modes = self.get_top_modes( - layer_idx, concept_name, mode_method, mode_param, - mode_min_indices, mode_max_indices, select_first - ) +# # Get top modes first +# top_modes = self.get_top_modes( +# layer_idx, concept_name, mode_method, mode_param, +# mode_min_indices, mode_max_indices, select_first +# ) - # Collect channels from all modes - all_channels = [] +# # Collect channels from all modes +# all_channels = [] - for mode_idx in top_modes: - # Get the dictionary vector (atom) for this mode - atom = layer.dictionary[mode_idx, :] +# for mode_idx in top_modes: +# # Get the dictionary vector (atom) for this mode +# atom = layer.dictionary[mode_idx, :] - # Get top channels for this atom - top_channels = select_significant_indices( - atom, - method=channel_method, - param=channel_param, - min_indices=channel_min_indices, - max_indices=channel_max_indices - ) +# # Get top channels for this atom +# top_channels = select_significant_indices( +# atom, +# method=channel_method, +# param=channel_param, +# min_indices=channel_min_indices, +# max_indices=channel_max_indices +# ) - all_channels.extend(top_channels) +# all_channels.extend(top_channels) - # Determine which channels to return - if len(top_modes) == 1: - channels = top_channels - else: - channels = all_channels +# # Determine which channels to return +# if len(top_modes) == 1: +# channels = top_channels +# else: +# channels = all_channels - # Apply concat mapping if requested - if concat: - channels = self.map_to_original_channels(channels, layer) +# # Apply concat mapping if 
requested +# if concat: +# channels = self.map_to_original_channels(channels, layer) - return channels +# return channels - def get_mode_correlations(self, layer_idx: int, concept_name: str, - select_first: bool = True) -> np.ndarray: - """ - Get correlation values for all modes with a specific concept. +# def get_mode_correlations(self, layer_idx: int, concept_name: str, +# select_first: bool = True) -> np.ndarray: +# """ +# Get correlation values for all modes with a specific concept. - Args: - layer_idx: Index of the layer - concept_name: Name of the concept - select_first: Whether to auto-select first concept match +# Args: +# layer_idx: Index of the layer +# concept_name: Name of the concept +# select_first: Whether to auto-select first concept match - Returns: - Array of correlation values (equivalent to your rs array) - """ - layer = self.get_layer(layer_idx) - concept_idx, _ = self.get_concept_info(concept_name, select_first) - return layer.corr_mtx[:, concept_idx] +# Returns: +# Array of correlation values (equivalent to your rs array) +# """ +# layer = self.get_layer(layer_idx) +# concept_idx, _ = self.get_concept_info(concept_name, select_first) +# return layer.corr_mtx[:, concept_idx] - def get_concept_sample_indices(self, concept_idx: int) -> np.ndarray: - """ - Get the sample indices for a specific concept. +# def get_concept_sample_indices(self, concept_idx: int) -> np.ndarray: +# """ +# Get the sample indices for a specific concept. 
- Args: - concept_idx: The index of the concept +# Args: +# concept_idx: The index of the concept - Returns: - Array of sample indices where the concept is present - """ - return np.where(self.mask_matrix[:, concept_idx] == 1)[0] - - def find_similar_concepts_by_channels( - self, - seed_concept: str, - layer_idx: int, - mode_method: str = 'percentile', - mode_param: float = 0.7, - channel_method: str = 'std', # Changed default to 'std' - channel_param: float = 2.0, # Changed default to 2.0 std deviations - min_overlap: int = 1, - select_first: bool = True, - concat=False, - # Remove top_n_channels parameter to avoid limiting - ): - """ - Find concepts that share the most contributing channels with a seed concept, - using consistent channel selection criteria for all concepts. +# Returns: +# Array of sample indices where the concept is present +# """ +# return np.where(self.mask_matrix[:, concept_idx] == 1)[0] + +# def find_similar_concepts_by_channels( +# self, +# seed_concept: str, +# layer_idx: int, +# mode_method: str = 'percentile', +# mode_param: float = 0.7, +# channel_method: str = 'std', # Changed default to 'std' +# channel_param: float = 2.0, # Changed default to 2.0 std deviations +# min_overlap: int = 1, +# select_first: bool = True, +# concat=False, +# # Remove top_n_channels parameter to avoid limiting +# ): +# """ +# Find concepts that share the most contributing channels with a seed concept, +# using consistent channel selection criteria for all concepts. - Args: - seed_concept: The concept to find similar concepts for - layer_idx: Which layer to analyze - mode_method: Method for selecting the most salient mode - mode_param: Parameter for mode selection - channel_method: Method for selecting top channels ('percentile', 'std', 'threshold', etc.) 
- channel_param: Parameter for channel selection (percentile value, std deviations, or threshold) - min_overlap: Minimum number of shared channels to include in results - select_first: Whether to auto-select first concept match +# Args: +# seed_concept: The concept to find similar concepts for +# layer_idx: Which layer to analyze +# mode_method: Method for selecting the most salient mode +# mode_param: Parameter for mode selection +# channel_method: Method for selecting top channels ('percentile', 'std', 'threshold', etc.) +# channel_param: Parameter for channel selection (percentile value, std deviations, or threshold) +# min_overlap: Minimum number of shared channels to include in results +# select_first: Whether to auto-select first concept match - Returns: - List of tuples: (concept_name, shared_count, overlap_ratio, shared_channels) - Sorted by number of shared channels (descending) - """ +# Returns: +# List of tuples: (concept_name, shared_count, overlap_ratio, shared_channels) +# Sorted by number of shared channels (descending) +# """ - print(f"Finding concepts similar to '{seed_concept}' at layer {layer_idx}") - print(f"Using channel selection method: {channel_method}, param: {channel_param}") - print("-" * 60) - - # Get seed concept's channels using the specified method - try: - seed_channels = self.get_top_channels( - layer_idx=layer_idx, - concept_name=seed_concept, - mode_method=mode_method, - mode_param=mode_param, - mode_min_indices=1, - mode_max_indices=1, # Just get the most salient mode - channel_method=channel_method, - channel_param=channel_param, - channel_min_indices=1, # Minimum of 1 channel - channel_max_indices=float('inf'), # No upper limit - important! 
- select_first=select_first, - concat=concat, - ) +# print(f"Finding concepts similar to '{seed_concept}' at layer {layer_idx}") +# print(f"Using channel selection method: {channel_method}, param: {channel_param}") +# print("-" * 60) + +# # Get seed concept's channels using the specified method +# try: +# seed_channels = self.get_top_channels( +# layer_idx=layer_idx, +# concept_name=seed_concept, +# mode_method=mode_method, +# mode_param=mode_param, +# mode_min_indices=1, +# mode_max_indices=1, # Just get the most salient mode +# channel_method=channel_method, +# channel_param=channel_param, +# channel_min_indices=1, # Minimum of 1 channel +# channel_max_indices=float('inf'), # No upper limit - important! +# select_first=select_first, +# concat=concat, +# ) - # No limit on top_n_channels anymore - seed_channels_list = seed_channels # Keep ordered list - seed_channels_set = set(seed_channels_list) # Create set for intersections +# # No limit on top_n_channels anymore +# seed_channels_list = seed_channels # Keep ordered list +# seed_channels_set = set(seed_channels_list) # Create set for intersections - except Exception as e: - print(f"Error getting channels for seed concept '{seed_concept}': {e}") - return [] +# except Exception as e: +# print(f"Error getting channels for seed concept '{seed_concept}': {e}") +# return [] - print(f"Seed concept '{seed_concept}' selected {len(seed_channels)} channels: {seed_channels}") +# print(f"Seed concept '{seed_concept}' selected {len(seed_channels)} channels: {seed_channels}") - # Compare with all other concepts - results = [] +# # Compare with all other concepts +# results = [] - # Get all concepts to check - syn = bic.SemanticAnalyzer('/home/zalaoui/semantic_indexes_test.json') - _, imagenet_class_names = syn.get_all_imagenet_masks(list(range(1000))) +# # Get all concepts to check +# syn = bic.SemanticAnalyzer('/home/zalaoui/semantic_indexes_test.json') +# _, imagenet_class_names = syn.get_all_imagenet_masks(list(range(1000))) 
- for concept_label in imagenet_class_names: - # Skip the seed concept itself - base_concept_name = concept_label.split('.')[0] - if base_concept_name.lower() == seed_concept.lower(): - continue +# for concept_label in imagenet_class_names: +# # Skip the seed concept itself +# base_concept_name = concept_label.split('.')[0] +# if base_concept_name.lower() == seed_concept.lower(): +# continue - try: - # Get this concept's channels using EXACTLY THE SAME method and parameters - concept_channels = self.get_top_channels( - layer_idx=layer_idx, - concept_name=base_concept_name, - mode_method=mode_method, - mode_param=mode_param, - mode_min_indices=1, - mode_max_indices=1, # Just get the most salient mode - channel_method=channel_method, - channel_param=channel_param, - channel_min_indices=1, # Minimum of 1 channel - channel_max_indices=float('inf'), # No upper limit - important! - select_first=True, # Auto-select to avoid prompts - concat=concat - ) +# try: +# # Get this concept's channels using EXACTLY THE SAME method and parameters +# concept_channels = self.get_top_channels( +# layer_idx=layer_idx, +# concept_name=base_concept_name, +# mode_method=mode_method, +# mode_param=mode_param, +# mode_min_indices=1, +# mode_max_indices=1, # Just get the most salient mode +# channel_method=channel_method, +# channel_param=channel_param, +# channel_min_indices=1, # Minimum of 1 channel +# channel_max_indices=float('inf'), # No upper limit - important! 
+# select_first=True, # Auto-select to avoid prompts +# concat=concat +# ) - # No limit on top_n_channels anymore - concept_channels_set = set(concept_channels) +# # No limit on top_n_channels anymore +# concept_channels_set = set(concept_channels) - # Calculate overlap - shared_channels = seed_channels_set.intersection(concept_channels_set) - shared_count = len(shared_channels) +# # Calculate overlap +# shared_channels = seed_channels_set.intersection(concept_channels_set) +# shared_count = len(shared_channels) - # Calculate overlap ratio based on seed channels - total_channels = len(seed_channels_set) - overlap_ratio = shared_count / total_channels if total_channels > 0 else 0 +# # Calculate overlap ratio based on seed channels +# total_channels = len(seed_channels_set) +# overlap_ratio = shared_count / total_channels if total_channels > 0 else 0 - # Only include if meets minimum overlap threshold - if shared_count >= min_overlap: - results.append(( - base_concept_name, - shared_count, - overlap_ratio, - sorted(list(shared_channels)) - )) +# # Only include if meets minimum overlap threshold +# if shared_count >= min_overlap: +# results.append(( +# base_concept_name, +# shared_count, +# overlap_ratio, +# sorted(list(shared_channels)) +# )) - except Exception as e: - # Skip concepts that cause errors (e.g., not found) - continue - - # Sort by shared count (descending), then by overlap ratio - results.sort(key=lambda x: (x[1], x[2]), reverse=True) - print("=" * 80) - print(f"Seed concept '{seed_concept}' selected {len(seed_channels)} channels: {list(seed_channels)}") - return results - def print_similar_concepts( - self, - results: List[Tuple[str, int, float, List[int]]], - top_n: int = 10, - show_channels: bool = True - ): - """ - Pretty print the results from find_similar_concepts_by_channels +# except Exception as e: +# # Skip concepts that cause errors (e.g., not found) +# continue + +# # Sort by shared count (descending), then by overlap ratio +# 
results.sort(key=lambda x: (x[1], x[2]), reverse=True) +# print("=" * 80) +# print(f"Seed concept '{seed_concept}' selected {len(seed_channels)} channels: {list(seed_channels)}") +# return results +# def print_similar_concepts( +# self, +# results: List[Tuple[str, int, float, List[int]]], +# top_n: int = 10, +# show_channels: bool = True +# ): +# """ +# Pretty print the results from find_similar_concepts_by_channels - Args: - results: Output from find_similar_concepts_by_channels - top_n: Number of top results to show - show_channels: Whether to show the actual shared channel numbers - """ +# Args: +# results: Output from find_similar_concepts_by_channels +# top_n: Number of top results to show +# show_channels: Whether to show the actual shared channel numbers +# """ - if not results: - print("No similar concepts found.") - return +# if not results: +# print("No similar concepts found.") +# return - print("Most Similar Concepts:") - print("=" * 80) +# print("Most Similar Concepts:") +# print("=" * 80) - for i, (concept, shared_count, overlap_ratio, shared_channels) in enumerate(results[:top_n]): - if concept == 'black_grouse': - print("HOLY BLACK GROUSEπŸ‘€πŸ‘€πŸ‘€") +# for i, (concept, shared_count, overlap_ratio, shared_channels) in enumerate(results[:top_n]): +# if concept == 'black_grouse': +# print("HOLY BLACK GROUSEπŸ‘€πŸ‘€πŸ‘€") - print(f"{i+1:2d}. {concept:20s} | " - f"Shared: {shared_count:2d} | " - f"Overlap: {overlap_ratio:.1%}") +# print(f"{i+1:2d}. {concept:20s} | " +# f"Shared: {shared_count:2d} | " +# f"Overlap: {overlap_ratio:.1%}") - if show_channels and shared_channels: - # Show channels in groups of 10 for readability - channel_str = str(shared_channels) - if len(channel_str) > 80: - channel_str = channel_str[:77] + "..." 
- print(f" Shared channels: {channel_str}") - print() +# if show_channels and shared_channels: +# # Show channels in groups of 10 for readability +# channel_str = str(shared_channels) +# if len(channel_str) > 80: +# channel_str = channel_str[:77] + "..." +# print(f" Shared channels: {channel_str}") +# print() - def plot_mode_comparison( - self, - seed_concept: str, - results: List[Tuple[str, int, float, List[int]]], - layer_idx: int, - top_n_display: int = 10, - mode_method: str = 'top_n', - mode_param: int = 1, - figsize_per_subplot: Tuple[int, int] = (12, 3), - # Remove seed_top_n_channels parameter - channel_method: str = 'std', # Add these parameters to be consistent - channel_param: float = 2.0, - concat=False, # with find_similar_concepts_by_channels, - text=False - ): - """ - Plot seed concept's top mode and similar concepts' top modes with shared channels highlighted. +# def plot_mode_comparison( +# self, +# seed_concept: str, +# results: List[Tuple[str, int, float, List[int]]], +# layer_idx: int, +# top_n_display: int = 10, +# mode_method: str = 'top_n', +# mode_param: int = 1, +# figsize_per_subplot: Tuple[int, int] = (12, 3), +# # Remove seed_top_n_channels parameter +# channel_method: str = 'std', # Add these parameters to be consistent +# channel_param: float = 2.0, +# concat=False, # with find_similar_concepts_by_channels, +# text=False +# ): +# """ +# Plot seed concept's top mode and similar concepts' top modes with shared channels highlighted. 
- Args: - analyzer: ModeAnalyzer instance - seed_concept: The original seed concept - results: Output from find_similar_concepts_by_channels - layer_idx: Layer to analyze - top_n_display: Number of similar concepts to display - mode_method: Method for getting top mode - mode_param: Parameter for mode selection - figsize_per_subplot: Size of each subplot - channel_method: Method for selecting channels (same as in find_similar_concepts) - channel_param: Parameter for channel selection (same as in find_similar_concepts) - """ +# Args: +# analyzer: ModeAnalyzer instance +# seed_concept: The original seed concept +# results: Output from find_similar_concepts_by_channels +# layer_idx: Layer to analyze +# top_n_display: Number of similar concepts to display +# mode_method: Method for getting top mode +# mode_param: Parameter for mode selection +# figsize_per_subplot: Size of each subplot +# channel_method: Method for selecting channels (same as in find_similar_concepts) +# channel_param: Parameter for channel selection (same as in find_similar_concepts) +# """ - display_results = results[:top_n_display] - total_concepts = len(display_results) + 1 # +1 for seed - - # Create subplot grid - fig, axes = plt.subplots(total_concepts, 1, - figsize=(figsize_per_subplot[0], - figsize_per_subplot[1] * total_concepts)) - - if total_concepts == 1: - axes = [axes] - - # Get layer for extracting dictionary atoms - layer = self.get_layer(layer_idx) - - # 1. 
Plot SEED CONCEPT first - print(f"Getting seed concept '{seed_concept}' top mode...") - seed_modes = self.get_top_modes( - layer_idx=layer_idx, - concept_name=seed_concept, - method=mode_method, - param=mode_param, - min_indices=1, - max_indices=1 - ) - - seed_mode_idx = seed_modes[0] - seed_atom = layer.dictionary[seed_mode_idx, :] - seed_correlations = self.get_mode_correlations(layer_idx=layer_idx, concept_name=seed_concept) - - # Get seed's top channels using the SAME method as in find_similar_concepts - seed_top_channels = self.get_top_channels( - layer_idx=layer_idx, - concept_name=seed_concept, - mode_method=mode_method, - mode_param=mode_param, - mode_min_indices=1, - mode_max_indices=1, - channel_method=channel_method, - channel_param=channel_param, - channel_min_indices=1, - channel_max_indices=float('inf'), - select_first=True, - concat=concat - ) - seed_top_channels_set = set(seed_top_channels) - - # Plot seed concept - ax = axes[0] - ax.plot(seed_atom, 'k-', linewidth=1, alpha=0.8) - ax.set_title(f'SEED: {seed_concept} (mode {seed_mode_idx}, corr={seed_correlations[seed_mode_idx]:.3f})', - fontsize=8, fontweight='bold', color='blue') - ax.set_ylabel('Activation', fontsize=12) - - if text: - - for ch in seed_top_channels_set: - ax.axvline(x=ch, color='blue', linestyle='-', alpha=0.2, linewidth=1.5) +# display_results = results[:top_n_display] +# total_concepts = len(display_results) + 1 # +1 for seed + +# # Create subplot grid +# fig, axes = plt.subplots(total_concepts, 1, +# figsize=(figsize_per_subplot[0], +# figsize_per_subplot[1] * total_concepts)) + +# if total_concepts == 1: +# axes = [axes] + +# # Get layer for extracting dictionary atoms +# layer = self.get_layer(layer_idx) + +# # 1. 
Plot SEED CONCEPT first +# print(f"Getting seed concept '{seed_concept}' top mode...") +# seed_modes = self.get_top_modes( +# layer_idx=layer_idx, +# concept_name=seed_concept, +# method=mode_method, +# param=mode_param, +# min_indices=1, +# max_indices=1 +# ) + +# seed_mode_idx = seed_modes[0] +# seed_atom = layer.dictionary[seed_mode_idx, :] +# seed_correlations = self.get_mode_correlations(layer_idx=layer_idx, concept_name=seed_concept) + +# # Get seed's top channels using the SAME method as in find_similar_concepts +# seed_top_channels = self.get_top_channels( +# layer_idx=layer_idx, +# concept_name=seed_concept, +# mode_method=mode_method, +# mode_param=mode_param, +# mode_min_indices=1, +# mode_max_indices=1, +# channel_method=channel_method, +# channel_param=channel_param, +# channel_min_indices=1, +# channel_max_indices=float('inf'), +# select_first=True, +# concat=concat +# ) +# seed_top_channels_set = set(seed_top_channels) + +# # Plot seed concept +# ax = axes[0] +# ax.plot(seed_atom, 'k-', linewidth=1, alpha=0.8) +# ax.set_title(f'SEED: {seed_concept} (mode {seed_mode_idx}, corr={seed_correlations[seed_mode_idx]:.3f})', +# fontsize=8, fontweight='bold', color='blue') +# ax.set_ylabel('Activation', fontsize=12) + +# if text: + +# for ch in seed_top_channels_set: +# ax.axvline(x=ch, color='blue', linestyle='-', alpha=0.2, linewidth=1.5) - ax.text(0.02, 0.95, f'{len(seed_top_channels_set)} top channels', - transform=ax.transAxes, va='top', fontsize=10, - bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.3)) +# ax.text(0.02, 0.95, f'{len(seed_top_channels_set)} top channels', +# transform=ax.transAxes, va='top', fontsize=10, +# bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.3)) - # 2. Plot SIMILAR CONCEPTS - for i, (concept_name, shared_count, overlap_ratio, shared_channels) in enumerate(display_results): - ax = axes[i + 1] +# # 2. 
Plot SIMILAR CONCEPTS +# for i, (concept_name, shared_count, overlap_ratio, shared_channels) in enumerate(display_results): +# ax = axes[i + 1] - # Get this concept's top mode - try: - # Extract base name if needed - base_concept_name = concept_name.split('.')[0] +# # Get this concept's top mode +# try: +# # Extract base name if needed +# base_concept_name = concept_name.split('.')[0] - concept_modes = self.get_top_modes( - layer_idx=layer_idx, - concept_name=base_concept_name, - method=mode_method, - param=mode_param, - min_indices=1, - max_indices=1, - select_first=True - ) +# concept_modes = self.get_top_modes( +# layer_idx=layer_idx, +# concept_name=base_concept_name, +# method=mode_method, +# param=mode_param, +# min_indices=1, +# max_indices=1, +# select_first=True +# ) - concept_mode_idx = concept_modes[0] - concept_atom = layer.dictionary[concept_mode_idx, :] - concept_correlations = self.get_mode_correlations( - layer_idx=layer_idx, concept_name=base_concept_name, select_first=True - ) +# concept_mode_idx = concept_modes[0] +# concept_atom = layer.dictionary[concept_mode_idx, :] +# concept_correlations = self.get_mode_correlations( +# layer_idx=layer_idx, concept_name=base_concept_name, select_first=True +# ) - # Plot the atom - ax.plot(concept_atom, 'k-', linewidth=1, alpha=0.8) - ax.set_title(f'{concept_name} (mode {concept_mode_idx}, corr={concept_correlations[concept_mode_idx]:.3f})\n' - f'Shared: {shared_count}/{len(seed_top_channels_set)} ({overlap_ratio:.1%})', - fontsize=8) - ax.set_ylabel('Activation', fontsize=8) - - if text: +# # Plot the atom +# ax.plot(concept_atom, 'k-', linewidth=1, alpha=0.8) +# ax.set_title(f'{concept_name} (mode {concept_mode_idx}, corr={concept_correlations[concept_mode_idx]:.3f})\n' +# f'Shared: {shared_count}/{len(seed_top_channels_set)} ({overlap_ratio:.1%})', +# fontsize=8) +# ax.set_ylabel('Activation', fontsize=8) + +# if text: - # Highlight shared channels in RED - for ch in shared_channels: - ax.axvline(x=ch, 
color='red', linestyle='-', alpha=0.2, linewidth=2) +# # Highlight shared channels in RED +# for ch in shared_channels: +# ax.axvline(x=ch, color='red', linestyle='-', alpha=0.2, linewidth=2) - # Highlight seed's non-shared top channels in light blue - non_shared_seed_channels = seed_top_channels_set - set(shared_channels) - for ch in non_shared_seed_channels: - ax.axvline(x=ch, color='lightblue', linestyle='--', alpha=0.4, linewidth=1) +# # Highlight seed's non-shared top channels in light blue +# non_shared_seed_channels = seed_top_channels_set - set(shared_channels) +# for ch in non_shared_seed_channels: +# ax.axvline(x=ch, color='lightblue', linestyle='--', alpha=0.4, linewidth=1) - # Add legend info - ax.text(0.02, 0.95, f'{len(shared_channels)} shared channels', - transform=ax.transAxes, va='top', fontsize=10, - bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral", alpha=0.7)) +# # Add legend info +# ax.text(0.02, 0.95, f'{len(shared_channels)} shared channels', +# transform=ax.transAxes, va='top', fontsize=10, +# bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral", alpha=0.7)) - except Exception as e: - ax.text(0.5, 0.5, f'Error loading {concept_name}:\n{str(e)}', - transform=ax.transAxes, ha='center', va='center', fontsize=10) - ax.set_title(f'{concept_name} (ERROR)', fontsize=12, color='red') - - # Set x-label only on bottom plot - axes[-1].set_xlabel('Channel Index', fontsize=12) - # No x-tick labels - for i, ax in enumerate(axes): - if i != len(axes) - 1: - ax.set_xticks([]) - - # Add overall title and legend - fig.suptitle(f'Mode Comparison: {seed_concept} vs Similar Concepts (Layer {layer_idx})', - fontsize=8, fontweight='bold') - - # Create legend - if text: - from matplotlib.lines import Line2D - legend_elements = [ - Line2D([0], [0], color='blue', lw=2, label=f'{seed_concept} top channels'), - Line2D([0], [0], color='red', lw=2, label='Shared channels'), - Line2D([0], [0], color='lightblue', lw=1, linestyle='--', alpha=0.6, - 
label=f'{seed_concept} non-shared channels') - ] - fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98)) - - plt.tight_layout() - plt.subplots_adjust(top=0.92) # Make room for suptitle and legend - plt.subplots_adjust(hspace=0.5) # Increase vertical spacing - plt.show() - - return fig - - def get_concepts_with_shared_channels( - self, - channels: List[int], - layer_idx: int, - top_n: int = 5, - concat: bool= False, - min_overlap: int = 1 - ) -> List[Tuple[str, int, List[int]]]: - """ - Find concepts that have the given channels in their top channels list. +# except Exception as e: +# ax.text(0.5, 0.5, f'Error loading {concept_name}:\n{str(e)}', +# transform=ax.transAxes, ha='center', va='center', fontsize=10) +# ax.set_title(f'{concept_name} (ERROR)', fontsize=12, color='red') + +# # Set x-label only on bottom plot +# axes[-1].set_xlabel('Channel Index', fontsize=12) +# # No x-tick labels +# for i, ax in enumerate(axes): +# if i != len(axes) - 1: +# ax.set_xticks([]) + +# # Add overall title and legend +# fig.suptitle(f'Mode Comparison: {seed_concept} vs Similar Concepts (Layer {layer_idx})', +# fontsize=8, fontweight='bold') + +# # Create legend +# if text: +# from matplotlib.lines import Line2D +# legend_elements = [ +# Line2D([0], [0], color='blue', lw=2, label=f'{seed_concept} top channels'), +# Line2D([0], [0], color='red', lw=2, label='Shared channels'), +# Line2D([0], [0], color='lightblue', lw=1, linestyle='--', alpha=0.6, +# label=f'{seed_concept} non-shared channels') +# ] +# fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98)) + +# plt.tight_layout() +# plt.subplots_adjust(top=0.92) # Make room for suptitle and legend +# plt.subplots_adjust(hspace=0.5) # Increase vertical spacing +# plt.show() + +# return fig + +# def get_concepts_with_shared_channels( +# self, +# channels: List[int], +# layer_idx: int, +# top_n: int = 5, +# concat: bool= False, +# min_overlap: int = 1 +# ) -> List[Tuple[str, 
int, List[int]]]: +# """ +# Find concepts that have the given channels in their top channels list. - Args: - channels: List of channel indices to search for - layer_idx: Layer to analyze - top_n: Number of top channels to get for each concept - min_overlap: Minimum number of shared channels to include +# Args: +# channels: List of channel indices to search for +# layer_idx: Layer to analyze +# top_n: Number of top channels to get for each concept +# min_overlap: Minimum number of shared channels to include - Returns: - List of (concept_name, shared_count, shared_channels) sorted by shared_count - """ - results = [] - channels_set = set(channels) +# Returns: +# List of (concept_name, shared_count, shared_channels) sorted by shared_count +# """ +# results = [] +# channels_set = set(channels) - for concept_label in self.summary.mask_labels: - try: - # Get this concept's top channels - concept_channels = self.get_top_channels( - layer_idx=layer_idx, - concept_name=concept_label, - mode_method='top_n', - mode_param=1, - mode_min_indices=1, - mode_max_indices=1, - channel_method='top_n', - channel_param=top_n, - channel_min_indices=top_n, - channel_max_indices=top_n, - select_first=True, - concat=concat - ) +# for concept_label in self.summary.mask_labels: +# try: +# # Get this concept's top channels +# concept_channels = self.get_top_channels( +# layer_idx=layer_idx, +# concept_name=concept_label, +# mode_method='top_n', +# mode_param=1, +# mode_min_indices=1, +# mode_max_indices=1, +# channel_method='top_n', +# channel_param=top_n, +# channel_min_indices=top_n, +# channel_max_indices=top_n, +# select_first=True, +# concat=concat +# ) - concept_channels_set = set(concept_channels[:top_n]) - shared_channels = list(channels_set.intersection(concept_channels_set)) - shared_count = len(shared_channels) +# concept_channels_set = set(concept_channels[:top_n]) +# shared_channels = list(channels_set.intersection(concept_channels_set)) +# shared_count = len(shared_channels) - if 
shared_count >= min_overlap: - results.append((concept_label, shared_count, shared_channels)) +# if shared_count >= min_overlap: +# results.append((concept_label, shared_count, shared_channels)) - except Exception as e: - # Skip concepts that cause errors - continue +# except Exception as e: +# # Skip concepts that cause errors +# continue - # Sort by shared count (descending) - results.sort(key=lambda x: x[1], reverse=True) - return results - def map_to_original_channels(self, doubled_indices, layer): - """Convert doubled channel indices back to original channel indices""" - original_n_channels = layer.dictionary.shape[1] // 2 - - pos_channels = [] - neg_channels = [] - - for idx in doubled_indices: - if idx < original_n_channels: - pos_channels.append(idx) - else: - neg_channels.append(idx - original_n_channels) - - # Check for conflicts and print if found - pos_set = set(pos_channels) - neg_set = set(neg_channels) - overlap = pos_set.intersection(neg_set) - - print(f"Found {len(overlap)} channels in both positive and negative: {sorted(overlap)}") - - return list(np.unique(pos_channels + neg_channels)) - def discover_related_concepts( - self, - seed_concept: str, - layer_idx: int, - seed_top_channels: int = 5, - concepts_top_channels: int = 5, - min_overlap: int = 1, - mode_method: str = 'top_n', - mode_param: int = 1, - channel_method: str = 'top_n', - channel_param: int = 10, - select_first: bool = True - ) -> Tuple[List[str], List[int], Dict]: - """ - Discover related concepts through direct channel overlap analysis. 
+# # Sort by shared count (descending) +# results.sort(key=lambda x: x[1], reverse=True) +# return results +# def map_to_original_channels(self, doubled_indices, layer): +# """Convert doubled channel indices back to original channel indices""" +# original_n_channels = layer.dictionary.shape[1] // 2 + +# pos_channels = [] +# neg_channels = [] + +# for idx in doubled_indices: +# if idx < original_n_channels: +# pos_channels.append(idx) +# else: +# neg_channels.append(idx - original_n_channels) + +# # Check for conflicts and print if found +# pos_set = set(pos_channels) +# neg_set = set(neg_channels) +# overlap = pos_set.intersection(neg_set) + +# print(f"Found {len(overlap)} channels in both positive and negative: {sorted(overlap)}") + +# return list(np.unique(pos_channels + neg_channels)) +# def discover_related_concepts( +# self, +# seed_concept: str, +# layer_idx: int, +# seed_top_channels: int = 5, +# concepts_top_channels: int = 5, +# min_overlap: int = 1, +# mode_method: str = 'top_n', +# mode_param: int = 1, +# channel_method: str = 'top_n', +# channel_param: int = 10, +# select_first: bool = True +# ) -> Tuple[List[str], List[int], Dict]: +# """ +# Discover related concepts through direct channel overlap analysis. - Args: - seed_concept: Starting concept - layer_idx: Layer to analyze - seed_top_channels: Number of top channels to get from seed - concepts_top_channels: Number of top channels to consider for each concept - min_overlap: Minimum number of shared channels to include - ... (other params same as existing methods) - - Returns: - Tuple of (all_concepts, all_channels, discovery_info) - """ +# Args: +# seed_concept: Starting concept +# layer_idx: Layer to analyze +# seed_top_channels: Number of top channels to get from seed +# concepts_top_channels: Number of top channels to consider for each concept +# min_overlap: Minimum number of shared channels to include +# ... 
(other params same as existing methods) - print(f"Starting discovery from seed concept: '{seed_concept}'") - print("=" * 60) - - # Step 1: Get seed concept's top channels - print(f"Step 1: Getting top {seed_top_channels} channels for '{seed_concept}'") - - seed_channels = self.get_top_channels( - layer_idx=layer_idx, - concept_name=seed_concept, - mode_method=mode_method, - mode_param=mode_param, - mode_min_indices=1, - mode_max_indices=1, - channel_method=channel_method, - channel_param=channel_param, - channel_min_indices=seed_top_channels, - channel_max_indices=seed_top_channels, - select_first=select_first - ) - - seed_channels = seed_channels[:seed_top_channels] - print(f"Seed channels: {seed_channels}") - print() - - # Step 2: Find concepts that share these channels - print(f"Step 2: Finding concepts that have these channels in their top {concepts_top_channels}") - - shared_concepts = self.get_concepts_with_shared_channels( - channels=seed_channels, - layer_idx=layer_idx, - top_n=concepts_top_channels, - min_overlap=min_overlap - ) - - # Remove seed concept from results and print - discovered_concepts = [] - for concept_name, shared_count, shared_channels_list in shared_concepts: - if concept_name.lower() != seed_concept.lower(): - discovered_concepts.append(concept_name) - print(f"{concept_name}: {shared_count} shared channels {shared_channels_list}") - - print(f"\nDiscovered {len(discovered_concepts)} related concepts") - print() - - # Step 3: Get top channels for each discovered concept - print(f"Step 3: Getting top {seed_top_channels} channels for each discovered concept") - - all_channels = set(seed_channels) - concept_channel_map = {seed_concept: seed_channels} - - for concept in discovered_concepts: - try: - concept_channels = self.get_top_channels( - layer_idx=layer_idx, - concept_name=concept, - mode_method=mode_method, - mode_param=mode_param, - mode_min_indices=1, - mode_max_indices=1, - channel_method=channel_method, - channel_param=channel_param, 
- channel_min_indices=seed_top_channels, - channel_max_indices=seed_top_channels, - select_first=True, - ) - - concept_channels = concept_channels[:seed_top_channels] - concept_channel_map[concept] = concept_channels - all_channels.update(concept_channels) - - print(f"{concept}: {concept_channels}") - - except Exception as e: - print(f"Error getting channels for {concept}: {e}") - continue - - all_channels = sorted(list(all_channels)) - all_concepts = [seed_concept] + discovered_concepts - - print() - print("=" * 60) - print("DISCOVERY SUMMARY") - print("=" * 60) - print(f"Seed concept: {seed_concept}") - print(f"Discovered concepts ({len(discovered_concepts)}): {discovered_concepts}") - print(f"Total concepts: {len(all_concepts)}") - print(f"Total unique channels: {len(all_channels)}") - print(f"All channels: {all_channels}") - - # Package discovery info - discovery_info = { - 'seed_concept': seed_concept, - 'seed_channels': seed_channels, - 'shared_concepts': shared_concepts, - 'concept_channel_map': concept_channel_map, - 'discovered_concepts': discovered_concepts - } - - return all_concepts, all_channels, discovery_info - - def print_discovery_network(self, discovery_info: Dict): - """ - Pretty print the discovery network showing connections. 
- """ - print("\nDISCOVERY NETWORK") - print("=" * 50) - - seed = discovery_info['seed_concept'] - seed_channels = discovery_info['seed_channels'] - shared_concepts = discovery_info['shared_concepts'] - - print(f"🌱 SEED: {seed}") - print(f" Channels: {seed_channels}") - print() - - print("πŸ”— CONCEPTS WITH SHARED CHANNELS:") - for concept_name, shared_count, shared_channels_list in shared_concepts: - if concept_name.lower() != seed.lower(): - print(f" {concept_name}: {shared_count} shared {shared_channels_list}") - - print("\nπŸ“Š DISCOVERED CONCEPT CHANNELS:") - concept_channel_map = discovery_info['concept_channel_map'] - for concept, channels in concept_channel_map.items(): - if concept != seed: - print(f" {concept}: {channels}") +# Returns: +# Tuple of (all_concepts, all_channels, discovery_info) +# """ + +# print(f"Starting discovery from seed concept: '{seed_concept}'") +# print("=" * 60) + +# # Step 1: Get seed concept's top channels +# print(f"Step 1: Getting top {seed_top_channels} channels for '{seed_concept}'") + +# seed_channels = self.get_top_channels( +# layer_idx=layer_idx, +# concept_name=seed_concept, +# mode_method=mode_method, +# mode_param=mode_param, +# mode_min_indices=1, +# mode_max_indices=1, +# channel_method=channel_method, +# channel_param=channel_param, +# channel_min_indices=seed_top_channels, +# channel_max_indices=seed_top_channels, +# select_first=select_first +# ) + +# seed_channels = seed_channels[:seed_top_channels] +# print(f"Seed channels: {seed_channels}") +# print() + +# # Step 2: Find concepts that share these channels +# print(f"Step 2: Finding concepts that have these channels in their top {concepts_top_channels}") + +# shared_concepts = self.get_concepts_with_shared_channels( +# channels=seed_channels, +# layer_idx=layer_idx, +# top_n=concepts_top_channels, +# min_overlap=min_overlap +# ) + +# # Remove seed concept from results and print +# discovered_concepts = [] +# for concept_name, shared_count, shared_channels_list 
in shared_concepts: +# if concept_name.lower() != seed_concept.lower(): +# discovered_concepts.append(concept_name) +# print(f"{concept_name}: {shared_count} shared channels {shared_channels_list}") + +# print(f"\nDiscovered {len(discovered_concepts)} related concepts") +# print() + +# # Step 3: Get top channels for each discovered concept +# print(f"Step 3: Getting top {seed_top_channels} channels for each discovered concept") + +# all_channels = set(seed_channels) +# concept_channel_map = {seed_concept: seed_channels} + +# for concept in discovered_concepts: +# try: +# concept_channels = self.get_top_channels( +# layer_idx=layer_idx, +# concept_name=concept, +# mode_method=mode_method, +# mode_param=mode_param, +# mode_min_indices=1, +# mode_max_indices=1, +# channel_method=channel_method, +# channel_param=channel_param, +# channel_min_indices=seed_top_channels, +# channel_max_indices=seed_top_channels, +# select_first=True, +# ) + +# concept_channels = concept_channels[:seed_top_channels] +# concept_channel_map[concept] = concept_channels +# all_channels.update(concept_channels) + +# print(f"{concept}: {concept_channels}") + +# except Exception as e: +# print(f"Error getting channels for {concept}: {e}") +# continue + +# all_channels = sorted(list(all_channels)) +# all_concepts = [seed_concept] + discovered_concepts + +# print() +# print("=" * 60) +# print("DISCOVERY SUMMARY") +# print("=" * 60) +# print(f"Seed concept: {seed_concept}") +# print(f"Discovered concepts ({len(discovered_concepts)}): {discovered_concepts}") +# print(f"Total concepts: {len(all_concepts)}") +# print(f"Total unique channels: {len(all_channels)}") +# print(f"All channels: {all_channels}") + +# # Package discovery info +# discovery_info = { +# 'seed_concept': seed_concept, +# 'seed_channels': seed_channels, +# 'shared_concepts': shared_concepts, +# 'concept_channel_map': concept_channel_map, +# 'discovered_concepts': discovered_concepts +# } + +# return all_concepts, all_channels, 
discovery_info + +# def print_discovery_network(self, discovery_info: Dict): +# """ +# Pretty print the discovery network showing connections. +# """ +# print("\nDISCOVERY NETWORK") +# print("=" * 50) + +# seed = discovery_info['seed_concept'] +# seed_channels = discovery_info['seed_channels'] +# shared_concepts = discovery_info['shared_concepts'] + +# print(f"🌱 SEED: {seed}") +# print(f" Channels: {seed_channels}") +# print() + +# print("πŸ”— CONCEPTS WITH SHARED CHANNELS:") +# for concept_name, shared_count, shared_channels_list in shared_concepts: +# if concept_name.lower() != seed.lower(): +# print(f" {concept_name}: {shared_count} shared {shared_channels_list}") + +# print("\nπŸ“Š DISCOVERED CONCEPT CHANNELS:") +# concept_channel_map = discovery_info['concept_channel_map'] +# for concept, channels in concept_channel_map.items(): +# if concept != seed: +# print(f" {concept}: {channels}") def select_significant_indices(vector, method='threshold', param=0.8, min_indices=1, max_indices=None): """ diff --git a/bscope/ic/semantic_analyzer.py b/bscope/ic/semantic_analyzer.py index edc348e..f617073 100644 --- a/bscope/ic/semantic_analyzer.py +++ b/bscope/ic/semantic_analyzer.py @@ -417,4 +417,523 @@ def get_concepts_from_path(self, concept): return ordered_names - \ No newline at end of file + +import os +from IPython import embed +import json +import tqdm +import numpy as np +import json +import requests +import bscope +import numpy as np +import matplotlib.pyplot as plt +from scipy import signal +import os +from IPython import embed +import json +import tqdm +import numpy as np +import json +import requests +import bscope +import numpy as np +import matplotlib.pyplot as plt +from scipy import signal + +def get_top_mode(mode_summary, layer, class_idx, which_mode=1): + corrs = mode_summary.layers[layer].imgnet_corr_mtx + top_mode = np.argsort(corrs[:, class_idx])[::-1] + top_mode = top_mode[which_mode] + + atom = mode_summary.layers[layer].dictionary[top_mode] + 
loadings = mode_summary.layers[layer].loadings[:, top_mode] + corr = corrs[top_mode, class_idx] + return top_mode, atom, loadings, corr + +def single_image_semantic_loading(mode_summary, layer, image_idx): + corrs = mode_summary.layers[layer].imgnet_corr_mtx + loadings = mode_summary.layers[layer].loadings[image_idx] + loading_idxs = np.where(loadings > 0.5)[0] + imagenet_labels = mode_summary.imgnet_mask_labels + labels = [] + for loading_idx in loading_idxs: + # Classes + corr = corrs[loading_idx,:] + top_classes = np.argsort(corr)[::-1] + top_class = top_classes[0] + imagenet_label = imagenet_labels[top_class] + labels.append(imagenet_label) + + + return loadings, loading_idxs, labels + +def top_n(vector, n=5): + idxs = np.argsort(vector)[-n:][::-1] + return idxs, vector[idxs] + +def load_hierarchy(path='/data/hierarchy_metadata/pruned_hierarchy.json'): + with open(path, 'r') as f: + return json.load(f) + +def get_masks(path='/data/hierarchy_metadata/pruned_hierarchy.json', leaf_only=False, targets=None): + if targets is None: + targets = [] + for i in range(1000): + targets.extend(np.ones(50)*i) + targets = np.array(targets).astype(int) + + print("Loading hierarchy from:", path) + hierarchy = load_hierarchy(path) + + masks = [] + labels = [] + for k,v in hierarchy.items(): + idxs = v['idxs'] + + if leaf_only and not v['leaf']: + continue + + masks.append(np.isin(targets, idxs)) + labels.append(k) + + return np.array(masks), labels + +def chunk_masks(mask_matrix, bins = [50, 5000]): + + if bins[0] != 0: + bins = [0] + bins + + if bins[-1] != None: + bins = bins + [np.inf] + + summed = mask_matrix.sum(1) + + valid_idxs = [] + for b1, b2 in zip(bins[:-1], bins[1:]): + valid = np.where((summed > b1) & (summed <= b2))[0] + valid_idxs.append(valid) + + return valid_idxs + + # if target_indices is None: + # target_indices = list(range(1000)) + + # data = load_hierarchy(path) + # node_names = list(data.keys()) + # num_nodes = len(node_names) + + # masks = [] + + # for 
node in node_names: + # indices = set(data[node]['idxs']) + # mask = np.array([idx in indices for idx in target_indices]) + # masks.append(mask) + + # mask_array = np.array(masks) + # return mask_array, node_names + +# class SemanticAnalyzer: +# def __init__(self, semantic_hierarchy_path ='/data/codec/hierarchy_metadata/misc/semantic_indexes_test.json'): +# self.data = self.load_data(semantic_hierarchy_path) + + + +# def load_data(self, path): +# """Load synset data from JSON file.""" +# with open(path, 'r') as f: +# return json.load(f) + +# def recursively_clean_names(self, tree): +# new_tree = {} + +# # Reverse the name if it exists +# for k, v in tree.items(): +# if k !='name': +# new_tree[k] = v +# else: +# new_tree['name'] = tree['name'].split('.n')[0] + +# if 'children' in tree and tree['children']: +# new_tree['children'] = [self.recursively_clean_names(child) for child in tree['children']] + +# return new_tree + +# def add_level_depth(self, tree, current_level=0): +# """ +# Recursively traverse a hierarchical tree and add a 'level' key +# to each node indicating how many levels from the top it is. + +# Args: +# tree (dict): Dictionary with keys 'name', 'definition', 'children' +# current_level (int): Current depth level (0 for root) + +# Returns: +# dict: New tree with 'level' key added to each node +# """ +# # Create a new dictionary to avoid modifying the original +# new_tree = {} + +# for k, v in tree.items(): +# if k != 'level': +# new_tree[k] = v + +# # Add the level depth +# new_tree['level'] = current_level + +# # Recursively process children with incremented level +# if 'children' in tree and tree['children']: +# new_tree['children'] = [ +# self.add_level_depth(child, current_level + 1) +# for child in tree['children'] +# ] + +# return new_tree +# def indices_helper(self, name, partial_match=False): +# """ +# Retrieve synsets by name or partial match, including all indices from descendants. 
+# """ +# # Get all descendants using the get_descendant_nodes function +# all_descendants = self.get_descendant_nodes(name) + +# # If no descendants found, try with partial match +# if not all_descendants and partial_match: +# # Try again with partial matching +# direct_matches = {} +# for synset_name, info in self.data.items(): +# base_name = synset_name.split('.')[0] +# if name.lower() in base_name.lower(): +# direct_matches[synset_name] = info + +# # Return just these direct partial matches if no descendants found +# if direct_matches: +# return direct_matches + +# # Process the descendants to add combined indices to each direct match +# results = {} +# direct_matches_found = False + +# # First identify direct matches +# for synset_name, info in all_descendants.items(): +# base_name = synset_name.split('.')[0] +# if base_name.lower() == name.lower(): +# # This is a direct match - store it separately +# results[synset_name] = info.copy( +# ) # Copy to avoid modifying original +# direct_matches_found = True + +# # If no direct matches but we have descendants, include all descendants +# if not direct_matches_found and all_descendants: +# # Just return all descendants when no direct matches +# return all_descendants + +# # Now, let's add all indices from all descendants to each direct match +# if direct_matches_found: +# # Collect all indices from all descendants +# all_indices = [] +# for desc_name, desc_info in all_descendants.items(): +# # Get indices from this descendant +# if 'indices' in desc_info and desc_info['indices']: +# all_indices.extend(desc_info['indices']) + +# # Remove duplicates and sort +# all_indices = sorted(list(set(all_indices))) + +# # Add these combined indices to each direct match +# for match_name in results: +# results[match_name]['all_descendant_indices'] = all_indices + +# # Return the results +# return results + +# def get_indices(self, name, partial_match=False): +# """ +# Print matching synsets and also show all indices from all 
descendants. +# Returns a list of all unique indices from all descendants. +# """ +# # First get direct matches +# matches = self.indices_helper(name, partial_match) + +# if not matches: +# print(f"No synsets found with name '{name}'.") +# return [] + +# print(f"Found {len(matches)} matching term for '{name}':") + +# all_descendants = self.get_descendant_nodes(name) + +# # Collect all indices from all descendants +# all_indices = set() +# for desc_name, desc_info in all_descendants.items(): +# if 'indices' in desc_info and desc_info['indices']: +# all_indices.update(desc_info['indices']) + +# # Sort the indices +# all_indices = sorted(list(all_indices)) + + +# # Return all unique indices for use in mask function +# return all_indices + +# def get_parent_nodes(self, name, partial_match=False): +# """Find parent nodes for the given name.""" +# matches = self.indices_helper(name, partial_match) +# if not matches: +# print(f"No synsets found with name '{name}'.") +# return {} +# parent_nodes = {} +# for synset_name, info in matches.items(): +# full_path = info['path'] +# components = full_path.split('.') +# current_path = "" +# for i, component in enumerate(components[:-1]): +# current_path += ("." if i > 0 else "") + component +# for other_name, other_info in self.data.items(): +# if other_info['path'] == current_path: +# parent_nodes[other_name] = other_info +# return parent_nodes + +# def get_descendant_nodes(self, term): +# """ +# Find all synsets that are descendants of or match the given term. +# The term is matched against the name part of the synset. 
+# """ +# descendant_nodes = {} +# matched_synsets = [] + +# # First try to find direct matches +# for synset_name, info in self.data.items(): +# base_name = synset_name.split('.')[0] +# if base_name.lower() == term.lower(): +# descendant_nodes[synset_name] = info +# matched_synsets.append(synset_name) + +# # If we found direct matches, look for their descendants +# for matched_synset in matched_synsets: +# for synset_name, info in self.data.items(): +# if synset_name in descendant_nodes: +# continue # Skip if already added +# # Check if the synset is a descendant of any matched synset +# if matched_synset in info['path']: +# descendant_nodes[synset_name] = info + +# # If we still didn't find anything or if the term is something like 'aquatic_mammal', +# # try looking for it as a component in paths +# if not descendant_nodes: +# for synset_name, info in self.data.items(): +# if (term.decode('utf-8') if isinstance(term, bytes) else term) in info['path']: +# descendant_nodes[synset_name] = info +# # Now look for descendants of this synset +# for other_synset, other_info in self.data.items(): +# if synset_name in other_info[ +# 'path'] and other_synset != synset_name: +# descendant_nodes[other_synset] = other_info + +# return descendant_nodes + + +# def get_all_indices_for_search(self, term): +# """ +# Get ALL indices from all synsets that have the term in their name or path. +# This includes all direct and indirect descendants. +# """ +# all_indices = [] + +# # Get all descendant nodes +# descendants = self.get_descendant_nodes(term) + +# # Extract all indices from these descendants +# for info in descendants.values(): +# all_indices.extend(info.get('indices', [])) + +# return sorted(list(set(all_indices))) + +# def get_mask(self, search_term, target_indices): +# """ +# Create a boolean mask for target_indices based on whether each index +# is in the descendants of the search_term. 
+# """ +# all_indices = set(self.get_all_indices_for_search(search_term)) +# return np.array([idx in all_indices for idx in target_indices]) + +# def get_all_semantic_masks(self, target_indices): +# """ +# Create boolean masks for all nodes using the create_mask method. + +# Parameters: +# - target_indices: A list of indices to check against each node's descendants + +# Returns: +# - mask_array: A numpy array of shape (num_nodes, len(target_indices)) +# where mask_array[i, j] is True if target_indices[j] is in node i's descendants +# - node_names: List of node names corresponding to rows in the mask_array +# """ +# import numpy as np + +# # Get list of all synset names +# node_names = list(self.data.keys()) +# num_nodes = len(node_names) + +# # Initialize list to store masks +# masks = [] + +# # For each node, create a mask using the existing create_mask method +# for synset_name in node_names: +# # Get the base name without the synset identifier +# base_name = synset_name.split('.')[0] + +# # Use the existing create_mask method +# mask = self.get_mask(base_name, target_indices) +# masks.append(mask) + +# # Convert list of masks to numpy array +# mask_array = np.array(masks) + +# return mask_array, node_names + +# def get_all_imagenet_masks(self, target_indices): +# """ +# Create masks for the 1000 ImageNet classes, using the most specific class name for each index. 
+ +# Parameters: +# - target_indices: A list of indices to check + +# Returns: +# - mask_array: A numpy array of shape (1000, len(target_indices)) +# - imagenet_class_names: List of the 1000 specific class names +# """ +# import numpy as np + +# # For each index, find the most specific class (i.e., the class with the longest path) +# index_to_specific_class = {} + +# for synset_name, data in self.data.items(): +# path_length = len(data['path'].split('.')) + +# for idx in data['indices']: +# if 0 <= idx < 1000: +# # If we haven't seen this index before, or if this class is more specific +# if idx not in index_to_specific_class or \ +# path_length > len(self.data[index_to_specific_class[idx]]['path'].split('.')): +# index_to_specific_class[idx] = synset_name + +# # Create a list of class names, one for each index +# class_names = [] +# for idx in range(1000): +# if idx in index_to_specific_class: +# class_names.append(index_to_specific_class[idx]) +# else: +# print(f"Warning: No class found for index {idx}") +# # Use a placeholder if no class is found +# class_names.append(f"unknown_class_{idx}") + +# # Create masks for each index +# masks = [] +# for idx in range(1000): +# # Simple mask where target index equals current index +# mask = np.array([i == idx for i in target_indices]) +# masks.append(mask) + +# # Convert list of masks to numpy array +# mask_array = np.array(masks) + +# return mask_array, class_names + +# def get_normalized_distance(self, concept, from_top=True, use_global_max=False): +# """ +# Get normalized distance (0-1) for a concept from top or bottom of the hierarchy. + +# Parameters: +# concept (str): Concept name (e.g., 'dog') +# from_top (bool): If True, distance is measured from root to concept. +# If False, distance is measured from concept to leaf. +# use_global_max (bool): If True, normalize against global max depth of hierarchy. +# If False, normalize within concept's own branch. 
+ +# Returns: +# float or None: Normalized distance in [0, 1] or None if concept not found. +# """ +# # Find the concept using same method as get_indices +# matches = self.indices_helper(concept, partial_match=True) + +# if not matches: +# print(f"[Warning] Concept '{concept}' not found in semantic data.") +# return None + +# # Choose the best match (you can modify sorting to prefer specific heuristics) +# concept_name = sorted(matches.keys(), key=lambda k: len(self.data[k]['path']))[0] +# concept_data = self.data[concept_name] + +# # Get concept distances +# distance_from_root = concept_data.get('distance_from_root', 0) +# distance_to_leaves = concept_data.get('distance_to_leaves', 0) + + +# # Calculate the normalization factor +# if use_global_max: +# max_depth = 0 +# for info in self.data.values(): +# branch_depth = info.get('distance_from_root', 0) + info.get('distance_to_leaves', 0) +# max_depth = max(max_depth, branch_depth) +# normalizer = max_depth +# else: +# normalizer = distance_from_root + distance_to_leaves + +# if normalizer <= 0: +# return 0.0 + +# if from_top: +# return distance_from_root / normalizer +# else: +# return distance_to_leaves / normalizer if not use_global_max else (normalizer - distance_from_root) / normalizer + +# def get_concepts_from_path(self, concept): +# """ +# Get all concept names from paths containing the specified concept, +# preserving the original order they appear in each path. 
+ +# Parameters: +# - concept: String representing the starting concept to search for + +# Returns: +# - List of concept names in the order they appear in paths +# """ +# # Get all descendant nodes for the concept +# descendants = self.get_descendant_nodes(concept) + +# # Track all unique paths to handle duplicates while preserving order +# all_paths = set() +# for info in descendants.values(): +# all_paths.add(info['path']) + +# # Process each unique path and extract names in order +# ordered_names = [] +# seen_names = set() # To track duplicates + +# for path in all_paths: +# # Split the path into components +# components = path.split(".") + +# # Process components in groups of 3 (name, pos, number) +# i = 0 +# while i < len(components): +# # Add the name component if not already seen +# if i < len(components) and components[i] not in seen_names: +# ordered_names.append(components[i]) +# seen_names.add(components[i]) + +# # Skip to the next name +# if i + 2 < len(components) and components[i+1] == 'n': +# i += 3 # Standard case: skip name, 'n', and number +# else: +# i += 1 # Fallback: move forward one component + +# return ordered_names + + +# if __name__ == "__main__": +# sem = SemanticAnalyzer() +# embed() + + + diff --git a/bscope/sae.py b/bscope/sae.py index b72b437..f8ec7da 100644 --- a/bscope/sae.py +++ b/bscope/sae.py @@ -95,7 +95,7 @@ def __init__(self, data_dim, num_atoms, mlp_hidden_dim=512): self.layers = nn.ModuleDict() # self.layers['layernorm1'] = nn.LayerNorm(data_dim, elementwise_affine=True) self.layers['layer1'] = nn.Linear(data_dim, self.mlp_hidden_dim, bias=True) - # self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['dropout1'] = nn.Dropout(p=0.05) # Add dropout layer with p=0.2 self.layers['relu1'] = nn.ReLU()# Add sigmoid activation @@ -117,6 +117,30 @@ def forward(self, x): x = layer(x) return x +class 
OneLayerEncoder(nn.Module): + def __init__(self, data_dim, num_atoms, mlp_hidden_dim=512): + super(OneLayerEncoder, self).__init__() + self.data_dim = data_dim + self.num_atoms = num_atoms + self.mlp_hidden_dim = mlp_hidden_dim + self.layers = nn.ModuleDict() + self.layers['layernorm1'] = nn.LayerNorm(data_dim, elementwise_affine=True) + self.layers['layer1'] = nn.Linear(data_dim, self.mlp_hidden_dim, bias=True) + # self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + self.layers['dropout1'] = nn.Dropout(p=0.05) # Add dropout layer with p=0.2 + self.layers['relu1'] = nn.ReLU()# Add sigmoid activation + self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + + self.layers['layer2'] = nn.Linear(self.mlp_hidden_dim, self.mlp_hidden_dim, bias=True) + # self.layers['layernorm2'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + self.layers['dropout2'] = nn.Dropout(p=0.05) # Add dropout layer with p=0.2 + self.layers['sigmoid'] = nn.Sigmoid() # Add sigmoid activation + + + def forward(self, x): + for layer in self.layers.values(): + x = layer(x) + return x class DefaultEncoder(nn.Module): def __init__(self, data_dim, num_atoms, mlp_hidden_dim=512): super(DefaultEncoder, self).__init__() @@ -129,13 +153,13 @@ def __init__(self, data_dim, num_atoms, mlp_hidden_dim=512): # self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['dropout1'] = nn.Dropout(p=0.05) # Add dropout layer with p=0.2 self.layers['relu1'] = nn.ReLU()# Add sigmoid activation - # self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + self.layers['layernorm1'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['layer2'] = nn.Linear(self.mlp_hidden_dim, self.mlp_hidden_dim, bias=True) # self.layers['layernorm2'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['dropout2'] = nn.Dropout(p=0.05) # Add 
dropout layer with p=0.2 self.layers['relu2'] = nn.ReLU() # Add ReLU activation - # self.layers['layernorm2'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) + self.layers['layernorm2'] = nn.LayerNorm(self.mlp_hidden_dim, elementwise_affine=True) self.layers['layer3'] = nn.Linear(self.mlp_hidden_dim, num_atoms, bias=False) self.layers['sigmoid'] = nn.Sigmoid() # Add sigmoid activation @@ -169,7 +193,8 @@ def forward(self, x): reconstructed = self.dictionary(z) return codes, z, reconstructed -class STSAE(nn.Module): + +class SSSAE(nn.Module): def __init__(self, data_dim, num_atoms, threshold = 0.95, mlp_hidden_dim=512, encoder=None): super(STSAE, self).__init__() @@ -190,13 +215,34 @@ def forward(self, x): reconstructed = self.dictionary(z) return codes, z, reconstructed +class STSAE(nn.Module): + def __init__(self, data_dim, num_atoms, threshold = 0.95, mlp_hidden_dim=512, nonnegative=False, encoder=None): + super(STSAE, self).__init__() + + + if encoder is not None: + self.encoder = encoder + else: + self.encoder = DefaultEncoder(data_dim, num_atoms,mlp_hidden_dim) + + self.dictionary = Dictionary(num_atoms, data_dim, nonnegative=nonnegative) + self.threshold = threshold + + def forward(self, x): + codes = self.encoder(x) + + mask = (codes >= self.threshold).float().detach() + z = codes * mask + + reconstructed = self.dictionary(z) + return codes, z, reconstructed class SigThreshSAE(nn.Module): - def __init__(self, data_dim, num_atoms, threshold = 0.95, mlp_hidden_dim=512): + def __init__(self, data_dim, num_atoms, threshold = 0.95, mlp_hidden_dim=512, nonnegative=False): super(SigThreshSAE, self).__init__() self.encoder = Encoder(data_dim, num_atoms,mlp_hidden_dim) - self.dictionary = Dictionary(num_atoms, data_dim) + self.dictionary = Dictionary(num_atoms, data_dim, nonnegative=nonnegative) self.threshold = threshold @@ -211,18 +257,21 @@ def forward(self, x): -class SigSigSAE(nn.Module): - def __init__(self, data_dim, num_atoms, a, b, 
mlp_hidden_dim=512, sigma=0.05): - super(SigSigSAE, self).__init__() - self.encoder = Encoder(data_dim, num_atoms,mlp_hidden_dim) +class SSSAE(nn.Module): + def __init__(self, data_dim, num_atoms, a, b, mlp_hidden_dim=512, sigma=0.05, encoder=None): + super(SSSAE, self).__init__() + + if encoder is not None: + self.encoder = encoder + else: + self.encoder = DefaultEncoder(data_dim, num_atoms,mlp_hidden_dim) + self.dictionary = Dictionary(num_atoms, data_dim) self.a = a self.b = b - self.noise = GaussianNoise(sigma=sigma) # Add Gaussian noise with sigma=0.1 - def sigmoid(self, x, a, b): """ Sigmoid function with parameters a and b. @@ -232,32 +281,39 @@ def sigmoid(self, x, a, b): a (float): Steepness of the sigmoid curve. b (float): Horizontal shift of the sigmoid curve. - An example of a very steep sigmoid function: + An example of a very steep sigmoid function would be a + """ s = torch.clip(x, min=1e-8, max=1 - 1e-8) # Avoid log(0) issues s = 1 / (1 + torch.exp(-a * (x - b))) s = torch.clamp(s, min=1e-8, max=1 - 1e-8) # Avoid log(0) issues return s + + def plot_sigmoid(self): + x = torch.linspace(-1, 1, 100) + y = self.sigmoid(x, self.a, self.b).detach().cpu().numpy() + plt.plot(x, y) + plt.title(f'Sigmoid Function (a={self.a}, b={self.b})') + plt.show() def forward(self, x): # If training # if self.training: - if self.training: - codes = self.encoder(x) - z = self.sigmoid(codes, a=self.a, b=self.b) - mask = torch.ones_like(codes).float().detach() # Use ones to keep all codes - reconstructed = self.dictionary(z) - return codes, z, reconstructed - else: - codes = self.encoder(x) - z= self.sigmoid(codes, a=self.a, b=self.b) + # if self.training: + codes = self.encoder(x) + z = self.sigmoid(codes, a=self.a, b=self.b) + reconstructed = self.dictionary(z) + return codes, z, reconstructed + # else: + # codes = self.encoder(x) + # z= self.sigmoid(codes, a=self.a, b=self.b) - mask = (z >= self.b).float().detach() - z= z* mask + # mask = (z >= self.b).float().detach() + 
# z= z* mask - reconstructed = self.dictionary(z) + # reconstructed = self.dictionary(z) - return codes, z, reconstructed + # return codes, z, reconstructed From 49059db6a4d4365926658b57d8ca4e271a349c2b Mon Sep 17 00:00:00 2001 From: Zaki Alaoui Date: Wed, 21 Jan 2026 22:03:23 +0000 Subject: [PATCH 6/6] update --- bscope/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/bscope/utils.py b/bscope/utils.py index e1e7194..e3fb699 100644 --- a/bscope/utils.py +++ b/bscope/utils.py @@ -7,6 +7,21 @@ Epsilon = 1e-6 +def sort_data(x, y): + sorted_indices = np.argsort(x) + return np.array(x)[sorted_indices], np.array(y)[sorted_indices] +def compute_auc(percentages, accuracies, method='trapz'): + x,y = sort_data(percentages, accuracies) + + if method == 'trapz': + # Trapezoidal rule - most common and robust + auc = np.trapz(y, x) + + elif method == 'simps': + # Simpson's rule - more accurate for smooth curves + auc = integrate.simps(y, x) + + return auc def select_significant_indices(vector, method='threshold', param=0.8, min_indices=1, max_indices=None): """ Select indices that contribute most to the overall sum of the vector.