|
1 | 1 | import pytest |
2 | 2 |
|
3 | | -from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer |
| 3 | +from spikeinterface.core import ( |
| 4 | + generate_ground_truth_recording, |
| 5 | + create_sorting_analyzer, |
| 6 | +) |
| 7 | + |
# Shared parallel-processing settings used by the compute calls in this module.
job_kwargs = {"n_jobs": 2, "progress_bar": True, "chunk_duration": "1s"}
| 9 | + |
| 10 | + |
def make_small_analyzer():
    """Build a small in-memory SortingAnalyzer for fast tests.

    Generates a short (10 s) ground-truth recording with 10 units, renames
    channels and units to plain integer ids, keeps only three units
    (relabelled "#3", "#9", "#4"), and computes a standard set of
    extensions on the analyzer.

    Returns
    -------
    sorting_analyzer : SortingAnalyzer
        In-memory analyzer with random_spikes, noise_levels, waveforms,
        templates, spike_amplitudes, spike_locations and
        principal_components computed.
    """
    recording, sorting = generate_ground_truth_recording(
        durations=[10.0],
        num_units=10,
        seed=1205,
    )

    # Plain integer ids; list(range(...)) replaces the original comprehension
    # that shadowed the builtin `id` (ruff PERF402).
    channel_ids_as_integers = list(range(recording.get_num_channels()))
    unit_ids_as_integers = list(range(sorting.get_num_units()))
    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

    # Keep a small subset of units, relabelled with string ids.
    sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])

    sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

    extensions_to_compute = {
        "random_spikes": {"seed": 1205},
        "noise_levels": {"seed": 1205},
        "waveforms": {},
        "templates": {"operators": ["average", "median"]},
        "spike_amplitudes": {},
        "spike_locations": {},
        "principal_components": {},
    }

    sorting_analyzer.compute(extensions_to_compute)

    return sorting_analyzer
4 | 40 |
|
5 | 41 |
|
@pytest.fixture(scope="module")
def small_sorting_analyzer():
    """Module-scoped fixture: small precomputed SortingAnalyzer, built once per test module."""
    return make_small_analyzer()
| 45 | + |
| 46 | + |
@pytest.fixture(scope="module")
def sorting_analyzer_simple():
    """Module-scoped fixture: a sparse in-memory SortingAnalyzer on a 120 s recording.

    The recording uses 10 units firing at ~10 Hz — a high firing rate is
    needed so amplitude_cutoff has enough spikes to work with. Channels
    and units are renamed to integer ids, and a standard set of
    extensions is computed with the module-level ``job_kwargs``.
    """
    # we need high firing rate for amplitude_cutoff
    recording, sorting = generate_ground_truth_recording(
        durations=[
            120.0,
        ],
        sampling_frequency=30_000.0,
        num_channels=6,
        num_units=10,
        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
        generate_unit_locations_kwargs=dict(
            margin_um=5.0,
            minimum_z=5.0,
            maximum_z=20.0,
        ),
        generate_templates_kwargs=dict(
            unit_params=dict(
                alpha=(200.0, 500.0),
            )
        ),
        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
        seed=1205,
    )

    # Plain integer ids; list(range(...)) replaces the original comprehension
    # that shadowed the builtin `id` (ruff PERF402).
    channel_ids_as_integers = list(range(recording.get_num_channels()))
    unit_ids_as_integers = list(range(sorting.get_num_units()))
    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

    # NOTE(review): compute order kept as in the original — random_spikes first,
    # presumably because later extensions depend on it; confirm before reordering.
    sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
    sorting_analyzer.compute("noise_levels")
    sorting_analyzer.compute("waveforms", **job_kwargs)
    sorting_analyzer.compute("templates")
    sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs)

    return sorting_analyzer
0 commit comments