@@ -155,7 +155,6 @@ class ExperimentLog:
155155 static_configs : dict
156156 grid_fields : list
157157 logs_file : str
158- info_fields : list
159158
160159 metric_fields : Optional [list ] = None
161160 df : Optional [pd .DataFrame ]= None
@@ -167,17 +166,15 @@ def __post_init__(self):
167166 if self .df is None :
168167 assert self .metric_fields is not None , 'Specify the metric fields of the experiment.'
169168 assert not (f := set (self .grid_fields ) & set (self .metric_fields )), f'Overlapping field names { f } in grid_fields and metric_fields. Remove one of them.'
170- columns = self .grid_fields + self .info_fields + self .metric_fields
171- self .df = pd .DataFrame (columns = columns ).set_index (self .grid_fields )
169+ self .df = pd .DataFrame (columns = self .grid_fields + self .metric_fields ).set_index (self .grid_fields )
172170 else :
173- self .metric_fields = [i for i in list (self .df ) if i not in self .info_fields ]
174- self .field_order = self .info_fields + self .metric_fields
171+ self .metric_fields = list (self .df )
175172
176173 # Constructors.
177174 # -----------------------------------------------------------------------------
178175 @classmethod
179- def from_exp_config (cls , exp_config , logs_file : str , info_fields : list , metric_fields : Optional [list ]= None , auto_update_tsv : bool = False ):
180- return cls (* (exp_config [k ] for k in ['static_configs' , 'grid_fields' ]), logs_file = logs_file , info_fields = info_fields ,
176+ def from_exp_config (cls , exp_config , logs_file : str , metric_fields : Optional [list ]= None , auto_update_tsv : bool = False ):
177+ return cls (* (exp_config [k ] for k in ['static_configs' , 'grid_fields' ]), logs_file = logs_file ,
181178 metric_fields = metric_fields , auto_update_tsv = auto_update_tsv )
182179
183180 @classmethod
@@ -218,19 +215,15 @@ def header():
218215 df = pd .read_csv (io .StringIO (csv_str ), sep = '\t ' )
219216 df = df .drop (['id' ], axis = 1 )
220217
221- # make str(list) to list
222- if not df .empty :
223- list_filt = lambda f : isinstance (v := df [f ].iloc [0 ], str ) and ('[' in v or '(' in v )
224- list_fields = [* filter (list_filt , list (df ))]
225- if parse_str :
226- df [list_fields ] = df [list_fields ].applymap (str2value )
218+ if parse_str :
219+ df = df .applymap (str2value )
227220
228221 # set grid_fields to multiindex
229222 df = df .set_index (idx [1 :])
230223
231224 return {'static_configs' : static_configs ,
232225 'grid_fields' : idx [1 :],
233- 'info_fields ' : list (df ),
226+ 'metric_fields ' : list (df ),
234227 'df' : df }
235228
236229
@@ -286,18 +279,33 @@ def wrapped(self, *args, **kwargs):
286279 # -----------------------------------------------------------------------------
287280
288281 @partial (update_tsv , mode = 'r' )
289- def add_result (self , configs , metrics = dict (), ** infos ):
282+ def add_result (self , configs , ** metrics ):
290283 '''Add experiment run result to dataframe'''
284+ if configs in self :
285+ cur_gridval = list2tuple ([configs [k ] for k in self .grid_fields ])
286+ self .df = self .df .drop (cur_gridval )
287+
288+ configs = {k :list2tuple (configs [k ]) for k in self .grid_fields }
289+ metrics = {k :metrics .get (k ) for k in self .metric_fields }
290+ result_dict = {k :[v ] for k , v in {** configs , ** metrics }.items ()}
291+ result_df = pd .DataFrame (result_dict ).set_index (self .grid_fields )
292+ self .df = pd .concat ([self .df , result_df ])[self .metric_fields ]
293+
294+ # Previous DataFrame.loc-based implementation, kept disabled below.
295+ # Problem: pandas cannot distinguish between
296+ # ``.loc[(row_idx, col_idx)]`` and ``.loc[(multi_idx_1, multi_idx_2)]``
297+ # for a DataFrame with a 2-level MultiIndex.
298+ """
291299 cur_gridval = list2tuple([configs[k] for k in self.grid_fields])
292300
293- row_dict = {** infos , ** metrics }
294- df_row = [row_dict .get (k ) for k in self .field_order ]
301+ df_row = [metrics.get(k) for k in self.field_order]
295302
296303 # Write over metric results if there is a config saved
297304 if configs in self:
298305 self.df = self.df.drop(cur_gridval)
299306
300307 self.df.loc[cur_gridval] = df_row
308+ """
301309
302310 @staticmethod
303311 def __add_column (df , new_column_name , fn , * fn_arg_fields ):
@@ -368,11 +376,10 @@ def same_diff(dictl, dictr):
368376
369377 self .static_configs = new_sttc
370378 self .grid_fields += new_to_self_sf
371- self .field_order = self .info_fields + self .metric_fields
372379
373380 self .df , other .df = (obj .df .reset_index () for obj in (self , other ))
374- self .df = pd .concat ([self .df , other .df ])[ self . grid_fields + self . field_order ] \
375- .set_index (self .grid_fields )
381+ self .df = pd .concat ([self .df , other .df ]) \
382+ .set_index (self .grid_fields )[ self . metric_fields ]
376383 return self
377384
378385 def merge (self , * others , same = True ):
@@ -411,20 +418,19 @@ def isin(self, config):
411418 '''Check if specific experiment config was already executed in log.'''
412419 if self .df .empty : return False
413420
414- cfg_same_with = lambda dct : [config [d ]== dct [ d ] for d in dct . keys ()]
421+ cfg_same_in_static = all ( [config [k ]== v for k , v in self . static_configs . items () if k in config ])
415422 cfg_matched_df = self .__cfg_match_row (config )
416423
417- return all ( cfg_same_with ( self . static_configs )) and not cfg_matched_df .empty
424+ return cfg_same_in_static and not cfg_matched_df .empty
418425
419426
420- def get_metric_and_info (self , config ):
427+ def get_metric (self , config ):
421428 '''Search for the log row matching the given config dict and return its metric_dict.'''
422429 assert config in self , 'config should be in self when using get_metric_dict.'
423430
424431 cfg_matched_df = self .__cfg_match_row (config )
425432 metric_dict = {k :(v .iloc [0 ] if not (v := cfg_matched_df [k ]).empty else None ) for k in self .metric_fields }
426- info_dict = {k :(v .iloc [0 ] if not (v := cfg_matched_df [k ]).empty else None ) for k in self .info_fields }
427- return metric_dict , info_dict
433+ return metric_dict
428434
429435 def is_same_exp (self , other ):
430436 '''Check if both logs have same config fields.'''
@@ -458,8 +464,8 @@ def melt_and_explode_metric(self, df=None, step=None):
458464
459465 # delete string and NaN valued rows
460466 df = df [pd .to_numeric (df ['metric_value' ], errors = 'coerce' ).notnull ()]\
461- .dropna ()\
462- .astype ('float' )
467+ .dropna ()\
468+ .astype ('float' )
463469
464470 return df
465471
@@ -492,8 +498,6 @@ class Experiment:
492498 - (current run checking) Save configs of currently running experiments to tsv so other running code can know.
493499 - Saves experiment logs, automatically resumes experiment using saved log.
494500 '''
495- info_field : ClassVar [list ] = ['datetime' , 'status' ]
496-
497501 __RUNNING : ClassVar [str ] = 'R'
498502 __FAILED : ClassVar [str ] = 'F'
499503 __COMPLETED : ClassVar [str ] = 'C'
@@ -557,7 +561,7 @@ def __get_log(self, logs_file, metric_fields=None, auto_update_tsv=False):
557561 logs_path , _ = os .path .split (logs_file )
558562 if not os .path .exists (logs_path ):
559563 os .makedirs (logs_path )
560- log = ExperimentLog .from_exp_config (self .configs .__dict__ , logs_file , self . info_field ,
564+ log = ExperimentLog .from_exp_config (self .configs .__dict__ , logs_file ,
561565 metric_fields = metric_fields , auto_update_tsv = auto_update_tsv )
562566 log .to_tsv ()
563567 return log
@@ -571,8 +575,8 @@ def get_paths(exp_folder):
571575 return cfg_file , tsv_file , fig_dir
572576
573577 def get_log_checkpoint (self , config , empty_metric ):
574- metric_dict , info_dict = self .log .get_metric_and_info (config )
575- if info_dict ['status' ] == self .__FAILED :
578+ metric_dict = self .log .get_metric (config )
579+ if metric_dict ['status' ] == self .__FAILED :
576580 return metric_dict
577581 return empty_metric
578582
@@ -593,8 +597,8 @@ def run(self):
593597 for i , config in enumerate (self .configs ):
594598
595599 if config in self .log :
596- metric_dict , info_dict = self .log .get_metric_and_info (config )
597- if info_dict .get ('status' ) != self .__FAILED :
600+ metric_dict = self .log .get_metric (config )
601+ if metric_dict .get ('status' ) != self .__FAILED :
598602 continue # skip already executed runs
599603
600604 # if config not in self.log or status==self.__FAILED
@@ -618,8 +622,9 @@ def run(self):
618622 raise
619623
620624 # Open log file and add result
621- self .log .add_result (config , metrics = metric_dict ,
622- datetime = str (datetime .now ()), status = self .__COMPLETED )
625+ self .log .add_result (config , ** metric_dict ,
626+ datetime = str (datetime .now ()),
627+ status = self .__COMPLETED )
623628 self .log .to_tsv ()
624629
625630 logging .info ("Saved experiment data to log" )
@@ -659,16 +664,15 @@ def resplit_logs(exp_folder_path: str, target_split: int=1, save_backup: bool=Tr
659664 # empty log
660665 lgs = ExperimentLog .from_exp_config (configs .__dict__ ,
661666 os .path .join (logs_folder , f'split_{ n } .tsv' ,),
662- base .info_fields ,
663667 base .metric_fields )
664668
665669 # resplitting nth split
666670 cfgs_temp = copy .deepcopy (configs )
667671 cfgs_temp .filter_iter (lambda i , _ : i % target_split == n )
668672 for cfg in track (cfgs_temp , description = f'split: { n } /{ target_split } ' ):
669673 if cfg in base :
670- metric_dict , info_dict = base .get_metric_and_info (cfg )
671- lgs .add_result (cfg , metric_dict , ** info_dict )
674+ metric_dict = base .get_metric (cfg )
675+ lgs .add_result (cfg , ** metric_dict )
672676
673677 lgs .to_tsv ()
674678
0 commit comments