Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/spikeinterface/benchmark/benchmark_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel


class ClusteringBenchmark(Benchmark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel

from pathlib import Path

Expand All @@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder):

# sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False)
# sorting_analyzer.compute(["random_spikes", "templates"])
extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index")
extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True)

spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
peaks[dataset] = spikes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel


@pytest.mark.skip()
Expand All @@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder):
sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs)
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("templates", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index")
extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
peaks[dataset] = spikes

Expand Down
7 changes: 6 additions & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class BaseSorting(BaseExtractor):
"""
    Abstract class representing several segments, several units, and their associated spiketrains.
"""
_main_properties = [
"main_channel_index",
]

def __init__(self, sampling_frequency: float, unit_ids: list):
BaseExtractor.__init__(self, unit_ids)
Expand Down Expand Up @@ -786,6 +789,7 @@ def _compute_and_cache_spike_vector(self) -> None:
self._cached_spike_vector = spikes
self._cached_spike_vector_segment_slices = segment_slices

# TODO sam : change extremum_channel_inds to main_channel_index with vector
def to_spike_vector(
self,
concatenated=True,
Expand All @@ -806,7 +810,8 @@ def to_spike_vector(
extremum_channel_inds : None or dict, default: None
If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index".
            This can be convenient for computing spike positions after sorting.
This dict can be computed with `get_template_extremum_channel(we, outputs="index")`
This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True)

use_cache : bool, default: True
When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`).
This caching only occurs when extremum_channel_inds=None.
Expand Down
4 changes: 4 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2444,6 +2444,10 @@ def generate_ground_truth_recording(
**generate_templates_kwargs,
)
sorting.set_property("gt_unit_locations", unit_locations)
distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2)
main_channel_index = np.argmin(distances, axis=1)
sorting.set_property("main_channel_index", main_channel_index)

else:
assert templates.shape[0] == num_units

Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
return (local_peaks,)


# TODO sam replace extremum_channels_indices by main_channel_index

# this is not implemented yet; it will be done in a separate PR
class SpikeRetriever(PeakSource):
"""
Expand Down
Loading
Loading