diff --git a/doc/api.rst b/doc/api.rst
index 38990a430d..adfdb85470 100755
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -373,9 +373,10 @@ spikeinterface.curation
     .. autofunction:: remove_redundant_units
     .. autofunction:: remove_duplicated_spikes
     .. autofunction:: remove_excess_spikes
-    .. autofunction:: auto_label_units
+    .. autofunction:: model_based_label_units
     .. autofunction:: load_model
     .. autofunction:: train_model
+    .. autofunction:: unitrefine_label_units
 
 Curation Model
 ~~~~~~~~~~~~~~
diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst
index ad672f22a8..3e5f8ddd4f 100644
--- a/doc/how_to/auto_curation_prediction.rst
+++ b/doc/how_to/auto_curation_prediction.rst
@@ -16,9 +16,9 @@ repo's URL after huggingface.co/) and that we trust the model.
 
 .. code::
 
-    from spikeinterface.curation import auto_label_units
+    from spikeinterface.curation import model_based_label_units
 
-    labels_and_probabilities = auto_label_units(
+    labels_and_probabilities = model_based_label_units(
         sorting_analyzer = sorting_analyzer,
         repo_id = "SpikeInterface/toy_tetrode_model",
         trust_model = True
@@ -29,7 +29,7 @@ create the labels:
 
 .. code::
 
-    labels_and_probabilities = si.auto_label_units(
+    labels_and_probabilities = si.model_based_label_units(
         sorting_analyzer = sorting_analyzer,
         model_folder = "my_folder_with_a_model_in_it",
     )
@@ -39,5 +39,5 @@ are also saved as a property of your ``sorting_analyzer`` and can be accessed li
 
 .. code::
 
-    labels = sorting_analyzer.sorting.get_property("classifier_label")
-    probabilities = sorting_analyzer.sorting.get_property("classifier_probability")
+    labels = sorting_analyzer.get_sorting_property("classifier_label")
+    probabilities = sorting_analyzer.get_sorting_property("classifier_probability")
diff --git a/doc/references.rst b/doc/references.rst
index be05b69c0c..24eba16902 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -92,7 +92,7 @@ If you use the default "similarity_correlograms" preset in the :code:`compute_me
 
 If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_
 
-If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_
+If you use :code:`unitrefine_label_units`, :code:`model_based_label_units` or :code:`train_model`, please cite [Jain]_
 
 Benchmark
 ---------
diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py
index 00bd606c44..6ef986a810 100644
--- a/examples/tutorials/curation/plot_1_automated_curation.py
+++ b/examples/tutorials/curation/plot_1_automated_curation.py
@@ -83,10 +83,10 @@
 
 ##############################################################################
 # Great! We can now use the model to predict labels. Here, we pass the HF repo id directly
-# to the ``auto_label_units`` function. This returns a dictionary containing a label and
+# to the ``model_based_label_units`` function. This returns a DataFrame containing a label and
 # a confidence for each unit contained in the ``sorting_analyzer``.
 
-labels = sc.auto_label_units(
+labels = sc.model_based_label_units(
     sorting_analyzer = sorting_analyzer,
     repo_id = "SpikeInterface/toy_tetrode_model",
     trusted = ['numpy.dtype']
@@ -211,16 +211,16 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
 # For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
 # V1,SC and ALM: https://huggingface.co/SpikeInterface/UnitRefine_noise_neural_classifier/ and
 # https://huggingface.co/SpikeInterface/UnitRefine_sua_mua_classifier/. One will classify units into
-# `noise` or `not-noise` and the other will classify the `not-noise` units into single
+# `noise` or `neural` and the other will classify the `neural` units into single
 # unit activity (sua) units and multi-unit activity (mua) units.
 #
 # There is more information about the model on the model's HuggingFace page. Take a look!
-# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
+# The idea here is to first apply the noise/neural classifier, then the sua/mua one.
 # We can do so as follows:
 #
 
-# Apply the noise/not-noise model
-noise_neuron_labels = sc.auto_label_units(
+# Apply the noise/neural model
+noise_neuron_labels = sc.model_based_label_units(
     sorting_analyzer=sorting_analyzer,
     repo_id="SpikeInterface/UnitRefine_noise_neural_classifier",
     trust_model=True,
@@ -230,7 +230,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
 analyzer_neural = sorting_analyzer.remove_units(noise_units.index)
 
 # Apply the sua/mua model
-sua_mua_labels = sc.auto_label_units(
+sua_mua_labels = sc.model_based_label_units(
     sorting_analyzer=analyzer_neural,
     repo_id="SpikeInterface/UnitRefine_sua_mua_classifier",
     trust_model=True,
@@ -239,6 +239,18 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
 all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
 print(all_labels)
 
+##############################################################################
+# Both steps can be done in one go using the ``unitrefine_label_units`` function:
+#
+
+all_labels = sc.unitrefine_label_units(
+    sorting_analyzer,
+    noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier",
+    sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier",
+)
+print(all_labels)
+
+
 ##############################################################################
 # If you run this without the ``trust_model=True`` parameter, you will receive an error:
 #
@@ -276,7 +288,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
 #
 # .. code-block::
 #
-#    labels = sc.auto_label_units(
+#    labels = sc.model_based_label_units(
 #        sorting_analyzer = sorting_analyzer,
 #        model_folder = "path/to/model/folder",
 #    )
diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py
index ad9d16cab5..36c7a20f1d 100644
--- a/examples/tutorials/curation/plot_3_upload_a_model.py
+++ b/examples/tutorials/curation/plot_3_upload_a_model.py
@@ -112,8 +112,8 @@
 #
 # ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading)
 #
-# from spikeinterface.curation import auto_label_units
-# labels = auto_label_units(
+# from spikeinterface.curation import model_based_label_units
+# labels = model_based_label_units(
 #     sorting_analyzer = sorting_analyzer,
 #     repo_id = "SpikeInterface/toy_tetrode_model",
 #     trust_model=True
@@ -123,9 +123,9 @@
 # or you can download the entire repositry to `a_folder_for_a_model`, and use
 #
 # ` ` ` python
-# from spikeinterface.curation import auto_label_units
+# from spikeinterface.curation import model_based_label_units
 #
-# labels = auto_label_units(
+# labels = model_based_label_units(
 #     sorting_analyzer = sorting_analyzer,
 #     model_folder = "path/to/a_folder_for_a_model",
 #     trusted = ['numpy.dtype']
diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py
index b64070e662..68737a0f12 100644
--- a/src/spikeinterface/curation/__init__.py
+++ b/src/spikeinterface/curation/__init__.py
@@ -20,5 +20,6 @@ from .sortingview_curation import apply_sortingview_curation
 
 # automated curation
-from .model_based_curation import auto_label_units, load_model
+from .model_based_curation import model_based_label_units, auto_label_units, load_model
 from .train_manual_curation import train_model, get_default_classifier_search_spaces
+from .unitrefine_curation import unitrefine_label_units
 
diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
index e779e13182..ef8571b62f 100644
--- a/src/spikeinterface/curation/model_based_curation.py
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -119,8 +119,8 @@ def predict_labels(
         )
 
         # Set predictions and probability as sorting properties
-        self.sorting_analyzer.sorting.set_property("classifier_label", predictions)
-        self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities)
+        self.sorting_analyzer.set_sorting_property("classifier_label", predictions)
+        self.sorting_analyzer.set_sorting_property("classifier_probability", probabilities)
 
         if export_to_phy:
             self._export_to_phy(classified_units)
@@ -204,7 +204,7 @@ def _export_to_phy(self, classified_df):
         classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id")
 
 
-def auto_label_units(
+def model_based_label_units(
     sorting_analyzer: SortingAnalyzer,
     model_folder=None,
     model_name=None,
@@ -227,7 +227,7 @@
     ----------
     sorting_analyzer : SortingAnalyzer
         The sorting analyzer object containing the spike sorting results.
-    model_folder : str or Path, defualt: None
+    model_folder : str or Path, default: None
         The path to the folder containing the model
     repo_id : str | Path, default: None
         Hugging face repo id which contains the model e.g. 'username/model'
@@ -281,6 +281,18 @@
     return classified_units
 
 
+def auto_label_units(*args, **kwargs):
+    """
+    Deprecated function. Please use `model_based_label_units` instead.
+    """
+    warnings.warn(
+        "`auto_label_units` is deprecated and will be removed in v0.105.0. "
+        "Please use `model_based_label_units` instead.",
+        DeprecationWarning,
+    )
+    return model_based_label_units(*args, **kwargs)
+
+
 def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None):
     """
     Loads a model and model_info from a HuggingFaceHub repo or a local folder.
diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py
index 9f845bb1c3..9b6ccd0a21 100644
--- a/src/spikeinterface/curation/tests/test_model_based_curation.py
+++ b/src/spikeinterface/curation/tests/test_model_based_curation.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
 from spikeinterface.curation.model_based_curation import ModelBasedClassification
-from spikeinterface.curation import auto_label_units, load_model
+from spikeinterface.curation import model_based_label_units, load_model
 from spikeinterface.curation.train_manual_curation import _get_computed_metrics
 
 import numpy as np
@@ -39,13 +39,13 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):
 
 
 def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path):
-    """The function `auto_label_units` needs the correct metrics to have been computed. However,
+    """The function `model_based_label_units` needs the correct metrics to have been computed. However,
     it should be independent of the order of computation. We test this here."""
 
     sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
     sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
 
-    prediction_prob_dataframe_1 = auto_label_units(
+    prediction_prob_dataframe_1 = model_based_label_units(
         sorting_analyzer=sorting_analyzer_for_curation,
         model_folder=trained_pipeline_path,
         trusted=["numpy.dtype"],
@@ -53,7 +53,7 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip
 
     sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])
 
-    prediction_prob_dataframe_2 = auto_label_units(
+    prediction_prob_dataframe_2 = model_based_label_units(
         sorting_analyzer=sorting_analyzer_for_curation,
         model_folder=trained_pipeline_path,
         trusted=["numpy.dtype"],
@@ -168,3 +168,22 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
     model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
     model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
     model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)
+
+
+def test_unitrefine_label_units(sorting_analyzer_for_curation):
+    """Test the `unitrefine_label_units` function."""
+
+    sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
+    sorting_analyzer_for_curation.compute("quality_metrics")
+
+    from spikeinterface.curation import unitrefine_label_units
+
+    labels = unitrefine_label_units(
+        sorting_analyzer_for_curation,
+        noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
+        sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
+    )
+
+    assert "label" in labels.columns
+    assert "probability" in labels.columns
+    assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids())
diff --git a/src/spikeinterface/curation/unitrefine_curation.py b/src/spikeinterface/curation/unitrefine_curation.py
new file mode 100644
index 0000000000..f7cd21e7e9
--- /dev/null
+++ b/src/spikeinterface/curation/unitrefine_curation.py
@@ -0,0 +1,78 @@
+from pathlib import Path
+
+from spikeinterface.core import SortingAnalyzer
+from spikeinterface.curation.model_based_curation import model_based_label_units
+
+
+def unitrefine_label_units(
+    sorting_analyzer: SortingAnalyzer,
+    noise_neural_classifier: str | Path | None = None,
+    sua_mua_classifier: str | Path | None = None,
+):
+    """Label units using a cascade of pre-trained classifiers for
+    noise/neural unit classification and SUA/MUA classification,
+    as shown in the UnitRefine paper (see References).
+    The noise/neural classifier is applied first to remove noise units,
+    then the SUA/MUA classifier is applied to the remaining units.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting results.
+    noise_neural_classifier : str or Path or None, default: None
+        The path to the folder containing the model or a string to a repo on HuggingFace.
+        If None, the noise/neural classification step is skipped.
+        Make sure to provide at least one of the two classifiers.
+    sua_mua_classifier : str or Path or None, default: None
+        The path to the folder containing the model or a string to a repo on HuggingFace.
+        If None, the SUA/MUA classification step is skipped.
+
+    Returns
+    -------
+    labels : pd.DataFrame
+        A DataFrame with unit ids as index and "label"/"probability" as columns.
+
+    References
+    ----------
+    The approach is described in [Jain]_.
+    """
+    import pandas as pd
+
+    if noise_neural_classifier is None and sua_mua_classifier is None:
+        raise ValueError(
+            "At least one of noise_neural_classifier or sua_mua_classifier must be provided. "
+            "Pre-trained models can be found at https://huggingface.co/collections/SpikeInterface/curation-models or "
+            "https://huggingface.co/AnoushkaJain3/models. You can also train models on your own data: "
+            "see https://github.com/anoushkajain/UnitRefine for more details."
+        )
+
+    if noise_neural_classifier is not None:
+        # 1. apply the noise/neural classification and remove noise
+        noise_neuron_labels = model_based_label_units(
+            sorting_analyzer=sorting_analyzer,
+            repo_id=noise_neural_classifier,
+            trust_model=True,
+        )
+        noise_units = noise_neuron_labels[noise_neuron_labels["prediction"] == "noise"]
+        sorting_analyzer_neural = sorting_analyzer.remove_units(noise_units.index)
+    else:
+        sorting_analyzer_neural = sorting_analyzer
+        noise_units = pd.DataFrame(columns=["prediction", "probability"])
+
+    if sua_mua_classifier is not None:
+        # 2. apply the sua/mua classification and aggregate results
+        if len(sorting_analyzer.unit_ids) > len(noise_units):
+            sua_mua_labels = model_based_label_units(
+                sorting_analyzer=sorting_analyzer_neural,
+                repo_id=sua_mua_classifier,
+                trust_model=True,
+            )
+            all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
+        else:
+            all_labels = noise_units
+    else:
+        all_labels = noise_neuron_labels
+
+    # rename prediction column to label
+    all_labels = all_labels.rename(columns={"prediction": "label"})
+    return all_labels
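
For convenience, here is a minimal usage sketch of the new ``unitrefine_label_units`` function introduced in this patch. It is illustrative only: it assumes an existing ``sorting_analyzer`` (a ``SortingAnalyzer``), it mirrors the metric computations used in the new test, and it uses the two UnitRefine repos referenced in the tutorial above; depending on the model, additional metrics may be required (see the model's HuggingFace page).

.. code::

    from spikeinterface.curation import unitrefine_label_units

    # assumed: ``sorting_analyzer`` is an existing SortingAnalyzer for your recording/sorting
    # the classifiers rely on computed metrics; these two calls mirror the new test
    sorting_analyzer.compute("template_metrics", include_multi_channel_metrics=True)
    sorting_analyzer.compute("quality_metrics")

    # noise/neural classification first, then sua/mua on the units that survive
    all_labels = unitrefine_label_units(
        sorting_analyzer,
        noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier",
        sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier",
    )

    # a pandas DataFrame indexed by unit id, with "label" and "probability" columns
    print(all_labels)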