Skip to content

Commit a436513

Browse files
authored
Merge pull request #10 from edong6768/changes/0.1.8
CNG remove info_field from ExperimentLog and FIX str2value, ExperimentLog.add_result, etc
2 parents c04709c + a35ff2d commit a436513

File tree

3 files changed

+50
-41
lines changed

3 files changed

+50
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ package-dir = {""="src"}
88

99
[project]
1010
name = "malet"
11-
version = "0.1.7"
11+
version = "0.1.8"
1212
description = "Malet: a tool for machine learning experiment"
1313
readme = "README.md"
1414
requires-python = ">=3.8"

src/malet/experiment.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ class ExperimentLog:
155155
static_configs: dict
156156
grid_fields: list
157157
logs_file: str
158-
info_fields: list
159158

160159
metric_fields: Optional[list] = None
161160
df: Optional[pd.DataFrame]=None
@@ -167,17 +166,15 @@ def __post_init__(self):
167166
if self.df is None:
168167
assert self.metric_fields is not None, 'Specify the metric fields of the experiment.'
169168
assert not (f:=set(self.grid_fields) & set(self.metric_fields)), f'Overlapping field names {f} in grid_fields and metric_fields. Remove one of them.'
170-
columns = self.grid_fields + self.info_fields + self.metric_fields
171-
self.df = pd.DataFrame(columns=columns).set_index(self.grid_fields)
169+
self.df = pd.DataFrame(columns=self.grid_fields+self.metric_fields).set_index(self.grid_fields)
172170
else:
173-
self.metric_fields = [i for i in list(self.df) if i not in self.info_fields]
174-
self.field_order = self.info_fields + self.metric_fields
171+
self.metric_fields = list(self.df)
175172

176173
# Constructors.
177174
# -----------------------------------------------------------------------------
178175
@classmethod
179-
def from_exp_config(cls, exp_config, logs_file: str, info_fields: list, metric_fields: Optional[list]=None, auto_update_tsv: bool=False):
180-
return cls(*(exp_config[k] for k in ['static_configs', 'grid_fields']), logs_file=logs_file, info_fields=info_fields,
176+
def from_exp_config(cls, exp_config, logs_file: str, metric_fields: Optional[list]=None, auto_update_tsv: bool=False):
177+
return cls(*(exp_config[k] for k in ['static_configs', 'grid_fields']), logs_file=logs_file,
181178
metric_fields=metric_fields, auto_update_tsv = auto_update_tsv)
182179

183180
@classmethod
@@ -218,19 +215,15 @@ def header():
218215
df = pd.read_csv(io.StringIO(csv_str), sep='\t')
219216
df = df.drop(['id'], axis=1)
220217

221-
# make str(list) to list
222-
if not df.empty:
223-
list_filt = lambda f: isinstance(v:=df[f].iloc[0], str) and ('[' in v or '(' in v)
224-
list_fields = [*filter(list_filt, list(df))]
225-
if parse_str:
226-
df[list_fields] = df[list_fields].applymap(str2value)
218+
if parse_str:
219+
df = df.applymap(str2value)
227220

228221
# set grid_fields to multiindex
229222
df = df.set_index(idx[1:])
230223

231224
return {'static_configs': static_configs,
232225
'grid_fields': idx[1:],
233-
'info_fields': list(df),
226+
'metric_fields': list(df),
234227
'df': df}
235228

236229

@@ -286,18 +279,33 @@ def wrapped(self, *args, **kwargs):
286279
# -----------------------------------------------------------------------------
287280

288281
@partial(update_tsv, mode='r')
289-
def add_result(self, configs, metrics=dict(), **infos):
282+
def add_result(self, configs, **metrics):
290283
'''Add experiment run result to dataframe'''
284+
if configs in self:
285+
cur_gridval = list2tuple([configs[k] for k in self.grid_fields])
286+
self.df = self.df.drop(cur_gridval)
287+
288+
configs = {k:list2tuple(configs[k]) for k in self.grid_fields}
289+
metrics = {k:metrics.get(k) for k in self.metric_fields}
290+
result_dict = {k:[v] for k, v in {**configs, **metrics}.items()}
291+
result_df = pd.DataFrame(result_dict).set_index(self.grid_fields)
292+
self.df = pd.concat([self.df, result_df])[self.metric_fields]
293+
294+
# Dataframe.loc based code
295+
# problem: pandas unable to distinguish between
296+
# ``.loc[(col_idx, row_idx)]`` and ``.loc[(multi_idx_1, multi_idx_2)]``
297+
# for 2-level multiindex dataframe
298+
"""
291299
cur_gridval = list2tuple([configs[k] for k in self.grid_fields])
292300
293-
row_dict = {**infos, **metrics}
294-
df_row = [row_dict.get(k) for k in self.field_order]
301+
df_row = [metrics.get(k) for k in self.field_order]
295302
296303
# Write over metric results if there is a config saved
297304
if configs in self:
298305
self.df = self.df.drop(cur_gridval)
299306
300307
self.df.loc[cur_gridval] = df_row
308+
"""
301309

302310
@staticmethod
303311
def __add_column(df, new_column_name, fn, *fn_arg_fields):
@@ -368,11 +376,10 @@ def same_diff(dictl, dictr):
368376

369377
self.static_configs = new_sttc
370378
self.grid_fields += new_to_self_sf
371-
self.field_order = self.info_fields + self.metric_fields
372379

373380
self.df, other.df = (obj.df.reset_index() for obj in (self, other))
374-
self.df = pd.concat([self.df, other.df])[self.grid_fields+self.field_order] \
375-
.set_index(self.grid_fields)
381+
self.df = pd.concat([self.df, other.df]) \
382+
.set_index(self.grid_fields)[self.metric_fields]
376383
return self
377384

378385
def merge(self, *others, same=True):
@@ -411,20 +418,19 @@ def isin(self, config):
411418
'''Check if specific experiment config was already executed in log.'''
412419
if self.df.empty: return False
413420

414-
cfg_same_with = lambda dct: [config[d]==dct[d] for d in dct.keys()]
421+
cfg_same_in_static = all([config[k]==v for k, v in self.static_configs.items() if k in config])
415422
cfg_matched_df = self.__cfg_match_row(config)
416423

417-
return all(cfg_same_with(self.static_configs)) and not cfg_matched_df.empty
424+
return cfg_same_in_static and not cfg_matched_df.empty
418425

419426

420-
def get_metric_and_info(self, config):
427+
def get_metric(self, config):
421428
'''Search matching log with given config dict and return metric_dict, info_dict'''
422429
assert config in self, 'config should be in self when using get_metric_dict.'
423430

424431
cfg_matched_df = self.__cfg_match_row(config)
425432
metric_dict = {k:(v.iloc[0] if not (v:=cfg_matched_df[k]).empty else None) for k in self.metric_fields}
426-
info_dict = {k:(v.iloc[0] if not (v:=cfg_matched_df[k]).empty else None) for k in self.info_fields}
427-
return metric_dict, info_dict
433+
return metric_dict
428434

429435
def is_same_exp(self, other):
430436
'''Check if both logs have same config fields.'''
@@ -458,8 +464,8 @@ def melt_and_explode_metric(self, df=None, step=None):
458464

459465
# delete string and NaN valued rows
460466
df = df[pd.to_numeric(df['metric_value'], errors='coerce').notnull()]\
461-
.dropna()\
462-
.astype('float')
467+
.dropna()\
468+
.astype('float')
463469

464470
return df
465471

@@ -492,8 +498,6 @@ class Experiment:
492498
- (current run checking) Save configs of currently running experiments to tsv so other running code can know.
493499
- Saves experiment logs, automatically resumes experiment using saved log.
494500
'''
495-
info_field: ClassVar[list] = ['datetime', 'status']
496-
497501
__RUNNING: ClassVar[str] = 'R'
498502
__FAILED: ClassVar[str] = 'F'
499503
__COMPLETED: ClassVar[str] = 'C'
@@ -557,7 +561,7 @@ def __get_log(self, logs_file, metric_fields=None, auto_update_tsv=False):
557561
logs_path, _ = os.path.split(logs_file)
558562
if not os.path.exists(logs_path):
559563
os.makedirs(logs_path)
560-
log = ExperimentLog.from_exp_config(self.configs.__dict__, logs_file, self.info_field,
564+
log = ExperimentLog.from_exp_config(self.configs.__dict__, logs_file,
561565
metric_fields=metric_fields, auto_update_tsv=auto_update_tsv)
562566
log.to_tsv()
563567
return log
@@ -571,8 +575,8 @@ def get_paths(exp_folder):
571575
return cfg_file, tsv_file, fig_dir
572576

573577
def get_log_checkpoint(self, config, empty_metric):
574-
metric_dict, info_dict = self.log.get_metric_and_info(config)
575-
if info_dict['status'] == self.__FAILED:
578+
metric_dict = self.log.get_metric(config)
579+
if metric_dict['status'] == self.__FAILED:
576580
return metric_dict
577581
return empty_metric
578582

@@ -593,8 +597,8 @@ def run(self):
593597
for i, config in enumerate(self.configs):
594598

595599
if config in self.log:
596-
metric_dict, info_dict = self.log.get_metric_and_info(config)
597-
if info_dict.get('status') != self.__FAILED:
600+
metric_dict = self.log.get_metric(config)
601+
if metric_dict.get('status') != self.__FAILED:
598602
continue # skip already executed runs
599603

600604
# if config not in self.log or status==self.__FAILED
@@ -618,8 +622,9 @@ def run(self):
618622
raise
619623

620624
# Open log file and add result
621-
self.log.add_result(config, metrics=metric_dict,
622-
datetime=str(datetime.now()), status=self.__COMPLETED)
625+
self.log.add_result(config, **metric_dict,
626+
datetime=str(datetime.now()),
627+
status=self.__COMPLETED)
623628
self.log.to_tsv()
624629

625630
logging.info("Saved experiment data to log")
@@ -659,16 +664,15 @@ def resplit_logs(exp_folder_path: str, target_split: int=1, save_backup: bool=Tr
659664
# empty log
660665
lgs = ExperimentLog.from_exp_config(configs.__dict__,
661666
os.path.join(logs_folder, f'split_{n}.tsv',),
662-
base.info_fields,
663667
base.metric_fields)
664668

665669
# resplitting nth split
666670
cfgs_temp = copy.deepcopy(configs)
667671
cfgs_temp.filter_iter(lambda i, _: i%target_split==n)
668672
for cfg in track(cfgs_temp, description=f'split: {n}/{target_split}'):
669673
if cfg in base:
670-
metric_dict, info_dict = base.get_metric_and_info(cfg)
671-
lgs.add_result(cfg, metric_dict, **info_dict)
674+
metric_dict = base.get_metric(cfg)
675+
lgs.add_result(cfg, **metric_dict)
672676

673677
lgs.to_tsv()
674678

src/malet/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,9 @@ def list2tuple(l):
3838

3939
def str2value(value_str):
4040
"""Casts string back to standard python types"""
41-
return literal_eval(value_str) if isinstance(value_str, str) else value_str
41+
if not isinstance(value_str, str): return value_str
42+
value_str = value_str.replace('inf', '2e+308')
43+
try:
44+
return literal_eval(value_str)
45+
except:
46+
return value_str

0 commit comments

Comments
 (0)