From e95b2f442d4f49073af89a800cd0d386d1cfdb92 Mon Sep 17 00:00:00 2001 From: Kacper Trebacz Date: Thu, 7 Dec 2023 12:47:27 +0100 Subject: [PATCH 1/2] adjust to embedding tutorial --- art/loggers.py | 53 ++++++++++++++++++++++++++++++++++++++++++++-- art/steps.py | 10 ++++++--- art/utils/paths.py | 3 ++- 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/art/loggers.py b/art/loggers.py index 8eb321b..5aa307d 100644 --- a/art/loggers.py +++ b/art/loggers.py @@ -7,9 +7,11 @@ from typing import TYPE_CHECKING, List, Optional, Union import numpy as np -from lightning.pytorch.loggers import NeptuneLogger, WandbLogger +from lightning.pytorch.loggers import NeptuneLogger, WandbLogger, TensorBoardLogger from loguru import logger +from art.utils.paths import EXPERIMENT_DIR_PATH + if TYPE_CHECKING: from loguru import Logger @@ -215,4 +217,51 @@ def add_tags(self, tags: Union[List[str], str]): """ if isinstance(tags, str): tags = [tags] - self.wandb.run.tags += tags + # self.wandb.run.tags += tags + + +class TensorBoardLoggerAdapter(TensorBoardLogger): + """ + This is a wrapper for LightningLogger for simplifying basic functionalities between different loggers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, save_dir=str(EXPERIMENT_DIR_PATH/"tensorboard"), **kwargs) + + def log_config(self, configFile: str): + """ + Logs a config file to TensorBoard. + + Args: + configFile (str): Path to config file. + """ + self.experiment.add_text("config", open(configFile, "r").read()) + + def log_img(self, image, path: str = "image"): + """ + Logs an image to TensorBoard. + + Args: + image (np.ndarray): Image to log. + path (str, optional): Path to log image to. Defaults to "image". + """ + self.experiment.add_image(path, image) + + def log_figure(self, figure, path: str = "figure"): + """ + Logs a figure to TensorBoard. + + Args: + figure (Any): Figure to log. + path (str, optional): Path to log figure to. Defaults to "figure". + """ + self.experiment.add_figure(path, figure) + + def add_tags(self, tags: Union[List[str], str]): + """ + Adds tags to the TensorBoard run. + + Args: + tags (Union[List[str], str]): Tag or list of tags to add. + """ + pass diff --git a/art/steps.py b/art/steps.py index ac5f709..99f6913 100644 --- a/art/steps.py +++ b/art/steps.py @@ -503,9 +503,11 @@ def __init__( self, model: ArtModule, number_of_steps: int = 50, + model_kwargs: Dict = {}, + logger: Optional[Logger] = None, ): self.number_of_steps = number_of_steps - super().__init__(model, {"overfit_batches": 1, "max_epochs": number_of_steps}) + super().__init__(model, {"overfit_batches": 1, "max_epochs": number_of_steps}, logger=logger, model_kwargs=model_kwargs) def do(self, previous_states: Dict): """ @@ -543,12 +545,14 @@ class Overfit(ModelStep): def __init__( self, model: ArtModule, - logger: Optional[Logger] = None, max_epochs: int = 1, + model_kwargs: Dict = {}, + trainer_kwargs: Dict = {}, + logger: Optional[Logger] = None, ): self.max_epochs = max_epochs - super().__init__(model, {"max_epochs": max_epochs}, logger=logger) + super().__init__(model, {"max_epochs": max_epochs}, logger=logger, model_kwargs=model_kwargs) def do(self, previous_states: Dict): """ diff --git a/art/utils/paths.py b/art/utils/paths.py index 3a97da1..e79c9ae 100644 --- a/art/utils/paths.py +++ b/art/utils/paths.py @@ -1,7 +1,8 @@ from pathlib import Path CHECKPOINTS_PATH = Path("art_checkpoints") -EXPERIMENT_LOG_DIR = CHECKPOINTS_PATH / "experiment" / "logs" +EXPERIMENT_DIR_PATH = CHECKPOINTS_PATH / "experiment" +EXPERIMENT_LOG_DIR = EXPERIMENT_DIR_PATH / "logs" def get_checkpoint_step_dir_path(full_step_name) -> Path: From f6302a0285d0a1b23a500449f1ac052d03a5d390 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Tr=C4=99bacz?= Date: Thu, 11 Jan 2024 23:51:59 +0100 Subject: [PATCH 2/2] Update loggers.py wandb fix --- art/loggers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/art/loggers.py b/art/loggers.py index 5aa307d..ec1d6fb 100644 --- a/art/loggers.py +++ b/art/loggers.py @@ -217,7 +217,7 @@ def add_tags(self, tags: Union[List[str], str]): """ if isinstance(tags, str): tags = [tags] - # self.wandb.run.tags += tags + self.wandb.run.tags += tags class TensorBoardLoggerAdapter(TensorBoardLogger):