@@ -961,18 +961,28 @@ def get_metric_column_descriptions(cls, metric_names=None):
961961 )
962962 return metric_column_descriptions
963963
964- def _cast_metrics (self , metrics_df ):
965- metric_dtypes = {}
966- for m in self .metric_list :
967- metric_dtypes .update (m .metric_columns )
968-
969- for col in metrics_df .columns :
970- if col in metric_dtypes :
971- try :
972- metrics_df [col ] = metrics_df [col ].astype (metric_dtypes [col ])
973- except Exception as e :
974- print (f"Error casting column { col } : { e } " )
975- return metrics_df
964+ @classmethod
965+ def get_optional_dependencies (cls , ** params ):
966+ metric_names = params .get ("metric_names" , None )
967+ if metric_names is None :
968+ metric_names = [m .metric_name for m in cls .metric_list ]
969+ else :
970+ for metric_name in metric_names :
971+ if metric_name not in [m .metric_name for m in cls .metric_list ]:
972+ raise ValueError (
973+ f"Metric { metric_name } not in available metrics { [m .metric_name for m in cls .metric_list ]} "
974+ )
975+ metric_depend_on = set ()
976+ for metric_name in metric_names :
977+ metric = [m for m in cls .metric_list if m .metric_name == metric_name ][0 ]
978+ for dep in metric .depend_on :
979+ if "|" in dep :
980+ dep_options = dep .split ("|" )
981+ metric_depend_on .update (dep_options )
982+ else :
983+ metric_depend_on .add (dep )
984+ depend_on = list (cls .depend_on ) + list (metric_depend_on )
985+ return depend_on
976986
977987 def _set_params (
978988 self ,
@@ -994,6 +1004,8 @@ def _set_params(
9941004 If None, default parameters for all metrics are used.
9951005 delete_existing_metrics : bool, default: False
9961006 If True, existing metrics in the extension will be deleted before computing new ones.
1007+ metrics_to_compute : list[str] | None
1008+ List of metric names to compute. If None, all metrics in `metric_names` are computed.
9971009 other_params : dict
9981010 Additional parameters for metric computation.
9991011
@@ -1208,15 +1220,18 @@ def _get_data(self):
12081220 # convert to correct dtype
12091221 return self .data ["metrics" ]
12101222
1211- def set_data (self , ext_data_name , data ):
1212- import pandas as pd
1223+ def _cast_metrics (self , metrics_df ):
1224+ metric_dtypes = {}
1225+ for m in self .metric_list :
1226+ metric_dtypes .update (m .metric_columns )
12131227
1214- if ext_data_name != "metrics" :
1215- return
1216- if not isinstance (data , pd .DataFrame ):
1217- return
1218- metrics = self ._cast_metrics (data )
1219- self .data [ext_data_name ] = metrics
1228+ for col in metrics_df .columns :
1229+ if col in metric_dtypes :
1230+ try :
1231+ metrics_df [col ] = metrics_df [col ].astype (metric_dtypes [col ])
1232+ except Exception as e :
1233+ print (f"Error casting column { col } : { e } " )
1234+ return metrics_df
12201235
12211236 def _select_extension_data (self , unit_ids : list [int | str ]):
12221237 """
@@ -1331,6 +1346,16 @@ def _split_extension_data(
13311346 new_data = dict (metrics = metrics )
13321347 return new_data
13331348
1349+ def set_data (self , ext_data_name , data ):
1350+ import pandas as pd
1351+
1352+ if ext_data_name != "metrics" :
1353+ return
1354+ if not isinstance (data , pd .DataFrame ):
1355+ return
1356+ metrics = self ._cast_metrics (data )
1357+ self .data [ext_data_name ] = metrics
1358+
13341359
13351360class BaseSpikeVectorExtension (AnalyzerExtension ):
13361361 """
0 commit comments