From 6dd726997b86310640378550e6fc8da27c6b659a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Feb 2026 16:59:55 +0100 Subject: [PATCH 1/2] Try to remove get_template_extremum_channel() --- .../benchmark/benchmark_clustering.py | 1 - .../tests/test_benchmark_clustering.py | 5 +- .../tests/test_benchmark_peak_detection.py | 3 +- src/spikeinterface/core/basesorting.py | 7 +- src/spikeinterface/core/generate.py | 4 + src/spikeinterface/core/node_pipeline.py | 2 + src/spikeinterface/core/sortinganalyzer.py | 89 +++++++++++++++++-- src/spikeinterface/core/sparsity.py | 62 ++++++++++--- src/spikeinterface/core/template.py | 24 +++++ src/spikeinterface/core/template_tools.py | 88 ++++++++++++++++++ .../core/tests/test_node_pipeline.py | 6 +- .../core/tests/test_sortinganalyzer.py | 31 +++++-- src/spikeinterface/exporters/report.py | 9 +- src/spikeinterface/exporters/to_ibl.py | 6 +- src/spikeinterface/generation/hybrid_tools.py | 19 ++-- .../generation/splitting_tools.py | 4 +- .../metrics/quality/misc_metrics.py | 15 ++-- .../metrics/quality/quality_metrics.py | 6 +- .../metrics/template/template_metrics.py | 5 +- .../postprocessing/amplitude_scalings.py | 7 +- .../postprocessing/localization_tools.py | 19 ++-- .../postprocessing/spike_amplitudes.py | 7 +- .../postprocessing/spike_locations.py | 5 +- .../sortingcomponents/matching/nearest.py | 6 +- .../sortingcomponents/matching/tdc_peeler.py | 7 +- .../widgets/spikes_on_traces.py | 7 +- src/spikeinterface/widgets/unit_locations.py | 5 +- src/spikeinterface/widgets/unit_summary.py | 6 +- .../widgets/unit_waveforms_density_map.py | 7 +- 29 files changed, 346 insertions(+), 116 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index ba9fa53a51..997eac878a 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -13,7 +13,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel class ClusteringBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index be1cf18fbf..5660b68fda 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -5,8 +5,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel from pathlib import Path @@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder): # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) # sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index") + extremum_channel_inds = gt_analyzer.get_main_channel(outputs="index", with_dict=True) + spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index b9207caaa3..86b6bde5c5 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -6,7 +6,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel @pytest.mark.skip() @@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder): sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("templates", **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index e17731c70e..b11801c0be 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -15,6 +15,9 @@ class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. """ + _main_properties = [ + "main_channel_index", + ] def __init__(self, sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) @@ -786,6 +789,7 @@ def _compute_and_cache_spike_vector(self) -> None: self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices + # TODO sam : change extremum_channel_inds to main_channel_index with vector def to_spike_vector( self, concatenated=True, @@ -806,7 +810,8 @@ def to_spike_vector( extremum_channel_inds : None or dict, default: None If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be computed with `get_template_extremum_channel(we, outputs="index")` + This dict can be given by analyzer.get_main_channel(outputs="index", with_dict=True) + use_cache : bool, default: True When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). This caching only occurs when extremum_channel_inds=None. diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 5d7ca1917a..454540663a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2444,6 +2444,10 @@ def generate_ground_truth_recording( **generate_templates_kwargs, ) sorting.set_property("gt_unit_locations", unit_locations) + distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2) + main_channel_index = np.argmin(distances, axis=1) + sorting.set_property("main_channel_index", main_channel_index) + else: assert templates.shape[0] == num_units diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1609f11d17..01340f2e68 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -132,6 +132,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea return (local_peaks,) +# TODO sam replace extremum_channels_indices by main_channel_index + # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8de45210cd..fe125193d3 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -45,6 +45,10 @@ from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open from .node_pipeline import run_node_pipeline +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection + + # high level function def create_sorting_analyzer( @@ -52,6 +56,10 @@ def create_sorting_analyzer( recording, format="memory", folder=None, + main_channel_index=None, + main_channel_peak_sign="both", + main_channel_mode="extremum", + num_spikes_for_main_channel=100, sparse=True, sparsity=None, set_sparsity_by_dict_key=False, @@ -59,7 +67,9 @@ def create_sorting_analyzer( return_in_uV=True, overwrite=False, backend_options=None, - **sparsity_kwargs, + sparsity_kwargs=None, + seed=None, + **job_kwargs ) -> "SortingAnalyzer": """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. @@ -69,6 +79,11 @@ def create_sorting_analyzer( This object will be also use used for plotting purpose. + The main_channel_index can be externally provided. If not then this is taken from + sorting property. If not then the main_channel_index is estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. + Parameters ---------- @@ -82,6 +97,12 @@ def create_sorting_analyzer( The mode to store analyzer. If "folder", the analyzer is stored on disk in the specified folder. The "folder" argument must be specified in case of mode "folder". If "memory" is used, the analyzer is stored in RAM. Use this option carefully! + main_channel_index : None | np.array + The main_channel_index can be externally provided + main_channel_peak_sign : "both" | "neg" + In case when the main_channel_index is estimated wich sign to consider "both" or "neg". + num_spikes_for_main_channel : int, default: 100 + How many spikes per units to compute the main channel. sparse : bool, default: True If True, then a sparsity mask is computed using the `estimate_sparsity()` function using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. @@ -107,8 +128,8 @@ def create_sorting_analyzer( * storage_options: dict | None (fsspec storage options) * saving_options: dict | None (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) - - sparsity_kwargs : keyword arguments + sparsity_kwargs : dict | None + Dict given to estimate the sparsity. Returns ------- @@ -144,6 +165,9 @@ def create_sorting_analyzer( sparsity off (or give external sparsity) like this. """ + if sparsity_kwargs is None: + sparsity_kwargs = dict() + if isinstance(sorting, dict) and isinstance(recording, dict): if sorting.keys() != recording.keys(): @@ -168,9 +192,14 @@ def create_sorting_analyzer( return_in_uV=return_in_uV, overwrite=overwrite, backend_options=backend_options, - **sparsity_kwargs, + sparsity_kwargs=sparsity_kwargs, + **job_kwargs ) + # normal case + + + if format != "memory": if format == "zarr": if not is_path_remote(folder): @@ -182,6 +211,26 @@ def create_sorting_analyzer( else: shutil.rmtree(folder) + + + # retrieve or compute the main channel index per unit + if main_channel_index is None: + if "main_channel_index" in sorting.get_property_keys(): + main_channel_index = sorting.get_property("main_channel_index") + + if main_channel_index is None: + # this is weird but due to the cyclic import + from .template_tools import estimate_main_channel_from_recording + main_channel_index = estimate_main_channel_from_recording( + recording, + sorting, + main_channel_peak_sign=main_channel_peak_sign, + mode=main_channel_mode, + num_spikes_for_main_channel=num_spikes_for_main_channel, + seed=seed, + **job_kwargs + ) + # handle sparsity if sparsity is not None: # some checks @@ -192,8 +241,9 @@ def create_sorting_analyzer( assert np.array_equal( recording.channel_ids, sparsity.channel_ids ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" + assert all(sparsity.mask[u, c] for u, c in enumerate(main_channel_index)), "sparsity si not constistentent with main_channel_index" elif sparse: - sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs) + sparsity = estimate_sparsity(sorting, recording, main_channel_index=main_channel_index, **sparsity_kwargs) else: sparsity = None @@ -215,6 +265,7 @@ def create_sorting_analyzer( recording, format=format, folder=folder, + main_channel_index=main_channel_index, sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, @@ -347,6 +398,7 @@ def create( "zarr", ] = "memory", folder=None, + main_channel_index=None, sparsity=None, return_scaled=None, return_in_uV=True, @@ -381,7 +433,10 @@ def create( from spikeinterface.curation.remove_excess_spikes import RemoveExcessSpikesSorting sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) - + + # This will ensure that the sorting saved always will have this main_channel + sorting.set_property("main_channel_index", main_channel_index) + if format == "memory": sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, rec_attributes=None) elif format == "binary_folder": @@ -541,6 +596,8 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV def load_from_binary_folder(cls, folder, recording=None, backend_options=None): from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -713,6 +770,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): import zarr from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) @@ -881,6 +940,24 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n Array of values for the property """ return self.sorting.get_property(key, ids=ids) + + def get_main_channel(self, outputs="index", with_dict=False): + """ + + """ + main_channel_index = self.get_sorting_property("main_channel_index") + if outputs is "index": + main_chans = main_channel_index + elif outputs is "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans + def are_units_mergeable( self, diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index ee19601068..70bb4cdb46 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -365,9 +365,39 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): + def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, channel_locations, radius_um): """ - Construct sparsity from a radius around the best channel. + Construct sparsity from a radius around the main channel. + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + main_channel_index : np.array + Main channel index per units. + channel_locations : np.array + Channel locations of the recording. + radius_um : float + Radius in um for "radius" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. + """ + mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") + distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) + for unit_ind, main_chan in enumerate(main_channel_index): + (chan_inds,) = np.nonzero(distances[main_chan, :] <= radius_um) + mask[unit_ind, chan_inds] = True + return cls(mask, unit_ids, channel_ids) + + + @classmethod + def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="both"): + """ + Construct sparsity from a radius around the main channel. Use the "radius_um" argument to specify the radius in um. Parameters @@ -384,16 +414,14 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): sparsity : ChannelSparsity The estimated sparsity. """ - from .template_tools import get_template_extremum_channel - mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) channel_locations = templates_or_sorting_analyzer.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") + main_channel_index = templates_or_sorting_analyzer.get_main_channel(outputs="index") for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_ind = best_chan[unit_id] + chan_ind = main_channel_index[unit_ind] (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) mask[unit_ind, chan_inds] = True return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @@ -724,6 +752,7 @@ def estimate_sparsity( amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, noise_levels: np.ndarray | list | None = None, + main_channel_index: np.ndarray | list | None = None, **job_kwargs, ): """ @@ -732,11 +761,10 @@ def estimate_sparsity( For the "snr" method, `noise_levels` must passed with the `noise_levels` argument. These can be computed with the `get_noise_levels()` function. - Contrary to the previous implementation: - * all units are computed in one read of recording - * it doesn't require a folder - * it doesn't consume too much memory - * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + If main_channel_index is given and method="radius" then there is not need estimate + the templates otherwise the templates must be estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. @@ -755,6 +783,9 @@ def estimate_sparsity( noise_levels : np.array | None, default: None Noise levels required for the "snr" and "energy" methods. You can use the `get_noise_levels()` function to compute them. + main_channel_index : np.array | None, default: None + Main channel indicies can be provided in case of method="radius", this avoid the + `estimate_templates_with_accumulator()` which is slow. {} Returns @@ -779,7 +810,14 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - if method != "by_property": + if method == "radius" and main_channel_index is not None: + assert main_channel_index.size == sorting.unit_ids.size + chan_locs = recording.get_channel_locations() + sparsity = ChannelSparsity.from_radius_and_main_channel( + sorting.unit_ids, recording.channel_ids, main_channel_index, chan_locs, radius_um + ) + + elif method != "by_property": nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 91d25bece6..151f6e2b8d 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -483,3 +483,27 @@ def get_channel_locations(self) -> np.ndarray: assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" channel_locations = self.probe.contact_positions return channel_locations + + def get_main_channel(self, + main_channel_peak_sign: "neg" | "both" | "pos" = "both", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + outputs="index", + with_dict=True + ): + from .template_tools import _get_main_channel_from_template_array + + templates_array = self.get_dense_templates() + main_channel_index = _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, self.nbefore) + + if outputs is "index": + main_chans = main_channel_index + elif outputs is "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans + diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index ecc878e1f4..fe04540476 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -2,8 +2,13 @@ import numpy as np from .template import Templates +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection from .sortinganalyzer import SortingAnalyzer +import warnings + + def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: bool = True): """ @@ -126,6 +131,86 @@ def get_template_amplitudes( return peak_values + +def _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore): + # Step1 : max on time axis + if mode == "extremum": + if main_channel_peak_sign == "both": + values = np.max(np.abs(templates_array), axis=1) + elif main_channel_peak_sign == "neg": + values = -np.min(templates_array, axis=1) + elif main_channel_peak_sign == "pos": + values = np.max(templates_array, axis=1) + elif mode == "at_index": + if main_channel_peak_sign == "both": + values = np.abs(templates_array[:, nbefore, :]) + elif main_channel_peak_sign in ["neg", "pos"]: + values = templates_array[:, nbefore, :] + elif mode == "peak_to_peak": + values = np.ptp(templates_array, axis=1) + + # Step2: max on channel axis + main_channel_index = np.argmax(values, axis=1) + + return main_channel_index + +def estimate_main_channel_from_recording( + recording, + sorting, + main_channel_peak_sign: "neg" | "both" | "pos" = "both", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + num_spikes_for_main_channel=100, + ms_before = 1.0, + ms_after = 2.5, + seed=None, + **job_kwargs +): + """ + Estimate the main channel from recording using `estimate_templates_with_accumulator()` + + """ + + if main_channel_peak_sign == "pos": + warnings.warn( + "estimate_main_channel_from_recording() with peak_sign='pos' is a strange case maybe you " \ + "should revert the traces instead" + ) + + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_main_channel, + margin_size=max(nbefore, nafter), + seed=seed, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_in_uV=False, + job_name="estimate_main_channel", + **job_kwargs, + ) + + main_channel_index = _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore) + + return main_channel_index + + + + + def get_template_extremum_channel( templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", @@ -156,6 +241,9 @@ def get_template_extremum_channel( Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ + warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channel() instead") + # TODO make a better logic here + assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`" assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`" diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 4f8e600a3f..74131dac8d 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface import create_sorting_analyzer, generate_ground_truth_recording from spikeinterface.core.base import spike_peak_dtype from spikeinterface.core.job_tools import divide_recording_into_chunks @@ -80,7 +80,7 @@ def test_run_node_pipeline(cache_folder_creation): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) @@ -202,7 +202,7 @@ def test_skip_after_n_peaks_and_recording_slices(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..7be23dbf85 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -52,6 +52,12 @@ def dataset(): def test_SortingAnalyzer_memory(tmp_path, dataset): recording, sorting = dataset + + # Note the sorting contain already main_channel_index + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + assert np.array_equal(sorting_analyzer.get_main_channel() , sorting.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -75,6 +81,16 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): assert "quality" in sorting_analyzer.sorting.get_property_keys() assert "number" in sorting_analyzer.sorting.get_property_keys() + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + + # Create when main_channel_index is not given : this is estimated + sorting2 = sorting.clone() + sorting2._properties.pop("main_channel_index") + print(sorting2.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting2, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting2, cache_folder=tmp_path) + def test_SortingAnalyzer_binary_folder(tmp_path, dataset): recording, sorting = dataset @@ -615,12 +631,11 @@ def _set_params(self, param0=5.5): return params def _get_pipeline_nodes(self): - from spikeinterface.core.template_tools import get_template_extremum_channel recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") + extremum_channel_inds = self.sorting_analyzer.get_main_channel( outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) @@ -718,9 +733,9 @@ def test_runtime_dependencies(dataset): tmp_path = Path("test_SortingAnalyzer") dataset = get_dataset() test_SortingAnalyzer_memory(tmp_path, dataset) - test_SortingAnalyzer_binary_folder(tmp_path, dataset) - test_SortingAnalyzer_zarr(tmp_path, dataset) - test_SortingAnalyzer_tmp_recording(dataset) - test_extension() - test_extension_params() - test_runtime_dependencies() + # test_SortingAnalyzer_binary_folder(tmp_path, dataset) + # test_SortingAnalyzer_zarr(tmp_path, dataset) + # test_SortingAnalyzer_tmp_recording(dataset) + # test_extension() + # test_extension_params() + # test_runtime_dependencies(dataset) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index fe9fb3ba52..8c12591c6e 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -6,7 +6,7 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs import spikeinterface.widgets as sw -from spikeinterface.core import get_template_extremum_channel, get_template_extremum_amplitude +from spikeinterface.core import get_template_extremum_amplitude from spikeinterface.postprocessing import compute_correlograms @@ -101,9 +101,10 @@ def export_report( # unit list units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" - units["max_on_channel_id"] = pd.Series( - get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="id") - ) + # max_on_channel_id is kept (oold name) + units["max_on_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) + units["main_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) + units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign)) units.to_csv(output_folder / "unit list.csv", sep="\t") diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 6559e89d52..b0445b4753 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -10,7 +10,6 @@ from spikeinterface.core import SortingAnalyzer, BaseRecording, get_random_data_chunks from spikeinterface.core.job_tools import fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.exporters import export_to_phy @@ -102,7 +101,7 @@ def export_to_ibl_gui( output_folder.mkdir(parents=True, exist_ok=True) ### Save spikes info ### - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices) # spikes.clusters @@ -137,7 +136,8 @@ def export_to_ibl_gui( np.save(output_folder / "clusters.waveforms.npy", templates) # cluster channels - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32") np.save(output_folder / "clusters.channels.npy", cluster_channels) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index bbab9262af..72d8cea634 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -13,7 +13,6 @@ InjectTemplatesRecording, _ensure_seed, ) -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.motion import Motion @@ -128,8 +127,8 @@ def select_templates( min_amplitude is not None or max_amplitude is not None or min_depth is not None or max_depth is not None ), "At least one of min_amplitude, max_amplitude, min_depth, max_depth should be provided" # get template amplitudes and depth - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + mask = np.ones(templates.num_units, dtype=bool) if min_amplitude is not None or max_amplitude is not None: @@ -143,7 +142,7 @@ def select_templates( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) if min_amplitude is not None: mask &= amplitudes >= min_amplitude if max_amplitude is not None: @@ -152,7 +151,7 @@ def select_templates( assert templates.probe is not None, "Templates should have a probe to filter based on depth" depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] if min_depth is not None: mask &= unit_depths >= min_depth if max_depth is not None: @@ -191,8 +190,7 @@ def scale_template_to_range( Templates The scaled templates. """ - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) # get amplitudes if amplitude_function == "ptp": @@ -204,7 +202,7 @@ def scale_template_to_range( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) # scale templates to meet min_amplitude and max_amplitude range min_scale = np.min(amplitudes) / min_amplitude @@ -265,11 +263,10 @@ def relocate_templates( """ seed = _ensure_seed(seed) - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] assert margin >= 0, "margin should be positive" top_margin = np.max(channel_depths) + margin diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index 1f404ea3f7..e2314319a1 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -107,9 +107,9 @@ def split_sorting_by_amplitudes( rng = np.random.default_rng(seed) fs = sorting_analyzer.sampling_frequency - from spikeinterface.core.template_tools import get_template_extremum_channel - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) new_spikes = spikes[0].copy() amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 3d26f7a85e..f4a6afcd9b 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -20,7 +20,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.core import SortingAnalyzer, get_noise_levels, NumpySorting from spikeinterface.core.template_tools import ( - get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, ) @@ -182,16 +181,13 @@ def compute_snrs( channel_ids = sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) + main_channel_index = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) - # make a dict to access by chan_id - noise_levels = dict(zip(channel_ids, noise_levels)) - snrs = {} for unit_id in unit_ids: - chan_id = extremum_channels_ids[unit_id] - noise = noise_levels[chan_id] + chan_ind = main_channel_index[unit_id] + noise = noise_levels[chan_ind] amplitude = unit_amplitudes[unit_id] snrs[unit_id] = np.abs(amplitude) / noise @@ -1294,7 +1290,8 @@ def compute_sd_ratio( noise_levels = get_noise_levels( sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs ) - best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs) + main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) if correct_for_template_itself: @@ -1330,7 +1327,7 @@ def compute_sd_ratio( else: unit_std = np.std(spk_amp) - best_channel = best_channels[unit_id] + best_channel = main_channels[unit_id] std_noise = noise_levels[best_channel] n_samples = sorting_analyzer.get_total_samples() diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index 5476aa405a..e4ce1a68fe 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -5,7 +5,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension @@ -137,7 +136,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): all_labels = sorting_analyzer.sorting.unit_ids[spike_unit_indices] # Get extremum channels for neighbor selection in sparse mode - extremum_channels = get_template_extremum_channel(sorting_analyzer) + + main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) # Pre-compute spike counts and firing rates if advanced NN metrics are requested advanced_nn_metrics = ["nn_advanced"] # Our grouped advanced NN metric @@ -152,7 +152,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): if sorting_analyzer.is_sparse(): neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] neighbor_unit_ids = [ - other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids + other_unit for other_unit in unit_ids if main_channels[other_unit] in neighbor_channel_ids ] neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) else: diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 85ef9e22cb..7f6175e47a 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -12,7 +12,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array +from spikeinterface.core.template_tools import get_dense_templates_array from .metrics import get_trough_and_peak_idx, single_channel_metrics, multi_channel_metrics @@ -189,7 +189,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True) channel_locations = sorting_analyzer.get_channel_locations() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 473798fe7c..6cae681f15 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore +from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -105,9 +105,8 @@ def _get_pipeline_nodes(self): cut_out_after = nafter peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + # collisions handle_collisions = self.params["handle_collisions"] diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 671b9bb239..e180b868ab 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -6,7 +6,7 @@ import numpy as np from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity -from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -99,7 +99,9 @@ def compute_monopolar_triangulation( chan_inds = sparsity.unit_id_to_channel_indices[unit_id] neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = get_template_extremum_channel(sorting_analyzer_or_templates, outputs="index") + + best_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) + unit_location = np.zeros((unit_ids.size, 4), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -278,7 +280,7 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - peak_channels = get_template_extremum_channel(sorting_analyzer_or_templates, peak_sign, outputs="index") + main_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) weights_sparsity_mask = weights > 0 @@ -286,7 +288,7 @@ def compute_grid_convolution( unit_location = np.zeros((len(unit_ids), 3), dtype="float64") for i, unit_id in enumerate(unit_ids): - main_chan = peak_channels[unit_id] + main_chan = main_channels[unit_id] wf = templates[i, :, :] nearest_mask = nearest_template_mask[main_chan, :] channel_mask = np.sum(weights_sparsity_mask[:, :, nearest_mask], axis=(0, 2)) > 0 @@ -661,14 +663,10 @@ def get_convolution_weights( def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> np.ndarray: """ Localize a unit using max channel. - This uses internally `get_template_extremum_channel()` - Parameters ---------- @@ -689,9 +687,8 @@ def compute_location_max_channel( unit_locations: np.ndarray 2d """ - extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" - ) + extremum_channels_index = templates_or_sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: unit_ids = templates_or_sorting_analyzer.unit_ids diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 993d1a105d..b5c7ff1f6d 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -4,7 +4,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type @@ -35,9 +35,8 @@ def _get_pipeline_nodes(self): peak_sign = self.params["peak_sign"] return_in_uV = self.sorting_analyzer.return_in_uV - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) spike_retriever_node = SpikeRetriever( diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d4e226aa99..cf5adb6198 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -4,7 +4,6 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.sortinganalyzer import register_result_extension -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.node_pipeline import SpikeRetriever from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -75,9 +74,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) retriever = SpikeRetriever( sorting, diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 44389cc503..7ac4d36d0a 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,13 +53,11 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - from spikeinterface.core.template_tools import get_template_extremum_channel + main_channels = self.templates.get_main_channel(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) - best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") - best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( - channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2 + channel_locations[:, None] - channel_locations[main_channels][np.newaxis, :], axis=2 ) self.neighborhood_mask = template_distances <= neighborhood_radius_um else: diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index d3ae787a4b..1622ab83a5 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -6,7 +6,6 @@ import numpy as np from spikeinterface.core import ( get_channel_distances, - get_template_extremum_channel, ) from spikeinterface.sortingcomponents.peak_detection.method_list import ( @@ -222,12 +221,12 @@ def __init__( self.sparse_templates_array_static = templates.templates_array self.dtype = self.sparse_templates_array_static.dtype - extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") + # as numpy vector - self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") + self.main_channels = templates.get_main_channel(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) channel_locations = templates.probe.contact_positions - unit_locations = channel_locations[self.extremum_channel] + unit_locations = channel_locations[self.main_channels] self.channel_locations = channel_locations # distance between units diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 505027f79a..6f8ba998e7 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -8,7 +8,6 @@ from .utils import get_unit_colors from .traces import TracesWidget from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from spikeinterface.core.baserecording import BaseRecording from spikeinterface.core.basesorting import BaseSorting @@ -121,9 +120,9 @@ def __init__( sparsity = sorting_analyzer.sparsity else: if sparsity is None: - # in this case, we construct a sparsity dictionary only with the best channel - extremum_channel_ids = get_template_extremum_channel(sorting_analyzer) - unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()} + # in this case, we construct a sparsity dictionary only with the main channel + main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) + unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_analyzer.unit_ids, diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index c55c802f9b..69bc2b05cf 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -4,7 +4,6 @@ import numpy as np from probeinterface import ProbeGroup -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from .base import BaseWidget, to_attr @@ -86,10 +85,10 @@ def __init__( if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])): warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) for unit_id in unit_ids: if np.any(np.isnan(unit_locations[unit_id])): - unit_locations[unit_id] = channel_locations[extremum_channel_indices[unit_id]] + unit_locations[unit_id] = channel_locations[main_channels[unit_id]] data_plot = dict( all_unit_ids=sorting.unit_ids, diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index fb26a228ef..411cd34203 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -4,7 +4,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -136,12 +135,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] if np.isnan(x) or np.isnan(y): warnings.warn(f"Unit {unit_id} location contains NaN values. Replacing NaN extremum channel location.") - x, y = sorting_analyzer.get_channel_locations()[extremum_channel_indices[unit_id]] + x, y = sorting_analyzer.get_channel_locations()[main_channels[unit_id]] ax_unit_locations.set_xlim(x - 80, x + 80) ax_unit_locations.set_ylim(y - 250, y + 250) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9543cbf734..91854ec0eb 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -5,7 +5,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core import ChannelSparsity, get_template_extremum_channel +from spikeinterface.core import ChannelSparsity class UnitWaveformDensityMapWidget(BaseWidget): @@ -43,7 +43,6 @@ def __init__( sparsity=None, same_axis=False, use_max_channel=False, - peak_sign="neg", unit_colors=None, backend=None, **backend_kwargs, @@ -61,9 +60,7 @@ def __init__( if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = get_template_extremum_channel( - sorting_analyzer, mode="extremum", peak_sign=peak_sign, outputs="index" - ) + max_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all if sorting_analyzer.is_sparse(): From eede722009b7aae012d1c7e8aac92c44d3d88a90 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Feb 2026 19:03:30 +0100 Subject: [PATCH 2/2] Put main_channel_peak_sign and main_channel_peak_mode in analyzer settings to reload them --- .../tests/test_benchmark_clustering.py | 2 +- .../tests/test_benchmark_peak_detection.py | 2 +- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 123 +++++++++++++++--- src/spikeinterface/core/sparsity.py | 27 ++-- src/spikeinterface/core/template.py | 2 +- src/spikeinterface/core/template_tools.py | 14 +- .../core/tests/test_node_pipeline.py | 4 +- .../core/tests/test_sortinganalyzer.py | 5 +- .../core/tests/test_sparsity.py | 4 +- src/spikeinterface/exporters/report.py | 4 +- src/spikeinterface/exporters/to_ibl.py | 4 +- src/spikeinterface/generation/hybrid_tools.py | 6 +- .../generation/splitting_tools.py | 2 +- .../metrics/quality/misc_metrics.py | 4 +- .../metrics/quality/quality_metrics.py | 2 +- .../metrics/template/template_metrics.py | 2 +- .../postprocessing/amplitude_scalings.py | 3 +- .../postprocessing/localization_tools.py | 6 +- .../postprocessing/spike_amplitudes.py | 2 +- .../postprocessing/spike_locations.py | 3 +- .../sortingcomponents/matching/nearest.py | 2 +- .../sortingcomponents/matching/tdc_peeler.py | 2 +- .../widgets/spikes_on_traces.py | 2 +- src/spikeinterface/widgets/unit_locations.py | 2 +- src/spikeinterface/widgets/unit_summary.py | 2 +- .../widgets/unit_waveforms_density_map.py | 4 +- 27 files changed, 158 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index 5660b68fda..eae0bf0e59 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -31,7 +31,7 @@ def test_benchmark_clustering(create_cache_folder): # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) # sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = gt_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index 86b6bde5c5..82a51e8292 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -29,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder): sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("templates", **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index b11801c0be..fd8476e2b6 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -810,7 +810,7 @@ def to_spike_vector( extremum_channel_inds : None or dict, default: None If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be given by analyzer.get_main_channel(outputs="index", with_dict=True) + This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True) use_cache : bool, default: True When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index fe125193d3..f1fe4b81c9 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -58,7 +58,7 @@ def create_sorting_analyzer( folder=None, main_channel_index=None, main_channel_peak_sign="both", - main_channel_mode="extremum", + main_channel_peak_mode="extremum", num_spikes_for_main_channel=100, sparse=True, sparsity=None, @@ -101,6 +101,11 @@ def create_sorting_analyzer( The main_channel_index can be externally provided main_channel_peak_sign : "both" | "neg" In case when the main_channel_index is estimated wich sign to consider "both" or "neg". + main_channel_peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude num_spikes_for_main_channel : int, default: 100 How many spikes per units to compute the main channel. sparse : bool, default: True @@ -225,7 +230,7 @@ def create_sorting_analyzer( recording, sorting, main_channel_peak_sign=main_channel_peak_sign, - mode=main_channel_mode, + peak_mode=main_channel_peak_mode, num_spikes_for_main_channel=num_spikes_for_main_channel, seed=seed, **job_kwargs @@ -266,6 +271,8 @@ def create_sorting_analyzer( format=format, folder=folder, main_channel_index=main_channel_index, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, @@ -335,6 +342,8 @@ def __init__( format: str | None = None, sparsity: ChannelSparsity | None = None, return_in_uV: bool = True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options: dict | None = None, ): # very fast init because checks are done in load and create @@ -345,6 +354,8 @@ def __init__( self.format = format self.sparsity = sparsity self.return_in_uV = return_in_uV + self.main_channel_peak_sign = main_channel_peak_sign + self.main_channel_peak_mode = main_channel_peak_mode # For backward compatibility self.return_scaled = return_in_uV @@ -402,6 +413,8 @@ def create( sparsity=None, return_scaled=None, return_in_uV=True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options=None, ): assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" @@ -435,10 +448,14 @@ def create( sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) # This will ensure that the sorting saved always will have this main_channel + assert main_channel_index is not None sorting.set_property("main_channel_index", main_channel_index) if format == "memory": - sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, rec_attributes=None) + sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes=None) elif format == "binary_folder": sorting_analyzer = cls.create_binary_folder( folder, @@ -446,6 +463,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -459,6 +478,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -497,7 +518,10 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe return sorting_analyzer @classmethod - def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attributes): + def create_memory(cls, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes): # used by create and save_as if rec_attributes is None: @@ -518,11 +542,18 @@ def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attribute format="memory", sparsity=sparsity, return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + + rec_attributes, backend_options): # used by create and save_as folder = Path(folder) @@ -586,12 +617,45 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV settings_file = folder / f"settings.json" settings = dict( return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, ) with open(settings_file, mode="w") as f: json.dump(check_json(settings), f, indent=4) return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) + @classmethod + def _handle_backward_compatibility(cls, settings, sorting, sparsity): + # backward compatibility at analyzer level + # (there is also something similar at extension level) + + new_settings = dict() + new_settings.update(settings) + if "return_scaled" in settings: + new_settings["return_in_uV"] = new_settings.pop("return_scaled") + elif "return_in_uV" in settings: + pass + else: + # old version did not have settings at all + new_settings["return_in_uV"] = True + + retrospect_main_channel_index = None + if "main_channel_peak_sign" not in settings: + # before 0.104.0 was not in main_channel_peak_sign + # TODO make something more fancy that exlore the previous params of extension + new_settings["main_channel_peak_sign"] = "both" + new_settings["main_channel_peak_mode"] = "extremum" + + if "main_channel_index" not in sorting.get_property_keys(): + # TODO + raise NotImplementedError("backward compatibility with main_channel_index is not implemented yet") + + if retrospect_main_channel_index is not None: + sorting.set_property("main_channel_index", retrospect_main_channel_index) + + return new_settings + @classmethod def load_from_binary_folder(cls, folder, recording=None, backend_options=None): from .loading import load @@ -653,13 +717,19 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): if settings_file.exists(): with open(settings_file, "r") as f: settings = json.load(f) + need_to_create = False else: + need_to_create = True + settings = dict() + + settings = cls._handle_backward_compatibility(settings, sorting, sparsity) + + if need_to_create: warnings.warn("settings.json not found for this folder writing one with return_in_uV=True") - settings = dict(return_in_uV=True) with open(settings_file, "w") as f: json.dump(check_json(settings), f, indent=4) - return_in_uV = settings.get("return_in_uV", settings.get("return_scaled", True)) + sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -667,7 +737,9 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="binary_folder", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV = settings["return_in_uV"], + main_channel_peak_sign = settings["main_channel_peak_sign"], + main_channel_peak_mode = settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -682,7 +754,11 @@ def _get_zarr_root(self, mode="r+"): return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + + rec_attributes, backend_options): # used by create and save_as import zarr import numcodecs @@ -706,7 +782,11 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) - settings = dict(return_in_uV=return_in_uV) + settings = dict( + return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) zarr_root.attrs["settings"] = check_json(settings) # the recording @@ -827,10 +907,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): ) else: sparsity = None - - return_in_uV = zarr_root.attrs["settings"].get( - "return_in_uV", zarr_root.attrs["settings"].get("return_scaled", True) - ) + + settings = zarr_root.attrs["settings"] + settings = cls._handle_backward_compatibility(settings, sorting, sparsity) sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -838,7 +917,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="zarr", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV = settings["return_in_uV"], + main_channel_peak_sign = settings["main_channel_peak_sign"], + main_channel_peak_mode = settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -941,7 +1022,7 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n """ return self.sorting.get_property(key, ids=ids) - def get_main_channel(self, outputs="index", with_dict=False): + def get_main_channels(self, outputs="index", with_dict=False): """ """ @@ -1180,7 +1261,10 @@ def _save_or_select_or_merge_or_split( if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( - sorting_provenance, recording, sparsity, self.return_in_uV, self.rec_attributes + sorting_provenance, recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, + self.rec_attributes ) elif format == "binary_folder": @@ -1193,6 +1277,9 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, + self.rec_attributes, backend_options=backend_options, ) @@ -1206,6 +1293,8 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, self.rec_attributes, backend_options=backend_options, ) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 70bb4cdb46..e82d5f19e9 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -29,8 +29,6 @@ In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. - peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels. num_channels : int Number of channels for "best_channels" method. radius_um : float @@ -83,14 +81,14 @@ class ChannelSparsity: Using the N best channels (largest template amplitude): - >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels) Using a neighborhood by radius: - >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um) Using a SNR threshold: - >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold) Using a template energy threshold: >>> sparsity = ChannelSparsity.from_energy(sorting_analyzer, threshold) @@ -395,7 +393,7 @@ def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, @classmethod - def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="both"): + def from_radius(cls, templates_or_sorting_analyzer, radius_um): """ Construct sparsity from a radius around the main channel. Use the "radius_um" argument to specify the radius in um. @@ -414,17 +412,14 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="both") sparsity : ChannelSparsity The estimated sparsity. """ - mask = np.zeros( - (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" - ) + main_channel_index = templates_or_sorting_analyzer.get_main_channels(outputs="index") channel_locations = templates_or_sorting_analyzer.get_channel_locations() - distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - main_channel_index = templates_or_sorting_analyzer.get_main_channel(outputs="index") - for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_ind = main_channel_index[unit_ind] - (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) - mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) + return cls.from_radius_and_main_channel( + templates_or_sorting_analyzer.unit_ids, + templates_or_sorting_analyzer.channel_ids, + main_channel_index, + channel_locations, + radius_um) @classmethod def from_snr( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 151f6e2b8d..4aac641372 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -484,7 +484,7 @@ def get_channel_locations(self) -> np.ndarray: channel_locations = self.probe.contact_positions return channel_locations - def get_main_channel(self, + def get_main_channels(self, main_channel_peak_sign: "neg" | "both" | "pos" = "both", mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", outputs="index", diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index fe04540476..584bfaee7e 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -132,21 +132,21 @@ def get_template_amplitudes( -def _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore): +def _get_main_channel_from_template_array(templates_array, peak_mode, main_channel_peak_sign, nbefore): # Step1 : max on time axis - if mode == "extremum": + if peak_mode == "extremum": if main_channel_peak_sign == "both": values = np.max(np.abs(templates_array), axis=1) elif main_channel_peak_sign == "neg": values = -np.min(templates_array, axis=1) elif main_channel_peak_sign == "pos": values = np.max(templates_array, axis=1) - elif mode == "at_index": + elif peak_mode == "at_index": if main_channel_peak_sign == "both": values = np.abs(templates_array[:, nbefore, :]) elif main_channel_peak_sign in ["neg", "pos"]: values = templates_array[:, nbefore, :] - elif mode == "peak_to_peak": + elif peak_mode == "peak_to_peak": values = np.ptp(templates_array, axis=1) # Step2: max on channel axis @@ -158,7 +158,7 @@ def estimate_main_channel_from_recording( recording, sorting, main_channel_peak_sign: "neg" | "both" | "pos" = "both", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", num_spikes_for_main_channel=100, ms_before = 1.0, ms_after = 2.5, @@ -203,7 +203,7 @@ def estimate_main_channel_from_recording( **job_kwargs, ) - main_channel_index = _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore) + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, main_channel_peak_sign, nbefore) return main_channel_index @@ -241,7 +241,7 @@ def get_template_extremum_channel( Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ - warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channel() instead") + warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channels() instead") # TODO make a better logic here assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`" diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 74131dac8d..7a29a3cee6 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -80,7 +80,7 @@ def test_run_node_pipeline(cache_folder_creation): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) @@ -202,7 +202,7 @@ def test_skip_after_n_peaks_and_recording_slices(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 7be23dbf85..765d5bec7e 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -56,7 +56,7 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): # Note the sorting contain already main_channel_index sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) - assert np.array_equal(sorting_analyzer.get_main_channel() , sorting.get_property("main_channel_index")) + assert np.array_equal(sorting_analyzer.get_main_channels() , sorting.get_property("main_channel_index")) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -365,7 +365,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): assert ext is None assert sorting_analyzer.has_recording() - # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -635,7 +634,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = self.sorting_analyzer.get_main_channel( outputs="index", with_dict=True) + extremum_channel_inds = self.sorting_analyzer.get_main_channels( outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index c865068e4a..b013c9ca90 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -285,7 +285,7 @@ def test_compute_sparsity(): # using object SortingAnalyzer sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0) sparsity = compute_sparsity(sorting_analyzer, method="closest_channels", num_channels=2) sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") sparsity = compute_sparsity( @@ -299,7 +299,7 @@ def test_compute_sparsity(): templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(templates, method="radius", radius_um=50.0) sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(templates, method="closest_channels", num_channels=2) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 8c12591c6e..66b71a1b1c 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -102,8 +102,8 @@ def export_report( units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" # max_on_channel_id is kept (oold name) - units["max_on_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) - units["main_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) + units["max_on_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) + units["main_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign)) units.to_csv(output_folder / "unit list.csv", sep="\t") diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index b0445b4753..9a0d5847fb 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -101,7 +101,7 @@ def export_to_ibl_gui( output_folder.mkdir(parents=True, exist_ok=True) ### Save spikes info ### - extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices) # spikes.clusters @@ -136,7 +136,7 @@ def export_to_ibl_gui( np.save(output_folder / "clusters.waveforms.npy", templates) # cluster channels - extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32") np.save(output_folder / "clusters.channels.npy", cluster_channels) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 72d8cea634..406fbff18e 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -127,7 +127,7 @@ def select_templates( min_amplitude is not None or max_amplitude is not None or min_depth is not None or max_depth is not None ), "At least one of min_amplitude, max_amplitude, min_depth, max_depth should be provided" # get template amplitudes and depth - main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) mask = np.ones(templates.num_units, dtype=bool) @@ -190,7 +190,7 @@ def scale_template_to_range( Templates The scaled templates. """ - main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) # get amplitudes if amplitude_function == "ptp": @@ -263,7 +263,7 @@ def relocate_templates( """ seed = _ensure_seed(seed) - main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] unit_depths = channel_depths[main_channel_indices] diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index e2314319a1..03d7cd5ef8 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -108,7 +108,7 @@ def split_sorting_by_amplitudes( rng = np.random.default_rng(seed) fs = sorting_analyzer.sampling_frequency - extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) new_spikes = spikes[0].copy() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index f4a6afcd9b..128a509e80 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -181,7 +181,7 @@ def compute_snrs( channel_ids = sorting_analyzer.channel_ids - main_channel_index = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) snrs = {} @@ -1290,7 +1290,7 @@ def compute_sd_ratio( noise_levels = get_noise_levels( sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs ) - main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index e4ce1a68fe..405a31d068 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -137,7 +137,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): # Get extremum channels for neighbor selection in sparse mode - main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) # Pre-compute spike counts and firing rates if advanced NN metrics are requested advanced_nn_metrics = ["nn_advanced"] # Our grouped advanced NN metric diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 7f6175e47a..ed7ded47e6 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -189,7 +189,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) - extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 6cae681f15..dfef0e3bc6 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -104,8 +104,7 @@ def _get_pipeline_nodes(self): else: cut_out_after = nafter - peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" - extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) # collisions diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e180b868ab..fb9c89c521 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -100,7 +100,7 @@ def compute_monopolar_triangulation( neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) + best_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) unit_location = np.zeros((unit_ids.size, 4), dtype="float64") @@ -280,7 +280,7 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - main_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) weights_sparsity_mask = weights > 0 @@ -687,7 +687,7 @@ def compute_location_max_channel( unit_locations: np.ndarray 2d """ - extremum_channels_index = templates_or_sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_index = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=True) contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index b5c7ff1f6d..5a7aebf728 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -35,7 +35,7 @@ def _get_pipeline_nodes(self): peak_sign = self.params["peak_sign"] return_in_uV = self.sorting_analyzer.return_in_uV - extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index cf5adb6198..c6c9eca021 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -73,8 +73,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] - extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) retriever = SpikeRetriever( sorting, diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 7ac4d36d0a..0e46ba4df2 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,7 +53,7 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - main_channels = self.templates.get_main_channel(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) + main_channels = self.templates.get_main_channels(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 1622ab83a5..96b667e404 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -223,7 +223,7 @@ def __init__( # as numpy vector - self.main_channels = templates.get_main_channel(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) + self.main_channels = templates.get_main_channels(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) channel_locations = templates.probe.contact_positions unit_locations = channel_locations[self.main_channels] diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 6f8ba998e7..23f662d14a 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -121,7 +121,7 @@ def __init__( else: if sparsity is None: # in this case, we construct a sparsity dictionary only with the main channel - main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 69bc2b05cf..2483f0a792 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -85,7 +85,7 @@ def __init__( if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])): warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.") - main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) for unit_id in unit_ids: if np.any(np.isnan(unit_locations[unit_id])): unit_locations[unit_id] = channel_locations[main_channels[unit_id]] diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 411cd34203..93c18a01e3 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -135,7 +135,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 91854ec0eb..10f9651bc6 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -25,8 +25,6 @@ class UnitWaveformDensityMapWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored use_max_channel : bool, default: False Use only the max channel - peak_sign : "neg" | "pos" | "both", default: "neg" - Used to detect max channel only when use_max_channel=True unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. @@ -60,7 +58,7 @@ def __init__( if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + max_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all if sorting_analyzer.is_sparse():