From 1db746561883383fe342117558a2e3a520b12513 Mon Sep 17 00:00:00 2001 From: Pascal <65159092+pas-calc@users.noreply.github.com> Date: Sun, 8 Feb 2026 18:05:38 +0100 Subject: [PATCH 1/2] Add load_templates_clusters parameter (kilosort output // phy/manual refinement) Added a new parameter 'load_templates_clusters' to control loading of templates or clusters from Kilosort output. Updated logic to handle loading based on the parameter value. --- .../extractors/phykilosortextractors.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 68b16074fb..4d8c9e06eb 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -37,6 +37,9 @@ class BasePhyKilosortSortingExtractor(BaseSorting): If True, empty units are removed from the sorting extractor. load_all_cluster_properties : bool, default: True If True, all cluster properties are loaded from the tsv/csv files. + load_templates_clusters : str, templates|clusters|auto , default: "auto" + Defines whether to load templates (kilosort output) or clusters (after manual refinement) + If "auto", try to load clusters, fallback to templates if not existing Notes ----- @@ -68,6 +71,7 @@ def __init__( keep_good_only: bool = False, remove_empty_units: bool = False, load_all_cluster_properties: bool = True, + load_templates_clusters="auto", ): try: import pandas as pd @@ -77,10 +81,17 @@ def __init__( phy_folder = Path(folder_path) spike_times = np.load(phy_folder / "spike_times.npy").astype(int) - if (phy_folder / "spike_clusters.npy").is_file(): + if load_templates_clusters=="auto": + if (phy_folder / "spike_clusters.npy").is_file(): + spike_clusters = np.load(phy_folder / "spike_clusters.npy") + else: + spike_clusters = np.load(phy_folder / "spike_templates.npy") + elif load_templates_clusters=="templates": + spike_clusters = np.load(phy_folder / "spike_templates.npy") + elif load_templates_clusters=="clusters": spike_clusters = np.load(phy_folder / "spike_clusters.npy") else: - spike_clusters = np.load(phy_folder / "spike_templates.npy") + raise ValueError("Invalid value provided for load_templates_clusters: '{}'.".format(load_templates_clusters)) # spike_times and spike_clusters can be 2d sometimes --> convert to 1d. spike_times = np.atleast_1d(spike_times.squeeze()) From d2a32f405491fa8245a3b1f76bb889560b92cf2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Feb 2026 17:07:43 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/extractors/phykilosortextractors.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 4d8c9e06eb..0401813e7c 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -81,17 +81,19 @@ def __init__( phy_folder = Path(folder_path) spike_times = np.load(phy_folder / "spike_times.npy").astype(int) - if load_templates_clusters=="auto": + if load_templates_clusters == "auto": if (phy_folder / "spike_clusters.npy").is_file(): spike_clusters = np.load(phy_folder / "spike_clusters.npy") else: spike_clusters = np.load(phy_folder / "spike_templates.npy") - elif load_templates_clusters=="templates": + elif load_templates_clusters == "templates": spike_clusters = np.load(phy_folder / "spike_templates.npy") - elif load_templates_clusters=="clusters": + elif load_templates_clusters == "clusters": spike_clusters = np.load(phy_folder / "spike_clusters.npy") else: - raise ValueError("Invalid value provided for load_templates_clusters: '{}'.".format(load_templates_clusters)) + raise ValueError( + "Invalid value provided for load_templates_clusters: '{}'.".format(load_templates_clusters) + ) # spike_times and spike_clusters can be 2d sometimes --> convert to 1d. spike_times = np.atleast_1d(spike_times.squeeze())