diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py
index ba9fa53a51..997eac878a 100644
--- a/src/spikeinterface/benchmark/benchmark_clustering.py
+++ b/src/spikeinterface/benchmark/benchmark_clustering.py
@@ -13,7 +13,6 @@
 from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
 from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
 from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
-from spikeinterface.core.template_tools import get_template_extremum_channel


 class ClusteringBenchmark(Benchmark):
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
index be1cf18fbf..eae0bf0e59 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
@@ -5,8 +5,6 @@
 from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
 from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy
-from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
-from spikeinterface.core.template_tools import get_template_extremum_channel

 from pathlib import Path
@@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder):
         # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False)
         # sorting_analyzer.compute(["random_spikes", "templates"])
-        extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index")
+        extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True)
+
         spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
         peaks[dataset] = spikes
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py
index b9207caaa3..82a51e8292 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py
@@ -6,7 +6,6 @@
 from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
 from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy
 from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
-from spikeinterface.core.template_tools import get_template_extremum_channel


 @pytest.mark.skip()
@@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder):
         sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs)
         sorting_analyzer.compute("random_spikes")
         sorting_analyzer.compute("templates", **job_kwargs)
-        extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index")
+        extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
         spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
         peaks[dataset] = spikes
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index e17731c70e..fd8476e2b6 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -15,6 +15,9 @@
 class BaseSorting(BaseExtractor):
     """
     Abstract class representing several segment several units and relative spiketrains.
""" + _main_properties = [ + "main_channel_index", + ] def __init__(self, sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) @@ -786,6 +789,7 @@ def _compute_and_cache_spike_vector(self) -> None: self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices + # TODO sam : change extremum_channel_inds to main_channel_index with vector def to_spike_vector( self, concatenated=True, @@ -806,7 +810,8 @@ def to_spike_vector( extremum_channel_inds : None or dict, default: None If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be computed with `get_template_extremum_channel(we, outputs="index")` + This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True) + use_cache : bool, default: True When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). This caching only occurs when extremum_channel_inds=None. diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 5d7ca1917a..454540663a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2444,6 +2444,10 @@ def generate_ground_truth_recording( **generate_templates_kwargs, ) sorting.set_property("gt_unit_locations", unit_locations) + distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2) + main_channel_index = np.argmin(distances, axis=1) + sorting.set_property("main_channel_index", main_channel_index) + else: assert templates.shape[0] == num_units diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1609f11d17..01340f2e68 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -132,6 +132,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea return (local_peaks,) +# TODO sam replace extremum_channels_indices by main_channel_index + # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8de45210cd..f1fe4b81c9 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -45,6 +45,10 @@ from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open from .node_pipeline import run_node_pipeline +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection + + # high level function def create_sorting_analyzer( @@ -52,6 +56,10 @@ def create_sorting_analyzer( recording, format="memory", folder=None, + main_channel_index=None, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", + num_spikes_for_main_channel=100, sparse=True, sparsity=None, set_sparsity_by_dict_key=False, @@ -59,7 +67,9 @@ def create_sorting_analyzer( return_in_uV=True, overwrite=False, backend_options=None, - **sparsity_kwargs, + sparsity_kwargs=None, + seed=None, + **job_kwargs ) -> "SortingAnalyzer": """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. @@ -69,6 +79,11 @@ def create_sorting_analyzer( This object will be also use used for plotting purpose. + The main_channel_index can be externally provided. If not then this is taken from + sorting property. 
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 1609f11d17..01340f2e68 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -132,6 +132,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
         return (local_peaks,)


+# TODO (sam) : replace extremum_channels_indices by main_channel_index
+# this is not implemented yet; it will be done in a separate PR
 class SpikeRetriever(PeakSource):
     """
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 8de45210cd..f1fe4b81c9 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -45,6 +45,10 @@
 from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open
 from .node_pipeline import run_node_pipeline

+from .waveform_tools import estimate_templates_with_accumulator
+from .sorting_tools import random_spikes_selection
+
+
 # high level function
 def create_sorting_analyzer(
@@ -52,6 +56,10 @@
     sorting,
     recording,
     format="memory",
     folder=None,
+    main_channel_index=None,
+    main_channel_peak_sign="both",
+    main_channel_peak_mode="extremum",
+    num_spikes_for_main_channel=100,
     sparse=True,
     sparsity=None,
     set_sparsity_by_dict_key=False,
@@ -59,7 +67,9 @@
     return_in_uV=True,
     overwrite=False,
     backend_options=None,
-    **sparsity_kwargs,
+    sparsity_kwargs=None,
+    seed=None,
+    **job_kwargs,
 ) -> "SortingAnalyzer":
     """
     Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording.
@@ -69,6 +79,11 @@
     This object will be also use used for plotting purpose.

+    The main_channel_index can be provided externally. If not, it is taken from the
+    "main_channel_index" sorting property. If that property is missing, the main_channel_index
+    is estimated with `estimate_templates_with_accumulator()`, which is fast and parallel
+    but needs to traverse the recording.
+

     Parameters
     ----------
@@ -82,6 +97,17 @@
         The mode to store analyzer. If "folder", the analyzer is stored on disk in the specified folder.
         The "folder" argument must be specified in case of mode "folder".
         If "memory" is used, the analyzer is stored in RAM. Use this option carefully!
+    main_channel_index : None | np.array, default: None
+        The main channel index per unit, provided externally.
+    main_channel_peak_sign : "both" | "neg" | "pos", default: "both"
+        When the main_channel_index is estimated, which peak sign to consider.
+    main_channel_peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum"
+        How the amplitude used to pick the main channel is computed:
+         * "extremum" : take the peak value (max or min depending on `peak_sign`)
+         * "at_index" : take the value at the `nbefore` index
+         * "peak_to_peak" : take the peak-to-peak amplitude
+    num_spikes_for_main_channel : int, default: 100
+        Number of spikes per unit used to estimate the main channel.
     sparse : bool, default: True
         If True, then a sparsity mask is computed using the `estimate_sparsity()` function using
         a few spikes to get an estimate of dense templates to create a ChannelSparsity object.
@@ -107,8 +133,8 @@
             * storage_options: dict | None (fsspec storage options)
             * saving_options: dict | None (additional saving options for creating and saving datasets,
                                            e.g. compression/filters for zarr)
-
-    sparsity_kwargs : keyword arguments
+    sparsity_kwargs : dict | None
+        Dictionary of kwargs passed to `estimate_sparsity()`.

     Returns
     -------
@@ -144,6 +170,9 @@
     sparsity off (or give external sparsity) like this.
     """

+    if sparsity_kwargs is None:
+        sparsity_kwargs = dict()
+
     if isinstance(sorting, dict) and isinstance(recording, dict):

         if sorting.keys() != recording.keys():
@@ -168,9 +197,14 @@
             return_in_uV=return_in_uV,
             overwrite=overwrite,
             backend_options=backend_options,
-            **sparsity_kwargs,
+            sparsity_kwargs=sparsity_kwargs,
+            **job_kwargs,
         )

+    # normal case : a single sorting/recording pair
+
     if format != "memory":
         if format == "zarr":
             if not is_path_remote(folder):
@@ -182,6 +216,26 @@
             else:
                 shutil.rmtree(folder)

+    # retrieve or compute the main channel index per unit
+    if main_channel_index is None:
+        if "main_channel_index" in sorting.get_property_keys():
+            main_channel_index = sorting.get_property("main_channel_index")
+
+    if main_channel_index is None:
+        # local import to avoid a cyclic import
+        from .template_tools import estimate_main_channel_from_recording
+
+        main_channel_index = estimate_main_channel_from_recording(
+            recording,
+            sorting,
+            main_channel_peak_sign=main_channel_peak_sign,
+            peak_mode=main_channel_peak_mode,
+            num_spikes_for_main_channel=num_spikes_for_main_channel,
+            seed=seed,
+            **job_kwargs,
+        )
+
     # handle sparsity
     if sparsity is not None:
         # some checks
@@ -192,8 +246,9 @@
         assert np.array_equal(
             recording.channel_ids, sparsity.channel_ids
         ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
+        assert all(
+            sparsity.mask[u, c] for u, c in enumerate(main_channel_index)
+        ), "sparsity is not consistent with main_channel_index"
     elif sparse:
-        sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
+        sparsity = estimate_sparsity(sorting, recording, main_channel_index=main_channel_index, **sparsity_kwargs)
     else:
         sparsity = None
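A sketch of the three resolution paths coded above, assuming the new keyword arguments land as written in this hunk (and that `generate_ground_truth_recording()` sets the property, per the generate.py hunk earlier):

    from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

    recording, sorting = generate_ground_truth_recording(durations=[5.0], num_units=4, seed=0)

    # 1) taken from the "main_channel_index" sorting property
    analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)

    # 2) provided externally
    main_channel_index = sorting.get_property("main_channel_index")
    analyzer = create_sorting_analyzer(
        sorting, recording, format="memory", sparse=False, main_channel_index=main_channel_index
    )

    # 3) estimated from the recording when neither is available
    sorting2 = sorting.clone()
    sorting2._properties.pop("main_channel_index")  # mirrors the test further down in this diff
    analyzer = create_sorting_analyzer(sorting2, recording, format="memory", sparse=False, seed=0)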
-215,6 +270,9 @@ def create_sorting_analyzer( recording, format=format, folder=folder, + main_channel_index=main_channel_index, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, @@ -284,6 +342,8 @@ def __init__( format: str | None = None, sparsity: ChannelSparsity | None = None, return_in_uV: bool = True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options: dict | None = None, ): # very fast init because checks are done in load and create @@ -294,6 +354,8 @@ def __init__( self.format = format self.sparsity = sparsity self.return_in_uV = return_in_uV + self.main_channel_peak_sign = main_channel_peak_sign + self.main_channel_peak_mode = main_channel_peak_mode # For backward compatibility self.return_scaled = return_in_uV @@ -347,9 +409,12 @@ def create( "zarr", ] = "memory", folder=None, + main_channel_index=None, sparsity=None, return_scaled=None, return_in_uV=True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options=None, ): assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" @@ -381,9 +446,16 @@ def create( from spikeinterface.curation.remove_excess_spikes import RemoveExcessSpikesSorting sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) - + + # This will ensure that the sorting saved always will have this main_channel + assert main_channel_index is not None + sorting.set_property("main_channel_index", main_channel_index) + if format == "memory": - sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, rec_attributes=None) + sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes=None) elif format == "binary_folder": sorting_analyzer = cls.create_binary_folder( folder, @@ -391,6 +463,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -404,6 +478,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -442,7 +518,10 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe return sorting_analyzer @classmethod - def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attributes): + def create_memory(cls, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes): # used by create and save_as if rec_attributes is None: @@ -463,11 +542,18 @@ def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attribute format="memory", sparsity=sparsity, return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + + rec_attributes, backend_options): # used by create and save_as folder = Path(folder) @@ -531,16 +617,51 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV settings_file = folder / f"settings.json" settings = 
             return_in_uV=return_in_uV,
+            main_channel_peak_sign=main_channel_peak_sign,
+            main_channel_peak_mode=main_channel_peak_mode,
         )
         with open(settings_file, mode="w") as f:
             json.dump(check_json(settings), f, indent=4)

         return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options)

+    @classmethod
+    def _handle_backward_compatibility(cls, settings, sorting, sparsity):
+        # backward compatibility at analyzer level
+        # (there is also something similar at extension level)
+
+        new_settings = dict()
+        new_settings.update(settings)
+        if "return_scaled" in settings:
+            new_settings["return_in_uV"] = new_settings.pop("return_scaled")
+        elif "return_in_uV" in settings:
+            pass
+        else:
+            # old versions did not have settings at all
+            new_settings["return_in_uV"] = True
+
+        retrospect_main_channel_index = None
+        if "main_channel_peak_sign" not in settings:
+            # before 0.104.0, main_channel_peak_sign was not stored in the settings
+            # TODO : something smarter that explores the previous params of extensions
+            new_settings["main_channel_peak_sign"] = "both"
+            new_settings["main_channel_peak_mode"] = "extremum"
+
+            if "main_channel_index" not in sorting.get_property_keys():
+                # TODO
+                raise NotImplementedError("backward compatibility with main_channel_index is not implemented yet")
+
+        if retrospect_main_channel_index is not None:
+            sorting.set_property("main_channel_index", retrospect_main_channel_index)
+
+        return new_settings
+
     @classmethod
     def load_from_binary_folder(cls, folder, recording=None, backend_options=None):
         from .loading import load

+        # TODO check that sorting has main_channel_index and ensure backward compatibility
+
         folder = Path(folder)
         assert folder.is_dir(), f"This folder does not exists {folder}"
@@ -596,13 +717,19 @@
         if settings_file.exists():
             with open(settings_file, "r") as f:
                 settings = json.load(f)
+            need_to_create = False
         else:
+            need_to_create = True
+            settings = dict()
+
+        settings = cls._handle_backward_compatibility(settings, sorting, sparsity)
+
+        if need_to_create:
             warnings.warn("settings.json not found for this folder, writing one with return_in_uV=True")
-            settings = dict(return_in_uV=True)
             with open(settings_file, "w") as f:
                 json.dump(check_json(settings), f, indent=4)

-        return_in_uV = settings.get("return_in_uV", settings.get("return_scaled", True))

         sorting_analyzer = SortingAnalyzer(
             sorting=sorting,
@@ -610,7 +737,9 @@
             rec_attributes=rec_attributes,
             format="binary_folder",
             sparsity=sparsity,
-            return_in_uV=return_in_uV,
+            return_in_uV=settings["return_in_uV"],
+            main_channel_peak_sign=settings["main_channel_peak_sign"],
+            main_channel_peak_mode=settings["main_channel_peak_mode"],
             backend_options=backend_options,
         )
         sorting_analyzer.folder = folder
@@ -625,7 +754,11 @@ def _get_zarr_root(self, mode="r+"):
         return zarr_root

     @classmethod
-    def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options):
+    def create_zarr(
+        cls,
+        folder,
+        sorting,
+        recording,
+        sparsity,
+        return_in_uV,
+        main_channel_peak_sign,
+        main_channel_peak_mode,
+        rec_attributes,
+        backend_options,
+    ):
         # used by create and save_as
         import zarr
         import numcodecs
@@ -649,7 +782,11 @@
         info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer")
         zarr_root.attrs["spikeinterface_info"] = check_json(info)

-        settings = dict(return_in_uV=return_in_uV)
+        settings = dict(
+            return_in_uV=return_in_uV,
+            main_channel_peak_sign=main_channel_peak_sign,
+            main_channel_peak_mode=main_channel_peak_mode,
+        )
         zarr_root.attrs["settings"] = check_json(settings)

         # the recording
@@ -713,6 +850,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):
         import zarr
         from .loading import load

+        # TODO check that sorting has main_channel_index and ensure backward compatibility
+
         backend_options = {} if backend_options is None else backend_options
         storage_options = backend_options.get("storage_options", {})
@@ -768,10 +907,9 @@
             )
         else:
             sparsity = None
-
-        return_in_uV = zarr_root.attrs["settings"].get(
-            "return_in_uV", zarr_root.attrs["settings"].get("return_scaled", True)
-        )
+
+        settings = zarr_root.attrs["settings"]
+        settings = cls._handle_backward_compatibility(settings, sorting, sparsity)

         sorting_analyzer = SortingAnalyzer(
             sorting=sorting,
@@ -779,7 +917,9 @@
             rec_attributes=rec_attributes,
             format="zarr",
             sparsity=sparsity,
-            return_in_uV=return_in_uV,
+            return_in_uV=settings["return_in_uV"],
+            main_channel_peak_sign=settings["main_channel_peak_sign"],
+            main_channel_peak_mode=settings["main_channel_peak_mode"],
             backend_options=backend_options,
         )
         sorting_analyzer.folder = folder
@@ -881,6 +1021,24 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n
             Array of values for the property
         """
         return self.sorting.get_property(key, ids=ids)
+
+    def get_main_channels(self, outputs="index", with_dict=False):
+        """
+        Get the main channel of each unit, read from the "main_channel_index" sorting property.
+
+        Parameters
+        ----------
+        outputs : "index" | "id", default: "index"
+            Return channel indices or channel ids.
+        with_dict : bool, default: False
+            If True, return a dict unit_id -> channel instead of a numpy vector.
+        """
+        main_channel_index = self.get_sorting_property("main_channel_index")
+        if outputs == "index":
+            main_chans = main_channel_index
+        elif outputs == "id":
+            main_chans = self.channel_ids[main_channel_index]
+        else:
+            raise ValueError("get_main_channels() : 'outputs' must be 'index' or 'id'")
+
+        if with_dict:
+            return dict(zip(self.unit_ids, main_chans))
+        else:
+            return main_chans

     def are_units_mergeable(
         self,
@@ -1103,7 +1261,10 @@ def _save_or_select_or_merge_or_split(
         if format == "memory":
             # This make a copy of actual SortingAnalyzer
             new_sorting_analyzer = SortingAnalyzer.create_memory(
-                sorting_provenance, recording, sparsity, self.return_in_uV, self.rec_attributes
+                sorting_provenance,
+                recording,
+                sparsity,
+                self.return_in_uV,
+                self.main_channel_peak_sign,
+                self.main_channel_peak_mode,
+                self.rec_attributes,
             )

         elif format == "binary_folder":
@@ -1116,6 +1277,9 @@
                 recording,
                 sparsity,
                 self.return_in_uV,
+                self.main_channel_peak_sign,
+                self.main_channel_peak_mode,
                 self.rec_attributes,
                 backend_options=backend_options,
             )
@@ -1129,6 +1293,8 @@
                 recording,
                 sparsity,
                 self.return_in_uV,
+                self.main_channel_peak_sign,
+                self.main_channel_peak_mode,
                 self.rec_attributes,
                 backend_options=backend_options,
             )
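The `get_main_channels()` accessor above is the replacement for `get_template_extremum_channel()` used throughout the rest of this diff. A minimal usage sketch, assuming the analyzer was created as in the hunks above so that the sorting carries "main_channel_index":

    from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

    recording, sorting = generate_ground_truth_recording(durations=[5.0], num_units=4, seed=0)
    analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)

    main_chan_inds = analyzer.get_main_channels(outputs="index")               # numpy vector of channel indices
    main_chan_ids = analyzer.get_main_channels(outputs="id", with_dict=True)   # {unit_id: channel_id}

    # the dict form plugs directly into to_spike_vector(), as done in several hunks below
    spikes = sorting.to_spike_vector(
        extremum_channel_inds=analyzer.get_main_channels(outputs="index", with_dict=True)
    )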
diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index ee19601068..e82d5f19e9 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -29,8 +29,6 @@
     In this case the sparsity for each unit is given by the channels that have the same property value as the unit.
     Use the "by_property" argument to specify the property name.
-    peak_sign : "neg" | "pos" | "both"
-        Sign of the template to compute best channels.
     num_channels : int
         Number of channels for "best_channels" method.
     radius_um : float
@@ -83,14 +81,14 @@ class ChannelSparsity:
     Using the N best channels (largest template amplitude):

-    >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels, peak_sign="neg")
+    >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels)

     Using a neighborhood by radius:

-    >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um, peak_sign="neg")
+    >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um)

     Using a SNR threshold:
-    >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold, peak_sign="neg")
+    >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold)

     Using a template energy threshold:
     >>> sparsity = ChannelSparsity.from_energy(sorting_analyzer, threshold)
@@ -365,9 +363,39 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels):
         return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids)

     @classmethod
-    def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"):
+    def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, channel_locations, radius_um):
         """
-        Construct sparsity from a radius around the best channel.
+        Construct sparsity from a radius around the main channel.
+        Use the "radius_um" argument to specify the radius in um.
+
+        Parameters
+        ----------
+        unit_ids : np.array
+            Unit ids of the sorting.
+        channel_ids : np.array
+            Channel ids of the recording.
+        main_channel_index : np.array
+            Main channel index per unit.
+        channel_locations : np.array
+            Channel locations of the recording.
+        radius_um : float
+            Radius in um.
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The estimated sparsity.
+        """
+        mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool")
+        distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2)
+        for unit_ind, main_chan in enumerate(main_channel_index):
+            (chan_inds,) = np.nonzero(distances[main_chan, :] <= radius_um)
+            mask[unit_ind, chan_inds] = True
+        return cls(mask, unit_ids, channel_ids)
+
+    @classmethod
+    def from_radius(cls, templates_or_sorting_analyzer, radius_um):
+        """
+        Construct sparsity from a radius around the main channel.
         Use the "radius_um" argument to specify the radius in um.

         Parameters
         ----------
@@ -384,19 +412,14 @@
         sparsity : ChannelSparsity
             The estimated sparsity.
""" - from .template_tools import get_template_extremum_channel - - mask = np.zeros( - (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" - ) + main_channel_index = templates_or_sorting_analyzer.get_main_channels(outputs="index") channel_locations = templates_or_sorting_analyzer.get_channel_locations() - distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") - for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_ind = best_chan[unit_id] - (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) - mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) + return cls.from_radius_and_main_channel( + templates_or_sorting_analyzer.unit_ids, + templates_or_sorting_analyzer.channel_ids, + main_channel_index, + channel_locations, + radius_um) @classmethod def from_snr( @@ -724,6 +747,7 @@ def estimate_sparsity( amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, noise_levels: np.ndarray | list | None = None, + main_channel_index: np.ndarray | list | None = None, **job_kwargs, ): """ @@ -732,11 +756,10 @@ def estimate_sparsity( For the "snr" method, `noise_levels` must passed with the `noise_levels` argument. These can be computed with the `get_noise_levels()` function. - Contrary to the previous implementation: - * all units are computed in one read of recording - * it doesn't require a folder - * it doesn't consume too much memory - * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + If main_channel_index is given and method="radius" then there is not need estimate + the templates otherwise the templates must be estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. @@ -755,6 +778,9 @@ def estimate_sparsity( noise_levels : np.array | None, default: None Noise levels required for the "snr" and "energy" methods. You can use the `get_noise_levels()` function to compute them. + main_channel_index : np.array | None, default: None + Main channel indicies can be provided in case of method="radius", this avoid the + `estimate_templates_with_accumulator()` which is slow. 
@@ -724,6 +747,7 @@ def estimate_sparsity(
     amplitude_mode: "extremum" | "peak_to_peak" = "extremum",
     by_property: str | None = None,
     noise_levels: np.ndarray | list | None = None,
+    main_channel_index: np.ndarray | list | None = None,
     **job_kwargs,
 ):
     """
@@ -732,11 +756,10 @@
     For the "snr" method, `noise_levels` must passed with the `noise_levels` argument.
     These can be computed with the `get_noise_levels()` function.

-    Contrary to the previous implementation:
-      * all units are computed in one read of recording
-      * it doesn't require a folder
-      * it doesn't consume too much memory
-      * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel
+    If main_channel_index is given and method="radius", then there is no need to estimate
+    the templates. Otherwise the templates are estimated using
+    `estimate_templates_with_accumulator()`, which is fast and parallel but needs to
+    traverse the recording.

     Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object.

@@ -755,6 +778,9 @@
     noise_levels : np.array | None, default: None
         Noise levels required for the "snr" and "energy" methods. You can use the
         `get_noise_levels()` function to compute them.
+    main_channel_index : np.array | None, default: None
+        Main channel indices can be provided when method="radius"; this avoids the call to
+        `estimate_templates_with_accumulator()`, which is slow.
     {}

     Returns
@@ -779,7 +805,14 @@
         chan_locs = recording.get_channel_locations()
         probe = recording.create_dummy_probe_from_locations(chan_locs)

-    if method != "by_property":
+    if method == "radius" and main_channel_index is not None:
+        main_channel_index = np.asarray(main_channel_index)
+        assert main_channel_index.size == sorting.unit_ids.size, "main_channel_index must have one entry per unit"
+        chan_locs = recording.get_channel_locations()
+        sparsity = ChannelSparsity.from_radius_and_main_channel(
+            sorting.unit_ids, recording.channel_ids, main_channel_index, chan_locs, radius_um
+        )
+
+    elif method != "by_property":
         nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
         nafter = int(ms_after * recording.sampling_frequency / 1000.0)
diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 91d25bece6..4aac641372 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -483,3 +483,27 @@ def get_channel_locations(self) -> np.ndarray:
         assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set"
         channel_locations = self.probe.contact_positions
         return channel_locations
+
+    def get_main_channels(
+        self,
+        main_channel_peak_sign: "neg" | "both" | "pos" = "both",
+        mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
+        outputs="index",
+        with_dict=False,
+    ):
+        """
+        Get the main channel of each unit from the dense templates array.
+        """
+        from .template_tools import _get_main_channel_from_template_array
+
+        templates_array = self.get_dense_templates()
+        main_channel_index = _get_main_channel_from_template_array(
+            templates_array, mode, main_channel_peak_sign, self.nbefore
+        )
+
+        if outputs == "index":
+            main_chans = main_channel_index
+        elif outputs == "id":
+            main_chans = self.channel_ids[main_channel_index]
+        else:
+            raise ValueError("get_main_channels() : 'outputs' must be 'index' or 'id'")
+
+        if with_dict:
+            return dict(zip(self.unit_ids, main_chans))
+        else:
+            return main_chans
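The helper `_get_main_channel_from_template_array()` imported here is defined in the template_tools.py hunk just below; its two-step reduction (peak over the time axis, then argmax over channels) can be checked on a toy array:

    import numpy as np

    templates_array = np.zeros((2, 30, 4))  # (num_units, num_samples, num_channels)
    templates_array[0, 10, 2] = -5.0        # unit 0: deep trough on channel 2
    templates_array[1, 12, 1] = 3.0         # unit 1: positive peak on channel 1

    # equivalent to peak_mode="extremum" with main_channel_peak_sign="both"
    values = np.max(np.abs(templates_array), axis=1)  # step 1: reduce over time, shape (2, 4)
    main_channel_index = np.argmax(values, axis=1)    # step 2: argmax over channels
    assert main_channel_index.tolist() == [2, 1]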
diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py
index ecc878e1f4..584bfaee7e 100644
--- a/src/spikeinterface/core/template_tools.py
+++ b/src/spikeinterface/core/template_tools.py
@@ -2,8 +2,13 @@
 import numpy as np

 from .template import Templates
+from .waveform_tools import estimate_templates_with_accumulator
+from .sorting_tools import random_spikes_selection
 from .sortinganalyzer import SortingAnalyzer

+import warnings
+
+
 def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: bool = True):
     """
@@ -126,6 +131,86 @@
     return peak_values


+def _get_main_channel_from_template_array(templates_array, peak_mode, main_channel_peak_sign, nbefore):
+    # Step 1 : reduce over the time axis
+    if peak_mode == "extremum":
+        if main_channel_peak_sign == "both":
+            values = np.max(np.abs(templates_array), axis=1)
+        elif main_channel_peak_sign == "neg":
+            values = -np.min(templates_array, axis=1)
+        elif main_channel_peak_sign == "pos":
+            values = np.max(templates_array, axis=1)
+    elif peak_mode == "at_index":
+        if main_channel_peak_sign == "both":
+            values = np.abs(templates_array[:, nbefore, :])
+        elif main_channel_peak_sign == "neg":
+            values = -templates_array[:, nbefore, :]
+        elif main_channel_peak_sign == "pos":
+            values = templates_array[:, nbefore, :]
+    elif peak_mode == "peak_to_peak":
+        values = np.ptp(templates_array, axis=1)
+
+    # Step 2 : argmax over the channel axis
+    main_channel_index = np.argmax(values, axis=1)
+
+    return main_channel_index
+
+
+def estimate_main_channel_from_recording(
+    recording,
+    sorting,
+    main_channel_peak_sign: "neg" | "both" | "pos" = "both",
+    peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
+    num_spikes_for_main_channel=100,
+    ms_before=1.0,
+    ms_after=2.5,
+    seed=None,
+    **job_kwargs,
+):
+    """
+    Estimate the main channel per unit from the recording using `estimate_templates_with_accumulator()`.
+    """
+
+    if main_channel_peak_sign == "pos":
+        warnings.warn(
+            "estimate_main_channel_from_recording() with peak_sign='pos' is an unusual case, "
+            "maybe you should invert the traces instead"
+        )
+
+    nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
+    nafter = int(ms_after * recording.sampling_frequency / 1000.0)
+
+    num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
+    random_spikes_indices = random_spikes_selection(
+        sorting,
+        num_samples,
+        method="uniform",
+        max_spikes_per_unit=num_spikes_for_main_channel,
+        margin_size=max(nbefore, nafter),
+        seed=seed,
+    )
+    spikes = sorting.to_spike_vector()
+    spikes = spikes[random_spikes_indices]
+
+    templates_array = estimate_templates_with_accumulator(
+        recording,
+        spikes,
+        sorting.unit_ids,
+        nbefore,
+        nafter,
+        return_in_uV=False,
+        job_name="estimate_main_channel",
+        **job_kwargs,
+    )
+
+    main_channel_index = _get_main_channel_from_template_array(
+        templates_array, peak_mode, main_channel_peak_sign, nbefore
+    )
+
+    return main_channel_index
+
+
 def get_template_extremum_channel(
     templates_or_sorting_analyzer,
     peak_sign: "neg" | "pos" | "both" = "neg",
@@ -156,6 +241,9 @@
         Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values
     """
+    warnings.warn("get_template_extremum_channel() is deprecated, use analyzer.get_main_channels() instead")
+    # TODO : make a better deprecation logic here
+
     assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`"
     assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'"
     assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`"
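For code that only has a recording/sorting pair, the new estimator above can be called directly. A sketch, assuming the function lands with the signature shown in this hunk:

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.core.template_tools import estimate_main_channel_from_recording

    recording, sorting = generate_ground_truth_recording(durations=[5.0], num_units=4, seed=0)

    main_channel_index = estimate_main_channel_from_recording(
        recording,
        sorting,
        main_channel_peak_sign="both",
        peak_mode="extremum",
        num_spikes_for_main_channel=50,
        seed=0,
    )
    assert main_channel_index.shape == (4,)  # one channel index per unit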
outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..765d5bec7e 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -52,6 +52,12 @@ def dataset(): def test_SortingAnalyzer_memory(tmp_path, dataset): recording, sorting = dataset + + # Note the sorting contain already main_channel_index + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + assert np.array_equal(sorting_analyzer.get_main_channels() , sorting.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -75,6 +81,16 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): assert "quality" in sorting_analyzer.sorting.get_property_keys() assert "number" in sorting_analyzer.sorting.get_property_keys() + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + + # Create when main_channel_index is not given : this is estimated + sorting2 = sorting.clone() + sorting2._properties.pop("main_channel_index") + print(sorting2.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting2, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting2, cache_folder=tmp_path) + def test_SortingAnalyzer_binary_folder(tmp_path, dataset): recording, sorting = dataset @@ -349,7 +365,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): assert ext is None assert sorting_analyzer.has_recording() - # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -615,12 +630,11 @@ def _set_params(self, param0=5.5): return params def _get_pipeline_nodes(self): - from spikeinterface.core.template_tools import get_template_extremum_channel recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") + extremum_channel_inds = self.sorting_analyzer.get_main_channels( outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) @@ -718,9 +732,9 @@ def test_runtime_dependencies(dataset): tmp_path = Path("test_SortingAnalyzer") dataset = get_dataset() test_SortingAnalyzer_memory(tmp_path, dataset) - test_SortingAnalyzer_binary_folder(tmp_path, dataset) - test_SortingAnalyzer_zarr(tmp_path, dataset) - test_SortingAnalyzer_tmp_recording(dataset) - test_extension() - test_extension_params() - test_runtime_dependencies() + # test_SortingAnalyzer_binary_folder(tmp_path, dataset) + # test_SortingAnalyzer_zarr(tmp_path, dataset) + # test_SortingAnalyzer_tmp_recording(dataset) + # test_extension() + # test_extension_params() + # test_runtime_dependencies(dataset) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 
--- a/src/spikeinterface/core/tests/test_sparsity.py
+++ b/src/spikeinterface/core/tests/test_sparsity.py
@@ -285,7 +285,7 @@ def test_compute_sparsity():

     # using object SortingAnalyzer
     sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg")
-    sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg")
+    sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0)
     sparsity = compute_sparsity(sorting_analyzer, method="closest_channels", num_channels=2)
     sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg")
     sparsity = compute_sparsity(
@@ -299,7 +299,7 @@
     templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates")
     noise_levels = sorting_analyzer.get_extension("noise_levels").get_data()
     sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg")
-    sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg")
+    sparsity = compute_sparsity(templates, method="radius", radius_um=50.0)
     sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg")
     sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak")
     sparsity = compute_sparsity(templates, method="closest_channels", num_channels=2)
diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py
index fe9fb3ba52..66b71a1b1c 100644
--- a/src/spikeinterface/exporters/report.py
+++ b/src/spikeinterface/exporters/report.py
@@ -6,7 +6,7 @@
 from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs
 import spikeinterface.widgets as sw

-from spikeinterface.core import get_template_extremum_channel, get_template_extremum_amplitude
+from spikeinterface.core import get_template_extremum_amplitude
 from spikeinterface.postprocessing import compute_correlograms

@@ -101,9 +101,10 @@ def export_report(
     # unit list
     units = pd.DataFrame(index=unit_ids)  #  , columns=['max_on_channel_id', 'amplitude'])
     units.index.name = "unit_id"
-    units["max_on_channel_id"] = pd.Series(
-        get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="id")
-    )
+    # max_on_channel_id is kept (old name) for backward compatibility
+    units["max_on_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False)
+    units["main_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False)
+
     units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign))
     units.to_csv(output_folder / "unit list.csv", sep="\t")
diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py
index 6559e89d52..9a0d5847fb 100644
--- a/src/spikeinterface/exporters/to_ibl.py
+++ b/src/spikeinterface/exporters/to_ibl.py
@@ -10,7 +10,6 @@
 from spikeinterface.core import SortingAnalyzer, BaseRecording, get_random_data_chunks
 from spikeinterface.core.job_tools import fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc
-from spikeinterface.core.template_tools import get_template_extremum_channel

 from spikeinterface.exporters import export_to_phy

@@ -102,7 +101,7 @@ def export_to_ibl_gui(
     output_folder.mkdir(parents=True, exist_ok=True)

     ### Save spikes info ###
-    extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index")
+    extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
     spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices)

     # spikes.clusters
@@ -137,7 +136,8 @@
     np.save(output_folder / "clusters.waveforms.npy", templates)

     # cluster channels
-    extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index")
+    extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
+
     cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32")
     np.save(output_folder / "clusters.channels.npy", cluster_channels)
diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py
index bbab9262af..406fbff18e 100644
--- a/src/spikeinterface/generation/hybrid_tools.py
+++ b/src/spikeinterface/generation/hybrid_tools.py
@@ -13,7 +13,6 @@
     InjectTemplatesRecording,
     _ensure_seed,
 )
-from spikeinterface.core.template_tools import get_template_extremum_channel
 from spikeinterface.core.motion import Motion

@@ -128,8 +127,8 @@ def select_templates(
         min_amplitude is not None or max_amplitude is not None or min_depth is not None or max_depth is not None
     ), "At least one of min_amplitude, max_amplitude, min_depth, max_depth should be provided"
     # get template amplitudes and depth
-    extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values())
-    extremum_channel_indices = np.array(extremum_channel_indices, dtype=int)
+    main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False)
+
     mask = np.ones(templates.num_units, dtype=bool)

     if min_amplitude is not None or max_amplitude is not None:
@@ -143,7 +142,7 @@
         amplitudes = np.zeros(templates.num_units)
         templates_array = templates.templates_array
         for i in range(templates.num_units):
-            amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]])
+            amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]])
         if min_amplitude is not None:
             mask &= amplitudes >= min_amplitude
         if max_amplitude is not None:
@@ -152,7 +151,7 @@
         assert templates.probe is not None, "Templates should have a probe to filter based on depth"
         depth_dimension = ["x", "y"].index(depth_direction)
         channel_depths = templates.get_channel_locations()[:, depth_dimension]
-        unit_depths = channel_depths[extremum_channel_indices]
+        unit_depths = channel_depths[main_channel_indices]
         if min_depth is not None:
             mask &= unit_depths >= min_depth
         if max_depth is not None:
@@ -191,8 +190,7 @@
     Templates
         The scaled templates.
""" - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) # get amplitudes if amplitude_function == "ptp": @@ -204,7 +202,7 @@ def scale_template_to_range( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) # scale templates to meet min_amplitude and max_amplitude range min_scale = np.min(amplitudes) / min_amplitude @@ -265,11 +263,10 @@ def relocate_templates( """ seed = _ensure_seed(seed) - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] assert margin >= 0, "margin should be positive" top_margin = np.max(channel_depths) + margin diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index 1f404ea3f7..03d7cd5ef8 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -107,9 +107,9 @@ def split_sorting_by_amplitudes( rng = np.random.default_rng(seed) fs = sorting_analyzer.sampling_frequency - from spikeinterface.core.template_tools import get_template_extremum_channel - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) new_spikes = spikes[0].copy() amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 3d26f7a85e..128a509e80 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -20,7 +20,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.core import SortingAnalyzer, get_noise_levels, NumpySorting from spikeinterface.core.template_tools import ( - get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, ) @@ -182,16 +181,13 @@ def compute_snrs( channel_ids = sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) + main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) - # make a dict to access by chan_id - noise_levels = dict(zip(channel_ids, noise_levels)) - snrs = {} for unit_id in unit_ids: - chan_id = extremum_channels_ids[unit_id] - noise = noise_levels[chan_id] + chan_ind = main_channel_index[unit_id] + noise = noise_levels[chan_ind] amplitude = 
         snrs[unit_id] = np.abs(amplitude) / noise
@@ -1294,7 +1290,8 @@ def compute_sd_ratio(
         noise_levels = get_noise_levels(
             sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs
         )
-    best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs)
+    main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
+
     n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids)

     if correct_for_template_itself:
@@ -1330,7 +1327,7 @@
         else:
             unit_std = np.std(spk_amp)

-        best_channel = best_channels[unit_id]
+        best_channel = main_channels[unit_id]
         std_noise = noise_levels[best_channel]
         n_samples = sorting_analyzer.get_total_samples()
diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py
index 5476aa405a..405a31d068 100644
--- a/src/spikeinterface/metrics/quality/quality_metrics.py
+++ b/src/spikeinterface/metrics/quality/quality_metrics.py
@@ -5,7 +5,6 @@
 import warnings

 import numpy as np
-from spikeinterface.core.template_tools import get_template_extremum_channel

 from spikeinterface.core.sortinganalyzer import register_result_extension
 from spikeinterface.core.analyzer_extension_core import BaseMetricExtension
@@ -137,7 +136,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None):
         all_labels = sorting_analyzer.sorting.unit_ids[spike_unit_indices]

         # Get extremum channels for neighbor selection in sparse mode
-        extremum_channels = get_template_extremum_channel(sorting_analyzer)
+        main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True)

         # Pre-compute spike counts and firing rates if advanced NN metrics are requested
         advanced_nn_metrics = ["nn_advanced"]  # Our grouped advanced NN metric
@@ -152,7 +152,7 @@
             if sorting_analyzer.is_sparse():
                 neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id]
                 neighbor_unit_ids = [
-                    other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids
+                    other_unit for other_unit in unit_ids if main_channels[other_unit] in neighbor_channel_ids
                 ]
                 neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids)
             else:
diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py
index 85ef9e22cb..ed7ded47e6 100644
--- a/src/spikeinterface/metrics/template/template_metrics.py
+++ b/src/spikeinterface/metrics/template/template_metrics.py
@@ -12,7 +12,7 @@
 from spikeinterface.core.sortinganalyzer import register_result_extension
 from spikeinterface.core.analyzer_extension_core import BaseMetricExtension
-from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array
+from spikeinterface.core.template_tools import get_dense_templates_array

 from .metrics import get_trough_and_peak_idx, single_channel_metrics, multi_channel_metrics
@@ -189,7 +189,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
             m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"]
         )

-        extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="index")
+        extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
+
         all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True)
         channel_locations = sorting_analyzer.get_channel_locations()
diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py
index 473798fe7c..dfef0e3bc6 100644
--- a/src/spikeinterface/postprocessing/amplitude_scalings.py
+++ b/src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -3,7 +3,7 @@
 import numpy as np

 from spikeinterface.core import ChannelSparsity
-from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore
+from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore
 from spikeinterface.core.sortinganalyzer import register_result_extension
 from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension
@@ -104,10 +104,8 @@ def _get_pipeline_nodes(self):
         else:
             cut_out_after = nafter

-        peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos"
-        extremum_channels_indices = get_template_extremum_channel(
-            self.sorting_analyzer, peak_sign=peak_sign, outputs="index"
-        )
+        extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
+
         # collisions
         handle_collisions = self.params["handle_collisions"]
diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py
index 671b9bb239..fb9c89c521 100644
--- a/src/spikeinterface/postprocessing/localization_tools.py
+++ b/src/spikeinterface/postprocessing/localization_tools.py
@@ -6,7 +6,7 @@
 import numpy as np

 from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity
-from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel
+from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array

 numba_spec = importlib.util.find_spec("numba")
 if numba_spec is not None:
@@ -99,7 +99,9 @@ def compute_monopolar_triangulation(
             chan_inds = sparsity.unit_id_to_channel_indices[unit_id]
             neighbours_mask[i, chan_inds] = True
     enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask)
-    best_channels = get_template_extremum_channel(sorting_analyzer_or_templates, outputs="index")
+
+    best_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True)
+
     unit_location = np.zeros((unit_ids.size, 4), dtype="float64")
     for i, unit_id in enumerate(unit_ids):
@@ -278,7 +280,7 @@ def compute_grid_convolution(
         contact_locations, radius_um, upsampling_um, margin_um, weight_method
     )

-    peak_channels = get_template_extremum_channel(sorting_analyzer_or_templates, peak_sign, outputs="index")
+    main_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True)

     weights_sparsity_mask = weights > 0
@@ -286,7 +288,7 @@
     unit_location = np.zeros((len(unit_ids), 3), dtype="float64")
     for i, unit_id in enumerate(unit_ids):
-        main_chan = peak_channels[unit_id]
+        main_chan = main_channels[unit_id]
         wf = templates[i, :, :]
         nearest_mask = nearest_template_mask[main_chan, :]
         channel_mask = np.sum(weights_sparsity_mask[:, :, nearest_mask], axis=(0, 2)) > 0
@@ -661,14 +663,10 @@ def get_convolution_weights(
 def compute_location_max_channel(
     templates_or_sorting_analyzer: SortingAnalyzer | Templates,
     unit_ids=None,
-    peak_sign: "neg" | "pos" | "both" = "neg",
-    mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
 ) -> np.ndarray:
     """
     Localize a unit using max channel.
-    This uses internally `get_template_extremum_channel()`
-
     Parameters
     ----------
@@ -689,9 +687,8 @@
     unit_locations: np.ndarray
         2d
     """
-    extremum_channels_index = get_template_extremum_channel(
-        templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index"
-    )
+    extremum_channels_index = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
+
     contact_locations = templates_or_sorting_analyzer.get_channel_locations()
     if unit_ids is None:
         unit_ids = templates_or_sorting_analyzer.unit_ids
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index 993d1a105d..5a7aebf728 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -4,7 +4,7 @@

 from spikeinterface.core.sortinganalyzer import register_result_extension
 from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension
-from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift
+from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift
 from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type
@@ -35,9 +35,8 @@ def _get_pipeline_nodes(self):
         peak_sign = self.params["peak_sign"]
         return_in_uV = self.sorting_analyzer.return_in_uV

-        extremum_channels_indices = get_template_extremum_channel(
-            self.sorting_analyzer, peak_sign=peak_sign, outputs="index"
-        )
+        extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
+
         peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign)

         spike_retriever_node = SpikeRetriever(
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index d4e226aa99..c6c9eca021 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -4,7 +4,6 @@

 from spikeinterface.core.job_tools import _shared_job_kwargs_doc
 from spikeinterface.core.sortinganalyzer import register_result_extension
-from spikeinterface.core.template_tools import get_template_extremum_channel
 from spikeinterface.core.node_pipeline import SpikeRetriever
 from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension
@@ -74,10 +73,7 @@ def _get_pipeline_nodes(self):
         recording = self.sorting_analyzer.recording
         sorting = self.sorting_analyzer.sorting

-        peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"]
-        extremum_channels_indices = get_template_extremum_channel(
-            self.sorting_analyzer, peak_sign=peak_sign, outputs="index"
-        )
+        extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True)

         retriever = SpikeRetriever(
             sorting,
diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py
index 44389cc503..0e46ba4df2 100644
--- a/src/spikeinterface/sortingcomponents/matching/nearest.py
+++ b/src/spikeinterface/sortingcomponents/matching/nearest.py
@@ -53,13 +53,11 @@ def __init__(
         num_channels = recording.get_num_channels()

         if neighborhood_radius_um is not None:
-            from spikeinterface.core.template_tools import get_template_extremum_channel
+            main_channels = self.templates.get_main_channels(
+                main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False
+            )
diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py
index 44389cc503..0e46ba4df2 100644
--- a/src/spikeinterface/sortingcomponents/matching/nearest.py
+++ b/src/spikeinterface/sortingcomponents/matching/nearest.py
@@ -53,13 +53,11 @@ def __init__(
         num_channels = recording.get_num_channels()

         if neighborhood_radius_um is not None:
-            from spikeinterface.core.template_tools import get_template_extremum_channel
-            best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index")
-            best_channels = np.array([best_channels[i] for i in templates.unit_ids])
+            main_channels = self.templates.get_main_channels(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False)
             channel_locations = recording.get_channel_locations()
             template_distances = np.linalg.norm(
-                channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2
+                channel_locations[:, None] - channel_locations[main_channels][np.newaxis, :], axis=2
             )
             self.neighborhood_mask = template_distances <= neighborhood_radius_um
         else:
diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py
index d3ae787a4b..96b667e404 100644
--- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py
+++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py
@@ -6,7 +6,6 @@
 import numpy as np

 from spikeinterface.core import (
     get_channel_distances,
-    get_template_extremum_channel,
 )

 from spikeinterface.sortingcomponents.peak_detection.method_list import (
@@ -222,12 +221,12 @@ def __init__(
         self.sparse_templates_array_static = templates.templates_array
         self.dtype = self.sparse_templates_array_static.dtype

-        extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index")
+
         # as numpy vector
-        self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64")
+        self.main_channels = templates.get_main_channels(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False)

         channel_locations = templates.probe.contact_positions
-        unit_locations = channel_locations[self.extremum_channel]
+        unit_locations = channel_locations[self.main_channels]
         self.channel_locations = channel_locations

         # distance between units
diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py
index 505027f79a..23f662d14a 100644
--- a/src/spikeinterface/widgets/spikes_on_traces.py
+++ b/src/spikeinterface/widgets/spikes_on_traces.py
@@ -8,7 +8,6 @@
 from .utils import get_unit_colors
 from .traces import TracesWidget
 from spikeinterface.core import ChannelSparsity
-from spikeinterface.core.template_tools import get_template_extremum_channel
 from spikeinterface.core.sortinganalyzer import SortingAnalyzer
 from spikeinterface.core.baserecording import BaseRecording
 from spikeinterface.core.basesorting import BaseSorting
@@ -121,9 +120,9 @@ def __init__(
                 sparsity = sorting_analyzer.sparsity
         else:
             if sparsity is None:
-                # in this case, we construct a sparsity dictionary only with the best channel
-                extremum_channel_ids = get_template_extremum_channel(sorting_analyzer)
-                unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()}
+                # in this case, we construct a sparsity dictionary only with the main channel
+                main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True)
+                unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()}
                 sparsity = ChannelSparsity.from_unit_id_to_channel_ids(
                     unit_id_to_channel_ids=unit_id_to_channel_ids,
                     unit_ids=sorting_analyzer.unit_ids,
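Review note: the matching engines above ask for `with_dict=False` to get a plain integer vector ordered like `unit_ids`, which can index numpy arrays directly. A sketch of that pattern; `templates` is assumed to be a `Templates` object with an attached probe, and `main_channel_peak_sign="neg"` is just an example value for the keyword this PR introduces.

    import numpy as np

    main_channels = templates.get_main_channels(
        main_channel_peak_sign="neg", outputs="index", with_dict=False
    )
    channel_locations = templates.probe.contact_positions

    # rough unit positions: the location of each unit's main channel
    unit_locations = channel_locations[main_channels]

    # pairwise unit distances, of the kind tdc_peeler computes above
    unit_distances = np.linalg.norm(unit_locations[:, None, :] - unit_locations[None, :, :], axis=2)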
diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py
index c55c802f9b..2483f0a792 100644
--- a/src/spikeinterface/widgets/unit_locations.py
+++ b/src/spikeinterface/widgets/unit_locations.py
@@ -4,7 +4,6 @@
 import numpy as np
 from probeinterface import ProbeGroup

-from spikeinterface.core.template_tools import get_template_extremum_channel
 from spikeinterface.core.sortinganalyzer import SortingAnalyzer

 from .base import BaseWidget, to_attr
@@ -86,10 +85,10 @@ def __init__(
         if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])):
             warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.")
-            extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index")
+            main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
             for unit_id in unit_ids:
                 if np.any(np.isnan(unit_locations[unit_id])):
-                    unit_locations[unit_id] = channel_locations[extremum_channel_indices[unit_id]]
+                    unit_locations[unit_id] = channel_locations[main_channels[unit_id]]

         data_plot = dict(
             all_unit_ids=sorting.unit_ids,
diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py
index fb26a228ef..93c18a01e3 100644
--- a/src/spikeinterface/widgets/unit_summary.py
+++ b/src/spikeinterface/widgets/unit_summary.py
@@ -4,7 +4,6 @@
 import warnings
 import numpy as np

-from spikeinterface.core.template_tools import get_template_extremum_channel

 from .base import BaseWidget, to_attr
 from .utils import get_unit_colors
@@ -136,12 +135,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
             col_counter += 1

         unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
-        extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index")
+        main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
+
         unit_location = unit_locations[unit_id]
         x, y = unit_location[0], unit_location[1]
         if np.isnan(x) or np.isnan(y):
             warnings.warn(f"Unit {unit_id} location contains NaN values. Replacing NaN extremum channel location.")
-            x, y = sorting_analyzer.get_channel_locations()[extremum_channel_indices[unit_id]]
+            x, y = sorting_analyzer.get_channel_locations()[main_channels[unit_id]]

         ax_unit_locations.set_xlim(x - 80, x + 80)
         ax_unit_locations.set_ylim(y - 250, y + 250)
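Review note: both widgets above share the same NaN fallback. Written out on its own, reusing `analyzer` from the earlier sketches and assuming the "unit_locations" extension has been computed:

    import numpy as np

    unit_locations = analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
    channel_locations = analyzer.get_channel_locations()
    main_channels = analyzer.get_main_channels(outputs="index", with_dict=True)

    for unit_id, loc in unit_locations.items():
        if np.any(np.isnan(loc)):
            # localization failed: fall back to the main channel position
            unit_locations[unit_id] = channel_locations[main_channels[unit_id]]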
diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py
index 9543cbf734..10f9651bc6 100644
--- a/src/spikeinterface/widgets/unit_waveforms_density_map.py
+++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py
@@ -5,7 +5,7 @@
 from .base import BaseWidget, to_attr
 from .utils import get_unit_colors
-from spikeinterface.core import ChannelSparsity, get_template_extremum_channel
+from spikeinterface.core import ChannelSparsity


 class UnitWaveformDensityMapWidget(BaseWidget):
@@ -25,8 +25,6 @@ class UnitWaveformDensityMapWidget(BaseWidget):
         If SortingAnalyzer is already sparse, the argument is ignored
     use_max_channel : bool, default: False
         Use only the max channel
-    peak_sign : "neg" | "pos" | "both", default: "neg"
-        Used to detect max channel only when use_max_channel=True
     unit_colors : dict | None, default: None
         Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
         by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
@@ -43,7 +41,6 @@ def __init__(
         sparsity=None,
         same_axis=False,
         use_max_channel=False,
-        peak_sign="neg",
         unit_colors=None,
         backend=None,
         **backend_kwargs,
@@ -61,9 +58,7 @@ def __init__(

         if use_max_channel:
             assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit"
-            max_channels = get_template_extremum_channel(
-                sorting_analyzer, mode="extremum", peak_sign=peak_sign, outputs="index"
-            )
+            max_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)

         # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all
         if sorting_analyzer.is_sparse():
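Review note: for completeness, the widget-side sparsity construction from spikes_on_traces.py as a standalone sketch. `outputs="id"` returns channel ids rather than indices, which is what `ChannelSparsity.from_unit_id_to_channel_ids` expects; `analyzer` is reused from the earlier sketches.

    from spikeinterface.core import ChannelSparsity

    main_channels = analyzer.get_main_channels(outputs="id", with_dict=True)
    unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()}
    sparsity = ChannelSparsity.from_unit_id_to_channel_ids(
        unit_id_to_channel_ids=unit_id_to_channel_ids,
        unit_ids=analyzer.unit_ids,
        channel_ids=analyzer.channel_ids,
    )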