diff --git a/doc/api.rst b/doc/api.rst
index 38990a430d..adfdb85470 100755
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -373,9 +373,10 @@ spikeinterface.curation
.. autofunction:: remove_redundant_units
.. autofunction:: remove_duplicated_spikes
.. autofunction:: remove_excess_spikes
- .. autofunction:: auto_label_units
+ .. autofunction:: model_based_label_units
.. autofunction:: load_model
.. autofunction:: train_model
+ .. autofunction:: unitrefine_label_units
Curation Model
~~~~~~~~~~~~~~
diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst
index ad672f22a8..3e5f8ddd4f 100644
--- a/doc/how_to/auto_curation_prediction.rst
+++ b/doc/how_to/auto_curation_prediction.rst
@@ -16,9 +16,9 @@ repo's URL after huggingface.co/) and that we trust the model.
.. code::
- from spikeinterface.curation import auto_label_units
+ from spikeinterface.curation import model_based_label_units
- labels_and_probabilities = auto_label_units(
+ labels_and_probabilities = model_based_label_units(
sorting_analyzer = sorting_analyzer,
repo_id = "SpikeInterface/toy_tetrode_model",
trust_model = True
@@ -29,7 +29,7 @@ create the labels:
.. code::
- labels_and_probabilities = si.auto_label_units(
+ labels_and_probabilities = si.model_based_label_units(
sorting_analyzer = sorting_analyzer,
model_folder = "my_folder_with_a_model_in_it",
)
@@ -39,5 +39,5 @@ are also saved as a property of your ``sorting_analyzer`` and can be accessed li
.. code::
- labels = sorting_analyzer.sorting.get_property("classifier_label")
- probabilities = sorting_analyzer.sorting.get_property("classifier_probability")
+ labels = sorting_analyzer.get_sorting_property("classifier_label")
+ probabilities = sorting_analyzer.get_sorting_property("classifier_probability")
diff --git a/doc/references.rst b/doc/references.rst
index 49f8b33add..53fcab39e1 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -92,7 +92,7 @@ If you use the default "similarity_correlograms" preset in the :code:`compute_me
If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_
-If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_
+If you use :code:`unitrefine_label_units`, :code:`model_based_label_units` or :code:`train_model`, please cite [Jain]_
Benchmark
---------
diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py
index 00bd606c44..bb5e8b9a57 100644
--- a/examples/tutorials/curation/plot_1_automated_curation.py
+++ b/examples/tutorials/curation/plot_1_automated_curation.py
@@ -83,10 +83,10 @@
##############################################################################
# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly
-# to the ``auto_label_units`` function. This returns a dictionary containing a label and
+# to the ``model_based_label_units`` function. This returns a dictionary containing a label and
# a confidence for each unit contained in the ``sorting_analyzer``.
-labels = sc.auto_label_units(
+labels = sc.model_based_label_units(
sorting_analyzer = sorting_analyzer,
repo_id = "SpikeInterface/toy_tetrode_model",
trusted = ['numpy.dtype']
@@ -211,16 +211,16 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
# V1,SC and ALM: https://huggingface.co/SpikeInterface/UnitRefine_noise_neural_classifier/ and
# https://huggingface.co/SpikeInterface/UnitRefine_sua_mua_classifier/. One will classify units into
-# `noise` or `not-noise` and the other will classify the `not-noise` units into single
+# `noise` or `neural` and the other will classify the `neural` units into single
# unit activity (sua) units and multi-unit activity (mua) units.
#
# There is more information about the model on the model's HuggingFace page. Take a look!
-# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
+# The idea here is to first apply the noise/neural classifier, then the sua/mua one.
# We can do so as follows:
#
-# Apply the noise/not-noise model
-noise_neuron_labels = sc.auto_label_units(
+# Apply the noise/neural model
+noise_neuron_labels = sc.model_based_label_units(
sorting_analyzer=sorting_analyzer,
repo_id="SpikeInterface/UnitRefine_noise_neural_classifier",
trust_model=True,
@@ -230,7 +230,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
analyzer_neural = sorting_analyzer.remove_units(noise_units.index)
# Apply the sua/mua model
-sua_mua_labels = sc.auto_label_units(
+sua_mua_labels = sc.model_based_label_units(
sorting_analyzer=analyzer_neural,
repo_id="SpikeInterface/UnitRefine_sua_mua_classifier",
trust_model=True,
@@ -239,6 +239,18 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
print(all_labels)
+##############################################################################
+# Both steps can be done in one go using the ``unitrefine_label_units`` function:
+#
+
+all_labels = sc.unitrefine_label_units(
+ sorting_analyzer,
+ noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier",
+ sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier",
+)
+print(all_labels)
+
+
##############################################################################
# If you run this without the ``trust_model=True`` parameter, you will receive an error:
#
@@ -252,7 +264,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
#
# .. dropdown:: More about security
#
-# Sharing models, with are Python objects, is complicated.
+# Sharing models, which are Python objects, is complicated.
# We have chosen to use the `skops format `_, instead
# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues
# `here `_). While unpacking the ``.skops`` file, each function
@@ -276,7 +288,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
#
# .. code-block::
#
-# labels = sc.auto_label_units(
+# labels = sc.model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "path/to/model/folder",
# )
diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py
index ad9d16cab5..36c7a20f1d 100644
--- a/examples/tutorials/curation/plot_3_upload_a_model.py
+++ b/examples/tutorials/curation/plot_3_upload_a_model.py
@@ -112,8 +112,8 @@
#
# ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading)
#
-# from spikeinterface.curation import auto_label_units
-# labels = auto_label_units(
+# from spikeinterface.curation import model_based_label_units
+# labels = model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# repo_id = "SpikeInterface/toy_tetrode_model",
# trust_model=True
@@ -123,9 +123,9 @@
# or you can download the entire repositry to `a_folder_for_a_model`, and use
#
# ` ` ` python
-# from spikeinterface.curation import auto_label_units
+# from spikeinterface.curation import model_based_label_units
#
-# labels = auto_label_units(
+# labels = model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "path/to/a_folder_for_a_model",
# trusted = ['numpy.dtype']
diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py
index b64070e662..730481937c 100644
--- a/src/spikeinterface/curation/__init__.py
+++ b/src/spikeinterface/curation/__init__.py
@@ -20,5 +20,6 @@
from .sortingview_curation import apply_sortingview_curation
# automated curation
-from .model_based_curation import auto_label_units, load_model
+from .model_based_curation import model_based_label_units, load_model, auto_label_units
from .train_manual_curation import train_model, get_default_classifier_search_spaces
+from .unitrefine_curation import unitrefine_label_units
diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
index e779e13182..e9a222f793 100644
--- a/src/spikeinterface/curation/model_based_curation.py
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -119,8 +119,8 @@ def predict_labels(
)
# Set predictions and probability as sorting properties
- self.sorting_analyzer.sorting.set_property("classifier_label", predictions)
- self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities)
+ self.sorting_analyzer.set_sorting_property("classifier_label", predictions)
+ self.sorting_analyzer.set_sorting_property("classifier_probability", probabilities)
if export_to_phy:
self._export_to_phy(classified_units)
@@ -204,11 +204,11 @@ def _export_to_phy(self, classified_df):
classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id")
-def auto_label_units(
+def model_based_label_units(
sorting_analyzer: SortingAnalyzer,
model_folder=None,
- model_name=None,
repo_id=None,
+ model_name=None,
label_conversion=None,
trust_model=False,
trusted=None,
@@ -227,11 +227,11 @@ def auto_label_units(
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting results.
- model_folder : str or Path, defualt: None
+ model_folder : str or Path, default: None
The path to the folder containing the model
- repo_id : str | Path, default: None
+ repo_id : str, default: None
Hugging face repo id which contains the model e.g. 'username/model'
- model_name: str | Path, default: None
+ model_name: str, default: None
Filename of model e.g. 'my_model.skops'. If None, uses first model found.
label_conversion : dic | None, default: None
A dictionary for converting the predicted labels (which are integers) to custom labels. If None,
@@ -281,6 +281,19 @@ def auto_label_units(
return classified_units
+def auto_label_units(*args, **kwargs):
+ """
+ Deprecated function. Please use `model_based_label_units` instead.
+ """
+ warnings.warn(
+ "`auto_label_units` is deprecated and will be removed in v0.105.0. "
+ "Please use `model_based_label_units` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return model_based_label_units(*args, **kwargs)
+
+
def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None):
"""
Loads a model and model_info from a HuggingFaceHub repo or a local folder.
@@ -289,9 +302,9 @@ def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=Fal
----------
model_folder : str or Path, defualt: None
The path to the folder containing the model
- repo_id : str | Path, default: None
+ repo_id : str, default: None
Hugging face repo id which contains the model e.g. 'username/model'
- model_name: str | Path, default: None
+ model_name: str, default: None
Filename of model e.g. 'my_model.skops'. If None, uses first model found.
trust_model : bool, default: False
Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be
diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py
index 8863f874d5..8553685459 100644
--- a/src/spikeinterface/curation/tests/common.py
+++ b/src/spikeinterface/curation/tests/common.py
@@ -91,22 +91,25 @@ def trained_pipeline_path():
If the model already exists, this function does nothing.
"""
trained_model_folder = Path(__file__).parent / Path("trained_pipeline")
- analyzer = make_sorting_analyzer(sparse=True)
- analyzer.compute(
- {
- "quality_metrics": {"metric_names": ["snr", "num_spikes"]},
- "template_metrics": {"metric_names": ["half_width"]},
- }
- )
- train_model(
- analyzers=[analyzer] * 5,
- labels=[[1, 0, 1, 0, 1]] * 5,
- folder=trained_model_folder,
- classifiers=["RandomForestClassifier"],
- imputation_strategies=["median"],
- scaling_techniques=["standard_scaler"],
- )
- yield trained_model_folder
+ if trained_model_folder.is_dir():
+ yield trained_model_folder
+ else:
+ analyzer = make_sorting_analyzer(sparse=True)
+ analyzer.compute(
+ {
+ "quality_metrics": {"metric_names": ["snr", "num_spikes"]},
+ "template_metrics": {"metric_names": ["half_width"]},
+ }
+ )
+ train_model(
+ analyzers=[analyzer] * 5,
+ labels=[[1, 0, 1, 0, 1]] * 5,
+ folder=trained_model_folder,
+ classifiers=["RandomForestClassifier"],
+ imputation_strategies=["median"],
+ scaling_techniques=["standard_scaler"],
+ )
+ yield trained_model_folder
if __name__ == "__main__":
diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py
index 9f845bb1c3..86b7d836b7 100644
--- a/src/spikeinterface/curation/tests/test_model_based_curation.py
+++ b/src/spikeinterface/curation/tests/test_model_based_curation.py
@@ -1,9 +1,12 @@
import pytest
from pathlib import Path
+
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
from spikeinterface.curation.model_based_curation import ModelBasedClassification
-from spikeinterface.curation import auto_label_units, load_model
+from spikeinterface.curation import model_based_label_units, load_model
from spikeinterface.curation.train_manual_curation import _get_computed_metrics
+from spikeinterface.curation import unitrefine_label_units
+
import numpy as np
@@ -39,13 +42,13 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):
def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path):
- """The function `auto_label_units` needs the correct metrics to have been computed. However,
+ """The function `model_based_label_units` needs the correct metrics to have been computed. However,
it should be independent of the order of computation. We test this here."""
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
- prediction_prob_dataframe_1 = auto_label_units(
+ prediction_prob_dataframe_1 = model_based_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
@@ -53,7 +56,7 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])
- prediction_prob_dataframe_2 = auto_label_units(
+ prediction_prob_dataframe_2 = model_based_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
@@ -168,3 +171,83 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)
+
+
+def test_unitrefine_label_units_hf(sorting_analyzer_for_curation):
+ """Test the `unitrefine_label_units` function."""
+ sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
+ sorting_analyzer_for_curation.compute("quality_metrics")
+
+ # test passing both classifiers
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
+ sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
+ )
+
+ assert "label" in labels.columns
+ assert "probability" in labels.columns
+ assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
+
+ # test only noise neural classifier
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
+ sua_mua_classifier=None,
+ )
+
+ assert "label" in labels.columns
+ assert "probability" in labels.columns
+ assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
+
+ # test only sua mua classifier
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier=None,
+ sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
+ )
+
+ assert "label" in labels.columns
+ assert "probability" in labels.columns
+ assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
+
+ # test passing none
+ with pytest.raises(ValueError):
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier=None,
+ sua_mua_classifier=None,
+ )
+
+ # test warnings when unexpected labels are returned
+ with pytest.warns(UserWarning):
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
+ sua_mua_classifier=None,
+ )
+
+ with pytest.warns(UserWarning):
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier=None,
+ sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
+ )
+
+
+def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path):
+ # test with trained local models
+ sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
+ sorting_analyzer_for_curation.compute("quality_metrics")
+
+ # test passing model folder
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier=trained_pipeline_path,
+ )
+
+ # test passing model folder
+ labels = unitrefine_label_units(
+ sorting_analyzer_for_curation,
+ noise_neural_classifier=trained_pipeline_path / "best_model.skops",
+ )
diff --git a/src/spikeinterface/curation/unitrefine_curation.py b/src/spikeinterface/curation/unitrefine_curation.py
new file mode 100644
index 0000000000..72a14483fd
--- /dev/null
+++ b/src/spikeinterface/curation/unitrefine_curation.py
@@ -0,0 +1,115 @@
+import warnings
+from pathlib import Path
+
+from spikeinterface.core import SortingAnalyzer
+from spikeinterface.curation.model_based_curation import model_based_label_units
+
+
+def unitrefine_label_units(
+ sorting_analyzer: SortingAnalyzer,
+ noise_neural_classifier: str | Path | None = None,
+ sua_mua_classifier: str | Path | None = None,
+):
+ """Label units using a cascade of pre-trained classifiers for
+ noise/neural unit classification and SUA/MUA classification,
+ as shown in the UnitRefine paper (see References).
+ The noise/neural classifier is applied first to remove noise units,
+ then the SUA/MUA classifier is applied to the remaining units.
+
+ Parameters
+ ----------
+ sorting_analyzer : SortingAnalyzer
+ The sorting analyzer object containing the spike sorting results.
+ noise_neural_classifier : str or Path or None, default: None
+ The path to the folder containing the model, a full path to a model (".skops")
+ or a string to a repo on HuggingFace.
+ If None, the noise/neural classification step is skipped.
+ Make sure to provide at least one of the two classifiers.
+ sua_mua_classifier : str or Path or None, default: None
+ The path to the folder containing the model, a full path to a model (".skops")
+ or a string to a repo on HuggingFace.
+ If None, the SUA/MUA classification step is skipped.
+
+ Returns
+ -------
+ labels : pd.DataFrame
+ A DataFrame with unit ids as index and "label"/"probability" as column.
+
+ References
+ ----------
+ The approach is described in [Jain]_.
+ """
+ import pandas as pd
+
+ if noise_neural_classifier is None and sua_mua_classifier is None:
+ raise ValueError(
+ "At least one of noise_neural_classifier or sua_mua_classifier must be provided. "
+ "Pre-trained models can be found at https://huggingface.co/collections/SpikeInterface/curation-models or "
+ "https://huggingface.co/AnoushkaJain3/models. You can also train models on your own data: "
+ "see https://github.com/anoushkajain/UnitRefine for more details."
+ )
+
+ if noise_neural_classifier is not None:
+ # 1. apply the noise/neural classification and remove noise
+ noise_neuron_labels = model_based_label_units(
+ sorting_analyzer=sorting_analyzer,
+ trust_model=True,
+ **get_model_based_classification_kwargs(noise_neural_classifier),
+ )
+ if set(noise_neuron_labels["prediction"]) != {"noise", "neural"}:
+ warnings.warn(
+ "The noise/neural classifier did not return the expected labels 'noise' and 'neural'. "
+ "Please check the model used for classification."
+ )
+ noise_units = noise_neuron_labels[noise_neuron_labels["prediction"] == "noise"]
+ sorting_analyzer_neural = sorting_analyzer.remove_units(noise_units.index)
+ else:
+ sorting_analyzer_neural = sorting_analyzer
+ noise_units = pd.DataFrame(columns=["prediction", "probability"])
+
+ if sua_mua_classifier is not None:
+ # 2. apply the sua/mua classification and aggregate results
+ if len(sorting_analyzer.unit_ids) > len(noise_units):
+ sua_mua_labels = model_based_label_units(
+ sorting_analyzer=sorting_analyzer_neural,
+ trust_model=True,
+ **get_model_based_classification_kwargs(sua_mua_classifier),
+ )
+ if set(sua_mua_labels["prediction"]) != {"sua", "mua"}:
+ warnings.warn(
+ "The sua/mua classifier did not return the expected labels 'sua' and 'mua'. "
+ "Please check the model used for classification."
+ )
+ all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
+ else:
+ all_labels = noise_units
+ else:
+ all_labels = noise_neuron_labels
+
+ # rename prediction column to label
+ all_labels = all_labels.rename(columns={"prediction": "label"})
+ return all_labels
+
+
+def get_model_based_classification_kwargs(model: str | Path) -> dict:
+ """Get kwargs for model_based_label_units function based on model parameter.
+
+ Parameters
+ ----------
+ model : str or Path
+ The model argument.
+
+ Returns
+ -------
+ kwargs : dict
+ A dictionary with kwargs for model_based_label_units function based on model parameter.
+ This could be `model_folder`, `model_folder` + `model_name` or `repo_id`.
+ """
+ if Path(model).exists():
+ if Path(model).is_dir():
+ kwargs = {"model_folder": model}
+ else:
+ kwargs = {"model_folder": Path(model).parent, "model_name": Path(model).name}
+ else:
+ kwargs = {"repo_id": model}
+ return kwargs