@@ -112,30 +112,20 @@ def check_sortings_equal(
112112
113113 max_spike_index = SX1 .to_spike_vector ()["sample_index" ].max ()
114114
115- # TODO for later use to_spike_vector() to do this without looping
116- for segment_idx in range (SX1 .get_num_segments ()):
117- # get_unit_ids
118- ids1 = np .sort (np .array (SX1 .get_unit_ids ()))
119- ids2 = np .sort (np .array (SX2 .get_unit_ids ()))
120- assert_array_equal (ids1 , ids2 )
121- for id in ids1 :
122- train1 = np .sort (SX1 .get_unit_spike_train (id , segment_index = segment_idx ))
123- train2 = np .sort (SX2 .get_unit_spike_train (id , segment_index = segment_idx ))
124- assert np .array_equal (train1 , train2 )
125- train1 = np .sort (SX1 .get_unit_spike_train (id , segment_index = segment_idx , start_frame = 30 ))
126- train2 = np .sort (SX2 .get_unit_spike_train (id , segment_index = segment_idx , start_frame = 30 ))
127- assert np .array_equal (train1 , train2 )
128- # test that slicing works correctly
129- train1 = np .sort (SX1 .get_unit_spike_train (id , segment_index = segment_idx , end_frame = max_spike_index - 30 ))
130- train2 = np .sort (SX2 .get_unit_spike_train (id , segment_index = segment_idx , end_frame = max_spike_index - 30 ))
131- assert np .array_equal (train1 , train2 )
132- train1 = np .sort (
133- SX1 .get_unit_spike_train (id , segment_index = segment_idx , start_frame = 30 , end_frame = max_spike_index - 30 )
134- )
135- train2 = np .sort (
136- SX2 .get_unit_spike_train (id , segment_index = segment_idx , start_frame = 30 , end_frame = max_spike_index - 30 )
137- )
138- assert np .array_equal (train1 , train2 )
115+ s1 = SX1 .to_spike_vector ()
116+ s2 = SX2 .to_spike_vector ()
117+ assert_array_equal (s1 , s2 )
118+
119+ for start_frame , end_frame in [
120+ (None , None ),
121+ (30 , None ),
122+ (None , max_spike_index - 30 ),
123+ (30 , max_spike_index - 30 ),
124+ ]:
125+
126+ slice1 = _slice_spikes (s1 , start_frame , end_frame )
127+ slice2 = _slice_spikes (s2 , start_frame , end_frame )
128+ assert np .array_equal (slice1 , slice2 )
139129
140130 if check_annotations :
141131 check_extractor_annotations_equal (SX1 , SX2 )
@@ -155,3 +145,16 @@ def check_extractor_properties_equal(EX1, EX2) -> None:
155145
156146 for property_name in EX1 .get_property_keys ():
157147 assert_array_equal (EX1 .get_property (property_name ), EX2 .get_property (property_name ))
148+
149+
150+ def _slice_spikes (spikes , start_frame = None , end_frame = None ):
151+ sample_indices = spikes ["sample_index" ]
152+ if len (sample_indices ) == 0 :
153+ return spikes [:0 ]
154+ if start_frame is None :
155+ start_frame = sample_indices [0 ]
156+ if end_frame is None :
157+ end_frame = sample_indices [- 1 ] + 1
158+ start_idx , end_idx = np .searchsorted (sample_indices , [start_frame , end_frame + 1 ], side = "left" )
159+
160+ return spikes [start_idx :end_idx ]
0 commit comments