diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 68b16074fb..0401813e7c 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -37,6 +37,9 @@ class BasePhyKilosortSortingExtractor(BaseSorting): If True, empty units are removed from the sorting extractor. load_all_cluster_properties : bool, default: True If True, all cluster properties are loaded from the tsv/csv files. + load_templates_clusters : str, templates|clusters|auto , default: "auto" + Defines whether to load templates (kilosort output) or clusters (after manual refinement) + If "auto", try to load clusters, fallback to templates if not existing Notes ----- @@ -68,6 +71,7 @@ def __init__( keep_good_only: bool = False, remove_empty_units: bool = False, load_all_cluster_properties: bool = True, + load_templates_clusters="auto", ): try: import pandas as pd @@ -77,10 +81,19 @@ def __init__( phy_folder = Path(folder_path) spike_times = np.load(phy_folder / "spike_times.npy").astype(int) - if (phy_folder / "spike_clusters.npy").is_file(): + if load_templates_clusters == "auto": + if (phy_folder / "spike_clusters.npy").is_file(): + spike_clusters = np.load(phy_folder / "spike_clusters.npy") + else: + spike_clusters = np.load(phy_folder / "spike_templates.npy") + elif load_templates_clusters == "templates": + spike_clusters = np.load(phy_folder / "spike_templates.npy") + elif load_templates_clusters == "clusters": spike_clusters = np.load(phy_folder / "spike_clusters.npy") else: - spike_clusters = np.load(phy_folder / "spike_templates.npy") + raise ValueError( + "Invalid value provided for load_templates_clusters: '{}'.".format(load_templates_clusters) + ) # spike_times and spike_clusters can be 2d sometimes --> convert to 1d. spike_times = np.atleast_1d(spike_times.squeeze())