Commit 7b93c50

Implement select_sorting_periods in metrics (#4302)
Co-authored-by: Chris Halcrow <[email protected]>
1 parent 02c940f commit 7b93c50

12 files changed: +1052 −590 lines changed

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 40 additions & 7 deletions
@@ -13,6 +13,7 @@
 import numpy as np
 from collections import namedtuple
 
+from .numpyextractors import NumpySorting
 from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
 from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
 from .recording_tools import get_noise_levels
@@ -823,10 +824,9 @@ class BaseMetric:
     metric_columns = {}  # column names and their dtypes of the dataframe
     metric_descriptions = {}  # descriptions of each metric column
     needs_recording = False  # whether the metric needs recording
-    needs_tmp_data = (
-        False  # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level
-    )
-    needs_job_kwargs = False
+    needs_tmp_data = False  # whether the metric needs temporary data computed with MetricExtension._prepare_data
+    needs_job_kwargs = False  # whether the metric needs job_kwargs
+    supports_periods = False  # whether the metric function supports periods
     depend_on = []  # extensions the metric depends on
 
     # the metric function must have the signature:
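The `needs_*`/`supports_*` flags are how a metric subclass opts in to extra positional inputs. Below is a minimal sketch of a period-aware metric, assuming the base positional arguments are `(sorting_analyzer, unit_ids)` (the authoritative signature is the comment truncated in the context above); every name here is hypothetical, not part of the commit:

```python
# Hypothetical metric (illustration only, not part of the commit).
def compute_toy_rate(sorting_analyzer, unit_ids, periods=None):
    # Toy body: spikes per second per unit over the whole recording; a real
    # implementation would restrict spike counts to the supplied periods.
    duration = sorting_analyzer.get_total_duration()
    sorting = sorting_analyzer.sorting
    return {uid: sorting.get_unit_spike_train(uid).size / duration for uid in unit_ids}


class ToyRateMetric(BaseMetric):
    metric_name = "toy_rate"
    metric_columns = {"toy_rate": float}
    supports_periods = True  # compute() will append `periods` to the call
    metric_function = compute_toy_rate
```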
@@ -839,7 +839,7 @@ class BaseMetric:
     metric_function = None  # to be defined in subclass
 
     @classmethod
-    def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs):
+    def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs, periods=None):
         """Compute the metric.
 
         Parameters
@@ -854,6 +854,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs
             Temporary data to pass to the metric function
         job_kwargs : dict
             Job keyword arguments to control parallelization
+        periods : np.ndarray | None
+            Numpy array of unit periods of unit_period_dtype if supports_periods is True
 
         Returns
         -------
@@ -865,6 +867,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs
             args += (tmp_data,)
         if cls.needs_job_kwargs:
             args += (job_kwargs,)
+        if cls.supports_periods:
+            args += (periods,)
 
         results = cls.metric_function(*args, **metric_params)
 
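The argument tuple is assembled in a fixed order, so a metric that sets several flags must accept the extras in that same order: `tmp_data`, then `job_kwargs`, then `periods`. A standalone sketch of the dispatch pattern, with names invented for illustration:

```python
def build_args(base_args, needs_tmp_data=False, needs_job_kwargs=False,
               supports_periods=False, tmp_data=None, job_kwargs=None, periods=None):
    # Mirrors BaseMetric.compute: optional inputs are appended positionally,
    # in a fixed order, only when the corresponding flag is set.
    args = tuple(base_args)
    if needs_tmp_data:
        args += (tmp_data,)
    if needs_job_kwargs:
        args += (job_kwargs,)
    if supports_periods:
        args += (periods,)
    return args


assert build_args(("analyzer", ["u0"]), supports_periods=True, periods="periods") == (
    "analyzer", ["u0"], "periods",
)
```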
@@ -897,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension):
     need_backward_compatibility_on_load = False
     metric_list: list[BaseMetric] = None  # list of BaseMetric
 
+    @classmethod
+    def get_available_metric_names(cls):
+        """Get the available metric names.
+
+        Returns
+        -------
+        available_metric_names : list[str]
+            List of available metric names.
+        """
+        return [m.metric_name for m in cls.metric_list]
+
     @classmethod
     def get_default_metric_params(cls):
         """Get the default metric parameters.
@@ -988,6 +1003,7 @@ def _set_params(
         metric_params: dict | None = None,
         delete_existing_metrics: bool = False,
         metrics_to_compute: list[str] | None = None,
+        periods: np.ndarray | None = None,
         **other_params,
     ):
         """
@@ -1004,6 +1020,8 @@ def _set_params(
             If True, existing metrics in the extension will be deleted before computing new ones.
         metrics_to_compute : list[str] | None
             List of metric names to compute. If None, all metrics in `metric_names` are computed.
+        periods : np.ndarray | None
+            Numpy array of unit_period_dtype defining periods to compute metrics over.
         other_params : dict
             Additional parameters for metric computation.
 
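The docstring defers the period layout to `unit_period_dtype`. For a concrete picture, here is a hedged sketch: the field names below are an assumption made purely for illustration, and in real use the dtype should be imported from `spikeinterface.core` rather than redefined:

```python
import numpy as np

# ASSUMED field layout, for illustration only -- import the real
# unit_period_dtype from spikeinterface.core instead of redefining it.
unit_period_dtype = np.dtype(
    [
        ("unit_index", "int64"),
        ("segment_index", "int64"),
        ("start_frame", "int64"),
        ("end_frame", "int64"),
    ]
)

# One period: unit 0, segment 0, the first 30 s at 30 kHz.
periods = np.array([(0, 0, 0, 30 * 30_000)], dtype=unit_period_dtype)
```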
@@ -1079,6 +1097,7 @@ def _set_params(
             metrics_to_compute=metrics_to_compute,
             delete_existing_metrics=delete_existing_metrics,
             metric_params=metric_params,
+            periods=periods,
             **other_params,
         )
         return params
@@ -1129,6 +1148,8 @@ def _compute_metrics(
         if metric_names is None:
             metric_names = self.params["metric_names"]
 
+        periods = self.params.get("periods", None)
+
         column_names_dtypes = {}
         for metric_name in metric_names:
             metric = [m for m in self.metric_list if m.metric_name == metric_name][0]
11531174
metric_params=metric_params,
11541175
tmp_data=tmp_data,
11551176
job_kwargs=job_kwargs,
1177+
periods=periods,
11561178
)
11571179
except Exception as e:
11581180
warnings.warn(f"Error computing metric {metric_name}: {e}")
@@ -1179,6 +1201,7 @@ def _run(self, **job_kwargs):
11791201

11801202
metrics_to_compute = self.params["metrics_to_compute"]
11811203
delete_existing_metrics = self.params["delete_existing_metrics"]
1204+
periods = self.params.get("periods", None)
11821205

11831206
_, job_kwargs = split_job_kwargs(job_kwargs)
11841207
job_kwargs = fix_job_kwargs(job_kwargs)
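The same `self.params.get("periods", None)` read appears in both `_compute_metrics` and `_run`, so `periods` only has to be supplied once, when computing the extension. A hedged usage sketch; the extension name, and whether the public `compute()` call forwards `periods` exactly this way, are assumptions based on this commit:

```python
# Illustrative only: `analyzer` is a SortingAnalyzer and `periods` is a
# unit_period_dtype array (see the earlier sketch).
analyzer.compute("quality_metrics", periods=periods)
```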
@@ -1452,6 +1475,16 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
                 periods,
             )
             all_data = all_data[keep_mask]
+            # since we have the mask already, we can use it directly to avoid double computation
+            spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=True)
+            sliced_spike_vector = spike_vector[keep_mask]
+            sorting = NumpySorting(
+                sliced_spike_vector,
+                sampling_frequency=self.sorting_analyzer.sampling_frequency,
+                unit_ids=self.sorting_analyzer.unit_ids,
+            )
+        else:
+            sorting = self.sorting_analyzer.sorting
 
         if outputs == "numpy":
             if copy:
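When a `keep_mask` was derived from the periods, the hunk above reuses it to build a filtered `NumpySorting` once, instead of re-deriving the selection later for the by-unit split. The idiom in isolation, with a dummy mask standing in for the real period selection (`analyzer` is assumed to be an existing `SortingAnalyzer`):

```python
import numpy as np
from spikeinterface.core import NumpySorting

# Filter a concatenated spike vector, then wrap the survivors in a sorting.
spike_vector = analyzer.sorting.to_spike_vector(concatenated=True)
keep_mask = np.zeros(spike_vector.size, dtype=bool)
keep_mask[: spike_vector.size // 2] = True  # stand-in for a period mask

sliced_sorting = NumpySorting(
    spike_vector[keep_mask],
    sampling_frequency=analyzer.sampling_frequency,
    unit_ids=analyzer.unit_ids,
)
```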
@@ -1460,10 +1493,10 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
                 return all_data
         elif outputs == "by_unit":
             unit_ids = self.sorting_analyzer.unit_ids
-            spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
+
             if keep_mask is not None:
                 # since we are filtering spikes, we need to recompute the spike indices
-                spike_vector = spike_vector[keep_mask]
+                spike_vector = sorting.to_spike_vector(concatenated=False)
                 spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
             else:
                 # use the cache of indices
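In the `by_unit` branch, the filtered sorting built in the previous hunk is expanded back into per-unit index arrays; `absolute_index=True` makes the indices point into the already-masked `all_data`. A sketch continuing from the masking example above (the import path for `spike_vector_to_indices` is an assumption, since the hunk uses it unqualified):

```python
# Assumed import location; the diff references the helper unqualified.
from spikeinterface.core.sorting_tools import spike_vector_to_indices

spike_vector = sliced_sorting.to_spike_vector(concatenated=False)
spike_indices = spike_vector_to_indices(spike_vector, sliced_sorting.unit_ids, absolute_index=True)
```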
Lines changed: 79 additions & 2 deletions
@@ -1,8 +1,85 @@
 import pytest
 
-from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer
+from spikeinterface.core import (
+    generate_ground_truth_recording,
+    create_sorting_analyzer,
+)
+
+job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
+
+
+def make_small_analyzer():
+    recording, sorting = generate_ground_truth_recording(
+        durations=[10.0],
+        num_units=10,
+        seed=1205,
+    )
+
+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
+    sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])
+
+    sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
+
+    extensions_to_compute = {
+        "random_spikes": {"seed": 1205},
+        "noise_levels": {"seed": 1205},
+        "waveforms": {},
+        "templates": {"operators": ["average", "median"]},
+        "spike_amplitudes": {},
+        "spike_locations": {},
+        "principal_components": {},
+    }
+
+    sorting_analyzer.compute(extensions_to_compute)
+
+    return sorting_analyzer
 
 
 @pytest.fixture(scope="module")
 def small_sorting_analyzer():
-    return _small_sorting_analyzer()
+    return make_small_analyzer()
+
+
+@pytest.fixture(scope="module")
+def sorting_analyzer_simple():
+    # we need high firing rate for amplitude_cutoff
+    recording, sorting = generate_ground_truth_recording(
+        durations=[
+            120.0,
+        ],
+        sampling_frequency=30_000.0,
+        num_channels=6,
+        num_units=10,
+        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
+        generate_unit_locations_kwargs=dict(
+            margin_um=5.0,
+            minimum_z=5.0,
+            maximum_z=20.0,
+        ),
+        generate_templates_kwargs=dict(
+            unit_params=dict(
+                alpha=(200.0, 500.0),
+            )
+        ),
+        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
+        seed=1205,
+    )
+
+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
+    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)
+
+    sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
+    sorting_analyzer.compute("noise_levels")
+    sorting_analyzer.compute("waveforms", **job_kwargs)
+    sorting_analyzer.compute("templates")
+    sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs)
+
+    return sorting_analyzer
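Both fixtures are module-scoped, so each analyzer is built once per test module. A minimal consumer of the first fixture; the assertion follows from the `select_units([2, 7, 0], ["#3", "#9", "#4"])` call above:

```python
def test_small_analyzer_units(small_sorting_analyzer):
    # select_units keeps three units and renames them to "#3", "#9", "#4"
    assert list(small_sorting_analyzer.unit_ids) == ["#3", "#9", "#4"]
```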
