Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,10 @@ spikeinterface.curation
.. autofunction:: remove_redundant_units
.. autofunction:: remove_duplicated_spikes
.. autofunction:: remove_excess_spikes
.. autofunction:: auto_label_units
.. autofunction:: model_based_label_units
.. autofunction:: load_model
.. autofunction:: train_model
.. autofunction:: unitrefine_label_units

Curation Model
~~~~~~~~~~~~~~
Expand Down
10 changes: 5 additions & 5 deletions doc/how_to/auto_curation_prediction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ repo's URL after huggingface.co/) and that we trust the model.

.. code::

from spikeinterface.curation import auto_label_units
from spikeinterface.curation import model_based_label_units

labels_and_probabilities = auto_label_units(
labels_and_probabilities = model_based_label_units(
sorting_analyzer = sorting_analyzer,
repo_id = "SpikeInterface/toy_tetrode_model",
trust_model = True
Expand All @@ -29,7 +29,7 @@ create the labels:

.. code::

labels_and_probabilities = si.auto_label_units(
labels_and_probabilities = si.model_based_label_units(
sorting_analyzer = sorting_analyzer,
model_folder = "my_folder_with_a_model_in_it",
)
Expand All @@ -39,5 +39,5 @@ are also saved as a property of your ``sorting_analyzer`` and can be accessed li

.. code::

labels = sorting_analyzer.sorting.get_property("classifier_label")
probabilities = sorting_analyzer.sorting.get_property("classifier_probability")
labels = sorting_analyzer.get_sorting_property("classifier_label")
probabilities = sorting_analyzer.get_sorting_property("classifier_probability")
2 changes: 1 addition & 1 deletion doc/references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ If you use the default "similarity_correlograms" preset in the :code:`compute_me

If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_

If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_
If you use :code:`unitrefine_label_units`, :code:`model_based_label_units` or :code:`train_model`, please cite [Jain]_

Benchmark
---------
Expand Down
28 changes: 20 additions & 8 deletions examples/tutorials/curation/plot_1_automated_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@

##############################################################################
# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly
# to the ``auto_label_units`` function. This returns a dictionary containing a label and
# to the ``model_based_label_units`` function. This returns a dictionary containing a label and
# a confidence for each unit contained in the ``sorting_analyzer``.

labels = sc.auto_label_units(
labels = sc.model_based_label_units(
sorting_analyzer = sorting_analyzer,
repo_id = "SpikeInterface/toy_tetrode_model",
trusted = ['numpy.dtype']
Expand Down Expand Up @@ -211,16 +211,16 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
# V1,SC and ALM: https://huggingface.co/SpikeInterface/UnitRefine_noise_neural_classifier/ and
# https://huggingface.co/SpikeInterface/UnitRefine_sua_mua_classifier/. One will classify units into
# `noise` or `not-noise` and the other will classify the `not-noise` units into single
# `noise` or `neural` and the other will classify the `neural` units into single
# unit activity (sua) units and multi-unit activity (mua) units.
#
# There is more information about the model on the model's HuggingFace page. Take a look!
# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
# The idea here is to first apply the noise/neural classifier, then the sua/mua one.
# We can do so as follows:
#

# Apply the noise/not-noise model
noise_neuron_labels = sc.auto_label_units(
# Apply the noise/neural model
noise_neuron_labels = sc.model_based_label_units(
sorting_analyzer=sorting_analyzer,
repo_id="SpikeInterface/UnitRefine_noise_neural_classifier",
trust_model=True,
Expand All @@ -230,7 +230,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
analyzer_neural = sorting_analyzer.remove_units(noise_units.index)

# Apply the sua/mua model
sua_mua_labels = sc.auto_label_units(
sua_mua_labels = sc.model_based_label_units(
sorting_analyzer=analyzer_neural,
repo_id="SpikeInterface/UnitRefine_sua_mua_classifier",
trust_model=True,
Expand All @@ -239,6 +239,18 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
print(all_labels)

##############################################################################
# Both steps can be done in one go using the ``unitrefine_label_units`` function:
#

all_labels = sc.unitrefine_label_units(
sorting_analyzer,
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier",
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier",
)
print(all_labels)


##############################################################################
# If you run this without the ``trust_model=True`` parameter, you will receive an error:
#
Expand Down Expand Up @@ -276,7 +288,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
#
# .. code-block::
#
# labels = sc.auto_label_units(
# labels = sc.model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "path/to/model/folder",
# )
Expand Down
8 changes: 4 additions & 4 deletions examples/tutorials/curation/plot_3_upload_a_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@
#
# ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading)
#
# from spikeinterface.curation import auto_label_units
# labels = auto_label_units(
# from spikeinterface.curation import model_based_label_units
# labels = model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# repo_id = "SpikeInterface/toy_tetrode_model",
# trust_model=True
Expand All @@ -123,9 +123,9 @@
# or you can download the entire repositry to `a_folder_for_a_model`, and use
#
# ` ` ` python
# from spikeinterface.curation import auto_label_units
# from spikeinterface.curation import model_based_label_units
#
# labels = auto_label_units(
# labels = model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "path/to/a_folder_for_a_model",
# trusted = ['numpy.dtype']
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/curation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from .sortingview_curation import apply_sortingview_curation

# automated curation
from .model_based_curation import auto_label_units, load_model
from .model_based_curation import model_based_label_units, load_model
from .train_manual_curation import train_model, get_default_classifier_search_spaces
from .unitrefine_curation import unitrefine_label_units
20 changes: 16 additions & 4 deletions src/spikeinterface/curation/model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def predict_labels(
)

# Set predictions and probability as sorting properties
self.sorting_analyzer.sorting.set_property("classifier_label", predictions)
self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities)
self.sorting_analyzer.set_sorting_property("classifier_label", predictions)
self.sorting_analyzer.set_sorting_property("classifier_probability", probabilities)

if export_to_phy:
self._export_to_phy(classified_units)
Expand Down Expand Up @@ -204,7 +204,7 @@ def _export_to_phy(self, classified_df):
classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id")


def auto_label_units(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we deprecate auto_label_units (slowly!)? I think a lot of people have copies of unitRefine notebooks using it!

def model_based_label_units(
sorting_analyzer: SortingAnalyzer,
model_folder=None,
model_name=None,
Expand All @@ -227,7 +227,7 @@ def auto_label_units(
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting results.
model_folder : str or Path, defualt: None
model_folder : str or Path, default: None
The path to the folder containing the model
repo_id : str | Path, default: None
Hugging face repo id which contains the model e.g. 'username/model'
Expand Down Expand Up @@ -281,6 +281,18 @@ def auto_label_units(
return classified_units


def auto_label_units(*args, **kwargs):
"""
Deprecated function. Please use `model_based_label_units` instead.
"""
warnings.warn(
"`auto_label_units` is deprecated and will be removed in v0.105.0. "
"Please use `model_based_label_units` instead.",
DeprecationWarning,
)
return model_based_label_units(*args, **kwargs)


def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None):
"""
Loads a model and model_info from a HuggingFaceHub repo or a local folder.
Expand Down
27 changes: 23 additions & 4 deletions src/spikeinterface/curation/tests/test_model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
from spikeinterface.curation.model_based_curation import ModelBasedClassification
from spikeinterface.curation import auto_label_units, load_model
from spikeinterface.curation import model_based_label_units, load_model
from spikeinterface.curation.train_manual_curation import _get_computed_metrics

import numpy as np
Expand Down Expand Up @@ -39,21 +39,21 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):


def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path):
"""The function `auto_label_units` needs the correct metrics to have been computed. However,
"""The function `model_based_label_units` needs the correct metrics to have been computed. However,
it should be independent of the order of computation. We test this here."""

sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])

prediction_prob_dataframe_1 = auto_label_units(
prediction_prob_dataframe_1 = model_based_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
)

sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])

prediction_prob_dataframe_2 = auto_label_units(
prediction_prob_dataframe_2 = model_based_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
Expand Down Expand Up @@ -168,3 +168,22 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)


def test_unitrefine_label_units(sorting_analyzer_for_curation):
"""Test the `unitrefine_label_units` function."""

sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
sorting_analyzer_for_curation.compute("quality_metrics")

from spikeinterface.curation import unitrefine_label_units

labels = unitrefine_label_units(
sorting_analyzer_for_curation,
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
)

assert "label" in labels.columns
assert "probability" in labels.columns
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids())
79 changes: 79 additions & 0 deletions src/spikeinterface/curation/unitrefine_curation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from pathlib import Path

from spikeinterface.core import SortingAnalyzer
from spikeinterface.curation.model_based_curation import model_based_label_units


def unitrefine_label_units(
sorting_analyzer: SortingAnalyzer,
noise_neural_classifier: str | Path | None = None,
sua_mua_classifier: str | Path | None = None,
):
"""Label units using a cascade of pre-trained classifiers for
noise/neural unit classification and SUA/MUA classification,
as shown in the UnitRefine paper (see References).
The noise/neural classifier is applied first to remove noise units,
then the SUA/MUA classifier is applied to the remaining units.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting results.
noise_neural_classifier : str or Path or None, default: None
The path to the folder containing the model or a string to a repo on HuggingFace.
If None, the noise/neural classification step is skipped.
Make sure to provide at least one of the two classifiers.
sua_mua_classifier : str or Path or None, default: None
The path to the folder containing the model or a string to a repo on HuggingFace.
If None, the SUA/MUA classification step is skipped.

Returns
-------
labels : pd.DataFrame
A DataFrame with unit ids as index and "label"/"probability" as column.

References
----------
The approach is described in [Jain]_.
"""
import pandas as pd
import pandas as pd

if noise_neural_classifier is None and sua_mua_classifier is None:
raise ValueError(
"At least one of noise_neural_classifier or sua_mua_classifier must be provided. "
"Pre-trained models can be found at https://huggingface.co/collections/SpikeInterface/curation-models or "
"https://huggingface.co/AnoushkaJain3/models. You can also train models on your own data: "
"see https://github.com/anoushkajain/UnitRefine for more details."
)

if noise_neural_classifier is not None:
# 1. apply the noise/neural classification and remove noise
noise_neuron_labels = model_based_label_units(
sorting_analyzer=sorting_analyzer,
repo_id=noise_neural_classifier,
trust_model=True,
)
noise_units = noise_neuron_labels[noise_neuron_labels["prediction"] == "noise"]
sorting_analyzer_neural = sorting_analyzer.remove_units(noise_units.index)
else:
sorting_analyzer_neural = sorting_analyzer
noise_units = pd.DataFrame(columns=["prediction", "probability"])

if sua_mua_classifier is not None:
# 2. apply the sua/mua classification and aggregate results
if len(sorting_analyzer.unit_ids) > len(noise_units):
sua_mua_labels = model_based_label_units(
sorting_analyzer=sorting_analyzer_neural,
repo_id=sua_mua_classifier,
trust_model=True,
)
all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
else:
all_labels = noise_units
else:
all_labels = noise_neuron_labels

# rename prediction column to label
all_labels = all_labels.rename(columns={"prediction": "label"})
return all_labels