Skip to content

Commit b5bf3c3

Browse files
committed
Move warnings at the end of the loop for firing range and drift
1 parent 3291638 commit b5bf3c3

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/spikeinterface/metrics/quality/misc_metrics.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -672,13 +672,20 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz
672672

673673
# finally we compute the percentiles
674674
firing_ranges = {}
675+
failed_units = []
675676
for unit_id in unit_ids:
676677
if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples:
678+
failed_units.append(unit_id)
677679
firing_ranges[unit_id] = np.nan
678680
continue
679681
firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile(
680682
firing_rate_histograms[unit_id], percentiles[0]
681683
)
684+
if len(failed_units) > 0:
685+
warnings.warn(
686+
f"Firing range could not be computed for units {failed_units} "
687+
f"because they have no spikes or the total duration is less than bin size."
688+
)
682689

683690
return firing_ranges
684691

@@ -1156,18 +1163,16 @@ def compute_drift_metrics(
11561163
bin_duration_s=interval_s,
11571164
)
11581165

1166+
failed_units = []
11591167
median_positions_per_unit = {}
11601168
for unit_id in unit_ids:
11611169
bins = bin_edges_for_units[unit_id]
11621170
num_bins = len(bins) - 1
11631171
if num_bins < min_num_bins:
1164-
warnings.warn(
1165-
f"Unit {unit_id} has only {num_bins} bins given the specified 'interval_s' and "
1166-
f"'min_num_bins'. Drift metrics will be set to NaN"
1167-
)
11681172
drift_ptps[unit_id] = np.nan
11691173
drift_stds[unit_id] = np.nan
11701174
drift_mads[unit_id] = np.nan
1175+
failed_units.append(unit_id)
11711176
continue
11721177

11731178
# bin_edges are global across segments, so we have to use spike_sample_indices,
@@ -1191,6 +1196,7 @@ def compute_drift_metrics(
11911196
if np.any(np.isnan(position_diff)):
11921197
# deal with nans: if more than 50% nans --> set to nan
11931198
if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff):
1199+
failed_units.append(unit_id)
11941200
ptp_drift = np.nan
11951201
std_drift = np.nan
11961202
mad_drift = np.nan
@@ -1206,6 +1212,12 @@ def compute_drift_metrics(
12061212
drift_stds[unit_id] = std_drift
12071213
drift_mads[unit_id] = mad_drift
12081214

1215+
if len(failed_units) > 0:
1216+
warnings.warn(
1217+
f"Drift metrics could not be computed for units {failed_units} because they have less than "
1218+
f"{min_num_bins} bins given the specified 'interval_s' and 'min_num_bins' or not enough valid intervals."
1219+
)
1220+
12091221
if return_positions:
12101222
outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit
12111223
else:

0 commit comments

Comments
 (0)