Skip to content

Commit 410e62e

Browse files
committed
Merge branch 'main' of github.com:QUMIA/train-scripts
2 parents 8a03fe2 + 74c34fe commit 410e62e

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

qumia_core.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,21 @@ def validate(trainer: QUMIA_Trainer, n_batches=None, set_type='validation', fold
145145
print("Possible mismatch between labels and inputs!")
146146
#raise Exception("Mismatch between labels and inputs")
147147

148-
# Save the dataframe to a csv file
148+
# Prepare the output directory
149149
val_output_dir = os.path.join(trainer.output_dir, folder)
150+
os.makedirs(val_output_dir, exist_ok=True)
151+
152+
# Save the dataframe to a csv file
150153
df_combined.to_csv(os.path.join(val_output_dir, f'df_{set_type}_predictions.csv'), index=False)
151154

152155
# Create a confusion matrix
153156
create_confusion_matrix(rounded_predictions.tolist(), labels.tolist(), set_type, val_output_dir)
154157

158+
# WandB confusion matrix
159+
label_list = [value - 1 for value in labels.astype(int)]
160+
pred_list = [value - 1 for value in rounded_predictions.astype(int)]
155161
wandb.log({"cm_" + folder: wandb.plot.confusion_matrix(probs=None,
156-
y_true=labels, preds=predictions,
162+
y_true=label_list, preds=pred_list,
157163
class_names=['1.0', '2.0', '3.0', '4.0'])})
158164

159165
return df_combined

0 commit comments

Comments
 (0)