Skip to content

Commit 066c378

Browse files
committed
Force NaN/-1 values for float/int metrics if num_spikes is 0
1 parent 173e747 commit 066c378

File tree

2 files changed

+108
-9
lines changed

2 files changed

+108
-9
lines changed

src/spikeinterface/metrics/quality/misc_metrics.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def compute_presence_ratios(
7575
if unit_ids is None:
7676
unit_ids = sorting_analyzer.unit_ids
7777
num_segs = sorting_analyzer.get_num_segments()
78+
num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
7879

7980
segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)]
8081
total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods)
@@ -104,6 +105,9 @@ def compute_presence_ratios(
104105
else:
105106

106107
for unit_id in unit_ids:
108+
if num_spikes[unit_id] == 0:
109+
presence_ratios[unit_id] = np.nan
110+
continue
107111
spike_train = []
108112
bin_edges = bin_edges_per_unit[unit_id]
109113
if len(bin_edges) < 2:
@@ -264,6 +268,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_th
264268
unit_ids = sorting_analyzer.unit_ids
265269

266270
total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods)
271+
num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
267272
fs = sorting_analyzer.sampling_frequency
268273

269274
isi_threshold_s = isi_threshold_ms / 1000
@@ -273,15 +278,17 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_th
273278
isi_violations_ratio = {}
274279

275280
for unit_id in unit_ids:
281+
if num_spikes[unit_id] == 0:
282+
isi_violations_ratio[unit_id] = np.nan
283+
isi_violations_count[unit_id] = -1
284+
continue
285+
276286
spike_train_list = []
277287
for segment_index in range(sorting_analyzer.get_num_segments()):
278288
spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
279289
if len(spike_train) > 0:
280290
spike_train_list.append(spike_train / fs)
281291

282-
if not any([len(train) > 0 for train in spike_train_list]):
283-
continue
284-
285292
total_duration = total_durations[unit_id]
286293
ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s)
287294

@@ -359,7 +366,7 @@ def compute_refrac_period_violations(
359366
if not HAVE_NUMBA:
360367
warnings.warn("Error: numba is not installed.")
361368
warnings.warn("compute_refrac_period_violations cannot run without numba.")
362-
return {unit_id: np.nan for unit_id in unit_ids}
369+
return res({unit_id: np.nan for unit_id in unit_ids}, {unit_id: 0 for unit_id in unit_ids})
363370

364371
num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
365372

@@ -372,6 +379,11 @@ def compute_refrac_period_violations(
372379
nb_violations = {}
373380
rp_contamination = {}
374381
for unit_id in unit_ids:
382+
if num_spikes[unit_id] == 0:
383+
rp_contamination[unit_id] = np.nan
384+
nb_violations[unit_id] = -1
385+
continue
386+
375387
nb_violations[unit_id] = 0
376388
total_samples_unit = total_samples[unit_id]
377389

@@ -556,7 +568,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, syn
556568
if unit_ids is None:
557569
unit_ids = sorting.unit_ids
558570

559-
spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
571+
num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
560572

561573
spikes = sorting.to_spike_vector()
562574
all_unit_ids = sorting.unit_ids
@@ -569,10 +581,10 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, syn
569581
for i, unit_id in enumerate(all_unit_ids):
570582
if unit_id not in unit_ids:
571583
continue
572-
if spike_counts[unit_id] != 0:
573-
sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id]
584+
if num_spikes[unit_id] != 0:
585+
sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / num_spikes[unit_id]
574586
else:
575-
sync_id_metrics_dict[unit_id] = 0
587+
sync_id_metrics_dict[unit_id] = -1
576588
synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = sync_id_metrics_dict
577589

578590
return res(**synchrony_metrics_dict)
@@ -629,6 +641,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz
629641
if unit_ids is None:
630642
unit_ids = sorting.unit_ids
631643

644+
num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
645+
632646
if all(
633647
[
634648
sorting_analyzer.get_num_samples(segment_index) < bin_size_samples
@@ -648,6 +662,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz
648662
)
649663
cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1])
650664
for unit_id in unit_ids:
665+
if num_spikes[unit_id] == 0:
666+
continue
651667
bin_edges = bin_edges_per_unit[unit_id]
652668

653669
# we can concatenate spike trains across segments adding the cumulative number of samples
@@ -665,6 +681,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz
665681
# finally we compute the percentiles
666682
firing_ranges = {}
667683
for unit_id in unit_ids:
684+
if num_spikes[unit_id] == 0:
685+
firing_ranges[unit_id] = np.nan
686+
continue
668687
firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile(
669688
firing_rate_histograms[unit_id], percentiles[0]
670689
)
@@ -748,6 +767,10 @@ def compute_amplitude_cv_metrics(
748767

749768
amplitude_cv_medians, amplitude_cv_ranges = {}, {}
750769
for unit_id in unit_ids:
770+
if num_spikes[unit_id] == 0:
771+
amplitude_cv_medians[unit_id] = np.nan
772+
amplitude_cv_ranges[unit_id] = np.nan
773+
continue
751774
total_duration = total_durations[unit_id]
752775
firing_rate = num_spikes[unit_id] / total_duration
753776
temporal_bin_size_samples = int(
@@ -1267,6 +1290,8 @@ def compute_sd_ratio(
12671290
if unit_ids is None:
12681291
unit_ids = sorting_analyzer.unit_ids
12691292

1293+
num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
1294+
12701295
if not sorting_analyzer.has_recording():
12711296
warnings.warn(
12721297
"The `sd_ratio` metric cannot work with a recordless SortingAnalyzer object"
@@ -1297,6 +1322,9 @@ def compute_sd_ratio(
12971322
sd_ratio = {}
12981323

12991324
for unit_id in unit_ids:
1325+
if num_spikes[unit_id] == 0:
1326+
sd_ratio[unit_id] = np.nan
1327+
continue
13001328
spk_amp = []
13011329
for segment_index in range(sorting_analyzer.get_num_segments()):
13021330
spike_train = sorting.get_unit_spike_train(unit_id, segment_index)

src/spikeinterface/metrics/quality/tests/test_metrics_functions.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,10 @@ def test_calculate_firing_range(sorting_analyzer_simple):
273273
firing_ranges_periods = compute_firing_ranges(sorting_analyzer, periods=periods, bin_size_s=1)
274274
assert firing_ranges == firing_ranges_periods
275275

276+
empty_periods = np.empty(0, dtype=unit_period_dtype)
277+
firing_ranges_empty = compute_firing_ranges(sorting_analyzer, periods=empty_periods)
278+
assert np.all(np.isnan(np.array(list(firing_ranges_empty.values()))))
279+
276280
with pytest.warns(UserWarning) as w:
277281
firing_ranges_nan = compute_firing_ranges(
278282
sorting_analyzer, bin_size_s=sorting_analyzer.get_total_duration() + 1
@@ -287,6 +291,10 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple):
287291
periods = compute_periods(sorting_analyzer, num_periods=5)
288292
amp_cuts_periods = compute_amplitude_cutoffs(sorting_analyzer, periods=periods, num_histogram_bins=10)
289293
assert amp_cuts == amp_cuts_periods
294+
295+
empty_periods = np.empty(0, dtype=unit_period_dtype)
296+
amp_cuts_empty = compute_amplitude_cutoffs(sorting_analyzer, periods=empty_periods)
297+
assert np.all(np.isnan(np.array(list(amp_cuts_empty.values()))))
290298
# print(amp_cuts)
291299

292300
# testing method accuracy with magic numbers is not a good practice, so I removed this.
@@ -302,6 +310,10 @@ def test_calculate_amplitude_median(sorting_analyzer_simple):
302310
amp_medians_periods = compute_amplitude_medians(sorting_analyzer, periods=periods)
303311
assert amp_medians == amp_medians_periods
304312

313+
empty_periods = np.empty(0, dtype=unit_period_dtype)
314+
amp_medians_empty = compute_amplitude_medians(sorting_analyzer, periods=empty_periods)
315+
assert np.all(np.isnan(np.array(list(amp_medians_empty.values()))))
316+
305317
# testing method accuracy with magic numbers is not a good practice, so I removed this.
306318
# amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725}
307319
# assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05)
@@ -319,6 +331,15 @@ def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple, periods_simple)
319331
assert amp_cv_median == amp_cv_median_periods
320332
assert amp_cv_range == amp_cv_range_periods
321333

334+
empty_periods = np.empty(0, dtype=unit_period_dtype)
335+
amp_cv_median_empty, amp_cv_range_empty = compute_amplitude_cv_metrics(
336+
sorting_analyzer,
337+
periods=empty_periods,
338+
average_num_spikes_per_bin=20,
339+
)
340+
assert np.all(np.isnan(np.array(list(amp_cv_median_empty.values()))))
341+
assert np.all(np.isnan(np.array(list(amp_cv_range_empty.values()))))
342+
322343
# amps_scalings = compute_amplitude_scalings(sorting_analyzer)
323344
sorting_analyzer.compute("amplitude_scalings", **job_kwargs)
324345
amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics(
@@ -354,6 +375,10 @@ def test_calculate_presence_ratio(sorting_analyzer_simple, periods_simple):
354375
periods = periods_simple
355376
ratios_periods = compute_presence_ratios(sorting_analyzer, periods=periods, bin_duration_s=10)
356377
assert ratios == ratios_periods
378+
379+
empty_periods = np.empty(0, dtype=unit_period_dtype)
380+
ratios_periods_empty = compute_presence_ratios(sorting_analyzer, periods=empty_periods)
381+
assert np.all(np.isnan(np.array(list(ratios_periods_empty.values()))))
357382
# testing method accuracy with magic numbers is not a good practice, so I removed this.
358383
# ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0}
359384
# np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values()))
@@ -367,6 +392,12 @@ def test_calculate_isi_violations(sorting_analyzer_violations, periods_violation
367392
sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0, periods=periods
368393
)
369394
assert isi_viol == isi_viol_periods
395+
assert counts == counts_periods
396+
397+
empty_periods = np.empty(0, dtype=unit_period_dtype)
398+
isi_viol_empty, isi_counts_empty = compute_isi_violations(sorting_analyzer, periods=empty_periods)
399+
assert np.all(np.isnan(np.array(list(isi_viol_empty.values()))))
400+
assert np.array_equal(np.array(list(isi_counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)))
370401

371402
# testing method accuracy with magic numbers is not a good practice, so I removed this.
372403
# isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754}
@@ -384,6 +415,12 @@ def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_vi
384415
)
385416
assert contaminations == contaminations_periods
386417

418+
empty_periods = np.empty(0, dtype=unit_period_dtype)
419+
contaminations_periods_empty = compute_sliding_rp_violations(
420+
sorting_analyzer, periods=empty_periods, bin_size_ms=0.25, window_size_s=1
421+
)
422+
assert np.all(np.isnan(np.array(list(contaminations_periods_empty.values()))))
423+
387424
# testing method accuracy with magic numbers is not a good practice, so I removed this.
388425
# contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325}
389426
# assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05)
@@ -399,6 +436,15 @@ def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations
399436
sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=periods
400437
)
401438
assert rp_contamination == rp_contamination_periods
439+
assert counts == counts_periods
440+
441+
empty_periods = np.empty(0, dtype=unit_period_dtype)
442+
rp_contamination_empty, counts_empty = compute_refrac_period_violations(
443+
sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=empty_periods
444+
)
445+
assert np.all(np.isnan(np.array(list(rp_contamination_empty.values()))))
446+
assert np.array_equal(np.array(list(counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)))
447+
402448
# testing method accuracy with magic numbers is not a good practice, so I removed this.
403449
# counts_gt = {0: 2, 1: 4, 2: 10}
404450
# rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0}
@@ -425,8 +471,19 @@ def test_synchrony_metrics(sorting_analyzer_simple, periods_simple):
425471
synchrony_metrics_periods = compute_synchrony_metrics(sorting_analyzer, periods=periods)
426472
assert synchrony_metrics == synchrony_metrics_periods
427473

428-
synchrony_sizes = np.array([2, 4, 8])
474+
empty_periods = np.empty(0, dtype=unit_period_dtype)
475+
synchrony_metrics_empty = compute_synchrony_metrics(sorting_analyzer, periods=empty_periods)
476+
assert np.array_equal(
477+
np.array(list(synchrony_metrics_empty.sync_spike_2.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))
478+
)
479+
assert np.array_equal(
480+
np.array(list(synchrony_metrics_empty.sync_spike_4.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))
481+
)
482+
assert np.array_equal(
483+
np.array(list(synchrony_metrics_empty.sync_spike_8.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))
484+
)
429485

486+
synchrony_sizes = np.array([2, 4, 8])
430487
# check returns
431488
for size in synchrony_sizes:
432489
assert f"sync_spike_{size}" in synchrony_metrics._fields
@@ -487,6 +544,15 @@ def test_calculate_drift_metrics(sorting_analyzer_simple):
487544
assert drifts_stds == drifts_stds_periods
488545
assert drift_mads == drift_mads_periods
489546

547+
# calculate num spikes with empty periods
548+
empty_periods = np.empty(0, dtype=unit_period_dtype)
549+
drifts_ptps_empty, drifts_stds_empty, drift_mads_empty = compute_drift_metrics(
550+
sorting_analyzer_simple, periods=empty_periods
551+
)
552+
assert np.all(np.isnan(np.array(list(drifts_ptps_empty.values()))))
553+
assert np.all(np.isnan(np.array(list(drifts_stds_empty.values()))))
554+
assert np.all(np.isnan(np.array(list(drift_mads_empty.values()))))
555+
490556
# print(drifts_ptps, drifts_stds, drift_mads)
491557

492558
# testing method accuracy with magic numbers is not a good practice, so I removed this.
@@ -507,6 +573,11 @@ def test_calculate_sd_ratio(sorting_analyzer_simple, periods_simple):
507573
assert sd_ratio == sd_ratio_periods
508574

509575
assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids)
576+
577+
# calculate num spikes with empty periods
578+
empty_periods = np.empty(0, dtype=unit_period_dtype)
579+
sd_ratios_empty_periods = compute_sd_ratio(sorting_analyzer_simple, periods=empty_periods)
580+
assert np.all(np.isnan(np.array(list(sd_ratios_empty_periods.values()))))
510581
# @aurelien can you check this, this is not working anymore
511582
# assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0)
512583

0 commit comments

Comments
 (0)