Skip to content

Commit 6926532

Browse files
committed
Merge branch 'select_sorting_periods' into goodtimes
2 parents e785b64 + d0a1e66 commit 6926532

File tree

9 files changed

+367
-241
lines changed

9 files changed

+367
-241
lines changed

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension):
901901
need_backward_compatibility_on_load = False
902902
metric_list: list[BaseMetric] = None # list of BaseMetric
903903

904+
@classmethod
def get_available_metric_names(cls):
    """Return the names of every metric this extension class provides.

    Returns
    -------
    available_metric_names : list[str]
        The ``metric_name`` attribute of each entry in ``cls.metric_list``.
    """
    names = [metric.metric_name for metric in cls.metric_list]
    return names
914+
904915
@classmethod
905916
def get_default_metric_params(cls):
906917
"""Get the default metric parameters.
Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,85 @@
11
import pytest
22

3-
from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer
3+
from spikeinterface.core import (
4+
generate_ground_truth_recording,
5+
create_sorting_analyzer,
6+
)
7+
8+
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
9+
10+
11+
def make_small_analyzer():
    """Build a small in-memory SortingAnalyzer for metric tests.

    Generates a 10 s ground-truth recording with 10 units, renames channels
    and units to plain integer ids, keeps only units 2, 7 and 0 under string
    labels, and precomputes a fixed set of extensions.
    """
    recording, sorting = generate_ground_truth_recording(
        durations=[10.0],
        num_units=10,
        seed=1205,
    )

    # Rename channels and units to consecutive integer ids.
    recording = recording.rename_channels(new_channel_ids=list(range(recording.get_num_channels())))
    sorting = sorting.rename_units(new_unit_ids=list(range(sorting.get_num_units())))

    # Keep a small subset of units, relabelled with string ids.
    sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])

    analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

    analyzer.compute(
        {
            "random_spikes": {"seed": 1205},
            "noise_levels": {"seed": 1205},
            "waveforms": {},
            "templates": {"operators": ["average", "median"]},
            "spike_amplitudes": {},
            "spike_locations": {},
            "principal_components": {},
        }
    )

    return analyzer
440

541

642
@pytest.fixture(scope="module")
def small_sorting_analyzer():
    """Module-scoped fixture: a small precomputed SortingAnalyzer."""
    analyzer = make_small_analyzer()
    return analyzer
45+
46+
47+
@pytest.fixture(scope="module")
def sorting_analyzer_simple():
    """Module-scoped SortingAnalyzer on a 120 s recording.

    A high firing rate (10 Hz) is used so that amplitude_cutoff has enough
    spikes per unit to work with.
    """
    recording, sorting = generate_ground_truth_recording(
        durations=[
            120.0,
        ],
        sampling_frequency=30_000.0,
        num_channels=6,
        num_units=10,
        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
        generate_unit_locations_kwargs=dict(
            margin_um=5.0,
            minimum_z=5.0,
            maximum_z=20.0,
        ),
        generate_templates_kwargs=dict(
            unit_params=dict(
                alpha=(200.0, 500.0),
            )
        ),
        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
        seed=1205,
    )

    # Rename channels and units to consecutive integer ids.
    recording = recording.rename_channels(new_channel_ids=list(range(recording.get_num_channels())))
    sorting = sorting.rename_units(new_unit_ids=list(range(sorting.get_num_units())))

    analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

    # Extension computation order is significant; job_kwargs parallelizes the
    # heavier steps.
    analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
    analyzer.compute("noise_levels")
    analyzer.compute("waveforms", **job_kwargs)
    analyzer.compute("templates")
    analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs)

    return analyzer

0 commit comments

Comments
 (0)