1313import numpy as np
1414from collections import namedtuple
1515
16+ from .numpyextractors import NumpySorting
1617from .sortinganalyzer import SortingAnalyzer , AnalyzerExtension , register_result_extension
1718from .waveform_tools import extract_waveforms_to_single_buffer , estimate_templates_with_accumulator
1819from .recording_tools import get_noise_levels
@@ -823,10 +824,9 @@ class BaseMetric:
823824 metric_columns = {} # column names and their dtypes of the dataframe
824825 metric_descriptions = {} # descriptions of each metric column
825826 needs_recording = False # whether the metric needs recording
826- needs_tmp_data = (
827- False # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level
828- )
829- needs_job_kwargs = False
827+ needs_tmp_data = False # whether the metric needs temporary data computed with MetricExtension._prepare_data
828+ needs_job_kwargs = False # whether the metric needs job_kwargs
829+ supports_periods = False # whether the metric function supports periods
830830 depend_on = [] # extensions the metric depends on
831831
832832 # the metric function must have the signature:
@@ -839,7 +839,7 @@ class BaseMetric:
839839 metric_function = None # to be defined in subclass
840840
841841 @classmethod
842- def compute (cls , sorting_analyzer , unit_ids , metric_params , tmp_data , job_kwargs ):
842+ def compute (cls , sorting_analyzer , unit_ids , metric_params , tmp_data , job_kwargs , periods = None ):
843843 """Compute the metric.
844844
845845 Parameters
@@ -854,6 +854,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs
854854 Temporary data to pass to the metric function
855855 job_kwargs : dict
856856 Job keyword arguments to control parallelization
857+ periods : np.ndarray | None
858+            Numpy array of unit periods (dtype ``unit_period_dtype``); only used when ``supports_periods`` is True
857859
858860 Returns
859861 -------
@@ -865,6 +867,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs
865867 args += (tmp_data ,)
866868 if cls .needs_job_kwargs :
867869 args += (job_kwargs ,)
870+ if cls .supports_periods :
871+ args += (periods ,)
868872
869873 results = cls .metric_function (* args , ** metric_params )
870874
@@ -897,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension):
897901 need_backward_compatibility_on_load = False
898902 metric_list : list [BaseMetric ] = None # list of BaseMetric
899903
@classmethod
def get_available_metric_names(cls):
    """Return the names of all metrics registered on this extension class.

    Returns
    -------
    available_metric_names : list[str]
        The ``metric_name`` of every metric in ``cls.metric_list``.
    """
    # metric_list is populated by each concrete extension subclass
    return [metric.metric_name for metric in cls.metric_list]
900915 @classmethod
901916 def get_default_metric_params (cls ):
902917 """Get the default metric parameters.
@@ -988,6 +1003,7 @@ def _set_params(
9881003 metric_params : dict | None = None ,
9891004 delete_existing_metrics : bool = False ,
9901005 metrics_to_compute : list [str ] | None = None ,
1006+ periods : np .ndarray | None = None ,
9911007 ** other_params ,
9921008 ):
9931009 """
@@ -1004,6 +1020,8 @@ def _set_params(
10041020 If True, existing metrics in the extension will be deleted before computing new ones.
10051021 metrics_to_compute : list[str] | None
10061022 List of metric names to compute. If None, all metrics in `metric_names` are computed.
1023+ periods : np.ndarray | None
1024+        Numpy array with dtype ``unit_period_dtype`` defining the periods over which metrics are computed.
10071025 other_params : dict
10081026 Additional parameters for metric computation.
10091027
@@ -1079,6 +1097,7 @@ def _set_params(
10791097 metrics_to_compute = metrics_to_compute ,
10801098 delete_existing_metrics = delete_existing_metrics ,
10811099 metric_params = metric_params ,
1100+ periods = periods ,
10821101 ** other_params ,
10831102 )
10841103 return params
@@ -1129,6 +1148,8 @@ def _compute_metrics(
11291148 if metric_names is None :
11301149 metric_names = self .params ["metric_names" ]
11311150
1151+ periods = self .params .get ("periods" , None )
1152+
11321153 column_names_dtypes = {}
11331154 for metric_name in metric_names :
11341155 metric = [m for m in self .metric_list if m .metric_name == metric_name ][0 ]
@@ -1153,6 +1174,7 @@ def _compute_metrics(
11531174 metric_params = metric_params ,
11541175 tmp_data = tmp_data ,
11551176 job_kwargs = job_kwargs ,
1177+ periods = periods ,
11561178 )
11571179 except Exception as e :
11581180 warnings .warn (f"Error computing metric { metric_name } : { e } " )
@@ -1179,6 +1201,7 @@ def _run(self, **job_kwargs):
11791201
11801202 metrics_to_compute = self .params ["metrics_to_compute" ]
11811203 delete_existing_metrics = self .params ["delete_existing_metrics" ]
1204+ periods = self .params .get ("periods" , None )
11821205
11831206 _ , job_kwargs = split_job_kwargs (job_kwargs )
11841207 job_kwargs = fix_job_kwargs (job_kwargs )
@@ -1452,6 +1475,16 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
14521475 periods ,
14531476 )
14541477 all_data = all_data [keep_mask ]
1478+ # since we have the mask already, we can use it directly to avoid double computation
1479+ spike_vector = self .sorting_analyzer .sorting .to_spike_vector (concatenated = True )
1480+ sliced_spike_vector = spike_vector [keep_mask ]
1481+ sorting = NumpySorting (
1482+ sliced_spike_vector ,
1483+ sampling_frequency = self .sorting_analyzer .sampling_frequency ,
1484+ unit_ids = self .sorting_analyzer .unit_ids ,
1485+ )
1486+ else :
1487+ sorting = self .sorting_analyzer .sorting
14551488
14561489 if outputs == "numpy" :
14571490 if copy :
@@ -1460,10 +1493,10 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
14601493 return all_data
14611494 elif outputs == "by_unit" :
14621495 unit_ids = self .sorting_analyzer .unit_ids
1463- spike_vector = self . sorting_analyzer . sorting . to_spike_vector ( concatenated = False )
1496+
14641497 if keep_mask is not None :
14651498 # since we are filtering spikes, we need to recompute the spike indices
1466- spike_vector = spike_vector [ keep_mask ]
1499+ spike_vector = sorting . to_spike_vector ( concatenated = False )
14671500 spike_indices = spike_vector_to_indices (spike_vector , unit_ids , absolute_index = True )
14681501 else :
14691502 # use the cache of indices
0 commit comments