Skip to content

Commit c4e5551

Browse files
authored
Merge branch 'main' into goodtimes
2 parents 72a4e68 + 456265c commit c4e5551

File tree

4 files changed

+32
-29
lines changed

4 files changed

+32
-29
lines changed

src/spikeinterface/core/basesorting.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -766,10 +766,6 @@ def _compute_and_cache_spike_vector(self) -> None:
766766
if len(sample_indices) > 0:
767767
sample_indices = np.concatenate(sample_indices, dtype="int64")
768768
unit_indices = np.concatenate(unit_indices, dtype="int64")
769-
order = np.argsort(sample_indices)
770-
sample_indices = sample_indices[order]
771-
unit_indices = unit_indices[order]
772-
773769
n = sample_indices.size
774770
segment_slices[segment_index, 0] = seg_pos
775771
segment_slices[segment_index, 1] = seg_pos + n
@@ -783,7 +779,9 @@ def _compute_and_cache_spike_vector(self) -> None:
783779
spikes_in_seg["unit_index"] = unit_indices
784780
spikes_in_seg["segment_index"] = segment_index
785781
spikes.append(spikes_in_seg)
782+
786783
spikes = np.concatenate(spikes)
784+
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
787785

788786
self._cached_spike_vector = spikes
789787
self._cached_spike_vector_segment_slices = segment_slices

src/spikeinterface/core/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def generate_sorting(
174174
spikes.append(spikes_on_borders)
175175

176176
spikes = np.concatenate(spikes)
177-
spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]
177+
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
178178

179179
sorting = NumpySorting(spikes, sampling_frequency, unit_ids)
180180

src/spikeinterface/core/testing.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -112,30 +112,20 @@ def check_sortings_equal(
112112

113113
max_spike_index = SX1.to_spike_vector()["sample_index"].max()
114114

115-
# TODO for later use to_spike_vector() to do this without looping
116-
for segment_idx in range(SX1.get_num_segments()):
117-
# get_unit_ids
118-
ids1 = np.sort(np.array(SX1.get_unit_ids()))
119-
ids2 = np.sort(np.array(SX2.get_unit_ids()))
120-
assert_array_equal(ids1, ids2)
121-
for id in ids1:
122-
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx))
123-
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx))
124-
assert np.array_equal(train1, train2)
125-
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
126-
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
127-
assert np.array_equal(train1, train2)
128-
# test that slicing works correctly
129-
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
130-
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
131-
assert np.array_equal(train1, train2)
132-
train1 = np.sort(
133-
SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
134-
)
135-
train2 = np.sort(
136-
SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
137-
)
138-
assert np.array_equal(train1, train2)
115+
s1 = SX1.to_spike_vector()
116+
s2 = SX2.to_spike_vector()
117+
assert_array_equal(s1, s2)
118+
119+
for start_frame, end_frame in [
120+
(None, None),
121+
(30, None),
122+
(None, max_spike_index - 30),
123+
(30, max_spike_index - 30),
124+
]:
125+
126+
slice1 = _slice_spikes(s1, start_frame, end_frame)
127+
slice2 = _slice_spikes(s2, start_frame, end_frame)
128+
assert np.array_equal(slice1, slice2)
139129

140130
if check_annotations:
141131
check_extractor_annotations_equal(SX1, SX2)
@@ -155,3 +145,16 @@ def check_extractor_properties_equal(EX1, EX2) -> None:
155145

156146
for property_name in EX1.get_property_keys():
157147
assert_array_equal(EX1.get_property(property_name), EX2.get_property(property_name))
148+
149+
150+
def _slice_spikes(spikes, start_frame=None, end_frame=None):
151+
sample_indices = spikes["sample_index"]
152+
if len(sample_indices) == 0:
153+
return spikes[:0]
154+
if start_frame is None:
155+
start_frame = sample_indices[0]
156+
if end_frame is None:
157+
end_frame = sample_indices[-1] + 1
158+
start_idx, end_idx = np.searchsorted(sample_indices, [start_frame, end_frame + 1], side="left")
159+
160+
return spikes[start_idx:end_idx]

src/spikeinterface/core/zarrextractors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,
295295
spikes["unit_index"] = spikes_group["unit_index"][:]
296296
for i, (start, end) in enumerate(segment_slices_list):
297297
spikes["segment_index"][start:end] = i
298+
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
299+
self._cached_spike_vector = spikes
298300

299301
for segment_index in range(num_segments):
300302
soring_segment = SpikeVectorSortingSegment(spikes, segment_index, unit_ids)

0 commit comments

Comments
 (0)