1313
1414from spikeinterface .metrics .utils import create_ground_truth_pc_distributions , compute_periods
1515
16- from spikeinterface .metrics .quality import (
17- get_quality_metric_list ,
18- compute_quality_metrics ,
19- )
16+ from spikeinterface .metrics .quality import get_quality_metric_list , compute_quality_metrics , ComputeQualityMetrics
2017from spikeinterface .metrics .quality .misc_metrics import (
2118 misc_metrics_list ,
2219 compute_amplitude_cutoffs ,
@@ -657,37 +654,9 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
657654
658655 # can't use _misc_metric_name_to_func as some functions compute several qms
659656 # e.g. isi_violation and synchrony
660- quality_metrics = [
661- "num_spikes" ,
662- "firing_rate" ,
663- "presence_ratio" ,
664- "snr" ,
665- "isi_violations_ratio" ,
666- "isi_violations_count" ,
667- "rp_contamination" ,
668- "rp_violations" ,
669- "sliding_rp_violation" ,
670- "amplitude_cutoff" ,
671- "amplitude_median" ,
672- "amplitude_cv_median" ,
673- "amplitude_cv_range" ,
674- "sync_spike_2" ,
675- "sync_spike_4" ,
676- "sync_spike_8" ,
677- "firing_range" ,
678- "drift_ptp" ,
679- "drift_std" ,
680- "drift_mad" ,
681- "sd_ratio" ,
682- "isolation_distance" ,
683- "l_ratio" ,
684- "d_prime" ,
685- "silhouette" ,
686- "nn_hit_rate" ,
687- "nn_miss_rate" ,
688- ]
689-
690- small_sorting_analyzer .compute ("quality_metrics" )
657+ quality_metric_columns = ComputeQualityMetrics .get_metric_columns ()
658+ all_metrics = ComputeQualityMetrics .get_available_metric_names ()
659+ small_sorting_analyzer .compute ("quality_metrics" , metric_names = all_metrics )
691660
692661 cache_folder = create_cache_folder
693662 output_folder = cache_folder / "sorting_analyzer"
@@ -699,7 +668,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
699668 saved_metrics = csv .reader (metrics_file )
700669 metric_names = next (saved_metrics )
701670
702- for metric_name in quality_metrics :
671+ for metric_name in quality_metric_columns :
703672 assert metric_name in metric_names
704673
705674 folder_analyzer .compute ("quality_metrics" , metric_names = ["snr" ], delete_existing_metrics = False )
@@ -708,7 +677,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
708677 saved_metrics = csv .reader (metrics_file )
709678 metric_names = next (saved_metrics )
710679
711- for metric_name in quality_metrics :
680+ for metric_name in quality_metric_columns :
712681 assert metric_name in metric_names
713682
714683 folder_analyzer .compute ("quality_metrics" , metric_names = ["snr" ], delete_existing_metrics = True )
@@ -717,7 +686,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
717686 saved_metrics = csv .reader (metrics_file )
718687 metric_names = next (saved_metrics )
719688
720- for metric_name in quality_metrics :
689+ for metric_name in quality_metric_columns :
721690 if metric_name == "snr" :
722691 assert metric_name in metric_names
723692 else :
0 commit comments