From b50d1a559d8b5f15b5b7d66fc25ff4abe7bde550 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 10:52:20 +0100 Subject: [PATCH 01/13] Add qualitymetrics curation --- src/spikeinterface/curation/curation_tools.py | 9 +++ .../curation/qualitymetrics_curation.py | 80 +++++++++++++++++++ .../tests/test_qualitymetrics_curation.py | 68 ++++++++++++++++ 3 files changed, 157 insertions(+) create mode 100644 src/spikeinterface/curation/qualitymetrics_curation.py create mode 100644 src/spikeinterface/curation/tests/test_qualitymetrics_curation.py diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index f1d4eba3b5..e25cfba1e2 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -14,6 +14,15 @@ _methods_numpy = ("keep_first", "random", "keep_last") +def _is_threshold_disabled(value): + """Check if a threshold value is disabled (None or np.nan).""" + if value is None: + return True + if isinstance(value, float) and np.isnan(value): + return True + return False + + def _find_duplicated_spikes_numpy( spike_train: np.ndarray, censored_period: int, diff --git a/src/spikeinterface/curation/qualitymetrics_curation.py b/src/spikeinterface/curation/qualitymetrics_curation.py new file mode 100644 index 0000000000..f61a3c8078 --- /dev/null +++ b/src/spikeinterface/curation/qualitymetrics_curation.py @@ -0,0 +1,80 @@ +import json +from pathlib import Path + +import numpy as np + +from spikeinterface.core.analyzer_extension_core import SortingAnalyzer + +from .curation_tools import is_threshold_disabled + + +def qualitymetrics_label_units( + analyzer: SortingAnalyzer, + thresholds: dict | str | Path, +): + """Label units based on quality metrics and thresholds. + + Parameters + ---------- + analyzer : SortingAnalyzer + The SortingAnalyzer object containing the quality metrics. 
+ thresholds : dict | str | Path + A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. + Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values + should contain at least "min" and/or "max" keys to specify threshold ranges. + Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will + be labeled as 'good'. + + Returns + ------- + labels : pd.DataFrame + A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good'). + """ + import pandas as pd + + # Get the quality metrics from the analyzer + assert analyzer.has_extension("quality_metrics"), ( + "The provided analyzer does not have quality metrics computed. " + "Please compute quality metrics before labeling units." + ) + qm = analyzer.get_extension("quality_metrics").get_data() + + # Load thresholds from file if a path is provided + if isinstance(thresholds, (str, Path)): + + with open(thresholds, "r") as f: + thresholds_dict = json.load(f) + elif isinstance(thresholds, dict): + thresholds_dict = thresholds + else: + raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.") + + # Check that all specified metrics are present in the quality metrics DataFrame + missing_metrics = [] + for metric in thresholds_dict.keys(): + if metric not in qm.columns: + missing_metrics.append(metric) + if len(missing_metrics) > 0: + raise ValueError( + f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. 
" + f"Available metrics are: {qm.columns.tolist()}" + ) + + # Initialize an empty DataFrame to store labels + labels = pd.DataFrame(index=qm.index, dtype=str) + labels["label"] = "noise" # Default label is 'noise' + + # Apply thresholds to label units + good_mask = np.ones(len(qm), dtype=bool) + + for metric_name, threshold in thresholds_dict.items(): + min_value = threshold.get("min", None) + max_value = threshold.get("max", None) + if not is_threshold_disabled(min_value): + good_mask &= qm[metric_name] >= min_value + if not is_threshold_disabled(max_value): + good_mask &= qm[metric_name] <= max_value + + labels.loc[good_mask, "label"] = "good" + + return labels diff --git a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py new file mode 100644 index 0000000000..96462818d1 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py @@ -0,0 +1,68 @@ +import pytest +import json + +from spikeinterface.curation.tests.common import sorting_analyzer_for_curation +from spikeinterface.curation import qualitymetrics_label_units + + +def test_qualitymetrics_label_units(sorting_analyzer_for_curation): + """Test the `qualitymetrics_label_units` function.""" + sorting_analyzer_for_curation.compute("quality_metrics") + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + labels = qualitymetrics_label_units( + sorting_analyzer_for_curation, + thresholds, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if ( + snr >= thresholds["snr"]["min"] + and 
thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] + ): + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_qualitymetrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): + """Test the `qualitymetrics_label_units` function with thresholds from a JSON file.""" + sorting_analyzer_for_curation.compute("quality_metrics") + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1}, + } + + thresholds_file = tmp_path / "thresholds.json" + with open(thresholds_file, "w") as f: + json.dump(thresholds, f) + + labels = qualitymetrics_label_units( + sorting_analyzer_for_curation, + thresholds_file, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" From d395c436393930e37d829156fb8e9fea26cecf34 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 10:53:58 +0100 Subject: [PATCH 02/13] Add to __init__ --- src/spikeinterface/curation/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 730481937c..8292116681 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,6 +20,7 @@ from .sortingview_curation import apply_sortingview_curation # automated curation +from .qualitymetrics_curation import qualitymetrics_label_units from 
.model_based_curation import model_based_label_units, load_model, auto_label_units from .train_manual_curation import train_model, get_default_classifier_search_spaces from .unitrefine_curation import unitrefine_label_units From 57616d0e944f744a18fef5b45fe0dd1bcdb8eff1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 10:57:44 +0100 Subject: [PATCH 03/13] rename function --- src/spikeinterface/curation/curation_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index e25cfba1e2..3b5cb046f6 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -14,7 +14,7 @@ _methods_numpy = ("keep_first", "random", "keep_last") -def _is_threshold_disabled(value): +def is_threshold_disabled(value): """Check if a threshold value is disabled (None or np.nan).""" if value is None: return True From 78f732f864d8027e847dbcf26d8f705922e4b5b5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 18:10:42 +0100 Subject: [PATCH 04/13] Rename threshold_metrics_label_units --- doc/api.rst | 1 + src/spikeinterface/curation/__init__.py | 2 +- .../curation/qualitymetrics_curation.py | 2 +- .../curation/tests/test_qualitymetrics_curation.py | 14 +++++++------- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index adfdb85470..f4a97caabe 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -373,6 +373,7 @@ spikeinterface.curation .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes + .. autofunction:: threshold_metrics_label_units .. autofunction:: model_based_label_units .. autofunction:: load_model .. 
autofunction:: train_model diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 8292116681..c72ef82d3d 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,7 +20,7 @@ from .sortingview_curation import apply_sortingview_curation # automated curation -from .qualitymetrics_curation import qualitymetrics_label_units +from .qualitymetrics_curation import threshold_metrics_label_units from .model_based_curation import model_based_label_units, load_model, auto_label_units from .train_manual_curation import train_model, get_default_classifier_search_spaces from .unitrefine_curation import unitrefine_label_units diff --git a/src/spikeinterface/curation/qualitymetrics_curation.py b/src/spikeinterface/curation/qualitymetrics_curation.py index f61a3c8078..c8cc379dfd 100644 --- a/src/spikeinterface/curation/qualitymetrics_curation.py +++ b/src/spikeinterface/curation/qualitymetrics_curation.py @@ -8,7 +8,7 @@ from .curation_tools import is_threshold_disabled -def qualitymetrics_label_units( +def threshold_metrics_label_units( analyzer: SortingAnalyzer, thresholds: dict | str | Path, ): diff --git a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py index 96462818d1..bd7c354688 100644 --- a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py +++ b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py @@ -2,11 +2,11 @@ import json from spikeinterface.curation.tests.common import sorting_analyzer_for_curation -from spikeinterface.curation import qualitymetrics_label_units +from spikeinterface.curation import threshold_metrics_label_units -def test_qualitymetrics_label_units(sorting_analyzer_for_curation): - """Test the `qualitymetrics_label_units` function.""" +def test_threshold_metrics_label_units(sorting_analyzer_for_curation): + """Test the 
`threshold_metrics_label_units` function.""" sorting_analyzer_for_curation.compute("quality_metrics") thresholds = { @@ -14,7 +14,7 @@ def test_qualitymetrics_label_units(sorting_analyzer_for_curation): "firing_rate": {"min": 0.1, "max": 20.0}, } - labels = qualitymetrics_label_units( + labels = threshold_metrics_label_units( sorting_analyzer_for_curation, thresholds, ) @@ -36,8 +36,8 @@ def test_qualitymetrics_label_units(sorting_analyzer_for_curation): assert labels.loc[unit_id, "label"] == "noise" -def test_qualitymetrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): - """Test the `qualitymetrics_label_units` function with thresholds from a JSON file.""" +def test_threshold_metrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): + """Test the `threshold_metrics_label_units` function with thresholds from a JSON file.""" sorting_analyzer_for_curation.compute("quality_metrics") thresholds = { @@ -49,7 +49,7 @@ def test_qualitymetrics_label_units_with_file(sorting_analyzer_for_curation, tmp with open(thresholds_file, "w") as f: json.dump(thresholds, f) - labels = qualitymetrics_label_units( + labels = threshold_metrics_label_units( sorting_analyzer_for_curation, thresholds_file, ) From 467420b75beae400edc405f37c4ba69cc2d17e1f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 18:14:41 +0100 Subject: [PATCH 05/13] Generalize over any metric --- src/spikeinterface/curation/__init__.py | 2 +- .../curation/qualitymetrics_curation.py | 80 ------------------- 2 files changed, 1 insertion(+), 81 deletions(-) delete mode 100644 src/spikeinterface/curation/qualitymetrics_curation.py diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index c72ef82d3d..e00629086b 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,7 +20,7 @@ from .sortingview_curation import apply_sortingview_curation # automated curation -from 
.qualitymetrics_curation import threshold_metrics_label_units +from .threshold_metrics_curation import threshold_metrics_label_units from .model_based_curation import model_based_label_units, load_model, auto_label_units from .train_manual_curation import train_model, get_default_classifier_search_spaces from .unitrefine_curation import unitrefine_label_units diff --git a/src/spikeinterface/curation/qualitymetrics_curation.py b/src/spikeinterface/curation/qualitymetrics_curation.py deleted file mode 100644 index c8cc379dfd..0000000000 --- a/src/spikeinterface/curation/qualitymetrics_curation.py +++ /dev/null @@ -1,80 +0,0 @@ -import json -from pathlib import Path - -import numpy as np - -from spikeinterface.core.analyzer_extension_core import SortingAnalyzer - -from .curation_tools import is_threshold_disabled - - -def threshold_metrics_label_units( - analyzer: SortingAnalyzer, - thresholds: dict | str | Path, -): - """Label units based on quality metrics and thresholds. - - Parameters - ---------- - analyzer : SortingAnalyzer - The SortingAnalyzer object containing the quality metrics. - thresholds : dict | str | Path - A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. - Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values - should contain at least "min" and/or "max" keys to specify threshold ranges. - Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will - be labeled as 'good'. - - Returns - ------- - labels : pd.DataFrame - A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good'). - """ - import pandas as pd - - # Get the quality metrics from the analyzer - assert analyzer.has_extension("quality_metrics"), ( - "The provided analyzer does not have quality metrics computed. " - "Please compute quality metrics before labeling units." 
- ) - qm = analyzer.get_extension("quality_metrics").get_data() - - # Load thresholds from file if a path is provided - if isinstance(thresholds, (str, Path)): - - with open(thresholds, "r") as f: - thresholds_dict = json.load(f) - elif isinstance(thresholds, dict): - thresholds_dict = thresholds - else: - raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.") - - # Check that all specified metrics are present in the quality metrics DataFrame - missing_metrics = [] - for metric in thresholds_dict.keys(): - if metric not in qm.columns: - missing_metrics.append(metric) - if len(missing_metrics) > 0: - raise ValueError( - f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. " - f"Available metrics are: {qm.columns.tolist()}" - ) - - # Initialize an empty DataFrame to store labels - labels = pd.DataFrame(index=qm.index, dtype=str) - labels["label"] = "noise" # Default label is 'noise' - - # Apply thresholds to label units - good_mask = np.ones(len(qm), dtype=bool) - - for metric_name, threshold in thresholds_dict.items(): - min_value = threshold.get("min", None) - max_value = threshold.get("max", None) - if not is_threshold_disabled(min_value): - good_mask &= qm[metric_name] >= min_value - if not is_threshold_disabled(max_value): - good_mask &= qm[metric_name] <= max_value - - labels.loc[good_mask, "label"] = "good" - - return labels From b8cb1fcd16d41109331f675868a583190637aa59 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 18:15:24 +0100 Subject: [PATCH 06/13] add file... 
--- .../curation/threshold_metrics_curation.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/spikeinterface/curation/threshold_metrics_curation.py diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py new file mode 100644 index 0000000000..95c75dbd14 --- /dev/null +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -0,0 +1,73 @@ +import json +from pathlib import Path + +import numpy as np + +from spikeinterface.core.analyzer_extension_core import SortingAnalyzer + +from .curation_tools import is_threshold_disabled + + +def threshold_metrics_label_units( + sorting_analyzer: SortingAnalyzer, + thresholds: dict | str | Path, +): + """Label units based on metrics and thresholds. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object containing the some metrics extensions (e.g., quality metrics). + thresholds : dict | str | Path + A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. + Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values + should contain at least "min" and/or "max" keys to specify threshold ranges. + Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will + be labeled as 'good'. + + Returns + ------- + labels : pd.DataFrame + A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good'). 
+ """ + import pandas as pd + + metrics = sorting_analyzer.get_metrics_extension_data() + + # Load thresholds from file if a path is provided + if isinstance(thresholds, (str, Path)): + with open(thresholds, "r") as f: + thresholds_dict = json.load(f) + elif isinstance(thresholds, dict): + thresholds_dict = thresholds + else: + raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.") + + # Check that all specified metrics are present in the quality metrics DataFrame + missing_metrics = [] + for metric in thresholds_dict.keys(): + if metric not in metrics.columns: + missing_metrics.append(metric) + if len(missing_metrics) > 0: + raise ValueError( + f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. " + f"Available metrics are: {metrics.columns.tolist()}" + ) + + # Initialize an empty DataFrame to store labels + labels = pd.DataFrame(index=metrics.index, dtype=str) + labels["label"] = "noise" # Default label is 'noise' + + # Apply thresholds to label units + good_mask = np.ones(len(metrics), dtype=bool) + for metric_name, threshold in thresholds_dict.items(): + min_value = threshold.get("min", None) + max_value = threshold.get("max", None) + if not is_threshold_disabled(min_value): + good_mask &= metrics[metric_name] >= min_value + if not is_threshold_disabled(max_value): + good_mask &= metrics[metric_name] <= max_value + + labels.loc[good_mask, "label"] = "good" + + return labels From 2958d76e185ad1fa9e945fb382975234fd23b3f7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Feb 2026 12:20:13 +0100 Subject: [PATCH 07/13] Allow passing external labels and accept analyzer or dataframe --- .../tests/test_qualitymetrics_curation.py | 68 ---------- .../tests/test_threshold_metrics_curation.py | 127 ++++++++++++++++++ .../curation/threshold_metrics_curation.py | 34 +++-- 3 files changed, 149 insertions(+), 80 deletions(-) delete mode 100644 
src/spikeinterface/curation/tests/test_qualitymetrics_curation.py create mode 100644 src/spikeinterface/curation/tests/test_threshold_metrics_curation.py diff --git a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py deleted file mode 100644 index bd7c354688..0000000000 --- a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest -import json - -from spikeinterface.curation.tests.common import sorting_analyzer_for_curation -from spikeinterface.curation import threshold_metrics_label_units - - -def test_threshold_metrics_label_units(sorting_analyzer_for_curation): - """Test the `threshold_metrics_label_units` function.""" - sorting_analyzer_for_curation.compute("quality_metrics") - - thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1, "max": 20.0}, - } - - labels = threshold_metrics_label_units( - sorting_analyzer_for_curation, - thresholds, - ) - - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if ( - snr >= thresholds["snr"]["min"] - and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] - ): - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): - """Test the `threshold_metrics_label_units` function with thresholds from a JSON file.""" - sorting_analyzer_for_curation.compute("quality_metrics") - - thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1}, - 
} - - thresholds_file = tmp_path / "thresholds.json" - with open(thresholds_file, "w") as f: - json.dump(thresholds, f) - - labels = threshold_metrics_label_units( - sorting_analyzer_for_curation, - thresholds_file, - ) - - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py new file mode 100644 index 0000000000..90625401bb --- /dev/null +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -0,0 +1,127 @@ +import pytest +import json + +from spikeinterface.curation.tests.common import sorting_analyzer_for_curation +from spikeinterface.curation import threshold_metrics_label_units + + +@pytest.fixture +def sorting_analyzer_with_metrics(sorting_analyzer_for_curation): + """A sorting analyzer with computed quality metrics.""" + + sorting_analyzer = sorting_analyzer_for_curation + sorting_analyzer.compute("quality_metrics") + return sorting_analyzer + + +def test_threshold_metrics_label_units(sorting_analyzer_with_metrics): + """Test the `threshold_metrics_label_units` function.""" + sorting_analyzer = sorting_analyzer_with_metrics + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + labels = threshold_metrics_label_units( + sorting_analyzer, + thresholds, + ) + + assert "label" in labels.columns + assert 
labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if ( + snr >= thresholds["snr"]["min"] + and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] + ): + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_threshold_metrics_label_units_with_file(sorting_analyzer_with_metrics, tmp_path): + """Test the `threshold_metrics_label_units` function with thresholds from a JSON file.""" + sorting_analyzer = sorting_analyzer_with_metrics + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1}, + } + + thresholds_file = tmp_path / "thresholds.json" + with open(thresholds_file, "w") as f: + json.dump(thresholds, f) + + labels = threshold_metrics_label_units( + sorting_analyzer, + thresholds_file, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_threshold_metrics_label_units_with_external_metrics(sorting_analyzer_with_metrics): + """Test the `threshold_metrics_label_units` function with external metrics DataFrame.""" + sorting_analyzer = sorting_analyzer_with_metrics + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 
20.0}, + } + + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + + labels = threshold_metrics_label_units( + sorting_analyzer_or_metrics=qm, + thresholds=thresholds, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + for unit_id in sorting_analyzer.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if ( + snr >= thresholds["snr"]["min"] + and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] + ): + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): + """Test the `threshold_metrics_label_units` function with custom pass/fail labels.""" + sorting_analyzer = sorting_analyzer_with_metrics + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + labels = threshold_metrics_label_units( + sorting_analyzer, + thresholds=thresholds, + pass_label="accepted", + fail_label="rejected", + ) + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + assert set(labels["label"]).issubset({"accepted", "rejected"}) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 95c75dbd14..451edea87b 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -9,21 +9,26 @@ def threshold_metrics_label_units( - sorting_analyzer: SortingAnalyzer, + sorting_analyzer_or_metrics: "SortingAnalyzer | pd.DataFrame", thresholds: dict | str | Path, + pass_label: str = "good", + fail_label: str = "noise", ): """Label units based on metrics and thresholds. 
     Parameters
     ----------
-    sorting_analyzer : SortingAnalyzer
-        The SortingAnalyzer object containing the some metrics extensions (e.g., quality metrics).
+    sorting_analyzer_or_metrics : SortingAnalyzer | pd.DataFrame
+        The SortingAnalyzer object containing some metrics extensions (e.g., quality metrics) or a DataFrame
+        containing unit metrics with unit IDs as index.
     thresholds : dict | str | Path
         A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units.
         Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values
         should contain at least "min" and/or "max" keys to specify threshold ranges.
-        Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will
-        be labeled as 'good'.
+    pass_label : str, default: "good"
+        The label to assign to units that pass all thresholds.
+    fail_label : str, default: "noise"
+        The label to assign to units that fail any threshold.
Returns ------- @@ -32,7 +37,13 @@ def threshold_metrics_label_units( """ import pandas as pd - metrics = sorting_analyzer.get_metrics_extension_data() + if not isinstance(sorting_analyzer_or_metrics, (SortingAnalyzer, pd.DataFrame)): + raise ValueError("Only SortingAnalyzer or pd.DataFrame are supported for sorting_analyzer_or_metrics.") + + if isinstance(sorting_analyzer_or_metrics, SortingAnalyzer): + metrics = sorting_analyzer_or_metrics.get_metrics_extension_data() + else: + metrics = sorting_analyzer_or_metrics # Load thresholds from file if a path is provided if isinstance(thresholds, (str, Path)): @@ -56,18 +67,17 @@ def threshold_metrics_label_units( # Initialize an empty DataFrame to store labels labels = pd.DataFrame(index=metrics.index, dtype=str) - labels["label"] = "noise" # Default label is 'noise' + labels["label"] = fail_label # Apply thresholds to label units - good_mask = np.ones(len(metrics), dtype=bool) + pass_mask = np.ones(len(metrics), dtype=bool) for metric_name, threshold in thresholds_dict.items(): min_value = threshold.get("min", None) max_value = threshold.get("max", None) if not is_threshold_disabled(min_value): - good_mask &= metrics[metric_name] >= min_value + pass_mask &= metrics[metric_name] >= min_value if not is_threshold_disabled(max_value): - good_mask &= metrics[metric_name] <= max_value - - labels.loc[good_mask, "label"] = "good" + pass_mask &= metrics[metric_name] <= max_value + labels.loc[pass_mask, "label"] = pass_label return labels From 7b795883a0e28bad702987123f60bcd3c9959118 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Feb 2026 14:48:00 +0100 Subject: [PATCH 08/13] Extend threshold_metrics curation with operator and nan policy --- .../tests/test_threshold_metrics_curation.py | 102 ++++++++++++++++++ .../curation/threshold_metrics_curation.py | 56 +++++++++- 2 files changed, 153 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py 
b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py index 90625401bb..46d11f6057 100644 --- a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -1,6 +1,9 @@ import pytest import json +import numpy as np +import pandas as pd + from spikeinterface.curation.tests.common import sorting_analyzer_for_curation from spikeinterface.curation import threshold_metrics_label_units @@ -125,3 +128,102 @@ def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): assert "label" in labels.columns assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) assert set(labels["label"]).issubset({"accepted", "rejected"}) + + +def test_threshold_metrics_label_units_operator_or_with_dataframe(): + metrics = pd.DataFrame( + { + "m1": [1.0, 1.0, -1.0, -1.0], + "m2": [1.0, -1.0, 1.0, -1.0], + }, + index=[0, 1, 2, 3], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_and = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="and", + ) + assert labels_and.index.equals(metrics.index) + assert labels_and["label"].to_dict() == {0: "good", 1: "noise", 2: "noise", 3: "noise"} + + labels_or = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="or", + ) + assert labels_or.index.equals(metrics.index) + assert labels_or["label"].to_dict() == {0: "good", 1: "good", 2: "good", 3: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): + metrics = pd.DataFrame( + { + "m1": [np.nan, 1.0, np.nan], + "m2": [1.0, -1.0, -1.0], + }, + index=[10, 11, 12], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_fail = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="and", + nan_policy="fail", + ) + assert labels_fail["label"].to_dict() == {10: 
"noise", 11: "noise", 12: "noise"} + + labels_ignore = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="and", + nan_policy="ignore", + ) + # unit 10: m1 ignored (NaN), m2 passes -> good + # unit 11: m2 fails -> noise + # unit 12: m1 ignored but m2 fails -> noise + assert labels_ignore["label"].to_dict() == {10: "good", 11: "noise", 12: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): + metrics = pd.DataFrame( + { + "m1": [np.nan, -1.0], + "m2": [-1.0, -1.0], + }, + index=[20, 21], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_ignore_or = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="or", + nan_policy="ignore", + ) + # unit 20: m1 is NaN and ignored => passes that metric => good under "or" + # unit 21: both metrics fail => noise + assert labels_ignore_or["label"].to_dict() == {20: "good", 21: "noise"} + + +def test_threshold_metrics_label_units_invalid_operator_raises(): + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"m1": {"min": 0.0}} + with pytest.raises(ValueError, match="operator must be 'and' or 'or'"): + threshold_metrics_label_units(metrics, thresholds, operator="xor") + + +def test_threshold_metrics_label_units_invalid_nan_policy_raises(): + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"m1": {"min": 0.0}} + with pytest.raises(ValueError, match="nan_policy must be 'fail' or 'ignore'"): + threshold_metrics_label_units(metrics, thresholds, nan_policy="omit") + + +def test_threshold_metrics_label_units_missing_metric_raises(): + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"does_not_exist": {"min": 0.0}} + with pytest.raises(ValueError, match="specified in thresholds are not present"): + threshold_metrics_label_units(metrics, thresholds) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py 
b/src/spikeinterface/curation/threshold_metrics_curation.py index 451edea87b..19e6e104f4 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -13,6 +13,8 @@ def threshold_metrics_label_units( thresholds: dict | str | Path, pass_label: str = "good", fail_label: str = "noise", + operator: str = "and", + nan_policy: str = "fail", ): """Label units based on metrics and thresholds. @@ -29,6 +31,12 @@ def threshold_metrics_label_units( The label to assign to units that pass all thresholds. fail_label : str, default: "noise" The label to assign to units that fail any threshold. + operator : "and" | "or", default: "and" + The logical operator to combine multiple metric thresholds. "and" means a unit must pass all thresholds to be + labeled as pass_label, while "or" means a unit must pass at least one threshold to be labeled as pass_label. + nan_policy : "fail" | "ignore", default: "fail" + Policy for handling NaN values in metrics. If "fail", units with NaN values in any metric will be labeled as + fail_label. 
If "ignore", NaN values will be ignored Returns ------- @@ -65,19 +73,57 @@ def threshold_metrics_label_units( f"Available metrics are: {metrics.columns.tolist()}" ) - # Initialize an empty DataFrame to store labels + if operator not in ("and", "or"): + raise ValueError("operator must be 'and' or 'or'") + + if nan_policy not in ("fail", "ignore"): + raise ValueError("nan_policy must be 'fail' or 'ignore'") + labels = pd.DataFrame(index=metrics.index, dtype=str) labels["label"] = fail_label - # Apply thresholds to label units - pass_mask = np.ones(len(metrics), dtype=bool) + # Key change: init depends on operator + pass_mask = np.ones(len(metrics), dtype=bool) if operator == "and" else np.zeros(len(metrics), dtype=bool) + any_threshold_applied = False + for metric_name, threshold in thresholds_dict.items(): min_value = threshold.get("min", None) max_value = threshold.get("max", None) + + # If both disabled, ignore this metric + if is_threshold_disabled(min_value) and is_threshold_disabled(max_value): + continue + + values = metrics[metric_name].to_numpy() + is_nan = np.isnan(values) + + metric_ok = np.ones(len(values), dtype=bool) if not is_threshold_disabled(min_value): - pass_mask &= metrics[metric_name] >= min_value + metric_ok &= values >= min_value if not is_threshold_disabled(max_value): - pass_mask &= metrics[metric_name] <= max_value + metric_ok &= values <= max_value + + metric_pass = np.ones(len(metrics), dtype=bool) + if not is_threshold_disabled(min_value): + metric_pass &= values >= min_value + if not is_threshold_disabled(max_value): + metric_pass &= values <= max_value + + # Handle NaNs + if nan_policy == "fail": + metric_ok &= ~is_nan + else: # "ignore" + metric_ok |= is_nan + + any_threshold_applied = True + + if operator == "and": + pass_mask &= metric_ok + else: + pass_mask |= metric_ok + + if not any_threshold_applied: + pass_mask[:] = True labels.loc[pass_mask, "label"] = pass_label return labels From 32189edf0d0416bf7d860b9b3e08f68289ce7f3a Mon 
Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Feb 2026 17:45:20 +0100 Subject: [PATCH 09/13] Fix imports --- .../tests/test_threshold_metrics_curation.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py index 46d11f6057..9e75dd7b98 100644 --- a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -2,7 +2,6 @@ import json import numpy as np -import pandas as pd from spikeinterface.curation.tests.common import sorting_analyzer_for_curation from spikeinterface.curation import threshold_metrics_label_units @@ -131,6 +130,8 @@ def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): def test_threshold_metrics_label_units_operator_or_with_dataframe(): + import pandas as pd + metrics = pd.DataFrame( { "m1": [1.0, 1.0, -1.0, -1.0], @@ -158,6 +159,8 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe(): def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): + import pandas as pd + metrics = pd.DataFrame( { "m1": [np.nan, 1.0, np.nan], @@ -188,6 +191,8 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): + import pandas as pd + metrics = pd.DataFrame( { "m1": [np.nan, -1.0], @@ -209,6 +214,8 @@ def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): def test_threshold_metrics_label_units_invalid_operator_raises(): + import pandas as pd + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"m1": {"min": 0.0}} with pytest.raises(ValueError, match="operator must be 'and' or 'or'"): @@ -216,6 +223,8 @@ def test_threshold_metrics_label_units_invalid_operator_raises(): def test_threshold_metrics_label_units_invalid_nan_policy_raises(): + 
import pandas as pd + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"m1": {"min": 0.0}} with pytest.raises(ValueError, match="nan_policy must be 'fail' or 'ignore'"): @@ -223,6 +232,8 @@ def test_threshold_metrics_label_units_invalid_nan_policy_raises(): def test_threshold_metrics_label_units_missing_metric_raises(): + import pandas as pd + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"does_not_exist": {"min": 0.0}} with pytest.raises(ValueError, match="specified in thresholds are not present"): From 8625b31bb86fa5cf60fe05c241166843f7d57571 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Feb 2026 12:57:21 +0100 Subject: [PATCH 10/13] Add pass nan_policy --- .../tests/test_threshold_metrics_curation.py | 175 ++++++++---------- .../curation/threshold_metrics_curation.py | 37 ++-- 2 files changed, 97 insertions(+), 115 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py index 9e75dd7b98..82e0400b29 100644 --- a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -3,54 +3,41 @@ import numpy as np -from spikeinterface.curation.tests.common import sorting_analyzer_for_curation from spikeinterface.curation import threshold_metrics_label_units -@pytest.fixture -def sorting_analyzer_with_metrics(sorting_analyzer_for_curation): - """A sorting analyzer with computed quality metrics.""" - - sorting_analyzer = sorting_analyzer_for_curation - sorting_analyzer.compute("quality_metrics") - return sorting_analyzer - - -def test_threshold_metrics_label_units(sorting_analyzer_with_metrics): - """Test the `threshold_metrics_label_units` function.""" - sorting_analyzer = sorting_analyzer_with_metrics +def test_threshold_metrics_label_units_with_dataframe(): + import pandas as pd + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0, 5.0], 
+ "firing_rate": [0.5, 0.2, 25.0], + }, + index=[0, 1, 2], + ) thresholds = { "snr": {"min": 5.0}, "firing_rate": {"min": 0.1, "max": 20.0}, } - labels = threshold_metrics_label_units( - sorting_analyzer, - thresholds, - ) + labels = threshold_metrics_label_units(metrics, thresholds) assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - qm = sorting_analyzer.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if ( - snr >= thresholds["snr"]["min"] - and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] - ): - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_units_with_file(sorting_analyzer_with_metrics, tmp_path): - """Test the `threshold_metrics_label_units` function with thresholds from a JSON file.""" - sorting_analyzer = sorting_analyzer_with_metrics + assert labels.shape[0] == len(metrics.index) + assert labels["label"].to_dict() == {0: "good", 1: "noise", 2: "noise"} + +def test_threshold_metrics_label_units_with_file(tmp_path): + import pandas as pd + + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0], + "firing_rate": [0.5, 0.05], + }, + index=[0, 1], + ) thresholds = { "snr": {"min": 5.0}, "firing_rate": {"min": 0.1}, @@ -60,72 +47,32 @@ def test_threshold_metrics_label_units_with_file(sorting_analyzer_with_metrics, with open(thresholds_file, "w") as f: json.dump(thresholds, f) - labels = threshold_metrics_label_units( - sorting_analyzer, - thresholds_file, - ) + labels = threshold_metrics_label_units(metrics, thresholds_file) - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 
are labeled as 'noise' - qm = sorting_analyzer.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_units_with_external_metrics(sorting_analyzer_with_metrics): - """Test the `threshold_metrics_label_units` function with external metrics DataFrame.""" - sorting_analyzer = sorting_analyzer_with_metrics - thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1, "max": 20.0}, - } + assert labels["label"].to_dict() == {0: "good", 1: "noise"} - qm = sorting_analyzer.get_extension("quality_metrics").get_data() - labels = threshold_metrics_label_units( - sorting_analyzer_or_metrics=qm, - thresholds=thresholds, - ) +def test_threshold_metrics_label_external_labels(): + import pandas as pd - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - for unit_id in sorting_analyzer.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if ( - snr >= thresholds["snr"]["min"] - and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] - ): - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): - """Test the `threshold_metrics_label_units` function with custom pass/fail labels.""" - sorting_analyzer = sorting_analyzer_with_metrics + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0], + "firing_rate": [0.5, 0.05], + }, + index=[0, 1], + ) thresholds = { "snr": {"min": 5.0}, - "firing_rate": 
{"min": 0.1, "max": 20.0}, + "firing_rate": {"min": 0.1}, } labels = threshold_metrics_label_units( - sorting_analyzer, + metrics, thresholds=thresholds, pass_label="accepted", fail_label="rejected", ) - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) assert set(labels["label"]).issubset({"accepted", "rejected"}) @@ -142,7 +89,7 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe(): thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} labels_and = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="and", ) @@ -150,7 +97,7 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe(): assert labels_and["label"].to_dict() == {0: "good", 1: "noise", 2: "noise", 3: "noise"} labels_or = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="or", ) @@ -171,7 +118,7 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} labels_fail = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="and", nan_policy="fail", @@ -179,7 +126,7 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): assert labels_fail["label"].to_dict() == {10: "noise", 11: "noise", 12: "noise"} labels_ignore = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="and", nan_policy="ignore", @@ -203,14 +150,48 @@ def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} labels_ignore_or = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="or", nan_policy="ignore", ) - # unit 20: m1 is NaN and ignored => passes that metric => good under "or" + # unit 20: m1 is NaN and ignored; 
m2 fails => noise # unit 21: both metrics fail => noise - assert labels_ignore_or["label"].to_dict() == {20: "good", 21: "noise"} + assert labels_ignore_or["label"].to_dict() == {20: "noise", 21: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_pass_and_or(): + import pandas as pd + + metrics = pd.DataFrame( + { + "m1": [np.nan, np.nan, 1.0, -1.0], + "m2": [1.0, -1.0, np.nan, np.nan], + }, + index=[30, 31, 32, 33], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_and = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="and", + nan_policy="pass", + ) + # unit 30: m1 NaN (pass), m2 pass => good + # unit 31: m1 NaN (pass), m2 fail => noise + # unit 32: m1 pass, m2 NaN (pass) => good + # unit 33: m1 fail, m2 NaN (pass) => noise + assert labels_and["label"].to_dict() == {30: "good", 31: "noise", 32: "good", 33: "noise"} + + labels_or = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="or", + nan_policy="pass", + ) + # any NaN counts as pass => good unless all metrics fail without NaN + assert labels_or["label"].to_dict() == {30: "good", 31: "good", 32: "good", 33: "good"} def test_threshold_metrics_label_units_invalid_operator_raises(): @@ -227,7 +208,7 @@ def test_threshold_metrics_label_units_invalid_nan_policy_raises(): metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"m1": {"min": 0.0}} - with pytest.raises(ValueError, match="nan_policy must be 'fail' or 'ignore'"): + with pytest.raises(ValueError, match="nan_policy must be"): threshold_metrics_label_units(metrics, thresholds, nan_policy="omit") diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 19e6e104f4..daa8b138f1 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -9,7 +9,7 @@ def threshold_metrics_label_units( - 
sorting_analyzer_or_metrics: "SortingAnalyzer | pd.DataFrame", + metrics: "pd.DataFrame", thresholds: dict | str | Path, pass_label: str = "good", fail_label: str = "noise", @@ -20,9 +20,8 @@ Parameters ---------- - sorting_analyzer_or_metrics : SortingAnalyzer | pd.DataFrame - The SortingAnalyzer object containing the some metrics extensions (e.g., quality metrics) or a DataFrame - containing unit metrics with unit IDs as index. + metrics : pd.DataFrame + A DataFrame containing unit metrics with unit IDs as index. thresholds : dict | str | Path A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values @@ -34,9 +33,12 @@ operator : "and" | "or", default: "and" The logical operator to combine multiple metric thresholds. "and" means a unit must pass all thresholds to be labeled as pass_label, while "or" means a unit must pass at least one threshold to be labeled as pass_label. - nan_policy : "fail" | "ignore", default: "fail" + nan_policy : "fail" | "pass" | "ignore", default: "fail" Policy for handling NaN values in metrics. If "fail", units with NaN values in any metric will be labeled as - fail_label. If "ignore", NaN values will be ignored + fail_label. If "pass", NaN values are treated as passing the threshold for that metric. + If "ignore", NaN values will be ignored. Note that the "ignore" behavior will depend on the operator used. + If "and", NaNs will be treated as passing, since the initial mask is all true; + if "or", NaNs will be treated as failing, since the initial mask is all false. 
Returns ------- @@ -45,13 +47,8 @@ def threshold_metrics_label_units( """ import pandas as pd - if not isinstance(sorting_analyzer_or_metrics, (SortingAnalyzer, pd.DataFrame)): - raise ValueError("Only SortingAnalyzer or pd.DataFrame are supported for sorting_analyzer_or_metrics.") - - if isinstance(sorting_analyzer_or_metrics, SortingAnalyzer): - metrics = sorting_analyzer_or_metrics.get_metrics_extension_data() - else: - metrics = sorting_analyzer_or_metrics + if not isinstance(metrics, pd.DataFrame): + raise ValueError("Only pd.DataFrame is supported for metrics.") # Load thresholds from file if a path is provided if isinstance(thresholds, (str, Path)): @@ -76,8 +73,8 @@ def threshold_metrics_label_units( if operator not in ("and", "or"): raise ValueError("operator must be 'and' or 'or'") - if nan_policy not in ("fail", "ignore"): - raise ValueError("nan_policy must be 'fail' or 'ignore'") + if nan_policy not in ("fail", "pass", "ignore"): + raise ValueError("nan_policy must be 'fail', 'pass', or 'ignore'") labels = pd.DataFrame(index=metrics.index, dtype=str) labels["label"] = fail_label @@ -110,17 +107,21 @@ def threshold_metrics_label_units( metric_pass &= values <= max_value # Handle NaNs + nan_mask = slice(None) if nan_policy == "fail": metric_ok &= ~is_nan - else: # "ignore" + elif nan_policy == "pass": metric_ok |= is_nan + else: + # if nan_policy == "ignore", we only set values for non-nan entries + nan_mask = ~is_nan any_threshold_applied = True if operator == "and": - pass_mask &= metric_ok + pass_mask[nan_mask] &= metric_ok[nan_mask] else: - pass_mask |= metric_ok + pass_mask[nan_mask] |= metric_ok[nan_mask] if not any_threshold_applied: pass_mask[:] = True From 78c831a7f4c7e57df03fec04a06a568b75628314 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Feb 2026 20:21:56 +0100 Subject: [PATCH 11/13] Update src/spikeinterface/curation/threshold_metrics_curation.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> 
--- src/spikeinterface/curation/threshold_metrics_curation.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index daa8b138f1..2186a58fe5 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -100,12 +100,6 @@ def threshold_metrics_label_units( if not is_threshold_disabled(max_value): metric_ok &= values <= max_value - metric_pass = np.ones(len(metrics), dtype=bool) - if not is_threshold_disabled(min_value): - metric_pass &= values >= min_value - if not is_threshold_disabled(max_value): - metric_pass &= values <= max_value - # Handle NaNs nan_mask = slice(None) if nan_policy == "fail": From 58c3b4dd640fee84f0057da1929ba324636fd89a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Feb 2026 20:22:06 +0100 Subject: [PATCH 12/13] Update src/spikeinterface/curation/threshold_metrics_curation.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/curation/threshold_metrics_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 2186a58fe5..32a6a48d91 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -43,7 +43,7 @@ def threshold_metrics_label_units( Returns ------- labels : pd.DataFrame - A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good'). 
+ A DataFrame with unit IDs as index and a column 'label' containing the assigned labels (`fail_label` or `pass_label`) """ import pandas as pd From 5696bdb7c1a6c3ec596d14fd946ca3d01e978634 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Feb 2026 16:04:57 +0100 Subject: [PATCH 13/13] Clarify logic and add column_name --- .../curation/threshold_metrics_curation.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 32a6a48d91..daab840893 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -15,6 +15,7 @@ def threshold_metrics_label_units( fail_label: str = "noise", operator: str = "and", nan_policy: str = "fail", + column_name: str = "label", ): """Label units based on metrics and thresholds. @@ -39,11 +40,13 @@ def threshold_metrics_label_units( If "ignore", NaN values will be ignored. Note that the "ignore" behavior will depend on the operator used. If "and", NaNs will be treated as passing, since the initial mask is all true; if "or", NaNs will be treated as failing, since the initial mask is all false. + column_name : str, default: "label" + The name of the column in the output DataFrame that will contain the assigned labels. 
Returns ------- labels : pd.DataFrame - A DataFrame with unit IDs as index and a column 'label' containing the assigned labels (`fail_label` or `pass_label`) + A DataFrame with unit IDs as index and a column `column_name` containing the assigned labels (`fail_label` or `pass_label`) """ import pandas as pd @@ -77,7 +80,7 @@ def threshold_metrics_label_units( raise ValueError("nan_policy must be 'fail', 'pass', or 'ignore'") labels = pd.DataFrame(index=metrics.index, dtype=str) - labels["label"] = fail_label + labels[column_name] = fail_label # Key change: init depends on operator pass_mask = np.ones(len(metrics), dtype=bool) if operator == "and" else np.zeros(len(metrics), dtype=bool) @@ -101,24 +104,24 @@ def threshold_metrics_label_units( metric_ok &= values <= max_value # Handle NaNs - nan_mask = slice(None) if nan_policy == "fail": metric_ok &= ~is_nan + valid_mask = slice(None) elif nan_policy == "pass": metric_ok |= is_nan - else: - # if nan_policy == "ignore", we only set values for non-nan entries - nan_mask = ~is_nan + valid_mask = slice(None) + elif nan_policy == "ignore": + valid_mask = ~is_nan any_threshold_applied = True if operator == "and": - pass_mask[nan_mask] &= metric_ok[nan_mask] - else: - pass_mask[nan_mask] |= metric_ok[nan_mask] + pass_mask[valid_mask] &= metric_ok[valid_mask] + elif operator == "or": + pass_mask[valid_mask] |= metric_ok[valid_mask] if not any_threshold_applied: pass_mask[:] = True - labels.loc[pass_mask, "label"] = pass_label + labels.loc[pass_mask, column_name] = pass_label return labels