diff --git a/art/loggers.py b/art/loggers.py
index 8eb321b..ec1d6fb 100644
--- a/art/loggers.py
+++ b/art/loggers.py
@@ -7,9 +7,11 @@
 from typing import TYPE_CHECKING, List, Optional, Union
 
 import numpy as np
-from lightning.pytorch.loggers import NeptuneLogger, WandbLogger
+from lightning.pytorch.loggers import NeptuneLogger, WandbLogger, TensorBoardLogger
 from loguru import logger
 
+from art.utils.paths import EXPERIMENT_DIR_PATH
+
 if TYPE_CHECKING:
     from loguru import Logger
 
@@ -216,3 +218,51 @@ def add_tags(self, tags: Union[List[str], str]):
         if isinstance(tags, str):
             tags = [tags]
         self.wandb.run.tags += tags
+
+
+class TensorBoardLoggerAdapter(TensorBoardLogger):
+    """
+    This is a wrapper for LightningLogger for simplifying basic functionalities between different loggers.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **{**kwargs, "save_dir": str(EXPERIMENT_DIR_PATH / "tensorboard")})
+
+    def log_config(self, configFile: str):
+        """
+        Logs a config file to TensorBoard.
+
+        Args:
+            configFile (str): Path to config file.
+        """
+        with open(configFile, "r") as f:
+            self.experiment.add_text("config", f.read())
+
+    def log_img(self, image, path: str = "image"):
+        """
+        Logs an image to TensorBoard.
+
+        Args:
+            image (np.ndarray): Image to log.
+            path (str, optional): Path to log image to. Defaults to "image".
+        """
+        self.experiment.add_image(path, image)
+
+    def log_figure(self, figure, path: str = "figure"):
+        """
+        Logs a figure to TensorBoard.
+
+        Args:
+            figure (Any): Figure to log.
+            path (str, optional): Path to log figure to. Defaults to "figure".
+        """
+        self.experiment.add_figure(path, figure)
+
+    def add_tags(self, tags: Union[List[str], str]):
+        """
+        Adds tags to the TensorBoard run.
+
+        Args:
+            tags (Union[List[str], str]): Tag or list of tags to add.
+        """
+        pass
diff --git a/art/steps.py b/art/steps.py
index ac5f709..99f6913 100644
--- a/art/steps.py
+++ b/art/steps.py
@@ -503,9 +503,11 @@ def __init__(
         self,
         model: ArtModule,
         number_of_steps: int = 50,
+        model_kwargs: Optional[Dict] = None,
+        logger: Optional[Logger] = None,
     ):
         self.number_of_steps = number_of_steps
-        super().__init__(model, {"overfit_batches": 1, "max_epochs": number_of_steps})
+        super().__init__(model, {"overfit_batches": 1, "max_epochs": number_of_steps}, logger=logger, model_kwargs=model_kwargs or {})
 
     def do(self, previous_states: Dict):
         """
@@ -543,12 +545,14 @@ class Overfit(ModelStep):
 
     def __init__(
         self,
         model: ArtModule,
-        logger: Optional[Logger] = None,
         max_epochs: int = 1,
+        model_kwargs: Optional[Dict] = None,
+        trainer_kwargs: Optional[Dict] = None,
+        logger: Optional[Logger] = None,
     ):
         self.max_epochs = max_epochs
-        super().__init__(model, {"max_epochs": max_epochs}, logger=logger)
+        super().__init__(model, {"max_epochs": max_epochs, **(trainer_kwargs or {})}, logger=logger, model_kwargs=model_kwargs or {})
 
     def do(self, previous_states: Dict):
         """
diff --git a/art/utils/paths.py b/art/utils/paths.py
index 3a97da1..e79c9ae 100644
--- a/art/utils/paths.py
+++ b/art/utils/paths.py
@@ -1,7 +1,8 @@
 from pathlib import Path
 
 CHECKPOINTS_PATH = Path("art_checkpoints")
-EXPERIMENT_LOG_DIR = CHECKPOINTS_PATH / "experiment" / "logs"
+EXPERIMENT_DIR_PATH = CHECKPOINTS_PATH / "experiment"
+EXPERIMENT_LOG_DIR = EXPERIMENT_DIR_PATH / "logs"
 
 
 def get_checkpoint_step_dir_path(full_step_name) -> Path: