Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,10 @@ spikeinterface.curation
.. autofunction:: remove_redundant_units
.. autofunction:: remove_duplicated_spikes
.. autofunction:: remove_excess_spikes
.. autofunction:: auto_label_units
.. autofunction:: model_based_label_units
.. autofunction:: load_model
.. autofunction:: train_model
.. autofunction:: unitrefine_label_units

Curation Model
~~~~~~~~~~~~~~
Expand Down
10 changes: 5 additions & 5 deletions doc/how_to/auto_curation_prediction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ repo's URL after huggingface.co/) and that we trust the model.

.. code::

from spikeinterface.curation import auto_label_units
from spikeinterface.curation import model_based_label_units

labels_and_probabilities = auto_label_units(
labels_and_probabilities = model_based_label_units(
sorting_analyzer = sorting_analyzer,
repo_id = "SpikeInterface/toy_tetrode_model",
trust_model = True
Expand All @@ -29,7 +29,7 @@ create the labels:

.. code::

labels_and_probabilities = si.auto_label_units(
labels_and_probabilities = si.model_based_label_units(
sorting_analyzer = sorting_analyzer,
model_folder = "my_folder_with_a_model_in_it",
)
Expand All @@ -39,5 +39,5 @@ are also saved as a property of your ``sorting_analyzer`` and can be accessed li

.. code::

labels = sorting_analyzer.sorting.get_property("classifier_label")
probabilities = sorting_analyzer.sorting.get_property("classifier_probability")
labels = sorting_analyzer.get_sorting_property("classifier_label")
probabilities = sorting_analyzer.get_sorting_property("classifier_probability")
2 changes: 1 addition & 1 deletion doc/references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ If you use the default "similarity_correlograms" preset in the :code:`compute_me

If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_

If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_
If you use :code:`unitrefine_label_units`, :code:`model_based_label_units` or :code:`train_model`, please cite [Jain]_

Benchmark
---------
Expand Down
30 changes: 21 additions & 9 deletions examples/tutorials/curation/plot_1_automated_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@

##############################################################################
# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly
# to the ``auto_label_units`` function. This returns a dictionary containing a label and
# to the ``model_based_label_units`` function. This returns a dictionary containing a label and
# a confidence for each unit contained in the ``sorting_analyzer``.

labels = sc.auto_label_units(
labels = sc.model_based_label_units(
sorting_analyzer = sorting_analyzer,
repo_id = "SpikeInterface/toy_tetrode_model",
trusted = ['numpy.dtype']
Expand Down Expand Up @@ -211,16 +211,16 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
# V1,SC and ALM: https://huggingface.co/SpikeInterface/UnitRefine_noise_neural_classifier/ and
# https://huggingface.co/SpikeInterface/UnitRefine_sua_mua_classifier/. One will classify units into
# `noise` or `not-noise` and the other will classify the `not-noise` units into single
# `noise` or `neural` and the other will classify the `neural` units into single
# unit activity (sua) units and multi-unit activity (mua) units.
#
# There is more information about the model on the model's HuggingFace page. Take a look!
# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
# The idea here is to first apply the noise/neural classifier, then the sua/mua one.
# We can do so as follows:
#

# Apply the noise/not-noise model
noise_neuron_labels = sc.auto_label_units(
# Apply the noise/neural model
noise_neuron_labels = sc.model_based_label_units(
sorting_analyzer=sorting_analyzer,
repo_id="SpikeInterface/UnitRefine_noise_neural_classifier",
trust_model=True,
Expand All @@ -230,7 +230,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
analyzer_neural = sorting_analyzer.remove_units(noise_units.index)

# Apply the sua/mua model
sua_mua_labels = sc.auto_label_units(
sua_mua_labels = sc.model_based_label_units(
sorting_analyzer=analyzer_neural,
repo_id="SpikeInterface/UnitRefine_sua_mua_classifier",
trust_model=True,
Expand All @@ -239,6 +239,18 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
print(all_labels)

##############################################################################
# Both steps can be done in one go using the ``unitrefine_label_units`` function:
#

all_labels = sc.unitrefine_label_units(
sorting_analyzer,
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier",
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier",
)
print(all_labels)


##############################################################################
# If you run this without the ``trust_model=True`` parameter, you will receive an error:
#
Expand All @@ -252,7 +264,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
#
# .. dropdown:: More about security
#
# Sharing models, with are Python objects, is complicated.
# Sharing models, which are Python objects, is complicated.
# We have chosen to use the `skops format <https://skops.readthedocs.io/en/stable/>`_, instead
# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues
# `here <https://lwn.net/Articles/964392/>`_). While unpacking the ``.skops`` file, each function
Expand All @@ -276,7 +288,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
#
# .. code-block::
#
# labels = sc.auto_label_units(
# labels = sc.model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "path/to/model/folder",
# )
Expand Down
8 changes: 4 additions & 4 deletions examples/tutorials/curation/plot_3_upload_a_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@
#
# ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading)
#
# from spikeinterface.curation import auto_label_units
# labels = auto_label_units(
# from spikeinterface.curation import model_based_label_units
# labels = model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# repo_id = "SpikeInterface/toy_tetrode_model",
# trust_model=True
Expand All @@ -123,9 +123,9 @@
# or you can download the entire repository to `a_folder_for_a_model`, and use
#
# ` ` ` python
# from spikeinterface.curation import auto_label_units
# from spikeinterface.curation import model_based_label_units
#
# labels = auto_label_units(
# labels = model_based_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "path/to/a_folder_for_a_model",
# trusted = ['numpy.dtype']
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/curation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from .sortingview_curation import apply_sortingview_curation

# automated curation
from .model_based_curation import auto_label_units, load_model
from .model_based_curation import model_based_label_units, load_model, auto_label_units
from .train_manual_curation import train_model, get_default_classifier_search_spaces
from .unitrefine_curation import unitrefine_label_units
31 changes: 22 additions & 9 deletions src/spikeinterface/curation/model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def predict_labels(
)

# Set predictions and probability as sorting properties
self.sorting_analyzer.sorting.set_property("classifier_label", predictions)
self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities)
self.sorting_analyzer.set_sorting_property("classifier_label", predictions)
self.sorting_analyzer.set_sorting_property("classifier_probability", probabilities)

if export_to_phy:
self._export_to_phy(classified_units)
Expand Down Expand Up @@ -204,11 +204,11 @@ def _export_to_phy(self, classified_df):
classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id")


def auto_label_units(
def model_based_label_units(
sorting_analyzer: SortingAnalyzer,
model_folder=None,
model_name=None,
repo_id=None,
model_name=None,
label_conversion=None,
trust_model=False,
trusted=None,
Expand All @@ -227,11 +227,11 @@ def auto_label_units(
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting results.
model_folder : str or Path, defualt: None
model_folder : str or Path, default: None
The path to the folder containing the model
repo_id : str | Path, default: None
repo_id : str, default: None
Hugging face repo id which contains the model e.g. 'username/model'
model_name: str | Path, default: None
model_name: str, default: None
Filename of model e.g. 'my_model.skops'. If None, uses first model found.
label_conversion : dict | None, default: None
A dictionary for converting the predicted labels (which are integers) to custom labels. If None,
Expand Down Expand Up @@ -281,6 +281,19 @@ def auto_label_units(
return classified_units


def auto_label_units(*args, **kwargs):
    """
    Deprecated function. Please use `model_based_label_units` instead.
    """
    # Warn at the caller's frame (stacklevel=2), then delegate all
    # arguments unchanged to the renamed implementation.
    deprecation_message = (
        "`auto_label_units` is deprecated and will be removed in v0.105.0. "
        "Please use `model_based_label_units` instead."
    )
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return model_based_label_units(*args, **kwargs)


def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None):
"""
Loads a model and model_info from a HuggingFaceHub repo or a local folder.
Expand All @@ -289,9 +302,9 @@ def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=Fal
----------
model_folder : str or Path, default: None
The path to the folder containing the model
repo_id : str | Path, default: None
repo_id : str, default: None
Hugging face repo id which contains the model e.g. 'username/model'
model_name: str | Path, default: None
model_name: str, default: None
Filename of model e.g. 'my_model.skops'. If None, uses first model found.
trust_model : bool, default: False
Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be
Expand Down
35 changes: 19 additions & 16 deletions src/spikeinterface/curation/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,25 @@ def trained_pipeline_path():
If the model already exists, this function does nothing.
"""
trained_model_folder = Path(__file__).parent / Path("trained_pipeline")
analyzer = make_sorting_analyzer(sparse=True)
analyzer.compute(
{
"quality_metrics": {"metric_names": ["snr", "num_spikes"]},
"template_metrics": {"metric_names": ["half_width"]},
}
)
train_model(
analyzers=[analyzer] * 5,
labels=[[1, 0, 1, 0, 1]] * 5,
folder=trained_model_folder,
classifiers=["RandomForestClassifier"],
imputation_strategies=["median"],
scaling_techniques=["standard_scaler"],
)
yield trained_model_folder
if trained_model_folder.is_dir():
yield trained_model_folder
else:
analyzer = make_sorting_analyzer(sparse=True)
analyzer.compute(
{
"quality_metrics": {"metric_names": ["snr", "num_spikes"]},
"template_metrics": {"metric_names": ["half_width"]},
}
)
train_model(
analyzers=[analyzer] * 5,
labels=[[1, 0, 1, 0, 1]] * 5,
folder=trained_model_folder,
classifiers=["RandomForestClassifier"],
imputation_strategies=["median"],
scaling_techniques=["standard_scaler"],
)
yield trained_model_folder


if __name__ == "__main__":
Expand Down
91 changes: 87 additions & 4 deletions src/spikeinterface/curation/tests/test_model_based_curation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pytest
from pathlib import Path

from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
from spikeinterface.curation.model_based_curation import ModelBasedClassification
from spikeinterface.curation import auto_label_units, load_model
from spikeinterface.curation import model_based_label_units, load_model
from spikeinterface.curation.train_manual_curation import _get_computed_metrics
from spikeinterface.curation import unitrefine_label_units


import numpy as np

Expand Down Expand Up @@ -39,21 +42,21 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):


def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path):
"""The function `auto_label_units` needs the correct metrics to have been computed. However,
"""The function `model_based_label_units` needs the correct metrics to have been computed. However,
it should be independent of the order of computation. We test this here."""

sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])

prediction_prob_dataframe_1 = auto_label_units(
prediction_prob_dataframe_1 = model_based_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
)

sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])

prediction_prob_dataframe_2 = auto_label_units(
prediction_prob_dataframe_2 = model_based_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
Expand Down Expand Up @@ -168,3 +171,83 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)


def test_unitrefine_label_units_hf(sorting_analyzer_for_curation):
    """Test `unitrefine_label_units` with models hosted on HuggingFace.

    Covers: both classifiers together, each classifier alone, neither
    classifier (raises ValueError), and swapped classifiers (warns about
    unexpected labels).
    """
    sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
    sorting_analyzer_for_curation.compute("quality_metrics")

    def _check_labels(labels):
        # Every successful configuration must return one labelled row per
        # unit, with both a predicted label and its probability.
        assert "label" in labels.columns
        assert "probability" in labels.columns
        assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)

    # test passing both classifiers
    labels = unitrefine_label_units(
        sorting_analyzer_for_curation,
        noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
        sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
    )
    _check_labels(labels)

    # test only noise neural classifier
    labels = unitrefine_label_units(
        sorting_analyzer_for_curation,
        noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
        sua_mua_classifier=None,
    )
    _check_labels(labels)

    # test only sua mua classifier
    labels = unitrefine_label_units(
        sorting_analyzer_for_curation,
        noise_neural_classifier=None,
        sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
    )
    _check_labels(labels)

    # passing neither classifier is an error
    with pytest.raises(ValueError):
        unitrefine_label_units(
            sorting_analyzer_for_curation,
            noise_neural_classifier=None,
            sua_mua_classifier=None,
        )

    # swapping the classifiers should warn that unexpected labels were returned
    with pytest.warns(UserWarning):
        unitrefine_label_units(
            sorting_analyzer_for_curation,
            noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
            sua_mua_classifier=None,
        )

    with pytest.warns(UserWarning):
        unitrefine_label_units(
            sorting_analyzer_for_curation,
            noise_neural_classifier=None,
            sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
        )


def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path):
    """Test `unitrefine_label_units` with locally trained models.

    Checks that the classifier argument accepts both a folder containing a
    model and a path to the model file itself.
    """
    # test with trained local models
    sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
    sorting_analyzer_for_curation.compute("quality_metrics")

    # test passing model folder
    labels = unitrefine_label_units(
        sorting_analyzer_for_curation,
        noise_neural_classifier=trained_pipeline_path,
    )

    # test passing the model file directly (not its containing folder)
    labels = unitrefine_label_units(
        sorting_analyzer_for_curation,
        noise_neural_classifier=trained_pipeline_path / "best_model.skops",
    )
Loading