@@ -273,6 +273,10 @@ def test_calculate_firing_range(sorting_analyzer_simple):
273273 firing_ranges_periods = compute_firing_ranges (sorting_analyzer , periods = periods , bin_size_s = 1 )
274274 assert firing_ranges == firing_ranges_periods
275275
276+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
277+ firing_ranges_empty = compute_firing_ranges (sorting_analyzer , periods = empty_periods )
278+ assert np .all (np .isnan (np .array (list (firing_ranges_empty .values ()))))
279+
276280 with pytest .warns (UserWarning ) as w :
277281 firing_ranges_nan = compute_firing_ranges (
278282 sorting_analyzer , bin_size_s = sorting_analyzer .get_total_duration () + 1
@@ -287,6 +291,10 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple):
287291 periods = compute_periods (sorting_analyzer , num_periods = 5 )
288292 amp_cuts_periods = compute_amplitude_cutoffs (sorting_analyzer , periods = periods , num_histogram_bins = 10 )
289293 assert amp_cuts == amp_cuts_periods
294+
295+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
296+ amp_cuts_empty = compute_amplitude_cutoffs (sorting_analyzer , periods = empty_periods )
297+ assert np .all (np .isnan (np .array (list (amp_cuts_empty .values ()))))
290298 # print(amp_cuts)
291299
292300 # testing method accuracy with magic number is not a good pratcice, I remove this.
@@ -302,6 +310,10 @@ def test_calculate_amplitude_median(sorting_analyzer_simple):
302310 amp_medians_periods = compute_amplitude_medians (sorting_analyzer , periods = periods )
303311 assert amp_medians == amp_medians_periods
304312
313+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
314+ amp_medians_empty = compute_amplitude_medians (sorting_analyzer , periods = empty_periods )
315+ assert np .all (np .isnan (np .array (list (amp_medians_empty .values ()))))
316+
305317 # testing method accuracy with magic number is not a good pratcice, I remove this.
306318 # amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725}
307319 # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05)
@@ -319,6 +331,15 @@ def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple, periods_simple)
319331 assert amp_cv_median == amp_cv_median_periods
320332 assert amp_cv_range == amp_cv_range_periods
321333
334+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
335+ amp_cv_median_empty , amp_cv_range_empty = compute_amplitude_cv_metrics (
336+ sorting_analyzer ,
337+ periods = empty_periods ,
338+ average_num_spikes_per_bin = 20 ,
339+ )
340+ assert np .all (np .isnan (np .array (list (amp_cv_median_empty .values ()))))
341+ assert np .all (np .isnan (np .array (list (amp_cv_range_empty .values ()))))
342+
322343 # amps_scalings = compute_amplitude_scalings(sorting_analyzer)
323344 sorting_analyzer .compute ("amplitude_scalings" , ** job_kwargs )
324345 amp_cv_median_scalings , amp_cv_range_scalings = compute_amplitude_cv_metrics (
@@ -354,6 +375,10 @@ def test_calculate_presence_ratio(sorting_analyzer_simple, periods_simple):
354375 periods = periods_simple
355376 ratios_periods = compute_presence_ratios (sorting_analyzer , periods = periods , bin_duration_s = 10 )
356377 assert ratios == ratios_periods
378+
379+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
380+ ratios_periods_empty = compute_presence_ratios (sorting_analyzer , periods = empty_periods )
381+ assert np .all (np .isnan (np .array (list (ratios_periods_empty .values ()))))
357382 # testing method accuracy with magic number is not a good pratcice, I remove this.
358383 # ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0}
359384 # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values()))
@@ -367,6 +392,12 @@ def test_calculate_isi_violations(sorting_analyzer_violations, periods_violation
367392 sorting_analyzer , isi_threshold_ms = 1 , min_isi_ms = 0.0 , periods = periods
368393 )
369394 assert isi_viol == isi_viol_periods
395+ assert counts == counts_periods
396+
397+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
398+ isi_viol_empty , isi_counts_empty = compute_isi_violations (sorting_analyzer , periods = empty_periods )
399+ assert np .all (np .isnan (np .array (list (isi_viol_empty .values ()))))
400+ assert np .array_equal (np .array (list (isi_counts_empty .values ())), - 1 * np .ones (len (sorting_analyzer .unit_ids )))
370401
371402 # testing method accuracy with magic number is not a good pratcice, I remove this.
372403 # isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754}
@@ -384,6 +415,12 @@ def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_vi
384415 )
385416 assert contaminations == contaminations_periods
386417
418+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
419+ contaminations_periods_empty = compute_sliding_rp_violations (
420+ sorting_analyzer , periods = empty_periods , bin_size_ms = 0.25 , window_size_s = 1
421+ )
422+ assert np .all (np .isnan (np .array (list (contaminations_periods_empty .values ()))))
423+
387424 # testing method accuracy with magic number is not a good pratcice, I remove this.
388425 # contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325}
389426 # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05)
@@ -399,6 +436,15 @@ def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations
399436 sorting_analyzer , refractory_period_ms = 1 , censored_period_ms = 0.0 , periods = periods
400437 )
401438 assert rp_contamination == rp_contamination_periods
439+ assert counts == counts_periods
440+
441+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
442+ rp_contamination_empty , counts_empty = compute_refrac_period_violations (
443+ sorting_analyzer , refractory_period_ms = 1 , censored_period_ms = 0.0 , periods = empty_periods
444+ )
445+ assert np .all (np .isnan (np .array (list (rp_contamination_empty .values ()))))
446+ assert np .array_equal (np .array (list (counts_empty .values ())), - 1 * np .ones (len (sorting_analyzer .unit_ids )))
447+
402448 # testing method accuracy with magic number is not a good pratcice, I remove this.
403449 # counts_gt = {0: 2, 1: 4, 2: 10}
404450 # rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0}
@@ -425,8 +471,19 @@ def test_synchrony_metrics(sorting_analyzer_simple, periods_simple):
425471 synchrony_metrics_periods = compute_synchrony_metrics (sorting_analyzer , periods = periods )
426472 assert synchrony_metrics == synchrony_metrics_periods
427473
428- synchrony_sizes = np .array ([2 , 4 , 8 ])
474+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
475+ synchrony_metrics_empty = compute_synchrony_metrics (sorting_analyzer , periods = empty_periods )
476+ assert np .array_equal (
477+ np .array (list (synchrony_metrics_empty .sync_spike_2 .values ())), - 1 * np .ones (len (sorting_analyzer .unit_ids ))
478+ )
479+ assert np .array_equal (
480+ np .array (list (synchrony_metrics_empty .sync_spike_4 .values ())), - 1 * np .ones (len (sorting_analyzer .unit_ids ))
481+ )
482+ assert np .array_equal (
483+ np .array (list (synchrony_metrics_empty .sync_spike_8 .values ())), - 1 * np .ones (len (sorting_analyzer .unit_ids ))
484+ )
429485
486+ synchrony_sizes = np .array ([2 , 4 , 8 ])
430487 # check returns
431488 for size in synchrony_sizes :
432489 assert f"sync_spike_{ size } " in synchrony_metrics ._fields
@@ -487,6 +544,15 @@ def test_calculate_drift_metrics(sorting_analyzer_simple):
487544 assert drifts_stds == drifts_stds_periods
488545 assert drift_mads == drift_mads_periods
489546
547+ # calculate num spikes with empty periods
548+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
549+ drifts_ptps_empty , drifts_stds_empty , drift_mads_empty = compute_drift_metrics (
550+ sorting_analyzer_simple , periods = empty_periods
551+ )
552+ assert np .all (np .isnan (np .array (list (drifts_ptps_empty .values ()))))
553+ assert np .all (np .isnan (np .array (list (drifts_stds_empty .values ()))))
554+ assert np .all (np .isnan (np .array (list (drift_mads_empty .values ()))))
555+
490556 # print(drifts_ptps, drifts_stds, drift_mads)
491557
492558 # testing method accuracy with magic number is not a good pratcice, I remove this.
@@ -507,6 +573,11 @@ def test_calculate_sd_ratio(sorting_analyzer_simple, periods_simple):
507573 assert sd_ratio == sd_ratio_periods
508574
509575 assert np .all (list (sd_ratio .keys ()) == sorting_analyzer_simple .unit_ids )
576+
577+ # calculate num spikes with empty periods
578+ empty_periods = np .empty (0 , dtype = unit_period_dtype )
579+ sd_ratios_empty_periods = compute_sd_ratio (sorting_analyzer_simple , periods = empty_periods )
580+ assert np .all (np .isnan (np .array (list (sd_ratios_empty_periods .values ()))))
510581 # @aurelien can you check this, this is not working anymore
511582 # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0)
512583
0 commit comments