@@ -1129,13 +1129,20 @@ def compute_drift_metrics(
     unit_ids = sorting.unit_ids
 
     spike_locations_ext = sorting_analyzer.get_extension("spike_locations")
-    spike_locations_array = spike_locations_ext.get_data(periods=periods)
+    spike_locations_by_unit_and_segments = spike_locations_ext.get_data(
+        outputs="by_unit", concatenated=False, periods=periods
+    )
     spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods)
 
     segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())]
-    assert direction in spike_locations_array.dtype.names, (
-        f"Direction {direction} is invalid. Available directions: " f"{spike_locations_array.dtype.names}"
+    data = spike_locations_by_unit[unit_ids[0]]
+    assert direction in data.dtype.names, (
+        f"Direction {direction} is invalid. Available directions: " f"{data.dtype.names}"
+    )
+    bin_edges_for_units = compute_bin_edges_per_unit(
+        sorting, segment_samples=segment_samples, periods=periods, bin_duration_s=interval_s, concatenated=False
     )
+    failed_units = []
 
     # we need
     drift_ptps = {}
@@ -1144,62 +1151,43 @@ def compute_drift_metrics(
 
     # reference positions are the medians across segments
     reference_positions = {}
+    median_position_segments = {unit_id: np.array([]) for unit_id in unit_ids}
+
     for unit_id in unit_ids:
         reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction])
 
-    # now compute median positions and concatenate them over segments
-    spike_vector = sorting.to_spike_vector()
-    spike_sample_indices = spike_vector["sample_index"].copy()
-    # we need to add the cumulative sum of segment samples to have global sample indices
-    cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1])
     for segment_index in range(sorting_analyzer.get_num_segments()):
-        segment_slice = sorting._get_spike_vector_segment_slices()[segment_index]
-        spike_sample_indices[segment_slice[0] : segment_slice[1]] += cumulative_segment_samples[segment_index]
-
-    bin_edges_for_units = compute_bin_edges_per_unit(
-        sorting,
-        segment_samples=segment_samples,
-        periods=periods,
-        bin_duration_s=interval_s,
-    )
-
-    failed_units = []
-    median_positions_per_unit = {}
+        for unit_id in unit_ids:
+            bins = bin_edges_for_units[unit_id][segment_index]
+            num_bin_edges = len(bins)
+            if (num_bin_edges - 1) < min_num_bins:
+                failed_units.append(unit_id)
+                continue
+            median_positions = np.nan * np.zeros((num_bin_edges - 1))
+            spikes_in_segment_of_unit = sorting.get_unit_spike_train(unit_id, segment_index)
+            bounds = np.searchsorted(spikes_in_segment_of_unit, bins, side="left")
+            for bin_index, (i0, i1) in enumerate(zip(bounds[:-1], bounds[1:])):
+                spike_locations_in_bin = spike_locations_by_unit_and_segments[segment_index][unit_id][i0:i1][direction]
+                if (i1 - i0) >= min_spikes_per_interval:
+                    median_positions[bin_index] = np.median(spike_locations_in_bin)
+            median_position_segments[unit_id] = np.concatenate((median_position_segments[unit_id], median_positions))
+
+    # finally, compute deviations and drifts
     for unit_id in unit_ids:
-        bins = bin_edges_for_units[unit_id]
-        num_bins = len(bins) - 1
-        if num_bins < min_num_bins:
+        # Skip units that already failed because not enough bins in at least one segment
+        if unit_id in failed_units:
             drift_ptps[unit_id] = np.nan
             drift_stds[unit_id] = np.nan
             drift_mads[unit_id] = np.nan
-            failed_units.append(unit_id)
             continue
-
-        # bin_edges are global across segments, so we have to use spike_sample_indices,
-        # since we offseted them to be global
-        bin_spike_indices = np.searchsorted(spike_sample_indices, bins)
-        median_positions = np.nan * np.zeros(num_bins)
-        for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])):
-            spikes_in_bin = spike_vector[i0:i1]
-            spike_locations_in_bin = spike_locations_array[i0:i1][direction]
-
-            unit_index = sorting_analyzer.sorting.id_to_index(unit_id)
-            mask = spikes_in_bin["unit_index"] == unit_index
-            if np.sum(mask) >= min_spikes_per_interval:
-                median_positions[bin_index] = np.median(spike_locations_in_bin[mask])
-            else:
-                median_positions[bin_index] = np.nan
-        median_positions_per_unit[unit_id] = median_positions
-
-        # now compute deviations and drifts for this unit
-        position_diff = median_positions - reference_positions[unit_id]
+        position_diff = median_position_segments[unit_id] - reference_positions[unit_id]
         if np.any(np.isnan(position_diff)):
             # deal with nans: if more than 50% nans --> set to nan
             if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff):
-                failed_units.append(unit_id)
                 ptp_drift = np.nan
                 std_drift = np.nan
                 mad_drift = np.nan
+                failed_units.append(unit_id)
             else:
                 ptp_drift = np.nanmax(position_diff) - np.nanmin(position_diff)
                 std_drift = np.nanstd(np.abs(position_diff))
@@ -1219,7 +1207,7 @@ def compute_drift_metrics(
     )
 
     if return_positions:
-        outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit
+        outs = res(drift_ptps, drift_stds, drift_mads), median_positions
     else:
        outs = res(drift_ptps, drift_stds, drift_mads)
     return outs
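
Note on the new binning logic (illustration, not part of the diff): the rewritten loop builds bin edges per unit and per segment, then uses np.searchsorted on each unit's spike train to locate the spikes falling in each bin and takes the median spike location per bin. Below is a minimal, self-contained sketch of that technique on made-up data; variable names and values are illustrative only and do not use the SpikeInterface API.

import numpy as np

# Fake data for one unit in one segment: sorted spike times (in samples)
# and one location value per spike (e.g. position along the probe, in um).
rng = np.random.default_rng(0)
spike_samples = np.sort(rng.integers(0, 300_000, size=500))
spike_positions = rng.normal(100.0, 5.0, size=500)

interval_samples = 30_000          # bin width in samples (e.g. 1 s at 30 kHz), illustrative
min_spikes_per_interval = 10       # bins with fewer spikes stay NaN
bins = np.arange(0, 300_000 + interval_samples, interval_samples)

# For each bin edge, find the index of the first spike at or after it,
# so spikes in bin k are spike_samples[bounds[k]:bounds[k + 1]].
median_positions = np.full(len(bins) - 1, np.nan)
bounds = np.searchsorted(spike_samples, bins, side="left")
for bin_index, (i0, i1) in enumerate(zip(bounds[:-1], bounds[1:])):
    if (i1 - i0) >= min_spikes_per_interval:
        median_positions[bin_index] = np.median(spike_positions[i0:i1])

# Drift summaries follow from deviations around the unit's overall median position.
position_diff = median_positions - np.median(spike_positions)
ptp_drift = np.nanmax(position_diff) - np.nanmin(position_diff)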