@@ -672,13 +672,20 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz
672672
673673 # finally we compute the percentiles
674674 firing_ranges = {}
675+ failed_units = []
675676 for unit_id in unit_ids :
676677 if num_spikes [unit_id ] == 0 or total_samples [unit_id ] < bin_size_samples :
678+ failed_units .append (unit_id )
677679 firing_ranges [unit_id ] = np .nan
678680 continue
679681 firing_ranges [unit_id ] = np .percentile (firing_rate_histograms [unit_id ], percentiles [1 ]) - np .percentile (
680682 firing_rate_histograms [unit_id ], percentiles [0 ]
681683 )
684+ if len (failed_units ) > 0 :
685+ warnings .warn (
686+ f"Firing range could not be computed for units { failed_units } "
687+ f"because they have no spikes or the total duration is less than bin size."
688+ )
682689
683690 return firing_ranges
684691
@@ -1156,18 +1163,16 @@ def compute_drift_metrics(
11561163 bin_duration_s = interval_s ,
11571164 )
11581165
1166+ failed_units = []
11591167 median_positions_per_unit = {}
11601168 for unit_id in unit_ids :
11611169 bins = bin_edges_for_units [unit_id ]
11621170 num_bins = len (bins ) - 1
11631171 if num_bins < min_num_bins :
1164- warnings .warn (
1165- f"Unit { unit_id } has only { num_bins } bins given the specified 'interval_s' and "
1166- f"'min_num_bins'. Drift metrics will be set to NaN"
1167- )
11681172 drift_ptps [unit_id ] = np .nan
11691173 drift_stds [unit_id ] = np .nan
11701174 drift_mads [unit_id ] = np .nan
1175+ failed_units .append (unit_id )
11711176 continue
11721177
11731178 # bin_edges are global across segments, so we have to use spike_sample_indices,
@@ -1191,6 +1196,7 @@ def compute_drift_metrics(
11911196 if np .any (np .isnan (position_diff )):
11921197 # deal with nans: if more than 50% nans --> set to nan
11931198 if np .sum (np .isnan (position_diff )) > min_fraction_valid_intervals * len (position_diff ):
1199+ failed_units .append (unit_id )
11941200 ptp_drift = np .nan
11951201 std_drift = np .nan
11961202 mad_drift = np .nan
@@ -1206,6 +1212,12 @@ def compute_drift_metrics(
12061212 drift_stds [unit_id ] = std_drift
12071213 drift_mads [unit_id ] = mad_drift
12081214
1215+ if len (failed_units ) > 0 :
1216+ warnings .warn (
1217+ f"Drift metrics could not be computed for units { failed_units } because they have less than "
1218+ f"{ min_num_bins } bins given the specified 'interval_s' and 'min_num_bins' or not enough valid intervals."
1219+ )
1220+
12091221 if return_positions :
12101222 outs = res (drift_ptps , drift_stds , drift_mads ), median_positions_per_unit
12111223 else :
0 commit comments