From 42aa4ff4aca7f1aa49a74064de9386c3bc33e946 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 16:14:16 +0100 Subject: [PATCH 1/5] auto_label_units -> model_base_label_units + unitrefine function --- doc/api.rst | 3 +- doc/how_to/auto_curation_prediction.rst | 10 +-- doc/references.rst | 2 +- src/spikeinterface/curation/__init__.py | 3 +- .../curation/model_based_curation.py | 8 +- .../tests/test_model_based_curation.py | 23 +++++- .../curation/unitrefine_curation.py | 74 +++++++++++++++++++ 7 files changed, 107 insertions(+), 16 deletions(-) create mode 100644 src/spikeinterface/curation/unitrefine_curation.py diff --git a/doc/api.rst b/doc/api.rst index 38990a430d..adfdb85470 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -373,9 +373,10 @@ spikeinterface.curation .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes - .. autofunction:: auto_label_units + .. autofunction:: model_based_label_units .. autofunction:: load_model .. autofunction:: train_model + .. autofunction:: unitrefine_label_units Curation Model ~~~~~~~~~~~~~~ diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst index ad672f22a8..3e5f8ddd4f 100644 --- a/doc/how_to/auto_curation_prediction.rst +++ b/doc/how_to/auto_curation_prediction.rst @@ -16,9 +16,9 @@ repo's URL after huggingface.co/) and that we trust the model. .. code:: - from spikeinterface.curation import auto_label_units + from spikeinterface.curation import model_based_label_units - labels_and_probabilities = auto_label_units( + labels_and_probabilities = model_based_label_units( sorting_analyzer = sorting_analyzer, repo_id = "SpikeInterface/toy_tetrode_model", trust_model = True @@ -29,7 +29,7 @@ create the labels: .. code:: - labels_and_probabilities = si.auto_label_units( + labels_and_probabilities = si.model_based_label_units( sorting_analyzer = sorting_analyzer, model_folder = "my_folder_with_a_model_in_it", ) @@ -39,5 +39,5 @@ are also saved as a property of your ``sorting_analyzer`` and can be accessed li .. code:: - labels = sorting_analyzer.sorting.get_property("classifier_label") - probabilities = sorting_analyzer.sorting.get_property("classifier_probability") + labels = sorting_analyzer.get_sorting_property("classifier_label") + probabilities = sorting_analyzer.get_sorting_property("classifier_probability") diff --git a/doc/references.rst b/doc/references.rst index be05b69c0c..24eba16902 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -92,7 +92,7 @@ If you use the default "similarity_correlograms" preset in the :code:`compute_me If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_ -If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_ +If you use :code:`unitrefine_label_units`, :code:`model_based_label_units` or :code:`train_model`, please cite [Jain]_ Benchmark --------- diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index b64070e662..68737a0f12 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,5 +20,6 @@ from .sortingview_curation import apply_sortingview_curation # automated curation -from .model_based_curation import auto_label_units, load_model +from .model_based_curation import model_based_label_units, load_model from .train_manual_curation import train_model, get_default_classifier_search_spaces +from .unitrefine_curation import unitrefine_label_units diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index e779e13182..0598a6afa1 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -119,8 +119,8 @@ def predict_labels( ) # Set predictions and probability as sorting properties - self.sorting_analyzer.sorting.set_property("classifier_label", predictions) - self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities) + self.sorting_analyzer.set_sorting_property("classifier_label", predictions) + self.sorting_analyzer.set_sorting_property("classifier_probability", probabilities) if export_to_phy: self._export_to_phy(classified_units) @@ -204,7 +204,7 @@ def _export_to_phy(self, classified_df): classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id") -def auto_label_units( +def model_based_label_units( sorting_analyzer: SortingAnalyzer, model_folder=None, model_name=None, @@ -227,7 +227,7 @@ def auto_label_units( ---------- sorting_analyzer : SortingAnalyzer The sorting analyzer object containing the spike sorting results. - model_folder : str or Path, defualt: None + model_folder : str or Path, default: None The path to the folder containing the model repo_id : str | Path, default: None Hugging face repo id which contains the model e.g. 'username/model' diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 9f845bb1c3..23e4dd3587 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -2,7 +2,7 @@ from pathlib import Path from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path from spikeinterface.curation.model_based_curation import ModelBasedClassification -from spikeinterface.curation import auto_label_units, load_model +from spikeinterface.curation import model_based_label_units, load_model from spikeinterface.curation.train_manual_curation import _get_computed_metrics import numpy as np @@ -39,13 +39,13 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model): def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path): - """The function `auto_label_units` needs the correct metrics to have been computed. However, + """The function `model_based_label_units` needs the correct metrics to have been computed. However, it should be independent of the order of computation. We test this here.""" sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) - prediction_prob_dataframe_1 = auto_label_units( + prediction_prob_dataframe_1 = model_based_label_units( sorting_analyzer=sorting_analyzer_for_curation, model_folder=trained_pipeline_path, trusted=["numpy.dtype"], @@ -53,7 +53,7 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"]) - prediction_prob_dataframe_2 = auto_label_units( + prediction_prob_dataframe_2 = model_based_label_units( sorting_analyzer=sorting_analyzer_for_curation, model_folder=trained_pipeline_path, trusted=["numpy.dtype"], @@ -168,3 +168,18 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) + + +def test_unitrefine_label_units(sorting_analyzer_for_curation): + """Test the `unitrefine_label_units` function.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + from spikeinterface.curation import unitrefine_label_units + + labels = unitrefine_label_units(sorting_analyzer_for_curation) + + assert "label" in labels.columns + assert "probability" in labels.columns + assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) diff --git a/src/spikeinterface/curation/unitrefine_curation.py b/src/spikeinterface/curation/unitrefine_curation.py new file mode 100644 index 0000000000..7b4692ad8c --- /dev/null +++ b/src/spikeinterface/curation/unitrefine_curation.py @@ -0,0 +1,74 @@ +from pathlib import Path + +from spikeinterface.core import SortingAnalyzer +from spikeinterface.curation.model_based_curation import model_based_label_units + + +def unitrefine_label_units( + sorting_analyzer: SortingAnalyzer, + noise_neural_classifier: str | Path | None = "SpikeInterface/UnitRefine_noise_neural_classifier", + sua_mua_classifier: str | Path | None = "SpikeInterface/UnitRefine_sua_mua_classifier", +): + """Label units using UnitRefine, which is a cascade of pre-trained classifiers for + noise/neural unit classification and SUA/MUA classification. + The noise/neural classifier is applied first to remove noise units, + then the SUA/MUA classifier is applied to the remaining units. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting results. + noise_neural_classifier : str or Path or None, default: "SpikeInterface/unitrefine-noise-neural-classifier" + The path to the folder containing the model or a string to a repo on HuggingFace. + If None, the noise/neural classification step is skipped. + By default, it uses a pre-trained model hosted on HuggingFace. + sua_mua_classifier : str or Path or None, default: "SpikeInterface/unitrefine-sua-mua-classifier" + The path to the folder containing the model or a string to a repo on HuggingFace. + If None, the SUA/MUA classification step is skipped. + By default, it uses a pre-trained model hosted on HuggingFace. + + Returns + ------- + labels : pd.DataFrame + A DataFrame with unit ids as index and "label"/"probability" as column. + + References + ---------- + The approach is described in [Jain]_. + """ + import pandas as pd + import pandas as pd + + if noise_neural_classifier is None and sua_mua_classifier is None: + raise ValueError("At least one of noise_neural_classifier or sua_mua_classifier must be provided.") + + if noise_neural_classifier is not None: + # 1. apply the noise/neural classification and remove noise + noise_neuron_labels = model_based_label_units( + sorting_analyzer=sorting_analyzer, + repo_id=noise_neural_classifier, + trust_model=True, + ) + noise_units = noise_neuron_labels[noise_neuron_labels["prediction"] == "noise"] + sorting_analyzer_neural = sorting_analyzer.remove_units(noise_units.index) + else: + sorting_analyzer_neural = sorting_analyzer + noise_units = pd.DataFrame(columns=["prediction", "probability"]) + + if sua_mua_classifier is not None: + # 2. apply the sua/mua classification and aggregate results + if len(sorting_analyzer.unit_ids) > len(noise_units): + sua_mua_labels = model_based_label_units( + sorting_analyzer=sorting_analyzer_neural, + repo_id=sua_mua_classifier, + trust_model=True, + ) + all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index() + else: + all_labels = noise_units + else: + all_labels = noise_neuron_labels + + # rename prediction column to label + all_labels = all_labels.rename(columns={"prediction": "label"}) + return all_labels From bbd0201f19806e26f5f3807f73779416559d83d1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 16:19:19 +0100 Subject: [PATCH 2/5] Use lightweight model by default --- .../curation/tests/test_model_based_curation.py | 4 ++-- src/spikeinterface/curation/unitrefine_curation.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 23e4dd3587..96e1de1316 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -173,8 +173,8 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura def test_unitrefine_label_units(sorting_analyzer_for_curation): """Test the `unitrefine_label_units` function.""" - sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True) + sorting_analyzer_for_curation.compute("quality_metrics") from spikeinterface.curation import unitrefine_label_units diff --git a/src/spikeinterface/curation/unitrefine_curation.py b/src/spikeinterface/curation/unitrefine_curation.py index 7b4692ad8c..c850637767 100644 --- a/src/spikeinterface/curation/unitrefine_curation.py +++ b/src/spikeinterface/curation/unitrefine_curation.py @@ -6,8 +6,8 @@ def unitrefine_label_units( sorting_analyzer: SortingAnalyzer, - noise_neural_classifier: str | Path | None = "SpikeInterface/UnitRefine_noise_neural_classifier", - sua_mua_classifier: str | Path | None = "SpikeInterface/UnitRefine_sua_mua_classifier", + noise_neural_classifier: str | Path | None = "SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", + sua_mua_classifier: str | Path | None = "SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", ): """Label units using UnitRefine, which is a cascade of pre-trained classifiers for noise/neural unit classification and SUA/MUA classification. @@ -18,14 +18,16 @@ def unitrefine_label_units( ---------- sorting_analyzer : SortingAnalyzer The sorting analyzer object containing the spike sorting results. - noise_neural_classifier : str or Path or None, default: "SpikeInterface/unitrefine-noise-neural-classifier" + noise_neural_classifier : str or Path or None, default: "SpikeInterface/UnitRefine_noise_neural_classifier_lightweight" The path to the folder containing the model or a string to a repo on HuggingFace. If None, the noise/neural classification step is skipped. - By default, it uses a pre-trained model hosted on HuggingFace. - sua_mua_classifier : str or Path or None, default: "SpikeInterface/unitrefine-sua-mua-classifier" + By default, it uses a pre-trained lightweight model hosted on HuggingFace that does not require principal + component analysis (PCA) features. + sua_mua_classifier : str or Path or None, default: "SpikeInterface/UnitRefine_sua_mua_classifier_lightweight" The path to the folder containing the model or a string to a repo on HuggingFace. If None, the SUA/MUA classification step is skipped. - By default, it uses a pre-trained model hosted on HuggingFace. + By default, it uses a pre-trained lightweight model hosted on HuggingFace that does not require principal + component analysis (PCA) features. Returns ------- From 03d70778fd33f6ae114d8b7d434eb52eec3053ab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 16:22:47 +0100 Subject: [PATCH 3/5] examples! --- .../tutorials/curation/plot_1_automated_curation.py | 10 +++++----- examples/tutorials/curation/plot_3_upload_a_model.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py index 00bd606c44..939acf0003 100644 --- a/examples/tutorials/curation/plot_1_automated_curation.py +++ b/examples/tutorials/curation/plot_1_automated_curation.py @@ -83,10 +83,10 @@ ############################################################################## # Great! We can now use the model to predict labels. Here, we pass the HF repo id directly -# to the ``auto_label_units`` function. This returns a dictionary containing a label and +# to the ``model_based_label_units`` function. This returns a dictionary containing a label and # a confidence for each unit contained in the ``sorting_analyzer``. -labels = sc.auto_label_units( +labels = sc.model_based_label_units( sorting_analyzer = sorting_analyzer, repo_id = "SpikeInterface/toy_tetrode_model", trusted = ['numpy.dtype'] @@ -220,7 +220,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size): # # Apply the noise/not-noise model -noise_neuron_labels = sc.auto_label_units( +noise_neuron_labels = sc.model_based_label_units( sorting_analyzer=sorting_analyzer, repo_id="SpikeInterface/UnitRefine_noise_neural_classifier", trust_model=True, @@ -230,7 +230,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size): analyzer_neural = sorting_analyzer.remove_units(noise_units.index) # Apply the sua/mua model -sua_mua_labels = sc.auto_label_units( +sua_mua_labels = sc.model_based_label_units( sorting_analyzer=analyzer_neural, repo_id="SpikeInterface/UnitRefine_sua_mua_classifier", trust_model=True, @@ -276,7 +276,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size): # # .. code-block:: # -# labels = sc.auto_label_units( +# labels = sc.model_based_label_units( # sorting_analyzer = sorting_analyzer, # model_folder = "path/to/model/folder", # ) diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py index ad9d16cab5..36c7a20f1d 100644 --- a/examples/tutorials/curation/plot_3_upload_a_model.py +++ b/examples/tutorials/curation/plot_3_upload_a_model.py @@ -112,8 +112,8 @@ # # ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading) # -# from spikeinterface.curation import auto_label_units -# labels = auto_label_units( +# from spikeinterface.curation import model_based_label_units +# labels = model_based_label_units( # sorting_analyzer = sorting_analyzer, # repo_id = "SpikeInterface/toy_tetrode_model", # trust_model=True @@ -123,9 +123,9 @@ # or you can download the entire repositry to `a_folder_for_a_model`, and use # # ` ` ` python -# from spikeinterface.curation import auto_label_units +# from spikeinterface.curation import model_based_label_units # -# labels = auto_label_units( +# labels = model_based_label_units( # sorting_analyzer = sorting_analyzer, # model_folder = "path/to/a_folder_for_a_model", # trusted = ['numpy.dtype'] From 6bfdb8be49a2ba03d3b4ceeb442703aaa3cceb00 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 18:38:08 +0100 Subject: [PATCH 4/5] Update src/spikeinterface/curation/unitrefine_curation.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/curation/unitrefine_curation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/unitrefine_curation.py b/src/spikeinterface/curation/unitrefine_curation.py index c850637767..46e8a2e0c3 100644 --- a/src/spikeinterface/curation/unitrefine_curation.py +++ b/src/spikeinterface/curation/unitrefine_curation.py @@ -9,8 +9,9 @@ def unitrefine_label_units( noise_neural_classifier: str | Path | None = "SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", sua_mua_classifier: str | Path | None = "SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", ): - """Label units using UnitRefine, which is a cascade of pre-trained classifiers for - noise/neural unit classification and SUA/MUA classification. + """Label units using a cascade of pre-trained classifiers for + noise/neural unit classification and SUA/MUA classification, + as shown in the UnitRefine paper (see References). The noise/neural classifier is applied first to remove noise units, then the SUA/MUA classifier is applied to the remaining units. From 48cd1c51a15913d3933d1906a4846e350fcc62cb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 23 Jan 2026 09:50:34 +0100 Subject: [PATCH 5/5] Deprecate auto_label_units and set default models to None --- .../curation/plot_1_automated_curation.py | 18 ++++++++++++++--- .../curation/model_based_curation.py | 12 +++++++++++ .../tests/test_model_based_curation.py | 6 +++++- .../curation/unitrefine_curation.py | 20 ++++++++++--------- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py index 939acf0003..6ef986a810 100644 --- a/examples/tutorials/curation/plot_1_automated_curation.py +++ b/examples/tutorials/curation/plot_1_automated_curation.py @@ -211,15 +211,15 @@ def calculate_moving_avg(label_df, confidence_label, window_size): # For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in # V1,SC and ALM: https://huggingface.co/SpikeInterface/UnitRefine_noise_neural_classifier/ and # https://huggingface.co/SpikeInterface/UnitRefine_sua_mua_classifier/. One will classify units into -# `noise` or `not-noise` and the other will classify the `not-noise` units into single +# `noise` or `neural` and the other will classify the `neural` units into single # unit activity (sua) units and multi-unit activity (mua) units. # # There is more information about the model on the model's HuggingFace page. Take a look! -# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one. +# The idea here is to first apply the noise/neural classifier, then the sua/mua one. # We can do so as follows: # -# Apply the noise/not-noise model +# Apply the noise/neural model noise_neuron_labels = sc.model_based_label_units( sorting_analyzer=sorting_analyzer, repo_id="SpikeInterface/UnitRefine_noise_neural_classifier", @@ -239,6 +239,18 @@ def calculate_moving_avg(label_df, confidence_label, window_size): all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index() print(all_labels) +############################################################################## +# Both steps can be done in one go using the ``unitrefine_label_units`` function: +# + +all_labels = sc.unitrefine_label_units( + sorting_analyzer, + noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier", + sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier", +) +print(all_labels) + + ############################################################################## # If you run this without the ``trust_model=True`` parameter, you will receive an error: # diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index 0598a6afa1..ef8571b62f 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -281,6 +281,18 @@ def model_based_label_units( return classified_units +def auto_label_units(*args, **kwargs): + """ + Deprecated function. Please use `model_based_label_units` instead. + """ + warnings.warn( + "`auto_label_units` is deprecated and will be removed in v0.105.0. " + "Please use `model_based_label_units` instead.", + DeprecationWarning, + ) + return model_based_label_units(*args, **kwargs) + + def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None): """ Loads a model and model_info from a HuggingFaceHub repo or a local folder. diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 96e1de1316..9b6ccd0a21 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -178,7 +178,11 @@ def test_unitrefine_label_units(sorting_analyzer_for_curation): from spikeinterface.curation import unitrefine_label_units - labels = unitrefine_label_units(sorting_analyzer_for_curation) + labels = unitrefine_label_units( + sorting_analyzer_for_curation, + noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", + sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", + ) assert "label" in labels.columns assert "probability" in labels.columns diff --git a/src/spikeinterface/curation/unitrefine_curation.py b/src/spikeinterface/curation/unitrefine_curation.py index 46e8a2e0c3..f7cd21e7e9 100644 --- a/src/spikeinterface/curation/unitrefine_curation.py +++ b/src/spikeinterface/curation/unitrefine_curation.py @@ -6,8 +6,8 @@ def unitrefine_label_units( sorting_analyzer: SortingAnalyzer, - noise_neural_classifier: str | Path | None = "SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", - sua_mua_classifier: str | Path | None = "SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", + noise_neural_classifier: str | Path | None = None, + sua_mua_classifier: str | Path | None = None, ): """Label units using a cascade of pre-trained classifiers for noise/neural unit classification and SUA/MUA classification, @@ -19,16 +19,13 @@ def unitrefine_label_units( ---------- sorting_analyzer : SortingAnalyzer The sorting analyzer object containing the spike sorting results. - noise_neural_classifier : str or Path or None, default: "SpikeInterface/UnitRefine_noise_neural_classifier_lightweight" + noise_neural_classifier : str or Path or None, default: None The path to the folder containing the model or a string to a repo on HuggingFace. If None, the noise/neural classification step is skipped. - By default, it uses a pre-trained lightweight model hosted on HuggingFace that does not require principal - component analysis (PCA) features. - sua_mua_classifier : str or Path or None, default: "SpikeInterface/UnitRefine_sua_mua_classifier_lightweight" + Make sure to provide at least one of the two classifiers. + sua_mua_classifier : str or Path or None, default: None The path to the folder containing the model or a string to a repo on HuggingFace. If None, the SUA/MUA classification step is skipped. - By default, it uses a pre-trained lightweight model hosted on HuggingFace that does not require principal - component analysis (PCA) features. Returns ------- @@ -43,7 +40,12 @@ def unitrefine_label_units( import pandas as pd if noise_neural_classifier is None and sua_mua_classifier is None: - raise ValueError("At least one of noise_neural_classifier or sua_mua_classifier must be provided.") + raise ValueError( + "At least one of noise_neural_classifier or sua_mua_classifier must be provided. " + "Pre-trained models can be found at https://huggingface.co/collections/SpikeInterface/curation-models or " + "https://huggingface.co/AnoushkaJain3/models. You can also train models on your own data: " + "see https://github.com/anoushkajain/UnitRefine for more details." + ) if noise_neural_classifier is not None: # 1. apply the noise/neural classification and remove noise