diff --git a/doc/how_to/import_kilosort_data.rst b/doc/how_to/import_kilosort_data.rst
index dad522334a..ed7bf7b6c0 100644
--- a/doc/how_to/import_kilosort_data.rst
+++ b/doc/how_to/import_kilosort_data.rst
@@ -49,7 +49,7 @@ If you'd like to store the information you've computed, you can save the analyze
)
You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui `__. to view the results
-interactively, or start manually labelling your units to `create an automated curation model `__.
+interactively, or start manually labeling your units to `create an automated curation model `__.
Note that if you have access to the raw recording, you can attach it to the analyzer, and re-compute extensions from the raw data. E.g.
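A minimal sketch of that workflow (hedged: ``recording`` stands for your raw ``Recording`` object and ``analyzer`` for the ``SortingAnalyzer`` built above; ``set_temporary_recording`` attaches the recording for the current session only):

.. code-block:: python

    # attach the raw recording to the analyzer (variable names are illustrative)
    analyzer.set_temporary_recording(recording)

    # extensions that need the raw traces can now be (re-)computed
    analyzer.compute(["random_spikes", "waveforms"])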
diff --git a/doc/modules/metrics.rst b/doc/modules/metrics.rst
index d3f3512670..fe5b631f19 100644
--- a/doc/modules/metrics.rst
+++ b/doc/modules/metrics.rst
@@ -28,14 +28,32 @@ metric information. For example, you can get the list of available metrics using
.. code-block::
Available metric columns:
- ['peak_to_valley', 'peak_trough_ratio', 'half_width', 'repolarization_slope',
- 'recovery_slope', 'num_positive_peaks', 'num_negative_peaks', 'velocity_above',
- 'velocity_below', 'exp_decay', 'spread']
+ [
+ 'peak_to_trough_duration',
+ 'half_width',
+ 'repolarization_slope',
+ 'recovery_slope',
+ 'num_positive_peaks',
+ 'num_negative_peaks',
+ 'main_to_next_peak_duration',
+ 'peak_before_to_trough_ratio',
+ 'peak_after_to_trough_ratio',
+ 'peak_before_to_peak_after_ratio',
+ 'main_peak_to_trough_ratio',
+ 'trough_width',
+ 'peak_before_width',
+ 'peak_after_width',
+ 'waveform_baseline_flatness',
+ 'velocity_above',
+ 'velocity_below',
+ 'exp_decay',
+ 'spread'
+ ]
.. code-block:: python
- metric_descriptions = ComputeTemplateMetrics.get_metric_descriptions()
+ metric_descriptions = ComputeTemplateMetrics.get_metric_column_descriptions()
print("Metric descriptions: ")
print(metric_descriptions)
@@ -44,21 +62,30 @@ metric information. For example, you can get the list of available metrics using
Metric descriptions:
{
- 'peak_to_valley': 'Duration in s between the trough (minimum) and the peak (maximum) of the spike waveform.',
- 'peak_trough_ratio': 'Ratio of the amplitude of the peak (maximum) to the trough (minimum) of the spike waveform.',
+ 'peak_to_trough_duration': 'Duration in seconds between the trough (minimum) and the peak (maximum) of the spike waveform.',
'half_width': 'Duration in s at half the amplitude of the trough (minimum) of the spike waveform.',
'repolarization_slope': 'Slope of the repolarization phase of the spike waveform, between the trough (minimum) and return to baseline in uV/s.',
'recovery_slope': 'Slope of the recovery phase of the spike waveform, after the peak (maximum) returning to baseline in uV/s.',
'num_positive_peaks': 'Number of positive peaks in the template',
- 'num_negative_peaks': 'Number of negative peaks in the template',
+ 'num_negative_peaks': 'Number of negative peaks (troughs) in the template',
+ 'main_to_next_peak_duration': 'Duration in seconds from main extremum to next extremum.',
+ 'peak_before_to_trough_ratio': 'Ratio of peak before amplitude to trough amplitude',
+ 'peak_after_to_trough_ratio': 'Ratio of peak after amplitude to trough amplitude',
+ 'peak_before_to_peak_after_ratio': 'Ratio of peak before amplitude to peak after amplitude',
+ 'main_peak_to_trough_ratio': 'Ratio of main peak amplitude to trough amplitude',
+ 'trough_width': 'Width of the main trough in seconds',
+ 'peak_before_width': 'Width of the main peak before trough in seconds',
+ 'peak_after_width': 'Width of the main peak after trough in seconds',
+ 'waveform_baseline_flatness': 'Ratio of max baseline amplitude to max waveform amplitude. Lower = flatter baseline.',
'velocity_above': 'Velocity of the spike propagation above the max channel in um/ms',
'velocity_below': 'Velocity of the spike propagation below the max channel in um/ms',
- 'exp_decay': 'Exponential decay of the template amplitude over distance from the extremum channel (1/um).',
+ 'exp_decay': 'Spatial decay of the template amplitude over distance from the extremum channel (1/um). Uses exponential or linear fit based on linear_fit parameter.',
'spread': 'Spread of the template amplitude in um, calculated as the distance between channels whose templates exceed the spread_threshold.'
}
+
.. toctree::
:caption: Metrics submodules
:maxdepth: 1
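As a hedged usage sketch (assuming ``analyzer`` is a ``SortingAnalyzer`` with templates available; most template metrics use the same name for their single column, as listed above), a subset of metrics can be requested by name and the resulting columns inspected:

.. code-block:: python

    # compute only a couple of template metrics by name
    analyzer.compute(
        "template_metrics",
        metric_names=["peak_to_trough_duration", "half_width"],
    )

    # the computed values are returned as a DataFrame indexed by unit id
    template_metrics = analyzer.get_extension("template_metrics").get_data()
    print(template_metrics.columns.tolist())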
diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index dd06e60458..1771ba5c16 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -569,7 +569,7 @@ def make_hungarian_match(agreement_scores, min_score):
def do_score_labels(sorting1, sorting2, delta_frames, unit_map12, label_misclassification=False):
"""
- Makes the labelling at spike level for each spike train:
+ Makes the labeling at spike level for each spike train:
* TP: true positive
* CL: classification error
* FN: False negative
diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 30038bc270..2c551b5342 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -997,6 +997,38 @@ def get_optional_dependencies(cls, **params):
depend_on = list(cls.depend_on) + list(metric_depend_on)
return depend_on
+ def get_computed_metric_names(self):
+ """
+ Get the list of already computed metric names.
+
+ Returns
+ -------
+ computed_metric_names : list[str]
+ List of computed metric names.
+ """
+ if self.data is None or len(self.data) == 0:
+ return []
+ else:
+ computed_metric_columns = self.data["metrics"].columns.tolist()
+ computed_metric_names = []
+ for m in self.metric_list:
+ if all(col in computed_metric_columns for col in m.metric_columns.keys()):
+ computed_metric_names.append(m.metric_name)
+ return computed_metric_names
+
+ def _cast_metrics(self, metrics_df):
+ metric_dtypes = {}
+ for m in self.metric_list:
+ metric_dtypes.update(m.metric_columns)
+
+ for col in metrics_df.columns:
+ if col in metric_dtypes:
+ try:
+ metrics_df[col] = metrics_df[col].astype(metric_dtypes[col])
+ except Exception as e:
+ print(f"Error casting column {col}: {e}")
+ return metrics_df
+
def _set_params(
self,
metric_names: list[str] | None = None,
@@ -1155,6 +1187,13 @@ def _compute_metrics(
metric = [m for m in self.metric_list if m.metric_name == metric_name][0]
column_names_dtypes.update(metric.metric_columns)
+ # drop metric names that don't map to any known metric
+ possible_metric_names = [m.metric_name for m in self.metric_list]
+ wrong_metric_names = [m for m in metric_names if m not in possible_metric_names]
+ if len(wrong_metric_names) > 0:
+ warnings.warn(f"The following metric names are not recognized and will be ignored: {wrong_metric_names}")
+ metric_names = [m for m in metric_names if m in possible_metric_names]
+
metrics = pd.DataFrame(index=unit_ids, columns=list(column_names_dtypes.keys()))
run_times = {}
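A minimal sketch of the new helper (assuming a ``SortingAnalyzer`` named ``analyzer`` whose quality metrics have already been computed):

.. code-block:: python

    # list the metrics an extension has already computed
    qm_ext = analyzer.get_extension("quality_metrics")
    print(qm_ext.get_computed_metric_names())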
diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py
index 119ab1d598..d5ba74cfd9 100644
--- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py
+++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py
@@ -189,7 +189,7 @@ def test_aggregation_labeling_for_lists():
assert np.all(user_group_property == [6, 6, 7, 7])
-def test_aggretion_labelling_for_dicts():
+def test_aggregation_labeling_for_dicts():
"""Aggregated dicts of recordings get different labels depending on their underlying `property`s"""
recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False)
diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py
index b64070e662..56de464cee 100644
--- a/src/spikeinterface/curation/__init__.py
+++ b/src/spikeinterface/curation/__init__.py
@@ -20,5 +20,12 @@
from .sortingview_curation import apply_sortingview_curation
# automated curation
+from .bombcell_curation import (
+ bombcell_get_default_thresholds,
+ bombcell_label_units,
+ get_bombcell_labeling_summary,
+ save_bombcell_results,
+)
+
from .model_based_curation import auto_label_units, load_model
from .train_manual_curation import train_model, get_default_classifier_search_spaces
diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py
new file mode 100644
index 0000000000..84a954ffa6
--- /dev/null
+++ b/src/spikeinterface/curation/bombcell_curation.py
@@ -0,0 +1,343 @@
+"""
+Unit labeling based on quality metrics (bombcell).
+
+Unit Types:
+ 0 (NOISE): Failed waveform quality checks
+ 1 (GOOD): Passed all thresholds
+ 2 (MUA): Failed spike quality checks
+ 3 (NON_SOMA): Non-somatic units (axonal)
+ 4 (NON_SOMA_MUA): Non-somatic MUA units (only when split_non_somatic_good_mua=True; 3 then means NON_SOMA_GOOD)
+"""
+
+from __future__ import annotations
+
+import numpy as np
+from typing import Optional
+
+NOISE_METRICS = [
+ "num_positive_peaks",
+ "num_negative_peaks",
+ "peak_to_trough_duration",
+ "waveform_baseline_flatness",
+ "peak_after_to_trough_ratio",
+ "exp_decay",
+]
+
+SPIKE_QUALITY_METRICS = [
+ "amplitude_median",
+ "snr",
+ "snr_baseline",
+ "amplitude_cutoff",
+ "num_spikes",
+ "rp_contamination",
+ "presence_ratio",
+ "drift_ptp",
+]
+
+NON_SOMATIC_METRICS = [
+ "peak_before_to_trough_ratio",
+ "peak_before_width",
+ "trough_width",
+ "peak_before_to_peak_after_ratio",
+ "main_peak_to_trough_ratio",
+]
+
+
+def bombcell_get_default_thresholds() -> dict:
+ """
+ bombcell - Returns default thresholds for unit labeling.
+
+ Each metric has 'min' and 'max' values. Use None to disable a threshold (e.g. to ignore a metric completely,
+ or to apply only a min or only a max threshold).
+ """
+ # bombcell
+ return {
+ # Waveform quality (failures -> NOISE)
+ "num_positive_peaks": {"min": None, "max": 2},
+ "num_negative_peaks": {"min": None, "max": 1},
+ "peak_to_trough_duration": {"min": 0.0001, "max": 0.00115}, # seconds
+ "waveform_baseline_flatness": {"min": None, "max": 0.5},
+ "peak_after_to_trough_ratio": {"min": None, "max": 0.8},
+ "exp_decay": {"min": 0.01, "max": 0.1},
+ # Spike quality (failures -> MUA)
+ "amplitude_median": {"min": 40, "max": None}, # uV
+ "snr_baseline": {"min": 5, "max": None},
+ "amplitude_cutoff": {"min": None, "max": 0.2},
+ "num_spikes": {"min": 300, "max": None},
+ "rp_contamination": {"min": None, "max": 0.1},
+ "presence_ratio": {"min": 0.7, "max": None},
+ "drift_ptp": {"min": None, "max": 100}, # um
+ # Non-somatic detection
+ "peak_before_to_trough_ratio": {"min": None, "max": 3},
+ "peak_before_width": {"min": 0.00015, "max": None}, # seconds
+ "trough_width": {"min": 0.0002, "max": None}, # seconds
+ "peak_before_to_peak_after_ratio": {"min": None, "max": 3},
+ "main_peak_to_trough_ratio": {"min": None, "max": 0.8},
+ }
+
+
+def _is_threshold_disabled(value):
+ """Check if a threshold value is disabled (None or np.nan)."""
+ if value is None:
+ return True
+ if isinstance(value, float) and np.isnan(value):
+ return True
+ return False
+
+
+def bombcell_label_units(
+ sorting_analyzer=None,
+ thresholds: Optional[dict] = None,
+ label_non_somatic: bool = True,
+ split_non_somatic_good_mua: bool = False,
+ external_metrics: Optional["pd.DataFrame | list[pd.DataFrame]"] = None,
+) -> tuple[np.ndarray, np.ndarray]:
+ """
+ bombcell - label units based on quality metrics and thresholds.
+
+ Parameters
+ ----------
+ sorting_analyzer : SortingAnalyzer, optional
+ SortingAnalyzer with computed quality_metrics and/or template_metrics extensions.
+ If provided, metrics are extracted automatically using get_metrics_extension_data().
+ thresholds : dict or None
+ Threshold dict: {"metric": {"min": val, "max": val}}. Use None to disable.
+ label_non_somatic : bool
+ If True, detect non-somatic (axonal) units.
+ split_non_somatic_good_mua : bool
+ If True, split non-somatic into NON_SOMA_GOOD (3) and NON_SOMA_MUA (4).
+ external_metrics : pd.DataFrame or list of pd.DataFrame, optional
+ External metrics DataFrame(s) (index = unit_ids) to use instead of those from the SortingAnalyzer.
+
+ Returns
+ -------
+ unit_type : np.ndarray
+ Numeric: 0=NOISE, 1=GOOD, 2=MUA, 3=NON_SOMA (4=NON_SOMA_MUA if split_non_somatic_good_mua=True)
+ unit_type_string : np.ndarray
+ String labels.
+ """
+ import pandas as pd
+
+ if sorting_analyzer is not None:
+ combined_metrics = sorting_analyzer.get_metrics_extension_data()
+ if combined_metrics.empty:
+ raise ValueError(
+ "SortingAnalyzer has no metrics extensions computed. "
+ "Compute quality_metrics and/or template_metrics first."
+ )
+ else:
+ if external_metrics is None:
+ raise ValueError("Either sorting_analyzer or external_metrics must be provided")
+ if isinstance(external_metrics, list):
+ assert all(
+ isinstance(df, pd.DataFrame) for df in external_metrics
+ ), "All items in external_metrics must be DataFrames"
+ combined_metrics = pd.concat(external_metrics, axis=1)
+ else:
+ combined_metrics = external_metrics
+
+ if thresholds is None:
+ thresholds = bombcell_get_default_thresholds()
+
+ n_units = len(combined_metrics)
+ unit_type = np.full(n_units, np.nan)
+ absolute_value_metrics = ["amplitude_median"]
+
+ # NOISE: waveform failures
+ noise_mask = np.zeros(n_units, dtype=bool)
+ for metric_name in NOISE_METRICS:
+ if metric_name not in combined_metrics.columns or metric_name not in thresholds:
+ continue
+ values = combined_metrics[metric_name].values
+ if metric_name in absolute_value_metrics:
+ values = np.abs(values)
+ thresh = thresholds[metric_name]
+ noise_mask |= np.isnan(values)
+ if not _is_threshold_disabled(thresh["min"]):
+ noise_mask |= values < thresh["min"]
+ if not _is_threshold_disabled(thresh["max"]):
+ noise_mask |= values > thresh["max"]
+ unit_type[noise_mask] = 0
+
+ # MUA: spike quality failures
+ mua_mask = np.zeros(n_units, dtype=bool)
+ for metric_name in SPIKE_QUALITY_METRICS:
+ if metric_name not in combined_metrics.columns or metric_name not in thresholds:
+ continue
+ values = combined_metrics[metric_name].values
+ if metric_name in absolute_value_metrics:
+ values = np.abs(values)
+ thresh = thresholds[metric_name]
+ valid_mask = np.isnan(unit_type)
+ if not _is_threshold_disabled(thresh["min"]):
+ mua_mask |= valid_mask & ~np.isnan(values) & (values < thresh["min"])
+ if not _is_threshold_disabled(thresh["max"]):
+ mua_mask |= valid_mask & ~np.isnan(values) & (values > thresh["max"])
+ unit_type[mua_mask & np.isnan(unit_type)] = 2
+
+ # GOOD: passed all checks
+ unit_type[np.isnan(unit_type)] = 1
+
+ # NON-SOMATIC
+ if label_non_somatic:
+
+ def get_metric(name):
+ if name in combined_metrics.columns:
+ return combined_metrics[name].values
+ return np.full(n_units, np.nan)
+
+ peak_before_width = get_metric("peak_before_width")
+ trough_width = get_metric("trough_width")
+ width_thresh_peak = thresholds.get("peak_before_width", {}).get("min", None)
+ width_thresh_trough = thresholds.get("trough_width", {}).get("min", None)
+
+ narrow_peak = (
+ ~np.isnan(peak_before_width) & (peak_before_width < width_thresh_peak)
+ if not _is_threshold_disabled(width_thresh_peak)
+ else np.zeros(n_units, dtype=bool)
+ )
+ narrow_trough = (
+ ~np.isnan(trough_width) & (trough_width < width_thresh_trough)
+ if not _is_threshold_disabled(width_thresh_trough)
+ else np.zeros(n_units, dtype=bool)
+ )
+ width_conditions = narrow_peak & narrow_trough
+
+ peak_before_to_trough = get_metric("peak_before_to_trough_ratio")
+ peak_before_to_peak_after = get_metric("peak_before_to_peak_after_ratio")
+ main_peak_to_trough = get_metric("main_peak_to_trough_ratio")
+
+ ratio_thresh_pbt = thresholds.get("peak_before_to_trough_ratio", {}).get("max", None)
+ ratio_thresh_pbpa = thresholds.get("peak_before_to_peak_after_ratio", {}).get("max", None)
+ ratio_thresh_mpt = thresholds.get("main_peak_to_trough_ratio", {}).get("max", None)
+
+ large_initial_peak = (
+ ~np.isnan(peak_before_to_trough) & (peak_before_to_trough > ratio_thresh_pbt)
+ if not _is_threshold_disabled(ratio_thresh_pbt)
+ else np.zeros(n_units, dtype=bool)
+ )
+ large_peak_ratio = (
+ ~np.isnan(peak_before_to_peak_after) & (peak_before_to_peak_after > ratio_thresh_pbpa)
+ if not _is_threshold_disabled(ratio_thresh_pbpa)
+ else np.zeros(n_units, dtype=bool)
+ )
+ large_main_peak = (
+ ~np.isnan(main_peak_to_trough) & (main_peak_to_trough > ratio_thresh_mpt)
+ if not _is_threshold_disabled(ratio_thresh_mpt)
+ else np.zeros(n_units, dtype=bool)
+ )
+
+ # (ratio AND width) OR standalone main_peak_to_trough
+ ratio_conditions = large_initial_peak | large_peak_ratio
+ is_non_somatic = (ratio_conditions & width_conditions) | large_main_peak
+
+ if split_non_somatic_good_mua:
+ unit_type[(unit_type == 1) & is_non_somatic] = 3
+ unit_type[(unit_type == 2) & is_non_somatic] = 4
+ else:
+ unit_type[(unit_type != 0) & is_non_somatic] = 3
+
+ # String labels
+ if split_non_somatic_good_mua:
+ labels = {0: "noise", 1: "good", 2: "mua", 3: "non_soma_good", 4: "non_soma_mua"}
+ else:
+ labels = {0: "noise", 1: "good", 2: "mua", 3: "non_soma"}
+
+ unit_type_string = np.array([labels.get(int(t), "unknown") for t in unit_type], dtype=object)
+ return unit_type.astype(int), unit_type_string
+
+
+def get_bombcell_labeling_summary(unit_type: np.ndarray, unit_type_string: np.ndarray) -> dict:
+ """Get counts and percentages for each unit type."""
+ n_total = len(unit_type)
+ unique_types, counts = np.unique(unit_type, return_counts=True)
+
+ summary = {"total_units": n_total, "counts": {}, "percentages": {}}
+ for utype, count in zip(unique_types, counts):
+ label = unit_type_string[unit_type == utype][0]
+ summary["counts"][label] = int(count)
+ summary["percentages"][label] = round(100 * count / n_total, 1)
+
+ return summary
+
+
+def save_bombcell_results(
+ quality_metrics: "pd.DataFrame",
+ unit_type: np.ndarray,
+ unit_type_string: np.ndarray,
+ thresholds: dict,
+ folder,
+ save_narrow: bool = True,
+ save_wide: bool = True,
+) -> None:
+ """
+ Save labeling results to CSV files.
+
+ Parameters
+ ----------
+ quality_metrics : pd.DataFrame
+ DataFrame with quality metrics (index = unit_ids).
+ unit_type : np.ndarray
+ Numeric unit type codes.
+ unit_type_string : np.ndarray
+ String labels for each unit.
+ thresholds : dict
+ Threshold dictionary used for labeling.
+ folder : str or Path
+ Folder to save the CSV files.
+ save_narrow : bool, default: True
+ Save narrow/tidy format (one row per unit-metric).
+ save_wide : bool, default: True
+ Save wide format (one row per unit, metrics as columns).
+ """
+ from pathlib import Path
+ import pandas as pd
+
+ folder = Path(folder)
+ folder.mkdir(parents=True, exist_ok=True)
+
+ unit_ids = quality_metrics.index.values
+
+ # Wide format: one row per unit
+ if save_wide:
+ wide_df = quality_metrics.copy()
+ wide_df.insert(0, "label", unit_type_string)
+ wide_df.insert(1, "label_code", unit_type)
+ wide_df.to_csv(folder / "labeling_results_wide.csv")
+
+ # Narrow format: one row per unit-metric combination
+ if save_narrow:
+ rows = []
+ for i, unit_id in enumerate(unit_ids):
+ label = unit_type_string[i]
+ label_code = unit_type[i]
+ for metric_name in quality_metrics.columns:
+ if metric_name not in thresholds:
+ continue
+ value = quality_metrics.loc[unit_id, metric_name]
+ thresh = thresholds[metric_name]
+ thresh_min = thresh.get("min", None)
+ thresh_max = thresh.get("max", None)
+
+ # Determine pass/fail
+ passed = True
+ if np.isnan(value):
+ passed = False
+ elif not _is_threshold_disabled(thresh_min) and value < thresh_min:
+ passed = False
+ elif not _is_threshold_disabled(thresh_max) and value > thresh_max:
+ passed = False
+
+ rows.append(
+ {
+ "unit_id": unit_id,
+ "label": label,
+ "label_code": label_code,
+ "metric_name": metric_name,
+ "value": value,
+ "threshold_min": thresh_min,
+ "threshold_max": thresh_max,
+ "passed": passed,
+ }
+ )
+
+ narrow_df = pd.DataFrame(rows)
+ narrow_df.to_csv(folder / "labeling_results_narrow.csv", index=False)
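Putting the new module together, a hedged end-to-end sketch (``analyzer`` is an existing ``SortingAnalyzer``; the tweaked threshold and output folder are illustrative, not defaults):

.. code-block:: python

    from spikeinterface.curation import (
        bombcell_get_default_thresholds,
        bombcell_label_units,
        get_bombcell_labeling_summary,
        save_bombcell_results,
    )

    # the labeling needs quality and/or template metrics on the analyzer
    analyzer.compute("quality_metrics")
    analyzer.compute("template_metrics")

    # start from the defaults and adjust a threshold
    thresholds = bombcell_get_default_thresholds()
    thresholds["num_spikes"]["min"] = 500

    unit_type, unit_type_string = bombcell_label_units(
        sorting_analyzer=analyzer, thresholds=thresholds
    )
    print(get_bombcell_labeling_summary(unit_type, unit_type_string))

    # optionally export per-unit / per-metric pass-fail tables as CSV
    metrics = analyzer.get_metrics_extension_data()
    save_bombcell_results(
        metrics, unit_type, unit_type_string, thresholds, folder="bombcell_labels"
    )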
diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
index e779e13182..99822d64f6 100644
--- a/src/spikeinterface/curation/model_based_curation.py
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -3,11 +3,10 @@
import json
import warnings
import re
+from packaging.version import parse
from spikeinterface.core import SortingAnalyzer
from spikeinterface.curation.train_manual_curation import (
- try_to_get_metrics_from_analyzer,
- _get_computed_metrics,
_format_metric_dataframe,
)
from copy import deepcopy
@@ -81,11 +80,12 @@ def predict_labels(
# Get metrics DataFrame for classification
if input_data is None:
- input_data = _get_computed_metrics(self.sorting_analyzer)
+ input_data = self.sorting_analyzer.get_metrics_extension_data()
else:
if not isinstance(input_data, pd.DataFrame):
raise ValueError("Input data must be a pandas DataFrame")
+ input_data = self.handle_backwards_compatibility_in_metrics(input_data, model_info=model_info)
input_data = self._check_required_metrics_are_present(input_data)
if model_info is not None:
@@ -127,8 +127,23 @@ def predict_labels(
return classified_units
- def _check_required_metrics_are_present(self, calculated_metrics):
+ def handle_backwards_compatibility_in_metrics(self, calculated_metrics, model_info):
+ si_version = model_info["requirements"].get("spikeinterface", None)
+ if si_version is not None and parse(si_version) < parse("0.103.2"):
+ # if the model was trained with SI version < 0.103.2, we need to rename some metrics
+ calculated_metrics = calculated_metrics.copy()
+ # peak_to_trough_duration was named peak_to_valley
+ if "peak_to_trough_duration" in calculated_metrics.columns:
+ calculated_metrics = calculated_metrics.rename(columns={"peak_to_trough_duration": "peak_to_valley"})
+ # main_peak_to_trough_ratio was named peak_trough_ratio and had inverted sign
+ if "main_peak_to_trough_ratio" in calculated_metrics.columns:
+ calculated_metrics = calculated_metrics.rename(
+ columns={"main_peak_to_trough_ratio": "peak_trough_ratio"}
+ )
+ calculated_metrics["peak_trough_ratio"] = -1 * calculated_metrics["peak_trough_ratio"]
+ return calculated_metrics
+
+ def _check_required_metrics_are_present(self, calculated_metrics):
# Check all the required metrics have been calculated
required_metrics = set(self.required_metrics)
if required_metrics.issubset(set(calculated_metrics)):
diff --git a/src/spikeinterface/curation/tests/test_bombcell_curation.py b/src/spikeinterface/curation/tests/test_bombcell_curation.py
new file mode 100644
index 0000000000..a867453064
--- /dev/null
+++ b/src/spikeinterface/curation/tests/test_bombcell_curation.py
@@ -0,0 +1,17 @@
+import pytest
+from spikeinterface.curation.tests.common import sorting_analyzer_for_curation
+from spikeinterface.curation.bombcell_curation import bombcell_label_units
+
+
+def test_bombcell_label_units(sorting_analyzer_for_curation):
+ """Test bombcell_label_units function on a sorting_analyzer with computed quality metrics."""
+
+ sorting_analyzer = sorting_analyzer_for_curation
+ sorting_analyzer.compute("quality_metrics")
+ sorting_analyzer.compute("template_metrics")
+
+ unit_type, unit_type_string = bombcell_label_units(sorting_analyzer=sorting_analyzer)
+
+ assert len(unit_type) == sorting_analyzer.unit_ids.size
+ assert set(unit_type_string).issubset({"noise", "good", "mua", "non_soma"})
diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py
index 198e98037c..c5f29ac329 100644
--- a/src/spikeinterface/metrics/quality/misc_metrics.py
+++ b/src/spikeinterface/metrics/quality/misc_metrics.py
@@ -146,7 +146,7 @@ class PresenceRatio(BaseMetric):
def compute_snrs(
sorting_analyzer,
unit_ids=None,
- peak_sign: str = "neg",
+ peak_sign: str = "both",
peak_mode: str = "extremum",
):
"""
@@ -207,6 +207,119 @@ class SNR(BaseMetric):
depend_on = ["noise_levels", "templates"]
+def compute_snrs_versus_baseline(
+ sorting_analyzer,
+ unit_ids=None,
+ peak_sign: str = "neg",
+ baseline_window_ms: float = 0.5,
+):
+ """
+ Compute signal to noise ratio versus baseline.
+
+ This differs from the standard SNR by using:
+ - Signal: Max absolute value of the median waveform on peak channel
+ - Noise: MAD (Median Absolute Deviation) of baseline samples from waveforms
+
+ Parameters
+ ----------
+ sorting_analyzer : SortingAnalyzer
+ A SortingAnalyzer object.
+ unit_ids : list or None
+ The list of unit ids to compute the SNR. If None, all units are used.
+ peak_sign : "neg" | "pos" | "both", default: "neg"
+ The sign of the template to compute best channels.
+ baseline_window_ms : float, default: 0.5
+ Duration in ms at the start of the waveform to use as baseline for noise calculation.
+
+ Returns
+ -------
+ snrs : dict
+ Computed signal to noise ratio for each unit.
+
+ Notes
+ -----
+ This implementation follows the bombcell methodology [1]:
+ - Signal is the maximum absolute amplitude of the median waveform on the peak channel
+ - Noise is computed as MAD of baseline samples (first N samples of each waveform)
+
+ Requires the "waveforms" extension to be computed.
+
+ References
+ ----------
+ [1] https://github.com/Julie-Fabre/bombcell
+ """
+ if not sorting_analyzer.has_extension("waveforms"):
+ raise ValueError(
+ "The 'waveforms' extension is required for compute_snrs_versus_baseline. "
+ "Please compute it first with: analyzer.compute('waveforms')"
+ )
+
+ if unit_ids is None:
+ unit_ids = sorting_analyzer.unit_ids
+
+ waveforms_ext = sorting_analyzer.get_extension("waveforms")
+ nbefore = waveforms_ext.nbefore
+ sampling_frequency = sorting_analyzer.sampling_frequency
+
+ # Calculate baseline samples from ms
+ baseline_samples = int(baseline_window_ms / 1000 * sampling_frequency)
+ baseline_samples = min(baseline_samples, nbefore) # Can't exceed nbefore
+
+ # Get peak channel for each unit from templates
+ extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign)
+
+ snrs = {}
+ for unit_id in unit_ids:
+ # Get waveforms for this unit (num_spikes, num_samples, num_channels)
+ waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False)
+
+ if waveforms is None or len(waveforms) == 0:
+ snrs[unit_id] = np.nan
+ continue
+
+ # Get peak channel index
+ peak_chan_id = extremum_channels_ids[unit_id]
+ if sorting_analyzer.is_sparse():
+ chan_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id]
+ if peak_chan_id not in chan_ids:
+ snrs[unit_id] = np.nan
+ continue
+ peak_chan_idx = np.where(chan_ids == peak_chan_id)[0][0]
+ else:
+ peak_chan_idx = sorting_analyzer.channel_ids_to_indices([peak_chan_id])[0]
+
+ # Extract waveforms on peak channel
+ waveforms_peak = waveforms[:, :, peak_chan_idx] # (num_spikes, num_samples)
+
+ # Signal: max absolute value of the median waveform
+ median_waveform = np.median(waveforms_peak, axis=0) # median across spikes
+ signal = np.max(np.abs(median_waveform))
+
+ # Noise: MAD of baseline samples (first N samples of each waveform)
+ baseline_samples_all = waveforms_peak[:, :baseline_samples].flatten()
+ median_baseline = np.median(baseline_samples_all)
+ noise = np.median(np.abs(baseline_samples_all - median_baseline))
+
+ # Calculate SNR (avoid division by zero)
+ if noise > 0:
+ snrs[unit_id] = signal / noise
+ else:
+ snrs[unit_id] = np.nan
+
+ return snrs
+
+
+class SNRBaseline(BaseMetric):
+ metric_name = "snr_baseline"
+ metric_function = compute_snrs_versus_baseline
+ metric_params = {"peak_sign": "neg", "baseline_window_ms": 0.5}
+ metric_columns = {"snr_baseline": float}
+ metric_descriptions = {
+ "snr_baseline": "Signal to noise ratio versus baseline (median waveform max / baseline MAD). Based on bombcell."
+ }
+ depend_on = ["waveforms", "templates"]
+
+
def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_threshold_ms=1.5, min_isi_ms=0):
"""
Calculate Inter-Spike Interval (ISI) violations.
@@ -890,7 +1003,10 @@ def compute_amplitude_cutoffs(
if invert_amplitudes:
amplitudes = -amplitudes
all_fraction_missing[unit_id] = amplitude_cutoff(
- amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio
+ amplitudes,
+ num_histogram_bins,
+ histogram_smoothing_value,
+ amplitudes_bins_min_ratio,
)
if np.any(np.isnan(list(all_fraction_missing.values()))):
@@ -1392,6 +1508,7 @@ class SDRatio(BaseMetric):
FiringRate,
PresenceRatio,
SNR,
+ # SNRBaseline,
ISIViolation,
RPViolation,
SlidingRPViolation,
@@ -1516,7 +1633,12 @@ def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_i
return isi_violations_ratio, isi_violations_rate, isi_violations_count
-def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5):
+def amplitude_cutoff(
+ amplitudes,
+ num_histogram_bins=500,
+ histogram_smoothing_value=3,
+ amplitudes_bins_min_ratio=5,
+):
"""
Calculate approximate fraction of spikes missing from a distribution of amplitudes.
diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py
index a1af1de348..fc22f09006 100644
--- a/src/spikeinterface/metrics/template/metrics.py
+++ b/src/spikeinterface/metrics/template/metrics.py
@@ -1,66 +1,445 @@
from __future__ import annotations
+import warnings
import numpy as np
-from collections import namedtuple
-
from spikeinterface.core.analyzer_extension_core import BaseMetric
-def get_trough_and_peak_idx(template):
+def get_trough_and_peak_idx(
+ template, min_thresh_detect_peaks_troughs=0.4, smooth=True, smooth_window_frac=0.1, smooth_polyorder=3
+):
"""
- Return the indices into the input template of the detected trough
- (minimum of template) and peak (maximum of template, after trough).
- Assumes negative trough and positive peak.
+ Detect troughs and peaks in a template waveform and return detailed information
+ about each detected feature.
Parameters
----------
- template: numpy.ndarray
+ template : numpy.ndarray
The 1D template waveform
+ min_thresh_detect_peaks_troughs : float, default: 0.4
+ Minimum prominence threshold as a fraction of the template's absolute max value
+ smooth : bool, default: True
+ Whether to apply smoothing before peak detection
+ smooth_window_frac : float, default: 0.1
+ Smoothing window length as a fraction of template length (0.05-0.2 recommended)
+ smooth_polyorder : int, default: 3
+ Polynomial order for Savitzky-Golay filter (must be < window_length)
Returns
-------
- trough_idx: int
- The index of the trough
- peak_idx: int
- The index of the peak
+ troughs : dict
+ Dictionary containing:
+ - "indices": array of all trough indices
+ - "values": array of all trough values
+ - "prominences": array of all trough prominences
+ - "widths": array of all trough widths
+ - "main_idx": index of the main trough (most prominent)
+ - "main_loc": location (sample index) of the main trough in template
+ peaks_before : dict
+ Dictionary containing peaks detected before the main trough (initial peaks):
+ - "indices": array of all peak indices (in original template coordinates)
+ - "values": array of all peak values
+ - "prominences": array of all peak prominences
+ - "widths": array of all peak widths
+ - "main_idx": index of the main peak (most prominent)
+ - "main_loc": location (sample index) of the main peak in template
+ peaks_after : dict
+ Dictionary containing peaks detected after the main trough (repolarization peaks):
+ - "indices": array of all peak indices (in original template coordinates)
+ - "values": array of all peak values
+ - "prominences": array of all peak prominences
+ - "widths": array of all peak widths
+ - "main_idx": index of the main peak (most prominent)
+ - "main_loc": location (sample index) of the main peak in template
"""
+ from scipy.signal import find_peaks, savgol_filter
+
assert template.ndim == 1
- trough_idx = np.argmin(template)
- peak_idx = trough_idx + np.argmax(template[trough_idx:])
- return trough_idx, peak_idx
+ # Smooth template to reduce noise while preserving peaks using Savitzky-Golay filter
+ if smooth:
+ window_length = int(len(template) * smooth_window_frac) // 2 * 2 + 1
+ window_length = max(smooth_polyorder + 2, window_length) # Must be > polyorder
+ template = savgol_filter(template, window_length=window_length, polyorder=smooth_polyorder)
+
+ # Initialize empty result dictionaries
+ empty_dict = {
+ "indices": np.array([], dtype=int),
+ "values": np.array([]),
+ "prominences": np.array([]),
+ "widths": np.array([]),
+ "main_idx": None,
+ "main_loc": None,
+ }
-#########################################################################################
-# Single-channel metrics
-def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float:
+ # Get min prominence to detect peaks and troughs relative to template abs max value
+ min_prominence = min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template))
+
+ # --- Find troughs (by inverting waveform and using find_peaks) ---
+ trough_locs, trough_props = find_peaks(-template, prominence=min_prominence, width=0)
+
+ if len(trough_locs) == 0:
+ # Fallback: use global minimum
+ trough_locs = np.array([np.nanargmin(template)])
+ trough_props = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])}
+
+ # Determine main trough (most prominent, or first if no valid prominences)
+ trough_prominences = trough_props.get("prominences", np.array([]))
+ if len(trough_prominences) > 0 and not np.all(np.isnan(trough_prominences)):
+ main_trough_idx = np.nanargmax(trough_prominences)
+ else:
+ main_trough_idx = 0
+
+ main_trough_loc = trough_locs[main_trough_idx]
+
+ troughs = {
+ "indices": trough_locs,
+ "values": template[trough_locs],
+ "prominences": trough_props.get("prominences", np.full(len(trough_locs), np.nan)),
+ "widths": trough_props.get("widths", np.full(len(trough_locs), np.nan)),
+ "main_idx": main_trough_idx,
+ "main_loc": main_trough_loc,
+ }
+
+ # --- Find peaks before the main trough ---
+ if main_trough_loc > 3:
+ template_before = template[:main_trough_loc]
+
+ # Try with original prominence
+ peak_locs_before, peak_props_before = find_peaks(template_before, prominence=min_prominence, width=0)
+
+ # If no peaks found, try with lower prominence (keep only max peak)
+ if len(peak_locs_before) == 0:
+ lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template))
+ peak_locs_before, peak_props_before = find_peaks(template_before, prominence=lower_prominence, width=0)
+ # Keep only the most prominent peak when using lower threshold
+ if len(peak_locs_before) > 1:
+ prominences = peak_props_before.get("prominences", np.array([]))
+ if len(prominences) > 0 and not np.all(np.isnan(prominences)):
+ max_idx = np.nanargmax(prominences)
+ peak_locs_before = np.array([peak_locs_before[max_idx]])
+ peak_props_before = {
+ "prominences": np.array([prominences[max_idx]]),
+ "widths": np.array([peak_props_before.get("widths", np.array([np.nan]))[max_idx]]),
+ }
+
+ # If still no peaks found, fall back to argmax
+ if len(peak_locs_before) == 0:
+ peak_locs_before = np.array([np.nanargmax(template_before)])
+ peak_props_before = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])}
+
+ peak_prominences_before = peak_props_before.get("prominences", np.array([]))
+ if len(peak_prominences_before) > 0 and not np.all(np.isnan(peak_prominences_before)):
+ main_peak_before_idx = np.nanargmax(peak_prominences_before)
+ else:
+ main_peak_before_idx = 0
+
+ peaks_before = {
+ "indices": peak_locs_before,
+ "values": template[peak_locs_before],
+ "prominences": peak_props_before.get("prominences", np.full(len(peak_locs_before), np.nan)),
+ "widths": peak_props_before.get("widths", np.full(len(peak_locs_before), np.nan)),
+ "main_idx": main_peak_before_idx,
+ "main_loc": peak_locs_before[main_peak_before_idx],
+ }
+ else:
+ peaks_before = empty_dict.copy()
+
+ # --- Find peaks after the main trough (repolarization peaks) ---
+ if main_trough_loc < len(template) - 3:
+ template_after = template[main_trough_loc:]
+
+ # Try with original prominence
+ peak_locs_after, peak_props_after = find_peaks(template_after, prominence=min_prominence, width=0)
+
+ # If no peaks found, try with lower prominence (keep only max peak)
+ if len(peak_locs_after) == 0:
+ lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template))
+ peak_locs_after, peak_props_after = find_peaks(template_after, prominence=lower_prominence, width=0)
+ # Keep only the most prominent peak when using lower threshold
+ if len(peak_locs_after) > 1:
+ prominences = peak_props_after.get("prominences", np.array([]))
+ if len(prominences) > 0 and not np.all(np.isnan(prominences)):
+ max_idx = np.nanargmax(prominences)
+ peak_locs_after = np.array([peak_locs_after[max_idx]])
+ peak_props_after = {
+ "prominences": np.array([prominences[max_idx]]),
+ "widths": np.array([peak_props_after.get("widths", np.array([np.nan]))[max_idx]]),
+ }
+
+ # If still no peaks found, fall back to argmax
+ if len(peak_locs_after) == 0:
+ peak_locs_after = np.array([np.nanargmax(template_after)])
+ peak_props_after = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])}
+
+ # Convert to original template coordinates
+ peak_locs_after_abs = peak_locs_after + main_trough_loc
+
+ peak_prominences_after = peak_props_after.get("prominences", np.array([]))
+ if len(peak_prominences_after) > 0 and not np.all(np.isnan(peak_prominences_after)):
+ main_peak_after_idx = np.nanargmax(peak_prominences_after)
+ else:
+ main_peak_after_idx = 0
+
+ peaks_after = {
+ "indices": peak_locs_after_abs,
+ "values": template[peak_locs_after_abs],
+ "prominences": peak_props_after.get("prominences", np.full(len(peak_locs_after), np.nan)),
+ "widths": peak_props_after.get("widths", np.full(len(peak_locs_after), np.nan)),
+ "main_idx": main_peak_after_idx,
+ "main_loc": peak_locs_after_abs[main_peak_after_idx],
+ }
+ else:
+ peaks_after = empty_dict.copy()
+
+ return troughs, peaks_before, peaks_after
+
+
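+# NOTE (editorial example, hypothetical variable names): the three dicts returned by
+# get_trough_and_peak_idx can be consumed directly, e.g.
+#     troughs, peaks_before, peaks_after = get_trough_and_peak_idx(template_1d)
+#     trough_sample = troughs["main_loc"]   # sample index of the main trough
+#     repol_peak = peaks_after["main_loc"]  # main repolarization peak, or None if absent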
+def get_main_to_next_peak_duration(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs):
"""
- Return the peak to valley duration in seconds of input waveforms.
+ Calculate duration from the main extremum to the next extremum.
+
+ The duration is measured from the largest absolute feature (main trough or main peak)
+ to the next extremum. For typical negative-first waveforms, this is trough-to-peak.
+ For positive-first waveforms, this is peak-to-trough.
Parameters
----------
- template_single: numpy.ndarray
+ template : numpy.ndarray
The 1D template waveform
sampling_frequency : float
- The sampling frequency of the template
- trough_idx: int, default: None
- The index of the trough
- peak_idx: int, default: None
- The index of the peak
+ The sampling frequency in Hz
+ troughs : dict
+ Trough detection results from get_trough_and_peak_idx
+ peaks_before : dict
+ Peak before trough results from get_trough_and_peak_idx
+ peaks_after : dict
+ Peak after trough results from get_trough_and_peak_idx
Returns
-------
- ptv: float
- The peak to valley duration in seconds
+ main_to_next_peak_duration : float
+ Duration in seconds from main extremum to next extremum
"""
- if trough_idx is None or peak_idx is None:
- trough_idx, peak_idx = get_trough_and_peak_idx(template_single)
- ptv = (peak_idx - trough_idx) / sampling_frequency
- return ptv
+
+ # Get main locations and values
+ trough_loc = troughs["main_loc"]
+ trough_val = template[trough_loc] if trough_loc is not None else None
+
+ peak_before_loc = peaks_before["main_loc"]
+ peak_before_val = template[peak_before_loc] if peak_before_loc is not None else None
+
+ peak_after_loc = peaks_after["main_loc"]
+ peak_after_val = template[peak_after_loc] if peak_after_loc is not None else None
+
+ # Find the main extremum (largest absolute value)
+ candidates = []
+ if trough_loc is not None and trough_val is not None:
+ candidates.append(("trough", trough_loc, abs(trough_val)))
+ if peak_before_loc is not None and peak_before_val is not None:
+ candidates.append(("peak_before", peak_before_loc, abs(peak_before_val)))
+ if peak_after_loc is not None and peak_after_val is not None:
+ candidates.append(("peak_after", peak_after_loc, abs(peak_after_val)))
+
+ if len(candidates) == 0:
+ return np.nan
+
+ # Sort by absolute value to find main extremum
+ candidates.sort(key=lambda x: x[2], reverse=True)
+ main_type, main_loc, _ = candidates[0]
+
+ # Find the next extremum after the main one
+ if main_type == "trough":
+ # Main is trough, next is peak_after
+ if peak_after_loc is not None:
+ duration_samples = abs(peak_after_loc - main_loc)
+ elif peak_before_loc is not None:
+ duration_samples = abs(main_loc - peak_before_loc)
+ else:
+ return np.nan
+ elif main_type == "peak_before":
+ # Main is peak before, next is trough
+ if trough_loc is not None:
+ duration_samples = abs(trough_loc - main_loc)
+ else:
+ return np.nan
+ else: # peak_after
+ # Main is peak after, previous is trough
+ if trough_loc is not None:
+ duration_samples = abs(main_loc - trough_loc)
+ else:
+ return np.nan
+
+ # Convert to seconds
+ main_to_next_peak_duration = duration_samples / sampling_frequency
+
+ return main_to_next_peak_duration
+
+
+def get_waveform_ratios(template, troughs, peaks_before, peaks_after, **kwargs):
+ """
+ Calculate various waveform amplitude ratios.
+
+ Parameters
+ ----------
+ template : numpy.ndarray
+ The 1D template waveform
+ troughs : dict
+ Trough detection results from get_trough_and_peak_idx
+ peaks_before : dict
+ Peak before trough results from get_trough_and_peak_idx
+ peaks_after : dict
+ Peak after trough results from get_trough_and_peak_idx
+
+ Returns
+ -------
+ ratios : dict
+ Dictionary containing:
+ - "peak_before_to_trough_ratio": ratio of peak before to trough amplitude
+ - "peak_after_to_trough_ratio": ratio of peak after to trough amplitude
+ - "peak_before_to_peak_after_ratio": ratio of peak before to peak after amplitude
+ - "main_peak_to_trough_ratio": ratio of larger peak to trough amplitude
+ """
+ # Get absolute amplitudes
+ trough_amp = abs(template[troughs["main_loc"]]) if troughs["main_loc"] is not None else np.nan
+ peak_before_amp = abs(template[peaks_before["main_loc"]]) if peaks_before["main_loc"] is not None else np.nan
+ peak_after_amp = abs(template[peaks_after["main_loc"]]) if peaks_after["main_loc"] is not None else np.nan
+
+ def safe_ratio(a, b):
+ if np.isnan(a) or np.isnan(b) or b == 0:
+ return np.nan
+ return a / b
+
+ ratios = {
+ "peak_before_to_trough_ratio": safe_ratio(peak_before_amp, trough_amp),
+ "peak_after_to_trough_ratio": safe_ratio(peak_after_amp, trough_amp),
+ "peak_before_to_peak_after_ratio": safe_ratio(peak_before_amp, peak_after_amp),
+ "main_peak_to_trough_ratio": safe_ratio(
+ (
+ max(peak_before_amp, peak_after_amp)
+ if not (np.isnan(peak_before_amp) and np.isnan(peak_after_amp))
+ else np.nan
+ ),
+ trough_amp,
+ ),
+ }
+
+ return ratios
-def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs) -> float:
+def get_waveform_baseline_flatness(template, sampling_frequency, **kwargs):
"""
- Return the peak to trough ratio of input waveforms.
+ Compute the baseline flatness of the waveform.
+
+ This metric measures the ratio of the max absolute amplitude in the baseline
+ window to the max absolute amplitude of the whole waveform. A lower value
+ indicates a flat baseline (expected for good units).
+
+ Parameters
+ ----------
+ template : numpy.ndarray
+ The 1D template waveform
+ sampling_frequency : float
+ The sampling frequency in Hz
+ **kwargs : Optional kwargs:
+ - baseline_window_ms : tuple of (start_ms, end_ms) defining the baseline window
+ relative to waveform start. Default is (0, 0.5), i.e. the first 0.5 ms.
+
+ Returns
+ -------
+ baseline_flatness : float
+ Ratio of max(abs(baseline)) / max(abs(waveform)). Lower = flatter baseline.
+ """
+ baseline_window_ms = kwargs.get("baseline_window_ms", (0.0, 0.5))
+
+ if baseline_window_ms is None:
+ return np.nan
+
+ start_ms, end_ms = baseline_window_ms
+ start_idx = int(start_ms / 1000 * sampling_frequency)
+ end_idx = int(end_ms / 1000 * sampling_frequency)
+
+ # Clamp to valid range
+ start_idx = max(0, start_idx)
+ end_idx = min(len(template), end_idx)
+
+ if end_idx <= start_idx:
+ return np.nan
+
+ baseline_segment = template[start_idx:end_idx]
+
+ if len(baseline_segment) == 0:
+ return np.nan
+
+ max_baseline = np.nanmax(np.abs(baseline_segment))
+ max_waveform = np.nanmax(np.abs(template))
+
+ if max_waveform == 0 or np.isnan(max_waveform):
+ return np.nan
+
+ baseline_flatness = max_baseline / max_waveform
+
+ return baseline_flatness
+
+
+def get_waveform_widths(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs):
+ """
+ Get the widths of the main trough and peaks in seconds.
+
+ Parameters
+ ----------
+ template : numpy.ndarray
+ The 1D template waveform
+ sampling_frequency : float
+ The sampling frequency in Hz
+ troughs : dict
+ Trough detection results from get_trough_and_peak_idx
+ peaks_before : dict
+ Peak before trough results from get_trough_and_peak_idx
+ peaks_after : dict
+ Peak after trough results from get_trough_and_peak_idx
+
+ Returns
+ -------
+ widths : dict
+ Dictionary containing:
+ - "trough_width": width of main trough in seconds
+ - "peak_before_width": width of main peak before trough in seconds
+ - "peak_after_width": width of main peak after trough in seconds
+ """
+
+ def get_main_width(feature_dict):
+ if feature_dict["main_idx"] is None:
+ return np.nan
+ widths = feature_dict.get("widths", np.array([]))
+ if len(widths) == 0:
+ return np.nan
+ main_idx = feature_dict["main_idx"]
+ if main_idx < len(widths):
+ return widths[main_idx]
+ return np.nan
+
+ # Convert from samples to seconds
+ samples_to_seconds = 1.0 / sampling_frequency
+
+ trough_width = get_main_width(troughs)
+ peak_before_width = get_main_width(peaks_before)
+ peak_after_width = get_main_width(peaks_after)
+
+ widths = {
+ "trough_width": trough_width * samples_to_seconds if not np.isnan(trough_width) else np.nan,
+ "peak_before_width": peak_before_width * samples_to_seconds if not np.isnan(peak_before_width) else np.nan,
+ "peak_after_width": peak_after_width * samples_to_seconds if not np.isnan(peak_after_width) else np.nan,
+ }
+
+ return widths
+
+
+#########################################################################################
+# Single-channel metrics
+def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float:
+ """
+ Return the peak to valley duration in seconds of input waveforms.
Parameters
----------
@@ -75,13 +454,17 @@ def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=N
Returns
-------
- ptratio: float
- The peak to trough ratio
+ ptv: float
+ The peak to valley duration in seconds
"""
if trough_idx is None or peak_idx is None:
- trough_idx, peak_idx = get_trough_and_peak_idx(template_single)
- ptratio = template_single[peak_idx] / template_single[trough_idx]
- return ptratio
+ troughs, _, peaks_after = get_trough_and_peak_idx(template_single)
+ trough_idx = troughs["main_loc"]
+ peak_idx = peaks_after["main_loc"]
+ if trough_idx is None or peak_idx is None:
+ return np.nan
+ ptv = (peak_idx - trough_idx) / sampling_frequency
+ return ptv
def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float:
@@ -105,9 +488,11 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id
The half width in seconds
"""
if trough_idx is None or peak_idx is None:
- trough_idx, peak_idx = get_trough_and_peak_idx(template_single)
+ troughs, _, peaks_after = get_trough_and_peak_idx(template_single)
+ trough_idx = troughs["main_loc"]
+ peak_idx = peaks_after["main_loc"]
- if peak_idx == 0:
+ if peak_idx is None or peak_idx == 0:
return np.nan
trough_val = template_single[trough_idx]
@@ -156,11 +541,12 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non
The repolarization slope
"""
if trough_idx is None:
- trough_idx, _ = get_trough_and_peak_idx(template_single)
+ troughs, _, _ = get_trough_and_peak_idx(template_single)
+ trough_idx = troughs["main_loc"]
times = np.arange(template_single.shape[0]) / sampling_frequency
- if trough_idx == 0:
+ if trough_idx is None or trough_idx == 0:
return np.nan
(rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0)
@@ -209,11 +595,12 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa
assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg"
recovery_window_ms = kwargs["recovery_window_ms"]
if peak_idx is None:
- _, peak_idx = get_trough_and_peak_idx(template_single)
+ _, _, peaks_after = get_trough_and_peak_idx(template_single)
+ peak_idx = peaks_after["main_loc"]
times = np.arange(template_single.shape[0]) / sampling_frequency
- if peak_idx == 0:
+ if peak_idx is None or peak_idx == 0:
return np.nan
max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency))
max_idx = np.min([max_idx, template_single.shape[0]])
@@ -222,9 +609,12 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa
return res.slope
-def get_number_of_peaks(template_single, sampling_frequency, **kwargs):
+def get_number_of_peaks(template_single, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs):
"""
- Count the total number of peaks (positive + negative) in the template.
+ Count the total number of peaks (positive) and troughs (negative) in the template.
+
+ Uses the pre-computed peak/trough detection from get_trough_and_peak_idx which
+ applies smoothing for more robust detection.
Parameters
----------
@@ -232,28 +622,28 @@ def get_number_of_peaks(template_single, sampling_frequency, **kwargs):
The 1D template waveform
sampling_frequency : float
The sampling frequency of the template
- **kwargs: Required kwargs:
- - peak_relative_threshold: the relative threshold to detect positive and negative peaks
- - peak_width_ms: the width in samples to detect peaks
+ troughs : dict
+ Trough detection results from get_trough_and_peak_idx
+ peaks_before : dict
+ Peak before trough results from get_trough_and_peak_idx
+ peaks_after : dict
+ Peak after trough results from get_trough_and_peak_idx
Returns
-------
- number_of_peaks: int
- the total number of peaks (positive + negative)
- """
- from scipy.signal import find_peaks
-
- assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg"
- assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg"
- peak_relative_threshold = kwargs["peak_relative_threshold"]
- peak_width_ms = kwargs["peak_width_ms"]
- max_value = np.max(np.abs(template_single))
- peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency)
-
- pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples)
- neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples)
- num_positive = len(pos_peaks[0])
- num_negative = len(neg_peaks[0])
+ num_positive_peaks : int
+ The number of positive peaks (peaks_before + peaks_after)
+ num_negative_peaks : int
+ The number of negative peaks (troughs)
+ """
+ # Count peaks (positive) from peaks_before and peaks_after
+ num_peaks_before = len(peaks_before["indices"])
+ num_peaks_after = len(peaks_after["indices"])
+ num_positive = num_peaks_before + num_peaks_after
+
+ # Count troughs (negative)
+ num_negative = len(troughs["indices"])
+
return num_positive, num_negative
@@ -293,8 +683,10 @@ def fit_velocity(peak_times, channel_dist):
from sklearn.linear_model import TheilSenRegressor
- theil = TheilSenRegressor()
- theil.fit(peak_times.reshape(-1, 1), channel_dist)
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", message=".*Maximum number of iterations.*")
+ theil = TheilSenRegressor(max_iter=1000)
+ theil.fit(peak_times.reshape(-1, 1), channel_dist)
slope = theil.coef_[0]
intercept = theil.intercept_
score = theil.score(peak_times.reshape(-1, 1), channel_dist)
@@ -376,7 +768,11 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs):
"""
- Compute the exponential decay of the template amplitude over distance in units um/s.
+ Compute the spatial decay of the template amplitude over distance.
+
+ Can fit either an exponential decay (with offset) or a linear decay model. If channel_tolerance
+ is given, channels are first filtered by x-distance from the max channel and the closest channels
+ in y-distance are then used for fitting; otherwise all channels are used.
Parameters
----------
@@ -387,13 +783,18 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs
sampling_frequency : float
The sampling frequency of the template
**kwargs: Required kwargs:
- - peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min")
- - min_r2: the minimum r2 to accept the exp decay fit
+ - peak_function: the function to use to compute the peak amplitude ("ptp" or "min")
+ - min_r2: the minimum r2 to accept the fit
+ - linear_fit: bool, if True use linear fit, otherwise exponential fit
+ - channel_tolerance: max x-distance (um) from max channel to include channels
+ - min_channels_for_fit: minimum number of valid channels required for fitting
+ - num_channels_for_fit: number of closest channels to use for fitting
+ - normalize_decay: bool, if True normalize amplitudes to max before fitting
Returns
-------
exp_decay_value : float
- The exponential decay of the template amplitude
+ The spatial decay slope (decay constant for exp fit, negative slope for linear fit)
"""
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
@@ -401,41 +802,117 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs
def exp_decay(x, decay, amp0, offset):
return amp0 * np.exp(-decay * x) + offset
+ def linear_fit_func(x, a, b):
+ return a * x + b
+
+ # Extract parameters
assert "peak_function" in kwargs, "peak_function must be given as kwarg"
peak_function = kwargs["peak_function"]
assert "min_r2" in kwargs, "min_r2 must be given as kwarg"
min_r2 = kwargs["min_r2"]
- # exp decay fit
+
+ use_linear_fit = kwargs.get("linear_fit", False)
+ channel_tolerance = kwargs.get("channel_tolerance", None)
+ normalize_decay = kwargs.get("normalize_decay", False)
+
+ # Set defaults based on fit type if not specified
+ min_channels_for_fit = kwargs.get("min_channels_for_fit")
+ if min_channels_for_fit is None:
+ min_channels_for_fit = 5 if use_linear_fit else 8
+
+ num_channels_for_fit = kwargs.get("num_channels_for_fit")
+ if num_channels_for_fit is None:
+ num_channels_for_fit = 6 if use_linear_fit else 10
+
+ # Compute peak amplitudes per channel
if peak_function == "ptp":
fun = np.ptp
elif peak_function == "min":
fun = np.min
+ else:
+ fun = np.ptp
+
peak_amplitudes = np.abs(fun(template, axis=0))
- max_channel_location = channel_locations[np.argmax(peak_amplitudes)]
- channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations])
- distances_sort_indices = np.argsort(channel_distances)
+ max_channel_idx = np.argmax(peak_amplitudes)
+ max_channel_location = channel_locations[max_channel_idx]
- # longdouble is float128 when the platform supports it, otherwise it is float64
- channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble)
- peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble)
+ # Channel selection based on tolerance (new bombcell-style) or use all channels (old style)
+ if channel_tolerance is not None:
+ # Calculate x-distances from max channel
+ x_dist = np.abs(channel_locations[:, 0] - max_channel_location[0])
- try:
- amp0 = peak_amplitudes_sorted[0]
- offset0 = np.min(peak_amplitudes_sorted)
-
- popt, _ = curve_fit(
- exp_decay,
- channel_distances_sorted,
- peak_amplitudes_sorted,
- bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]),
- p0=[1e-3, peak_amplitudes_sorted[0], offset0],
+ # Find channels within x-distance tolerance
+ valid_x_channels = np.argwhere(x_dist <= channel_tolerance).flatten()
+
+ if len(valid_x_channels) < min_channels_for_fit:
+ return np.nan
+
+ # Calculate y-distances for channel selection
+ y_dist = np.abs(channel_locations[:, 1] - max_channel_location[1])
+
+ # Set y distances to max for channels outside x tolerance (so they won't be selected)
+ y_dist_masked = y_dist.copy()
+ y_dist_masked[~np.isin(np.arange(len(y_dist)), valid_x_channels)] = y_dist.max() + 1
+
+ # Select the closest channels in y-distance
+ use_these_channels = np.argsort(y_dist_masked)[:num_channels_for_fit]
+
+ # Calculate distances from max channel for selected channels
+ channel_distances = np.sqrt(
+ np.sum(np.square(channel_locations[use_these_channels] - max_channel_location), axis=1)
)
- r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt))
- exp_decay_value = popt[0]
+
+ # Get amplitudes for selected channels
+ spatial_decay_points = np.max(np.abs(template[:, use_these_channels]), axis=0)
+
+ # Sort by distance
+ sort_idx = np.argsort(channel_distances)
+ channel_distances_sorted = channel_distances[sort_idx]
+ peak_amplitudes_sorted = spatial_decay_points[sort_idx]
+
+ # Normalize if requested
+ if normalize_decay:
+ peak_amplitudes_sorted = peak_amplitudes_sorted / np.max(peak_amplitudes_sorted)
+
+ # Ensure float64 for numerical stability
+ channel_distances_sorted = np.float64(channel_distances_sorted)
+ peak_amplitudes_sorted = np.float64(peak_amplitudes_sorted)
+
+ else:
+ # Old style: use all channels sorted by distance
+ channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations])
+ distances_sort_indices = np.argsort(channel_distances)
+
+ # longdouble is float128 when the platform supports it, otherwise it is float64
+ channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble)
+ peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble)
+
+ try:
+ if use_linear_fit:
+ # Linear fit: y = a*x + b
+ popt, _ = curve_fit(linear_fit_func, channel_distances_sorted, peak_amplitudes_sorted)
+ predicted = linear_fit_func(channel_distances_sorted, *popt)
+ r2 = r2_score(peak_amplitudes_sorted, predicted)
+ exp_decay_value = -popt[0] # Negative of slope
+ else:
+ # Exponential fit with offset: y = amp0 * exp(-decay * x) + offset
+ amp0 = peak_amplitudes_sorted[0]
+ offset0 = np.min(peak_amplitudes_sorted)
+
+ popt, _ = curve_fit(
+ exp_decay,
+ channel_distances_sorted,
+ peak_amplitudes_sorted,
+ bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]),
+ p0=[1e-3, peak_amplitudes_sorted[0], offset0],
+ )
+ r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt))
+ exp_decay_value = popt[0]
if r2 < min_r2:
exp_decay_value = np.nan
- except:
+
+ except Exception:
exp_decay_value = np.nan
return exp_decay_value
@@ -512,17 +989,17 @@ def single_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, *
return result
-class PeakToValley(BaseMetric):
- metric_name = "peak_to_valley"
+class PeakToTroughDuration(BaseMetric):
+ metric_name = "peak_to_trough_duration"
metric_params = {}
- metric_columns = {"peak_to_valley": float}
+ metric_columns = {"peak_to_trough_duration": float}
metric_descriptions = {
- "peak_to_valley": "Duration in s between the trough (minimum) and the peak (maximum) of the spike waveform."
+ "peak_to_trough_duration": "Duration in seconds between the trough (minimum) and the peak (maximum) of the spike waveform."
}
needs_tmp_data = True
@staticmethod
- def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
+ def _peak_to_trough_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
return single_channel_metric(
unit_function=get_peak_to_valley,
sorting_analyzer=sorting_analyzer,
@@ -531,29 +1008,7 @@ def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metr
**metric_params,
)
- metric_function = _peak_to_valley_metric_function
-
-
-class PeakToTroughRatio(BaseMetric):
- metric_name = "peak_trough_ratio"
- metric_params = {}
- metric_columns = {"peak_trough_ratio": float}
- metric_descriptions = {
- "peak_trough_ratio": "Ratio of the amplitude of the peak (maximum) to the trough (minimum) of the spike waveform."
- }
- needs_tmp_data = True
-
- @staticmethod
- def _peak_to_trough_ratio_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
- return single_channel_metric(
- unit_function=get_peak_trough_ratio,
- sorting_analyzer=sorting_analyzer,
- unit_ids=unit_ids,
- tmp_data=tmp_data,
- **metric_params,
- )
-
- metric_function = _peak_to_trough_ratio_metric_function
+ metric_function = _peak_to_trough_duration_metric_function
class HalfWidth(BaseMetric):
@@ -623,14 +1078,26 @@ def _recovery_slope_metric_function(sorting_analyzer, unit_ids, tmp_data, **metr
def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
+ from collections import namedtuple
+
num_peaks_result = namedtuple("NumberOfPeaksResult", ["num_positive_peaks", "num_negative_peaks"])
num_positive_peaks_dict = {}
num_negative_peaks_dict = {}
- sampling_frequency = sorting_analyzer.sampling_frequency
+ sampling_frequency = tmp_data["sampling_frequency"]
templates_single = tmp_data["templates_single"]
+ troughs_info = tmp_data["troughs_info"]
+ peaks_before_info = tmp_data["peaks_before_info"]
+ peaks_after_info = tmp_data["peaks_after_info"]
for unit_index, unit_id in enumerate(unit_ids):
template_single = templates_single[unit_index]
- num_positive, num_negative = get_number_of_peaks(template_single, sampling_frequency, **metric_params)
+ num_positive, num_negative = get_number_of_peaks(
+ template_single,
+ sampling_frequency,
+ troughs_info[unit_id],
+ peaks_before_info[unit_id],
+ peaks_after_info[unit_id],
+ **metric_params,
+ )
num_positive_peaks_dict[unit_id] = num_positive
num_negative_peaks_dict[unit_id] = num_negative
return num_peaks_result(num_positive_peaks=num_positive_peaks_dict, num_negative_peaks=num_negative_peaks_dict)
@@ -639,26 +1106,192 @@ def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **met
class NumberOfPeaks(BaseMetric):
metric_name = "number_of_peaks"
metric_function = _number_of_peaks_metric_function
- metric_params = {"peak_relative_threshold": 0.2, "peak_width_ms": 0.1}
+ metric_params = {}
metric_columns = {"num_positive_peaks": int, "num_negative_peaks": int}
metric_descriptions = {
"num_positive_peaks": "Number of positive peaks in the template",
- "num_negative_peaks": "Number of negative peaks in the template",
+ "num_negative_peaks": "Number of negative peaks (troughs) in the template",
+ }
+ needs_tmp_data = True
+
+
+def _main_to_next_peak_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
+ result = {}
+ templates_single = tmp_data["templates_single"]
+ troughs_info = tmp_data["troughs_info"]
+ peaks_before_info = tmp_data["peaks_before_info"]
+ peaks_after_info = tmp_data["peaks_after_info"]
+ sampling_frequency = tmp_data["sampling_frequency"]
+ for unit_index, unit_id in enumerate(unit_ids):
+ template_single = templates_single[unit_index]
+ value = get_main_to_next_peak_duration(
+ template_single,
+ sampling_frequency,
+ troughs_info[unit_id],
+ peaks_before_info[unit_id],
+ peaks_after_info[unit_id],
+ **metric_params,
+ )
+ result[unit_id] = value
+ return result
+
+
+class MainToNextPeakDuration(BaseMetric):
+ metric_name = "main_to_next_peak_duration"
+ metric_function = _main_to_next_peak_duration_metric_function
+ metric_params = {}
+ metric_columns = {"main_to_next_peak_duration": float}
+ metric_descriptions = {"main_to_next_peak_duration": "Duration in seconds from main extremum to next extremum."}
+ needs_tmp_data = True
+
+
+def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
+ from collections import namedtuple
+
+ waveform_ratios_result = namedtuple(
+ "WaveformRatiosResult",
+ [
+ "peak_before_to_trough_ratio",
+ "peak_after_to_trough_ratio",
+ "peak_before_to_peak_after_ratio",
+ "main_peak_to_trough_ratio",
+ ],
+ )
+ peak_before_to_trough = {}
+ peak_after_to_trough = {}
+ peak_before_to_peak_after = {}
+ main_peak_to_trough = {}
+ templates_single = tmp_data["templates_single"]
+ troughs_info = tmp_data["troughs_info"]
+ peaks_before_info = tmp_data["peaks_before_info"]
+ peaks_after_info = tmp_data["peaks_after_info"]
+ for unit_index, unit_id in enumerate(unit_ids):
+ template_single = templates_single[unit_index]
+ ratios = get_waveform_ratios(
+ template_single,
+ troughs_info[unit_id],
+ peaks_before_info[unit_id],
+ peaks_after_info[unit_id],
+ **metric_params,
+ )
+ peak_before_to_trough[unit_id] = ratios["peak_before_to_trough_ratio"]
+ peak_after_to_trough[unit_id] = ratios["peak_after_to_trough_ratio"]
+ peak_before_to_peak_after[unit_id] = ratios["peak_before_to_peak_after_ratio"]
+ main_peak_to_trough[unit_id] = ratios["main_peak_to_trough_ratio"]
+ return waveform_ratios_result(
+ peak_before_to_trough_ratio=peak_before_to_trough,
+ peak_after_to_trough_ratio=peak_after_to_trough,
+ peak_before_to_peak_after_ratio=peak_before_to_peak_after,
+ main_peak_to_trough_ratio=main_peak_to_trough,
+ )
+
+
+class WaveformRatios(BaseMetric):
+ metric_name = "waveform_ratios"
+ metric_function = _waveform_ratios_metric_function
+ metric_params = {}
+ metric_columns = {
+ "peak_before_to_trough_ratio": float,
+ "peak_after_to_trough_ratio": float,
+ "peak_before_to_peak_after_ratio": float,
+ "main_peak_to_trough_ratio": float,
+ }
+ metric_descriptions = {
+ "peak_before_to_trough_ratio": "Ratio of peak before amplitude to trough amplitude",
+ "peak_after_to_trough_ratio": "Ratio of peak after amplitude to trough amplitude",
+ "peak_before_to_peak_after_ratio": "Ratio of peak before amplitude to peak after amplitude",
+ "main_peak_to_trough_ratio": "Ratio of main peak amplitude to trough amplitude",
+ }
+ needs_tmp_data = True
+
+
+def _waveform_widths_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
+ from collections import namedtuple
+
+ waveform_widths_result = namedtuple(
+ "WaveformWidthsResult", ["trough_width", "peak_before_width", "peak_after_width"]
+ )
+ trough_width_dict = {}
+ peak_before_width_dict = {}
+ peak_after_width_dict = {}
+ templates_single = tmp_data["templates_single"]
+ troughs_info = tmp_data["troughs_info"]
+ peaks_before_info = tmp_data["peaks_before_info"]
+ peaks_after_info = tmp_data["peaks_after_info"]
+ sampling_frequency = tmp_data["sampling_frequency"]
+ for unit_index, unit_id in enumerate(unit_ids):
+ template_single = templates_single[unit_index]
+ widths = get_waveform_widths(
+ template_single,
+ sampling_frequency,
+ troughs_info[unit_id],
+ peaks_before_info[unit_id],
+ peaks_after_info[unit_id],
+ **metric_params,
+ )
+ trough_width_dict[unit_id] = widths["trough_width"]
+ peak_before_width_dict[unit_id] = widths["peak_before_width"]
+ peak_after_width_dict[unit_id] = widths["peak_after_width"]
+ return waveform_widths_result(
+ trough_width=trough_width_dict, peak_before_width=peak_before_width_dict, peak_after_width=peak_after_width_dict
+ )
+
+
+class WaveformWidths(BaseMetric):
+ metric_name = "waveform_widths"
+ metric_function = _waveform_widths_metric_function
+ metric_params = {}
+ metric_columns = {
+ "trough_width": float,
+ "peak_before_width": float,
+ "peak_after_width": float,
+ }
+ metric_descriptions = {
+ "trough_width": "Width of the main trough in seconds",
+ "peak_before_width": "Width of the main peak before trough in seconds",
+ "peak_after_width": "Width of the main peak after trough in seconds",
+ }
+ needs_tmp_data = True
+
+
+def _waveform_baseline_flatness_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
+ result = {}
+ templates_single = tmp_data["templates_single"]
+ sampling_frequency = tmp_data["sampling_frequency"]
+ for unit_index, unit_id in enumerate(unit_ids):
+ template_single = templates_single[unit_index]
+ value = get_waveform_baseline_flatness(template_single, sampling_frequency, **metric_params)
+ result[unit_id] = value
+ return result
+
+
+class WaveformBaselineFlatness(BaseMetric):
+ metric_name = "waveform_baseline_flatness"
+ metric_function = _waveform_baseline_flatness_metric_function
+ metric_params = {"baseline_window_ms": (0.0, 0.5)}
+ metric_columns = {"waveform_baseline_flatness": float}
+ metric_descriptions = {
+ "waveform_baseline_flatness": "Ratio of max baseline amplitude to max waveform amplitude. Lower = flatter baseline."
}
needs_tmp_data = True
single_channel_metrics = [
- PeakToValley,
- PeakToTroughRatio,
+ PeakToTroughDuration,
HalfWidth,
RepolarizationSlope,
RecoverySlope,
NumberOfPeaks,
+ MainToNextPeakDuration,
+ WaveformRatios,
+ WaveformWidths,
+ WaveformBaselineFlatness,
]
def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
+ from collections import namedtuple
+
velocity_above_result = namedtuple("Velocities", ["velocity_above", "velocity_below"])
velocity_above_dict = {}
velocity_below_dict = {}
@@ -707,10 +1340,21 @@ def multi_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, **
class ExpDecay(BaseMetric):
metric_name = "exp_decay"
- metric_params = {"peak_function": "ptp", "min_r2": 0.2}
+ metric_params = {
+ "peak_function": "ptp",
+ "min_r2": 0.2,
+ "linear_fit": False,
+ "channel_tolerance": None, # None uses old style (all channels), set to e.g. 33 for bombcell-style
+ "min_channels_for_fit": None, # None means use default based on linear_fit (5 for linear, 8 for exp)
+ "num_channels_for_fit": None, # None means use default based on linear_fit (6 for linear, 10 for exp)
+ "normalize_decay": False,
+ }
metric_columns = {"exp_decay": float}
metric_descriptions = {
- "exp_decay": ("Exponential decay of the template amplitude over distance from the extremum channel (1/um).")
+ "exp_decay": (
+ "Spatial decay of the template amplitude over distance from the extremum channel (1/um). "
+ "Uses exponential or linear fit based on linear_fit parameter."
+ )
}
needs_tmp_data = True
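+
+    # Illustrative sketch only (the exact compute call and nested metric_params layout are assumed):
+    #   compute_template_metrics(
+    #       sorting_analyzer,
+    #       metric_names=["exp_decay"],
+    #       metric_params={"exp_decay": {"linear_fit": True, "channel_tolerance": 33, "normalize_decay": True}},
+    #   )
+    # would request a bombcell-style linear spatial-decay fit restricted to channels within 33 um in x.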
@@ -729,7 +1373,7 @@ def _exp_decay_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_pa
class Spread(BaseMetric):
metric_name = "spread"
- metric_params = {"spread_threshold": 0.5, "spread_smooth_um": 20, "column_range": None}
+ metric_params = {"spread_threshold": 0.2, "spread_smooth_um": 20, "column_range": None}
metric_columns = {"spread": float}
metric_descriptions = {
"spread": (
diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py
index 85ef9e22cb..212369ca12 100644
--- a/src/spikeinterface/metrics/template/template_metrics.py
+++ b/src/spikeinterface/metrics/template/template_metrics.py
@@ -6,9 +6,8 @@
from __future__ import annotations
-import numpy as np
import warnings
-from copy import deepcopy
+import numpy as np
from spikeinterface.core.sortinganalyzer import register_result_extension
from spikeinterface.core.analyzer_extension_core import BaseMetricExtension
@@ -33,6 +32,8 @@ def get_template_metric_list():
def get_template_metric_names():
+ import warnings
+
warnings.warn(
"get_template_metric_names is deprecated and will be removed in a version 0.105.0. "
"Please use get_template_metric_list instead.",
@@ -45,8 +46,8 @@ def get_template_metric_names():
class ComputeTemplateMetrics(BaseMetricExtension):
"""
Compute template metrics including:
- * peak_to_valley
- * peak_trough_ratio
+ * peak_to_trough_duration
+        * main_peak_to_trough_ratio
* halfwidth
* repolarization_slope
* recovery_slope
@@ -95,6 +96,8 @@ class ComputeTemplateMetrics(BaseMetricExtension):
metric_list = single_channel_metrics + multi_channel_metrics
def _handle_backward_compatibility_on_load(self):
+ from copy import deepcopy
+
# For backwards compatibility - this reformats metrics_kwargs as metric_params
if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None:
@@ -106,24 +109,52 @@ def _handle_backward_compatibility_on_load(self):
del self.params["metrics_kwargs"]
# handle metric names change:
- # num_positive_peaks/num_negative_peaks merged into number_of_peaks
if "num_positive_peaks" in self.params["metric_names"]:
self.params["metric_names"].remove("num_positive_peaks")
if "number_of_peaks" not in self.params["metric_names"]:
self.params["metric_names"].append("number_of_peaks")
+ if "num_positive_peaks" in self.params["metric_params"]:
+ del self.params["metric_params"]["num_positive_peaks"]
if "num_negative_peaks" in self.params["metric_names"]:
self.params["metric_names"].remove("num_negative_peaks")
if "number_of_peaks" not in self.params["metric_names"]:
self.params["metric_names"].append("number_of_peaks")
+ if "num_negative_peaks" in self.params["metric_params"]:
+ del self.params["metric_params"]["num_negative_peaks"]
# velocity_above/velocity_below merged into velocity_fits
if "velocity_above" in self.params["metric_names"]:
self.params["metric_names"].remove("velocity_above")
if "velocity_fits" not in self.params["metric_names"]:
self.params["metric_names"].append("velocity_fits")
+ self.params["metric_params"]["velocity_fits"] = self.params["metric_params"]["velocity_above"]
+ self.params["metric_params"]["velocity_fits"]["min_channels"] = self.params["metric_params"][
+ "velocity_above"
+ ]["min_channels_for_velocity"]
+ self.params["metric_params"]["velocity_fits"]["min_r2"] = self.params["metric_params"]["velocity_above"][
+ "min_r2_velocity"
+ ]
+ del self.params["metric_params"]["velocity_above"]
if "velocity_below" in self.params["metric_names"]:
self.params["metric_names"].remove("velocity_below")
if "velocity_fits" not in self.params["metric_names"]:
self.params["metric_names"].append("velocity_fits")
+ # parameters are already updated from velocity_above
+ if "velocity_below" in self.params["metric_params"]:
+ del self.params["metric_params"]["velocity_below"]
+        # peak_to_valley -> peak_to_trough_duration
+ if "peak_to_valley" in self.params["metric_names"]:
+ self.params["metric_names"].remove("peak_to_valley")
+ if "peak_to_trough_duration" not in self.params["metric_names"]:
+ self.params["metric_names"].append("peak_to_trough_duration")
+        # peak_trough_ratio -> main_peak_to_trough_ratio (now part of waveform_ratios).
+        # Note that the new implementation correctly uses the absolute peak values, which
+        # differs from the old implementation, so we set a flag to invert the polarity of
+        # old values when they are loaded (see get_data).
+ if "peak_trough_ratio" in self.params["metric_names"]:
+ self.params["metric_names"].remove("peak_trough_ratio")
+ if "waveform_ratios" not in self.params["metric_names"]:
+ self.params["metric_names"].append("waveform_ratios")
+ self.params["metric_params"]["invert_peak_to_trough"] = True
def _set_params(
self,
@@ -137,6 +168,10 @@ def _set_params(
upsampling_factor=10,
include_multi_channel_metrics=False,
depth_direction="y",
+ min_thresh_detect_peaks_troughs=0.4,
+ smooth=True,
+ smooth_window_frac=0.1,
+ smooth_polyorder=3,
):
# Auto-detect if multi-channel metrics should be included based on number of channels
num_channels = self.sorting_analyzer.get_num_channels()
@@ -166,9 +201,15 @@ def _set_params(
upsampling_factor=upsampling_factor,
include_multi_channel_metrics=include_multi_channel_metrics,
depth_direction=depth_direction,
+ min_thresh_detect_peaks_troughs=min_thresh_detect_peaks_troughs,
+ smooth=smooth,
+ smooth_window_frac=smooth_window_frac,
+ smooth_polyorder=smooth_polyorder,
)
def _prepare_data(self, sorting_analyzer, unit_ids):
+ import warnings
+
from scipy.signal import resample_poly
# compute templates_single and templates_multi (if include_multi_channel_metrics is True)
@@ -197,6 +238,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
templates_single = []
troughs = {}
peaks = {}
+ troughs_info = {}
+ peaks_before_info = {}
+ peaks_after_info = {}
templates_multi = []
channel_locations_multi = []
for unit_id in unit_ids:
@@ -210,11 +254,22 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
else:
template_upsampled = template_single
sampling_frequency_up = sampling_frequency
- trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled)
+ troughs_dict, peaks_before_dict, peaks_after_dict = get_trough_and_peak_idx(
+ template_upsampled,
+ min_thresh_detect_peaks_troughs=self.params["min_thresh_detect_peaks_troughs"],
+ smooth=self.params["smooth"],
+ smooth_window_frac=self.params["smooth_window_frac"],
+ smooth_polyorder=self.params["smooth_polyorder"],
+ )
templates_single.append(template_upsampled)
- troughs[unit_id] = trough_idx
- peaks[unit_id] = peak_idx
+ # Store main locations for backward compatibility
+ troughs[unit_id] = troughs_dict["main_loc"]
+ peaks[unit_id] = peaks_after_dict["main_loc"]
+ # Store full dicts for new metrics
+ troughs_info[unit_id] = troughs_dict
+ peaks_before_info[unit_id] = peaks_before_dict
+ peaks_after_info[unit_id] = peaks_after_dict
if include_multi_channel_metrics:
if sorting_analyzer.is_sparse():
@@ -239,6 +294,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
tmp_data["troughs"] = troughs
tmp_data["peaks"] = peaks
+ tmp_data["troughs_info"] = troughs_info
+ tmp_data["peaks_before_info"] = peaks_before_info
+ tmp_data["peaks_after_info"] = peaks_after_info
tmp_data["templates_single"] = np.array(templates_single)
if include_multi_channel_metrics:
@@ -249,6 +307,18 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
return tmp_data
+ def get_data(self, *args, **kwargs):
+ """Override to handle deprecated polarity of 'peak_trough_ratio' metric."""
+ metrics = super().get_data(*args, **kwargs)
+ if self.params["metric_params"].get("invert_peak_to_trough", False):
+ if "peak_trough_ratio" in metrics.columns:
+ warnings.warn(
+ "The 'peak_trough_ratio' metric has been deprecated and replaced by 'main_peak_to_trough_ratio'. "
+ "The values have been inverted to maintain consistency with previous versions."
+ )
+ metrics["peak_trough_ratio"] = -metrics["peak_trough_ratio"]
+ return metrics
+
register_result_extension(ComputeTemplateMetrics)
compute_template_metrics = ComputeTemplateMetrics.function_factory()
@@ -273,6 +343,8 @@ def get_default_tm_params(metric_names=None):
metric_params : dict
Dictionary with default parameters for template metrics.
"""
+ import warnings
+
warnings.warn(
"get_default_tm_params is deprecated and will be removed in a version 0.105.0. "
"Please use get_default_template_metrics_params instead.",
diff --git a/src/spikeinterface/metrics/template/tests/test_template_metrics.py b/src/spikeinterface/metrics/template/tests/test_template_metrics.py
index 8f1bf05b85..66437b156e 100644
--- a/src/spikeinterface/metrics/template/tests/test_template_metrics.py
+++ b/src/spikeinterface/metrics/template/tests/test_template_metrics.py
@@ -83,7 +83,7 @@ def test_metric_names_in_same_order(small_sorting_analyzer):
"""
    Computes specified template metrics and checks order is propagated.
"""
- specified_metric_names = ["peak_trough_ratio", "half_width", "peak_to_valley"]
+    specified_metric_names = ["main_peak_to_trough_ratio", "half_width", "peak_to_trough_duration"]
small_sorting_analyzer.compute(
"template_metrics", metric_names=specified_metric_names, delete_existing_metrics=True
)
diff --git a/src/spikeinterface/widgets/bombcell_curation.py b/src/spikeinterface/widgets/bombcell_curation.py
new file mode 100644
index 0000000000..030cb67327
--- /dev/null
+++ b/src/spikeinterface/widgets/bombcell_curation.py
@@ -0,0 +1,389 @@
+"""Widgets for visualizing bombcell-style curation and unit labeling results."""
+
+from __future__ import annotations
+
+import numpy as np
+from typing import Optional
+
+from .base import BaseWidget, to_attr
+
+from .unit_labels import WaveformOverlayByLabelWidget
+
+
+def _is_threshold_disabled(value):
+ """Check if a threshold value is disabled (None or np.nan)."""
+ if value is None:
+ return True
+ if isinstance(value, float) and np.isnan(value):
+ return True
+ return False
+
+
+class LabelingHistogramsWidget(BaseWidget):
+ """Plot histograms of quality metrics with threshold lines."""
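+
+    # Expected thresholds layout (assumed, mirroring bombcell_get_default_thresholds):
+    #   {"metric_name": {"min": <lower bound or None>, "max": <upper bound or None>}, ...}
+    # A bound of None or np.nan is treated as disabled (see _is_threshold_disabled above).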
+
+ def __init__(
+ self,
+ sorting_analyzer,
+ thresholds: Optional[dict] = None,
+ metrics_to_plot: Optional[list] = None,
+ backend=None,
+ **backend_kwargs,
+ ):
+ from spikeinterface.curation import bombcell_get_default_thresholds
+
+ sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)
+ combined_metrics = sorting_analyzer.get_metrics_extension_data()
+ if combined_metrics.empty:
+ raise ValueError(
+ "SortingAnalyzer has no metrics extensions computed. "
+ "Compute quality_metrics and/or template_metrics first."
+ )
+
+ if thresholds is None:
+ thresholds = bombcell_get_default_thresholds()
+ if metrics_to_plot is None:
+ metrics_to_plot = [m for m in thresholds.keys() if m in combined_metrics.columns]
+
+ plot_data = dict(
+ quality_metrics=combined_metrics,
+ thresholds=thresholds,
+ metrics_to_plot=metrics_to_plot,
+ )
+ BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+ def plot_matplotlib(self, data_plot, **backend_kwargs):
+ from .utils_matplotlib import make_mpl_figure
+ import matplotlib.pyplot as plt
+
+ dp = to_attr(data_plot)
+ quality_metrics = dp.quality_metrics
+ thresholds = dp.thresholds
+ metrics_to_plot = dp.metrics_to_plot
+
+ n_metrics = len(metrics_to_plot)
+ if n_metrics == 0:
+ print("No metrics to plot")
+ return
+
+ n_cols = min(4, n_metrics)
+ n_rows = int(np.ceil(n_metrics / n_cols))
+ backend_kwargs["ncols"] = n_cols
+ if "figsize" not in backend_kwargs:
+ backend_kwargs["figsize"] = (4 * n_cols, 3 * n_rows)
+        backend_kwargs["num_axes"] = n_metrics
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+ colors = plt.cm.tab10(np.linspace(0, 1, 10))
+ absolute_value_metrics = ["amplitude_median"]
+
+ axes = self.axes
+ for idx, metric_name in enumerate(metrics_to_plot):
+ row, col = idx // n_cols, idx % n_cols
+ ax = axes[row, col]
+
+ values = quality_metrics[metric_name].values
+ if metric_name in absolute_value_metrics:
+ values = np.abs(values)
+ values = values[~np.isnan(values) & ~np.isinf(values)]
+
+ if len(values) == 0:
+ ax.set_title(f"{metric_name}\n(no valid data)")
+ continue
+
+ ax.hist(values, bins=30, color=colors[idx % 10], alpha=0.7, edgecolor="black", density=True)
+
+ thresh = thresholds.get(metric_name, {})
+ has_thresh = False
+ if not _is_threshold_disabled(thresh.get("min", None)):
+ ax.axvline(thresh["min"], color="red", ls="--", lw=2, label=f"min={thresh['min']:.2g}")
+ has_thresh = True
+ if not _is_threshold_disabled(thresh.get("max", None)):
+ ax.axvline(thresh["max"], color="blue", ls="--", lw=2, label=f"max={thresh['max']:.2g}")
+ has_thresh = True
+
+ ax.set_xlabel(metric_name)
+ ax.set_ylabel("Density")
+ if has_thresh:
+ ax.legend(fontsize=8, loc="upper right")
+ ax.spines["top"].set_visible(False)
+ ax.spines["right"].set_visible(False)
+
+ for idx in range(len(metrics_to_plot), n_rows * n_cols):
+ axes[idx // n_cols, idx % n_cols].set_visible(False)
+
+
+class UpsetPlotWidget(BaseWidget):
+ """
+ Plot UpSet plots showing which metrics fail together for each unit type.
+
+    Requires the `upsetplot` package. Each unit type shows relevant metrics:
+ NOISE -> waveform metrics, MUA -> spike quality metrics, NON_SOMA -> non-somatic metrics.
+ """
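+
+    # Illustrative sketch (``unit_type``/``unit_type_string`` are assumed to come from a prior
+    # bombcell-style labeling step in spikeinterface.curation):
+    #   UpsetPlotWidget(sorting_analyzer, unit_type, unit_type_string, min_subset_size=2)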
+
+ def __init__(
+ self,
+ sorting_analyzer,
+ unit_type: np.ndarray,
+ unit_type_string: np.ndarray,
+ thresholds: Optional[dict] = None,
+ unit_types_to_plot: Optional[list] = None,
+ split_non_somatic: bool = False,
+ min_subset_size: int = 1,
+ backend=None,
+ **backend_kwargs,
+ ):
+ from spikeinterface.curation import bombcell_get_default_thresholds
+
+ sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)
+ combined_metrics = sorting_analyzer.get_metrics_extension_data()
+ if combined_metrics.empty:
+ raise ValueError(
+ "SortingAnalyzer has no metrics extensions computed. "
+ "Compute quality_metrics and/or template_metrics first."
+ )
+
+ if thresholds is None:
+ thresholds = bombcell_get_default_thresholds()
+ if unit_types_to_plot is None:
+ if split_non_somatic:
+ unit_types_to_plot = ["noise", "mua", "non_soma_good", "non_soma_mua"]
+ else:
+ unit_types_to_plot = ["noise", "mua", "non_soma"]
+
+ plot_data = dict(
+ quality_metrics=combined_metrics,
+ unit_type=unit_type,
+ unit_type_string=unit_type_string,
+ thresholds=thresholds,
+ unit_types_to_plot=unit_types_to_plot,
+ min_subset_size=min_subset_size,
+ )
+ BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+ def _get_metrics_for_unit_type(self, unit_type_label):
+ from spikeinterface.curation.bombcell_curation import (
+ NOISE_METRICS,
+ SPIKE_QUALITY_METRICS,
+ NON_SOMATIC_METRICS,
+ )
+
+ if unit_type_label == "noise":
+ return NOISE_METRICS
+ elif unit_type_label == "mua":
+ return SPIKE_QUALITY_METRICS
+ elif unit_type_label in ("non_soma", "non_soma_good", "non_soma_mua"):
+ return NON_SOMATIC_METRICS
+ return None
+
+ def plot_matplotlib(self, data_plot, **backend_kwargs):
+ from .utils_matplotlib import make_mpl_figure
+ import warnings
+ import matplotlib.pyplot as plt
+
+ dp = to_attr(data_plot)
+ quality_metrics = dp.quality_metrics
+ unit_type_string = dp.unit_type_string
+ thresholds = dp.thresholds
+ unit_types_to_plot = dp.unit_types_to_plot
+ min_subset_size = dp.min_subset_size
+
+ try:
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=FutureWarning, module="upsetplot")
+ from upsetplot import UpSet, from_memberships
+ except ImportError:
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.text(
+ 0.5,
+ 0.5,
+ "UpSet plots require 'upsetplot' package.\n\npip install upsetplot",
+ ha="center",
+ va="center",
+ fontsize=14,
+ family="monospace",
+ bbox=dict(boxstyle="round", facecolor="lightyellow", edgecolor="orange"),
+ )
+ ax.axis("off")
+ ax.set_title("UpSet Plot - Package Not Installed", fontsize=16)
+ self.figure = fig
+ self.axes = ax
+ self.figures = [fig]
+ return
+
+ failure_table = self._build_failure_table(quality_metrics, thresholds)
+ figures = []
+ axes_list = []
+
+ for unit_type_label in unit_types_to_plot:
+ mask = unit_type_string == unit_type_label
+ n_units = np.sum(mask)
+ if n_units == 0:
+ continue
+
+ relevant_metrics = self._get_metrics_for_unit_type(unit_type_label)
+ if relevant_metrics is not None:
+ available_metrics = [m for m in relevant_metrics if m in failure_table.columns]
+ if len(available_metrics) == 0:
+ continue
+ unit_failure_table = failure_table[available_metrics]
+ else:
+ unit_failure_table = failure_table
+
+ unit_failures = unit_failure_table.loc[mask]
+ memberships = []
+ for idx in unit_failures.index:
+ failed = unit_failures.columns[unit_failures.loc[idx]].tolist()
+ if failed:
+ memberships.append(failed)
+
+ if not memberships:
+ continue
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=FutureWarning, module="upsetplot")
+ upset_data = from_memberships(memberships)
+ upset_data = upset_data[upset_data >= min_subset_size]
+ if len(upset_data) == 0:
+ continue
+
+ fig = plt.figure(figsize=(12, 6))
+ UpSet(
+ upset_data,
+ subset_size="count",
+ show_counts=True,
+ sort_by="cardinality",
+ sort_categories_by="cardinality",
+ ).plot(fig=fig)
+ fig.suptitle(f"{unit_type_label} (n={n_units})", fontsize=14, y=1.02)
+ figures.append(fig)
+ axes_list.append(fig.axes)
+
+ if not figures:
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
+ ax.text(0.5, 0.5, "No units found or no metric failures detected.", ha="center", va="center", fontsize=12)
+ ax.axis("off")
+ figures = [fig]
+ axes_list = [ax]
+
+ self.figures = figures
+ self.figure = figures[0] if figures else None
+ self.axes = axes_list
+
+ def _build_failure_table(self, quality_metrics, thresholds):
+ import pandas as pd
+
+ absolute_value_metrics = ["amplitude_median"]
+ failure_data = {}
+
+ for metric_name, thresh in thresholds.items():
+ if metric_name not in quality_metrics.columns:
+ continue
+ values = quality_metrics[metric_name].values.copy()
+ if metric_name in absolute_value_metrics:
+ values = np.abs(values)
+
+ failed = np.isnan(values)
+ if not _is_threshold_disabled(thresh.get("min", None)):
+ failed |= values < thresh["min"]
+ if not _is_threshold_disabled(thresh.get("max", None)):
+ failed |= values > thresh["max"]
+ failure_data[metric_name] = failed
+
+ return pd.DataFrame(failure_data, index=quality_metrics.index)
+
+
+def plot_unit_labeling_all(
+ sorting_analyzer,
+ unit_type: np.ndarray,
+ unit_type_string: np.ndarray,
+ thresholds: Optional[dict] = None,
+ split_non_somatic: bool = False,
+ include_upset: bool = True,
+ save_folder=None,
+ backend=None,
+ **kwargs,
+):
+ """
+ Generate all unit labeling plots and optionally save to folder.
+
+ Parameters
+ ----------
+ sorting_analyzer : SortingAnalyzer
+ The sorting analyzer object with computed metrics extensions.
+ unit_type : np.ndarray
+ Array of unit type codes (0=noise, 1=good, 2=mua, 3=non_soma, etc.).
+ unit_type_string : np.ndarray
+ Array of unit type labels as strings.
+ thresholds : dict, optional
+ Threshold dictionary. If None, uses default thresholds.
+ split_non_somatic : bool, default: False
+ Whether to split "non_soma" into "non_soma_good" and "non_soma_mua".
+ include_upset : bool, default: True
+ Whether to include UpSet plots (requires upsetplot package).
+ save_folder : str or Path, optional
+ If provided, saves all plots and CSV results to this folder.
+ backend : str, optional
+ Plotting backend.
+ **kwargs
+ Additional arguments passed to plot functions.
+
+ Returns
+ -------
+ dict
+ Dictionary with keys 'histograms', 'waveforms', 'upset' containing widget objects.
+ """
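+    # Illustrative usage sketch (``analyzer``, ``unit_type`` and ``unit_type_string`` are assumed to
+    # come from a prior bombcell-style labeling step):
+    #   results = plot_unit_labeling_all(analyzer, unit_type, unit_type_string, save_folder="bombcell_plots")
+    #   results["histograms"], results["waveforms"], results["upset"]  # widget objects per plot type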
+ from pathlib import Path
+ from spikeinterface.curation import bombcell_get_default_thresholds, save_bombcell_results
+
+ if thresholds is None:
+ thresholds = bombcell_get_default_thresholds()
+
+ combined_metrics = sorting_analyzer.get_metrics_extension_data()
+ has_metrics = not combined_metrics.empty
+
+ results = {}
+
+ # Histograms
+ if has_metrics:
+ results["histograms"] = LabelingHistogramsWidget(
+ sorting_analyzer,
+ thresholds=thresholds,
+ backend=backend,
+ **kwargs,
+ )
+
+ # Waveform overlay
+ results["waveforms"] = WaveformOverlayByLabelWidget(sorting_analyzer, unit_type_string, backend=backend, **kwargs)
+
+ # UpSet plots
+ if include_upset and has_metrics:
+ results["upset"] = UpsetPlotWidget(
+ sorting_analyzer,
+ unit_type,
+ unit_type_string,
+ thresholds=thresholds,
+ split_non_somatic=split_non_somatic,
+ backend=backend,
+ **kwargs,
+ )
+
+ # Save to folder if requested
+ if save_folder is not None:
+ save_folder = Path(save_folder)
+ save_folder.mkdir(parents=True, exist_ok=True)
+
+ # Save plots
+ if "histograms" in results and results["histograms"].figure is not None:
+ results["histograms"].figure.savefig(save_folder / "labeling_histograms.png", dpi=150, bbox_inches="tight")
+ if "waveforms" in results and results["waveforms"].figure is not None:
+ results["waveforms"].figure.savefig(save_folder / "waveform_overlay.png", dpi=150, bbox_inches="tight")
+ if "upset" in results and hasattr(results["upset"], "figures"):
+ for i, fig in enumerate(results["upset"].figures):
+ fig.savefig(save_folder / f"upset_plot_{i}.png", dpi=150, bbox_inches="tight")
+
+ # Save CSV results
+ if has_metrics:
+ save_bombcell_results(combined_metrics, unit_type, unit_type_string, thresholds, save_folder)
+
+ return results
diff --git a/src/spikeinterface/widgets/unit_labels.py b/src/spikeinterface/widgets/unit_labels.py
new file mode 100644
index 0000000000..03ee0b2391
--- /dev/null
+++ b/src/spikeinterface/widgets/unit_labels.py
@@ -0,0 +1,114 @@
+"""Widgets for visualizing unit labeling results."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from .base import BaseWidget, to_attr
+
+
+class WaveformOverlayByLabelWidget(BaseWidget):
+ """Plot overlaid waveforms grouped by unit label type.
+
+ Parameters
+ ----------
+ sorting_analyzer : SortingAnalyzer
+ A SortingAnalyzer object with 'templates' extension computed.
+ unit_labels : np.ndarray
+ Array of unit type labels corresponding to each unit in the sorting.
+ labels_order : list, optional
+ List specifying the order of labels to display. If None, unique labels in unit_labels are
+ used in the order they appear.
+ max_columns : int, default: 3
+ Maximum number of columns in the plot grid.
+ """
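+
+    # Illustrative sketch (label values here are just examples; any string labels work):
+    #   WaveformOverlayByLabelWidget(sorting_analyzer, unit_labels, labels_order=["good", "mua", "noise"])
+    # The widget is also exposed as ``plot_unit_labels`` in widget_list.py.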
+
+ def __init__(
+ self,
+ sorting_analyzer,
+ unit_labels: np.ndarray,
+ labels_order=None,
+ max_columns: int = 3,
+ backend=None,
+ **backend_kwargs,
+ ):
+ sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)
+ self.check_extensions(sorting_analyzer, "templates")
+ if labels_order is not None:
+ assert len(labels_order) == len(np.unique(unit_labels)), "labels_order length must match unique unit types"
+ assert all(
+ [label in np.unique(unit_labels) for label in labels_order]
+ ), "All labels in labels_order must be present in unit_labels"
+ else:
+ labels_order = np.unique(unit_labels)
+ plot_data = dict(
+ sorting_analyzer=sorting_analyzer,
+ labels_order=labels_order,
+ unit_labels=unit_labels,
+ max_columns=max_columns,
+ )
+ BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+ def plot_matplotlib(self, data_plot, **backend_kwargs):
+ import matplotlib.pyplot as plt
+ from .utils_matplotlib import make_mpl_figure
+
+ dp = to_attr(data_plot)
+ sorting_analyzer = dp.sorting_analyzer
+ unit_labels = dp.unit_labels
+ labels_order = dp.labels_order
+
+ if not sorting_analyzer.has_extension("templates"):
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
+ ax.text(
+ 0.5,
+ 0.5,
+ "Templates extension not computed.\nRun: analyzer.compute('templates')",
+ ha="center",
+ va="center",
+ fontsize=12,
+ )
+ ax.axis("off")
+ self.figure = fig
+ self.axes = ax
+ return
+
+ templates_ext = sorting_analyzer.get_extension("templates")
+ templates = templates_ext.get_templates(operator="average")
+
+ backend_kwargs["num_axes"] = len(labels_order)
+ if len(labels_order) <= dp.max_columns:
+ ncols = len(labels_order)
+ else:
+            ncols = min(dp.max_columns, int(np.ceil(len(labels_order) / 2)))
+ nrows = int(np.ceil(len(labels_order) / ncols))
+ backend_kwargs["ncols"] = ncols
+ if "figsize" not in backend_kwargs:
+ backend_kwargs["figsize"] = (5 * ncols, 4 * nrows)
+ self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+ axes_flat = self.axes.flatten()
+ for index, label in enumerate(labels_order):
+ ax = axes_flat[index]
+ mask = unit_labels == label
+ n_units = np.sum(mask)
+
+ if n_units > 0:
+ unit_indices = np.where(mask)[0]
+ alpha = max(0.05, min(0.3, 10 / n_units))
+ for unit_idx in unit_indices:
+ template = templates[unit_idx]
+ best_chan = np.argmax(np.max(np.abs(template), axis=0))
+ ax.plot(template[:, best_chan], color="black", alpha=alpha, linewidth=0.5)
+ ax.set_title(f"{label} (n={n_units})")
+ else:
+ ax.set_title(f"{label} (n=0)")
+ ax.text(0.5, 0.5, "No units", ha="center", va="center", transform=ax.transAxes)
+
+ for spine in ax.spines.values():
+ spine.set_visible(False)
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+ for idx in range(len(labels_order), len(axes_flat)):
+ axes_flat[idx].set_visible(False)
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 6edba67c96..fe191d4450 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -37,12 +37,19 @@
from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget
from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary
from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget
+from .unit_labels import WaveformOverlayByLabelWidget
+from .bombcell_curation import (
+ LabelingHistogramsWidget,
+ UpsetPlotWidget,
+ plot_unit_labeling_all,
+)
widget_list = [
AgreementMatrixWidget,
AllAmplitudesDistributionsWidget,
AmplitudesWidget,
AutoCorrelogramsWidget,
+ LabelingHistogramsWidget,
ConfusionMatrixWidget,
ComparisonCollisionBySimilarityWidget,
CrossCorrelogramsWidget,
@@ -75,6 +82,8 @@
UnitTemplatesWidget,
UnitWaveformDensityMapWidget,
UnitWaveformsWidget,
+ UpsetPlotWidget,
+ WaveformOverlayByLabelWidget,
StudyRunTimesWidget,
StudyUnitCountsWidget,
StudyPerformances,
@@ -148,6 +157,9 @@
plot_template_similarity = TemplateSimilarityWidget
plot_traces = TracesWidget
plot_unit_depths = UnitDepthsWidget
+plot_unit_labels = WaveformOverlayByLabelWidget
+plot_unit_labeling_upset = UpsetPlotWidget
+plot_unit_labeling_histograms = LabelingHistogramsWidget
plot_unit_locations = UnitLocationsWidget
plot_unit_presence = UnitPresenceWidget
plot_unit_probe_map = UnitProbeMapWidget