From b50d1a559d8b5f15b5b7d66fc25ff4abe7bde550 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 10:52:20 +0100 Subject: [PATCH 01/12] Add qualitymetrics curation --- src/spikeinterface/curation/curation_tools.py | 9 +++ .../curation/qualitymetrics_curation.py | 80 +++++++++++++++++++ .../tests/test_qualitymetrics_curation.py | 68 ++++++++++++++++ 3 files changed, 157 insertions(+) create mode 100644 src/spikeinterface/curation/qualitymetrics_curation.py create mode 100644 src/spikeinterface/curation/tests/test_qualitymetrics_curation.py diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index f1d4eba3b5..e25cfba1e2 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -14,6 +14,15 @@ _methods_numpy = ("keep_first", "random", "keep_last") +def _is_threshold_disabled(value): + """Check if a threshold value is disabled (None or np.nan).""" + if value is None: + return True + if isinstance(value, float) and np.isnan(value): + return True + return False + + def _find_duplicated_spikes_numpy( spike_train: np.ndarray, censored_period: int, diff --git a/src/spikeinterface/curation/qualitymetrics_curation.py b/src/spikeinterface/curation/qualitymetrics_curation.py new file mode 100644 index 0000000000..f61a3c8078 --- /dev/null +++ b/src/spikeinterface/curation/qualitymetrics_curation.py @@ -0,0 +1,80 @@ +import json +from pathlib import Path + +import numpy as np + +from spikeinterface.core.analyzer_extension_core import SortingAnalyzer + +from .curation_tools import is_threshold_disabled + + +def qualitymetrics_label_units( + analyzer: SortingAnalyzer, + thresholds: dict | str | Path, +): + """Label units based on quality metrics and thresholds. + + Parameters + ---------- + analyzer : SortingAnalyzer + The SortingAnalyzer object containing the quality metrics. + thresholds : dict | str | Path + A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. + Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values + should contain at least "min" and/or "max" keys to specify threshold ranges. + Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will + be labeled as 'good'. + + Returns + ------- + labels : pd.DataFrame + A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good'). + """ + import pandas as pd + + # Get the quality metrics from the analyzer + assert analyzer.has_extension("quality_metrics"), ( + "The provided analyzer does not have quality metrics computed. " + "Please compute quality metrics before labeling units." 
+ ) + qm = analyzer.get_extension("quality_metrics").get_data() + + # Load thresholds from file if a path is provided + if isinstance(thresholds, (str, Path)): + + with open(thresholds, "r") as f: + thresholds_dict = json.load(f) + elif isinstance(thresholds, dict): + thresholds_dict = thresholds + else: + raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.") + + # Check that all specified metrics are present in the quality metrics DataFrame + missing_metrics = [] + for metric in thresholds_dict.keys(): + if metric not in qm.columns: + missing_metrics.append(metric) + if len(missing_metrics) > 0: + raise ValueError( + f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. " + f"Available metrics are: {qm.columns.tolist()}" + ) + + # Initialize an empty DataFrame to store labels + labels = pd.DataFrame(index=qm.index, dtype=str) + labels["label"] = "noise" # Default label is 'noise' + + # Apply thresholds to label units + good_mask = np.ones(len(qm), dtype=bool) + + for metric_name, threshold in thresholds_dict.items(): + min_value = threshold.get("min", None) + max_value = threshold.get("max", None) + if not is_threshold_disabled(min_value): + good_mask &= qm[metric_name] >= min_value + if not is_threshold_disabled(max_value): + good_mask &= qm[metric_name] <= max_value + + labels.loc[good_mask, "label"] = "good" + + return labels diff --git a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py new file mode 100644 index 0000000000..96462818d1 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py @@ -0,0 +1,68 @@ +import pytest +import json + +from spikeinterface.curation.tests.common import sorting_analyzer_for_curation +from spikeinterface.curation import qualitymetrics_label_units + + +def test_qualitymetrics_label_units(sorting_analyzer_for_curation): + """Test the `qualitymetrics_label_units` function.""" + sorting_analyzer_for_curation.compute("quality_metrics") + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + labels = qualitymetrics_label_units( + sorting_analyzer_for_curation, + thresholds, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if ( + snr >= thresholds["snr"]["min"] + and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] + ): + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_qualitymetrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): + """Test the `qualitymetrics_label_units` function with thresholds from a JSON file.""" + sorting_analyzer_for_curation.compute("quality_metrics") + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1}, + } + + thresholds_file = tmp_path / "thresholds.json" + with open(thresholds_file, "w") as f: + json.dump(thresholds, f) + + labels = qualitymetrics_label_units( + sorting_analyzer_for_curation, + thresholds_file, + ) + + assert "label" in labels.columns + assert labels.shape[0] 
== len(sorting_analyzer_for_curation.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" From d395c436393930e37d829156fb8e9fea26cecf34 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 10:53:58 +0100 Subject: [PATCH 02/12] Add to __init__ --- src/spikeinterface/curation/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 730481937c..8292116681 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,6 +20,7 @@ from .sortingview_curation import apply_sortingview_curation # automated curation +from .qualitymetrics_curation import qualitymetrics_label_units from .model_based_curation import model_based_label_units, load_model, auto_label_units from .train_manual_curation import train_model, get_default_classifier_search_spaces from .unitrefine_curation import unitrefine_label_units From 57616d0e944f744a18fef5b45fe0dd1bcdb8eff1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 10:57:44 +0100 Subject: [PATCH 03/12] rename function --- src/spikeinterface/curation/curation_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index e25cfba1e2..3b5cb046f6 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -14,7 +14,7 @@ _methods_numpy = ("keep_first", "random", "keep_last") -def _is_threshold_disabled(value): +def is_threshold_disabled(value): """Check if a threshold value is disabled (None or np.nan).""" if value is None: return True From 78f732f864d8027e847dbcf26d8f705922e4b5b5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 18:10:42 +0100 Subject: [PATCH 04/12] Rename threshold_metrics_label_units --- doc/api.rst | 1 + src/spikeinterface/curation/__init__.py | 2 +- .../curation/qualitymetrics_curation.py | 2 +- .../curation/tests/test_qualitymetrics_curation.py | 14 +++++++------- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index adfdb85470..f4a97caabe 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -373,6 +373,7 @@ spikeinterface.curation .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes + .. autofunction:: threshold_metrics_label_units .. autofunction:: model_based_label_units .. autofunction:: load_model .. 
autofunction:: train_model diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 8292116681..c72ef82d3d 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,7 +20,7 @@ from .sortingview_curation import apply_sortingview_curation # automated curation -from .qualitymetrics_curation import qualitymetrics_label_units +from .qualitymetrics_curation import threshold_metrics_label_units from .model_based_curation import model_based_label_units, load_model, auto_label_units from .train_manual_curation import train_model, get_default_classifier_search_spaces from .unitrefine_curation import unitrefine_label_units diff --git a/src/spikeinterface/curation/qualitymetrics_curation.py b/src/spikeinterface/curation/qualitymetrics_curation.py index f61a3c8078..c8cc379dfd 100644 --- a/src/spikeinterface/curation/qualitymetrics_curation.py +++ b/src/spikeinterface/curation/qualitymetrics_curation.py @@ -8,7 +8,7 @@ from .curation_tools import is_threshold_disabled -def qualitymetrics_label_units( +def threshold_metrics_label_units( analyzer: SortingAnalyzer, thresholds: dict | str | Path, ): diff --git a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py index 96462818d1..bd7c354688 100644 --- a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py +++ b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py @@ -2,11 +2,11 @@ import json from spikeinterface.curation.tests.common import sorting_analyzer_for_curation -from spikeinterface.curation import qualitymetrics_label_units +from spikeinterface.curation import threshold_metrics_label_units -def test_qualitymetrics_label_units(sorting_analyzer_for_curation): - """Test the `qualitymetrics_label_units` function.""" +def test_threshold_metrics_label_units(sorting_analyzer_for_curation): + """Test the `threshold_metrics_label_units` function.""" sorting_analyzer_for_curation.compute("quality_metrics") thresholds = { @@ -14,7 +14,7 @@ def test_qualitymetrics_label_units(sorting_analyzer_for_curation): "firing_rate": {"min": 0.1, "max": 20.0}, } - labels = qualitymetrics_label_units( + labels = threshold_metrics_label_units( sorting_analyzer_for_curation, thresholds, ) @@ -36,8 +36,8 @@ def test_qualitymetrics_label_units(sorting_analyzer_for_curation): assert labels.loc[unit_id, "label"] == "noise" -def test_qualitymetrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): - """Test the `qualitymetrics_label_units` function with thresholds from a JSON file.""" +def test_threshold_metrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): + """Test the `threshold_metrics_label_units` function with thresholds from a JSON file.""" sorting_analyzer_for_curation.compute("quality_metrics") thresholds = { @@ -49,7 +49,7 @@ def test_qualitymetrics_label_units_with_file(sorting_analyzer_for_curation, tmp with open(thresholds_file, "w") as f: json.dump(thresholds, f) - labels = qualitymetrics_label_units( + labels = threshold_metrics_label_units( sorting_analyzer_for_curation, thresholds_file, ) From 467420b75beae400edc405f37c4ba69cc2d17e1f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 18:14:41 +0100 Subject: [PATCH 05/12] Generalize over any metric --- src/spikeinterface/curation/__init__.py | 2 +- .../curation/qualitymetrics_curation.py | 80 ------------------- 2 files changed, 1 insertion(+), 81 deletions(-) delete 
mode 100644 src/spikeinterface/curation/qualitymetrics_curation.py diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index c72ef82d3d..e00629086b 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,7 +20,7 @@ from .sortingview_curation import apply_sortingview_curation # automated curation -from .qualitymetrics_curation import threshold_metrics_label_units +from .threshold_metrics_curation import threshold_metrics_label_units from .model_based_curation import model_based_label_units, load_model, auto_label_units from .train_manual_curation import train_model, get_default_classifier_search_spaces from .unitrefine_curation import unitrefine_label_units diff --git a/src/spikeinterface/curation/qualitymetrics_curation.py b/src/spikeinterface/curation/qualitymetrics_curation.py deleted file mode 100644 index c8cc379dfd..0000000000 --- a/src/spikeinterface/curation/qualitymetrics_curation.py +++ /dev/null @@ -1,80 +0,0 @@ -import json -from pathlib import Path - -import numpy as np - -from spikeinterface.core.analyzer_extension_core import SortingAnalyzer - -from .curation_tools import is_threshold_disabled - - -def threshold_metrics_label_units( - analyzer: SortingAnalyzer, - thresholds: dict | str | Path, -): - """Label units based on quality metrics and thresholds. - - Parameters - ---------- - analyzer : SortingAnalyzer - The SortingAnalyzer object containing the quality metrics. - thresholds : dict | str | Path - A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. - Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values - should contain at least "min" and/or "max" keys to specify threshold ranges. - Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will - be labeled as 'good'. - - Returns - ------- - labels : pd.DataFrame - A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good'). - """ - import pandas as pd - - # Get the quality metrics from the analyzer - assert analyzer.has_extension("quality_metrics"), ( - "The provided analyzer does not have quality metrics computed. " - "Please compute quality metrics before labeling units." - ) - qm = analyzer.get_extension("quality_metrics").get_data() - - # Load thresholds from file if a path is provided - if isinstance(thresholds, (str, Path)): - - with open(thresholds, "r") as f: - thresholds_dict = json.load(f) - elif isinstance(thresholds, dict): - thresholds_dict = thresholds - else: - raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.") - - # Check that all specified metrics are present in the quality metrics DataFrame - missing_metrics = [] - for metric in thresholds_dict.keys(): - if metric not in qm.columns: - missing_metrics.append(metric) - if len(missing_metrics) > 0: - raise ValueError( - f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. 
" - f"Available metrics are: {qm.columns.tolist()}" - ) - - # Initialize an empty DataFrame to store labels - labels = pd.DataFrame(index=qm.index, dtype=str) - labels["label"] = "noise" # Default label is 'noise' - - # Apply thresholds to label units - good_mask = np.ones(len(qm), dtype=bool) - - for metric_name, threshold in thresholds_dict.items(): - min_value = threshold.get("min", None) - max_value = threshold.get("max", None) - if not is_threshold_disabled(min_value): - good_mask &= qm[metric_name] >= min_value - if not is_threshold_disabled(max_value): - good_mask &= qm[metric_name] <= max_value - - labels.loc[good_mask, "label"] = "good" - - return labels From b8cb1fcd16d41109331f675868a583190637aa59 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Feb 2026 18:15:24 +0100 Subject: [PATCH 06/12] add file... --- .../curation/threshold_metrics_curation.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/spikeinterface/curation/threshold_metrics_curation.py diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py new file mode 100644 index 0000000000..95c75dbd14 --- /dev/null +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -0,0 +1,73 @@ +import json +from pathlib import Path + +import numpy as np + +from spikeinterface.core.analyzer_extension_core import SortingAnalyzer + +from .curation_tools import is_threshold_disabled + + +def threshold_metrics_label_units( + sorting_analyzer: SortingAnalyzer, + thresholds: dict | str | Path, +): + """Label units based on metrics and thresholds. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object containing the some metrics extensions (e.g., quality metrics). + thresholds : dict | str | Path + A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. + Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values + should contain at least "min" and/or "max" keys to specify threshold ranges. + Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will + be labeled as 'good'. + + Returns + ------- + labels : pd.DataFrame + A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good'). + """ + import pandas as pd + + metrics = sorting_analyzer.get_metrics_extension_data() + + # Load thresholds from file if a path is provided + if isinstance(thresholds, (str, Path)): + with open(thresholds, "r") as f: + thresholds_dict = json.load(f) + elif isinstance(thresholds, dict): + thresholds_dict = thresholds + else: + raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.") + + # Check that all specified metrics are present in the quality metrics DataFrame + missing_metrics = [] + for metric in thresholds_dict.keys(): + if metric not in metrics.columns: + missing_metrics.append(metric) + if len(missing_metrics) > 0: + raise ValueError( + f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. 
" + f"Available metrics are: {metrics.columns.tolist()}" + ) + + # Initialize an empty DataFrame to store labels + labels = pd.DataFrame(index=metrics.index, dtype=str) + labels["label"] = "noise" # Default label is 'noise' + + # Apply thresholds to label units + good_mask = np.ones(len(metrics), dtype=bool) + for metric_name, threshold in thresholds_dict.items(): + min_value = threshold.get("min", None) + max_value = threshold.get("max", None) + if not is_threshold_disabled(min_value): + good_mask &= metrics[metric_name] >= min_value + if not is_threshold_disabled(max_value): + good_mask &= metrics[metric_name] <= max_value + + labels.loc[good_mask, "label"] = "good" + + return labels From 2958d76e185ad1fa9e945fb382975234fd23b3f7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Feb 2026 12:20:13 +0100 Subject: [PATCH 07/12] Allow passing external labels and accept analyzer or dataframe --- .../tests/test_qualitymetrics_curation.py | 68 ---------- .../tests/test_threshold_metrics_curation.py | 127 ++++++++++++++++++ .../curation/threshold_metrics_curation.py | 34 +++-- 3 files changed, 149 insertions(+), 80 deletions(-) delete mode 100644 src/spikeinterface/curation/tests/test_qualitymetrics_curation.py create mode 100644 src/spikeinterface/curation/tests/test_threshold_metrics_curation.py diff --git a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py b/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py deleted file mode 100644 index bd7c354688..0000000000 --- a/src/spikeinterface/curation/tests/test_qualitymetrics_curation.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest -import json - -from spikeinterface.curation.tests.common import sorting_analyzer_for_curation -from spikeinterface.curation import threshold_metrics_label_units - - -def test_threshold_metrics_label_units(sorting_analyzer_for_curation): - """Test the `threshold_metrics_label_units` function.""" - sorting_analyzer_for_curation.compute("quality_metrics") - - thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1, "max": 20.0}, - } - - labels = threshold_metrics_label_units( - sorting_analyzer_for_curation, - thresholds, - ) - - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if ( - snr >= thresholds["snr"]["min"] - and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] - ): - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path): - """Test the `threshold_metrics_label_units` function with thresholds from a JSON file.""" - sorting_analyzer_for_curation.compute("quality_metrics") - - thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1}, - } - - thresholds_file = tmp_path / "thresholds.json" - with open(thresholds_file, "w") as f: - json.dump(thresholds, f) - - labels = threshold_metrics_label_units( - sorting_analyzer_for_curation, - thresholds_file, - ) - - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) - - # Check that units with snr < 5.0 or 
firing_rate < 0.1 are labeled as 'noise' - qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer_for_curation.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py new file mode 100644 index 0000000000..90625401bb --- /dev/null +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -0,0 +1,127 @@ +import pytest +import json + +from spikeinterface.curation.tests.common import sorting_analyzer_for_curation +from spikeinterface.curation import threshold_metrics_label_units + + +@pytest.fixture +def sorting_analyzer_with_metrics(sorting_analyzer_for_curation): + """A sorting analyzer with computed quality metrics.""" + + sorting_analyzer = sorting_analyzer_for_curation + sorting_analyzer.compute("quality_metrics") + return sorting_analyzer + + +def test_threshold_metrics_label_units(sorting_analyzer_with_metrics): + """Test the `threshold_metrics_label_units` function.""" + sorting_analyzer = sorting_analyzer_with_metrics + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + labels = threshold_metrics_label_units( + sorting_analyzer, + thresholds, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if ( + snr >= thresholds["snr"]["min"] + and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] + ): + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_threshold_metrics_label_units_with_file(sorting_analyzer_with_metrics, tmp_path): + """Test the `threshold_metrics_label_units` function with thresholds from a JSON file.""" + sorting_analyzer = sorting_analyzer_with_metrics + + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1}, + } + + thresholds_file = tmp_path / "thresholds.json" + with open(thresholds_file, "w") as f: + json.dump(thresholds, f) + + labels = threshold_metrics_label_units( + sorting_analyzer, + thresholds_file, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + for unit_id in sorting_analyzer.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_threshold_metrics_label_units_with_external_metrics(sorting_analyzer_with_metrics): + """Test the `threshold_metrics_label_units` function with external metrics DataFrame.""" + sorting_analyzer = sorting_analyzer_with_metrics + thresholds 
= { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + + labels = threshold_metrics_label_units( + sorting_analyzer_or_metrics=qm, + thresholds=thresholds, + ) + + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + + # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' + for unit_id in sorting_analyzer.sorting.unit_ids: + snr = qm.loc[unit_id, "snr"] + firing_rate = qm.loc[unit_id, "firing_rate"] + if ( + snr >= thresholds["snr"]["min"] + and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] + ): + assert labels.loc[unit_id, "label"] == "good" + else: + assert labels.loc[unit_id, "label"] == "noise" + + +def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): + """Test the `threshold_metrics_label_units` function with custom pass/fail labels.""" + sorting_analyzer = sorting_analyzer_with_metrics + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + labels = threshold_metrics_label_units( + sorting_analyzer, + thresholds=thresholds, + pass_label="accepted", + fail_label="rejected", + ) + assert "label" in labels.columns + assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) + assert set(labels["label"]).issubset({"accepted", "rejected"}) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 95c75dbd14..451edea87b 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -9,21 +9,26 @@ def threshold_metrics_label_units( - sorting_analyzer: SortingAnalyzer, + sorting_analyzer_or_metrics: "SortingAnalyzer | pd.DataFrame", thresholds: dict | str | Path, + pass_label: str = "good", + fail_label: str = "noise", ): """Label units based on metrics and thresholds. Parameters ---------- - sorting_analyzer : SortingAnalyzer - The SortingAnalyzer object containing the some metrics extensions (e.g., quality metrics). + sorting_analyzer_or_metrics : SortingAnalyzer | pd.DataFrame + The SortingAnalyzer object containing the some metrics extensions (e.g., quality metrics) or a DataFrame + containing unit metrics with unit IDs as index. thresholds : dict | str | Path A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values should contain at least "min" and/or "max" keys to specify threshold ranges. - Units that do not meet the threshold for a given metric will be labeled as 'noise', while those that do will - be labeled as 'good'. + pass_label : str, default: "good" + The label to assign to units that pass all thresholds. + fail_label : str, default: "noise" + The label to assign to units that fail any threshold. 
Returns ------- @@ -32,7 +37,13 @@ def threshold_metrics_label_units( """ import pandas as pd - metrics = sorting_analyzer.get_metrics_extension_data() + if not isinstance(sorting_analyzer_or_metrics, (SortingAnalyzer, pd.DataFrame)): + raise ValueError("Only SortingAnalyzer or pd.DataFrame are supported for sorting_analyzer_or_metrics.") + + if isinstance(sorting_analyzer_or_metrics, SortingAnalyzer): + metrics = sorting_analyzer_or_metrics.get_metrics_extension_data() + else: + metrics = sorting_analyzer_or_metrics # Load thresholds from file if a path is provided if isinstance(thresholds, (str, Path)): @@ -56,18 +67,17 @@ def threshold_metrics_label_units( # Initialize an empty DataFrame to store labels labels = pd.DataFrame(index=metrics.index, dtype=str) - labels["label"] = "noise" # Default label is 'noise' + labels["label"] = fail_label # Apply thresholds to label units - good_mask = np.ones(len(metrics), dtype=bool) + pass_mask = np.ones(len(metrics), dtype=bool) for metric_name, threshold in thresholds_dict.items(): min_value = threshold.get("min", None) max_value = threshold.get("max", None) if not is_threshold_disabled(min_value): - good_mask &= metrics[metric_name] >= min_value + pass_mask &= metrics[metric_name] >= min_value if not is_threshold_disabled(max_value): - good_mask &= metrics[metric_name] <= max_value - - labels.loc[good_mask, "label"] = "good" + pass_mask &= metrics[metric_name] <= max_value + labels.loc[pass_mask, "label"] = pass_label return labels From 7b795883a0e28bad702987123f60bcd3c9959118 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Feb 2026 14:48:00 +0100 Subject: [PATCH 08/12] Extend threshold_metrics curation with operator and nan policy --- .../tests/test_threshold_metrics_curation.py | 102 ++++++++++++++++++ .../curation/threshold_metrics_curation.py | 56 +++++++++- 2 files changed, 153 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py index 90625401bb..46d11f6057 100644 --- a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -1,6 +1,9 @@ import pytest import json +import numpy as np +import pandas as pd + from spikeinterface.curation.tests.common import sorting_analyzer_for_curation from spikeinterface.curation import threshold_metrics_label_units @@ -125,3 +128,102 @@ def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): assert "label" in labels.columns assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) assert set(labels["label"]).issubset({"accepted", "rejected"}) + + +def test_threshold_metrics_label_units_operator_or_with_dataframe(): + metrics = pd.DataFrame( + { + "m1": [1.0, 1.0, -1.0, -1.0], + "m2": [1.0, -1.0, 1.0, -1.0], + }, + index=[0, 1, 2, 3], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_and = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="and", + ) + assert labels_and.index.equals(metrics.index) + assert labels_and["label"].to_dict() == {0: "good", 1: "noise", 2: "noise", 3: "noise"} + + labels_or = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="or", + ) + assert labels_or.index.equals(metrics.index) + assert labels_or["label"].to_dict() == {0: "good", 1: "good", 2: "good", 3: "noise"} + + +def 
test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): + metrics = pd.DataFrame( + { + "m1": [np.nan, 1.0, np.nan], + "m2": [1.0, -1.0, -1.0], + }, + index=[10, 11, 12], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_fail = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="and", + nan_policy="fail", + ) + assert labels_fail["label"].to_dict() == {10: "noise", 11: "noise", 12: "noise"} + + labels_ignore = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="and", + nan_policy="ignore", + ) + # unit 10: m1 ignored (NaN), m2 passes -> good + # unit 11: m2 fails -> noise + # unit 12: m1 ignored but m2 fails -> noise + assert labels_ignore["label"].to_dict() == {10: "good", 11: "noise", 12: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): + metrics = pd.DataFrame( + { + "m1": [np.nan, -1.0], + "m2": [-1.0, -1.0], + }, + index=[20, 21], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_ignore_or = threshold_metrics_label_units( + sorting_analyzer_or_metrics=metrics, + thresholds=thresholds, + operator="or", + nan_policy="ignore", + ) + # unit 20: m1 is NaN and ignored => passes that metric => good under "or" + # unit 21: both metrics fail => noise + assert labels_ignore_or["label"].to_dict() == {20: "good", 21: "noise"} + + +def test_threshold_metrics_label_units_invalid_operator_raises(): + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"m1": {"min": 0.0}} + with pytest.raises(ValueError, match="operator must be 'and' or 'or'"): + threshold_metrics_label_units(metrics, thresholds, operator="xor") + + +def test_threshold_metrics_label_units_invalid_nan_policy_raises(): + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"m1": {"min": 0.0}} + with pytest.raises(ValueError, match="nan_policy must be 'fail' or 'ignore'"): + threshold_metrics_label_units(metrics, thresholds, nan_policy="omit") + + +def test_threshold_metrics_label_units_missing_metric_raises(): + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"does_not_exist": {"min": 0.0}} + with pytest.raises(ValueError, match="specified in thresholds are not present"): + threshold_metrics_label_units(metrics, thresholds) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 451edea87b..19e6e104f4 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -13,6 +13,8 @@ def threshold_metrics_label_units( thresholds: dict | str | Path, pass_label: str = "good", fail_label: str = "noise", + operator: str = "and", + nan_policy: str = "fail", ): """Label units based on metrics and thresholds. @@ -29,6 +31,12 @@ def threshold_metrics_label_units( The label to assign to units that pass all thresholds. fail_label : str, default: "noise" The label to assign to units that fail any threshold. + operator : "and" | "or", default: "and" + The logical operator to combine multiple metric thresholds. "and" means a unit must pass all thresholds to be + labeled as pass_label, while "or" means a unit must pass at least one threshold to be labeled as pass_label. + nan_policy : "fail" | "ignore", default: "fail" + Policy for handling NaN values in metrics. If "fail", units with NaN values in any metric will be labeled as + fail_label. 
If "ignore", NaN values will be ignored Returns ------- @@ -65,19 +73,57 @@ def threshold_metrics_label_units( f"Available metrics are: {metrics.columns.tolist()}" ) - # Initialize an empty DataFrame to store labels + if operator not in ("and", "or"): + raise ValueError("operator must be 'and' or 'or'") + + if nan_policy not in ("fail", "ignore"): + raise ValueError("nan_policy must be 'fail' or 'ignore'") + labels = pd.DataFrame(index=metrics.index, dtype=str) labels["label"] = fail_label - # Apply thresholds to label units - pass_mask = np.ones(len(metrics), dtype=bool) + # Key change: init depends on operator + pass_mask = np.ones(len(metrics), dtype=bool) if operator == "and" else np.zeros(len(metrics), dtype=bool) + any_threshold_applied = False + for metric_name, threshold in thresholds_dict.items(): min_value = threshold.get("min", None) max_value = threshold.get("max", None) + + # If both disabled, ignore this metric + if is_threshold_disabled(min_value) and is_threshold_disabled(max_value): + continue + + values = metrics[metric_name].to_numpy() + is_nan = np.isnan(values) + + metric_ok = np.ones(len(values), dtype=bool) if not is_threshold_disabled(min_value): - pass_mask &= metrics[metric_name] >= min_value + metric_ok &= values >= min_value if not is_threshold_disabled(max_value): - pass_mask &= metrics[metric_name] <= max_value + metric_ok &= values <= max_value + + metric_pass = np.ones(len(metrics), dtype=bool) + if not is_threshold_disabled(min_value): + metric_pass &= values >= min_value + if not is_threshold_disabled(max_value): + metric_pass &= values <= max_value + + # Handle NaNs + if nan_policy == "fail": + metric_ok &= ~is_nan + else: # "ignore" + metric_ok |= is_nan + + any_threshold_applied = True + + if operator == "and": + pass_mask &= metric_ok + else: + pass_mask |= metric_ok + + if not any_threshold_applied: + pass_mask[:] = True labels.loc[pass_mask, "label"] = pass_label return labels From 32189edf0d0416bf7d860b9b3e08f68289ce7f3a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Feb 2026 17:45:20 +0100 Subject: [PATCH 09/12] Fix imports --- .../tests/test_threshold_metrics_curation.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py index 46d11f6057..9e75dd7b98 100644 --- a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -2,7 +2,6 @@ import json import numpy as np -import pandas as pd from spikeinterface.curation.tests.common import sorting_analyzer_for_curation from spikeinterface.curation import threshold_metrics_label_units @@ -131,6 +130,8 @@ def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): def test_threshold_metrics_label_units_operator_or_with_dataframe(): + import pandas as pd + metrics = pd.DataFrame( { "m1": [1.0, 1.0, -1.0, -1.0], @@ -158,6 +159,8 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe(): def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): + import pandas as pd + metrics = pd.DataFrame( { "m1": [np.nan, 1.0, np.nan], @@ -188,6 +191,8 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): + import pandas as pd + metrics = pd.DataFrame( { "m1": [np.nan, -1.0], @@ -209,6 +214,8 @@ def 
test_threshold_metrics_label_units_nan_policy_ignore_with_or(): def test_threshold_metrics_label_units_invalid_operator_raises(): + import pandas as pd + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"m1": {"min": 0.0}} with pytest.raises(ValueError, match="operator must be 'and' or 'or'"): @@ -216,6 +223,8 @@ def test_threshold_metrics_label_units_invalid_operator_raises(): def test_threshold_metrics_label_units_invalid_nan_policy_raises(): + import pandas as pd + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"m1": {"min": 0.0}} with pytest.raises(ValueError, match="nan_policy must be 'fail' or 'ignore'"): @@ -223,6 +232,8 @@ def test_threshold_metrics_label_units_invalid_nan_policy_raises(): def test_threshold_metrics_label_units_missing_metric_raises(): + import pandas as pd + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"does_not_exist": {"min": 0.0}} with pytest.raises(ValueError, match="specified in thresholds are not present"): From 8625b31bb86fa5cf60fe05c241166843f7d57571 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Feb 2026 12:57:21 +0100 Subject: [PATCH 10/12] Add pass nan_policy --- .../tests/test_threshold_metrics_curation.py | 175 ++++++++---------- .../curation/threshold_metrics_curation.py | 37 ++-- 2 files changed, 97 insertions(+), 115 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py index 9e75dd7b98..82e0400b29 100644 --- a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -3,54 +3,41 @@ import numpy as np -from spikeinterface.curation.tests.common import sorting_analyzer_for_curation from spikeinterface.curation import threshold_metrics_label_units -@pytest.fixture -def sorting_analyzer_with_metrics(sorting_analyzer_for_curation): - """A sorting analyzer with computed quality metrics.""" - - sorting_analyzer = sorting_analyzer_for_curation - sorting_analyzer.compute("quality_metrics") - return sorting_analyzer - - -def test_threshold_metrics_label_units(sorting_analyzer_with_metrics): - """Test the `threshold_metrics_label_units` function.""" - sorting_analyzer = sorting_analyzer_with_metrics +def test_threshold_metrics_label_units_with_dataframe(): + import pandas as pd + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0, 5.0], + "firing_rate": [0.5, 0.2, 25.0], + }, + index=[0, 1, 2], + ) thresholds = { "snr": {"min": 5.0}, "firing_rate": {"min": 0.1, "max": 20.0}, } - labels = threshold_metrics_label_units( - sorting_analyzer, - thresholds, - ) + labels = threshold_metrics_label_units(metrics, thresholds) assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - qm = sorting_analyzer.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if ( - snr >= thresholds["snr"]["min"] - and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] - ): - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_units_with_file(sorting_analyzer_with_metrics, tmp_path): - """Test the `threshold_metrics_label_units` function with thresholds from a JSON 
file.""" - sorting_analyzer = sorting_analyzer_with_metrics + assert labels.shape[0] == len(metrics.index) + assert labels["label"].to_dict() == {0: "good", 1: "noise", 2: "noise"} + +def test_threshold_metrics_label_units_with_file(tmp_path): + import pandas as pd + + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0], + "firing_rate": [0.5, 0.05], + }, + index=[0, 1], + ) thresholds = { "snr": {"min": 5.0}, "firing_rate": {"min": 0.1}, @@ -60,72 +47,32 @@ def test_threshold_metrics_label_units_with_file(sorting_analyzer_with_metrics, with open(thresholds_file, "w") as f: json.dump(thresholds, f) - labels = threshold_metrics_label_units( - sorting_analyzer, - thresholds_file, - ) + labels = threshold_metrics_label_units(metrics, thresholds_file) - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - qm = sorting_analyzer.get_extension("quality_metrics").get_data() - for unit_id in sorting_analyzer.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if snr >= thresholds["snr"]["min"] and firing_rate >= thresholds["firing_rate"]["min"]: - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_units_with_external_metrics(sorting_analyzer_with_metrics): - """Test the `threshold_metrics_label_units` function with external metrics DataFrame.""" - sorting_analyzer = sorting_analyzer_with_metrics - thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1, "max": 20.0}, - } + assert labels["label"].to_dict() == {0: "good", 1: "noise"} - qm = sorting_analyzer.get_extension("quality_metrics").get_data() - labels = threshold_metrics_label_units( - sorting_analyzer_or_metrics=qm, - thresholds=thresholds, - ) +def test_threshold_metrics_label_external_labels(): + import pandas as pd - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) - - # Check that units with snr < 5.0 or firing_rate < 0.1 are labeled as 'noise' - for unit_id in sorting_analyzer.sorting.unit_ids: - snr = qm.loc[unit_id, "snr"] - firing_rate = qm.loc[unit_id, "firing_rate"] - if ( - snr >= thresholds["snr"]["min"] - and thresholds["firing_rate"]["min"] <= firing_rate <= thresholds["firing_rate"]["max"] - ): - assert labels.loc[unit_id, "label"] == "good" - else: - assert labels.loc[unit_id, "label"] == "noise" - - -def test_threshold_metrics_label_external_labels(sorting_analyzer_with_metrics): - """Test the `threshold_metrics_label_units` function with custom pass/fail labels.""" - sorting_analyzer = sorting_analyzer_with_metrics + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0], + "firing_rate": [0.5, 0.05], + }, + index=[0, 1], + ) thresholds = { "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1, "max": 20.0}, + "firing_rate": {"min": 0.1}, } labels = threshold_metrics_label_units( - sorting_analyzer, + metrics, thresholds=thresholds, pass_label="accepted", fail_label="rejected", ) - assert "label" in labels.columns - assert labels.shape[0] == len(sorting_analyzer.sorting.unit_ids) assert set(labels["label"]).issubset({"accepted", "rejected"}) @@ -142,7 +89,7 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe(): thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} labels_and = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="and", ) @@ 
-150,7 +97,7 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe(): assert labels_and["label"].to_dict() == {0: "good", 1: "noise", 2: "noise", 3: "noise"} labels_or = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="or", ) @@ -171,7 +118,7 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} labels_fail = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="and", nan_policy="fail", @@ -179,7 +126,7 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): assert labels_fail["label"].to_dict() == {10: "noise", 11: "noise", 12: "noise"} labels_ignore = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="and", nan_policy="ignore", @@ -203,14 +150,48 @@ def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} labels_ignore_or = threshold_metrics_label_units( - sorting_analyzer_or_metrics=metrics, + metrics, thresholds=thresholds, operator="or", nan_policy="ignore", ) - # unit 20: m1 is NaN and ignored => passes that metric => good under "or" + # unit 20: m1 is NaN and ignored; m2 fails => noise # unit 21: both metrics fail => noise - assert labels_ignore_or["label"].to_dict() == {20: "good", 21: "noise"} + assert labels_ignore_or["label"].to_dict() == {20: "noise", 21: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_pass_and_or(): + import pandas as pd + + metrics = pd.DataFrame( + { + "m1": [np.nan, np.nan, 1.0, -1.0], + "m2": [1.0, -1.0, np.nan, np.nan], + }, + index=[30, 31, 32, 33], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_and = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="and", + nan_policy="pass", + ) + # unit 30: m1 NaN (pass), m2 pass => good + # unit 31: m1 NaN (pass), m2 fail => noise + # unit 32: m1 pass, m2 NaN (pass) => good + # unit 33: m1 fail, m2 NaN (pass) => noise + assert labels_and["label"].to_dict() == {30: "good", 31: "noise", 32: "good", 33: "noise"} + + labels_or = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="or", + nan_policy="pass", + ) + # any NaN counts as pass => good unless all metrics fail without NaN + assert labels_or["label"].to_dict() == {30: "good", 31: "good", 32: "good", 33: "good"} def test_threshold_metrics_label_units_invalid_operator_raises(): @@ -227,7 +208,7 @@ def test_threshold_metrics_label_units_invalid_nan_policy_raises(): metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) thresholds = {"m1": {"min": 0.0}} - with pytest.raises(ValueError, match="nan_policy must be 'fail' or 'ignore'"): + with pytest.raises(ValueError, match="nan_policy must be"): threshold_metrics_label_units(metrics, thresholds, nan_policy="omit") diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 19e6e104f4..daa8b138f1 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -9,7 +9,7 @@ def threshold_metrics_label_units( - sorting_analyzer_or_metrics: "SortingAnalyzer | pd.DataFrame", + metrics: "pd.DataFrame", thresholds: dict | str | Path, pass_label: str = "good", fail_label: str = "noise", @@ -20,9 +20,8 @@ def threshold_metrics_label_units( 
Parameters ---------- - sorting_analyzer_or_metrics : SortingAnalyzer | pd.DataFrame - The SortingAnalyzer object containing the some metrics extensions (e.g., quality metrics) or a DataFrame - containing unit metrics with unit IDs as index. + metrics : pd.DataFrame + A DataFrame containing unit metrics with unit IDs as index. thresholds : dict | str | Path A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values @@ -34,9 +33,12 @@ def threshold_metrics_label_units( operator : "and" | "or", default: "and" The logical operator to combine multiple metric thresholds. "and" means a unit must pass all thresholds to be labeled as pass_label, while "or" means a unit must pass at least one threshold to be labeled as pass_label. - nan_policy : "fail" | "ignore", default: "fail" + nan_policy : "fail" | "pass" | "ignore", default: "fail" Policy for handling NaN values in metrics. If "fail", units with NaN values in any metric will be labeled as - fail_label. If "ignore", NaN values will be ignored + fail_label. If "pass", units with NaN values in one metric will be labeled as pass_label. + If "ignore", NaN values will be ignored. Note that the "ignore" behavior will depend on the operator used. + If "and", NaNs will be treated as passing, since the initial mask is all true; + if "or", NaNs will be treated as failing, since the initial mask is all false. Returns ------- @@ -45,13 +47,8 @@ def threshold_metrics_label_units( """ import pandas as pd - if not isinstance(sorting_analyzer_or_metrics, (SortingAnalyzer, pd.DataFrame)): - raise ValueError("Only SortingAnalyzer or pd.DataFrame are supported for sorting_analyzer_or_metrics.") - - if isinstance(sorting_analyzer_or_metrics, SortingAnalyzer): - metrics = sorting_analyzer_or_metrics.get_metrics_extension_data() - else: - metrics = sorting_analyzer_or_metrics + if not isinstance(metrics, pd.DataFrame): + raise ValueError("Only pd.DataFrame is supported for metrics.") # Load thresholds from file if a path is provided if isinstance(thresholds, (str, Path)): @@ -76,8 +73,8 @@ def threshold_metrics_label_units( if operator not in ("and", "or"): raise ValueError("operator must be 'and' or 'or'") - if nan_policy not in ("fail", "ignore"): - raise ValueError("nan_policy must be 'fail' or 'ignore'") + if nan_policy not in ("fail", "pass", "ignore"): + raise ValueError("nan_policy must be 'fail', 'pass', or 'ignore'") labels = pd.DataFrame(index=metrics.index, dtype=str) labels["label"] = fail_label @@ -110,17 +107,21 @@ def threshold_metrics_label_units( metric_pass &= values <= max_value # Handle NaNs + nan_mask = slice(None) if nan_policy == "fail": metric_ok &= ~is_nan - else: # "ignore" + elif nan_policy == "pass": metric_ok |= is_nan + else: + # if nan_policy == "ignore", we only set values for non-nan entries + nan_mask = ~is_nan any_threshold_applied = True if operator == "and": - pass_mask &= metric_ok + pass_mask[nan_mask] &= metric_ok[nan_mask] else: - pass_mask |= metric_ok + pass_mask[nan_mask] |= metric_ok[nan_mask] if not any_threshold_applied: pass_mask[:] = True From 78c831a7f4c7e57df03fec04a06a568b75628314 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Feb 2026 20:21:56 +0100 Subject: [PATCH 11/12] Update src/spikeinterface/curation/threshold_metrics_curation.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- 
 src/spikeinterface/curation/threshold_metrics_curation.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py
index daa8b138f1..2186a58fe5 100644
--- a/src/spikeinterface/curation/threshold_metrics_curation.py
+++ b/src/spikeinterface/curation/threshold_metrics_curation.py
@@ -100,12 +100,6 @@ def threshold_metrics_label_units(
         if not is_threshold_disabled(max_value):
             metric_ok &= values <= max_value
 
-        metric_pass = np.ones(len(metrics), dtype=bool)
-        if not is_threshold_disabled(min_value):
-            metric_pass &= values >= min_value
-        if not is_threshold_disabled(max_value):
-            metric_pass &= values <= max_value
-
         # Handle NaNs
         nan_mask = slice(None)
         if nan_policy == "fail":
             metric_ok &= ~is_nan

From 58c3b4dd640fee84f0057da1929ba324636fd89a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 5 Feb 2026 20:22:06 +0100
Subject: [PATCH 12/12] Update src/spikeinterface/curation/threshold_metrics_curation.py

Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
---
 src/spikeinterface/curation/threshold_metrics_curation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py
index 2186a58fe5..32a6a48d91 100644
--- a/src/spikeinterface/curation/threshold_metrics_curation.py
+++ b/src/spikeinterface/curation/threshold_metrics_curation.py
@@ -43,7 +43,7 @@ def threshold_metrics_label_units(
     Returns
     -------
     labels : pd.DataFrame
-        A DataFrame with unit IDs as index and a column 'label' containing the assigned labels ('noise' or 'good').
+        A DataFrame with unit IDs as index and a column 'label' containing the assigned labels (`fail_label` or `pass_label`).
     """
     import pandas as pd
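
For reference, below is a minimal usage sketch of the API this series converges on. It is a sketch only: it assumes a spikeinterface build with these patches applied, and the metric values and the "thresholds.json" file name are invented for illustration.

import json

import numpy as np
import pandas as pd

from spikeinterface.curation import threshold_metrics_label_units

# Toy metrics table: unit IDs as index, one column per metric (values are illustrative).
metrics = pd.DataFrame(
    {
        "snr": [6.0, 4.0, np.nan],
        "firing_rate": [0.5, 0.2, 25.0],
    },
    index=[0, 1, 2],
)

# Per-metric "min"/"max" bounds; a bound set to None (or np.nan) is disabled.
thresholds = {
    "snr": {"min": 5.0},
    "firing_rate": {"min": 0.1, "max": 20.0},
}

# With operator="and", a unit must satisfy every enabled bound;
# with nan_policy="fail", a NaN metric value makes that unit fail.
labels = threshold_metrics_label_units(
    metrics,
    thresholds,
    pass_label="accepted",
    fail_label="rejected",
    operator="and",
    nan_policy="fail",
)
print(labels["label"].to_dict())  # {0: 'accepted', 1: 'rejected', 2: 'rejected'}

# Thresholds can also be given as a path to a JSON file (hypothetical file name).
with open("thresholds.json", "w") as f:
    json.dump(thresholds, f)
labels = threshold_metrics_label_units(metrics, "thresholds.json")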