Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion art/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
from lightning.pytorch.loggers import NeptuneLogger, WandbLogger
from lightning.pytorch.loggers import NeptuneLogger, WandbLogger, TensorBoardLogger
from loguru import logger

from art.utils.paths import EXPERIMENT_DIR_PATH

if TYPE_CHECKING:
from loguru import Logger

Expand Down Expand Up @@ -216,3 +218,50 @@ def add_tags(self, tags: Union[List[str], str]):
if isinstance(tags, str):
tags = [tags]
self.wandb.run.tags += tags


class TensorBoardLoggerAdapter(TensorBoardLogger):
"""
This is a wrapper for LightningLogger for simplifying basic functionalities between different loggers.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, save_dir=str(EXPERIMENT_DIR_PATH/"tensorboard"), **kwargs)

def log_config(self, configFile: str):
"""
Logs a config file to TensorBoard.

Args:
configFile (str): Path to config file.
"""
self.experiment.add_text("config", open(configFile, "r").read())

def log_img(self, image, path: str = "image"):
"""
Logs an image to TensorBoard.

Args:
image (np.ndarray): Image to log.
path (str, optional): Path to log image to. Defaults to "image".
"""
self.experiment.add_image(path, image)

def log_figure(self, figure, path: str = "figure"):
"""
Logs a figure to TensorBoard.

Args:
figure (Any): Figure to log.
path (str, optional): Path to log figure to. Defaults to "figure".
"""
self.experiment.add_figure(path, figure)

def add_tags(self, tags: Union[List[str], str]):
"""
Adds tags to the TensorBoard run.

Args:
tags (Union[List[str], str]): Tag or list of tags to add.
"""
pass
10 changes: 7 additions & 3 deletions art/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,11 @@ def __init__(
self,
model: ArtModule,
number_of_steps: int = 50,
model_kwargs: Dict = {},
logger: Optional[Logger] = None,
):
self.number_of_steps = number_of_steps
super().__init__(model, {"overfit_batches": 1, "max_epochs": number_of_steps})
super().__init__(model, {"overfit_batches": 1, "max_epochs": number_of_steps}, logger=logger, model_kwargs=model_kwargs)

def do(self, previous_states: Dict):
"""
Expand Down Expand Up @@ -543,12 +545,14 @@ class Overfit(ModelStep):
def __init__(
self,
model: ArtModule,
logger: Optional[Logger] = None,
max_epochs: int = 1,
model_kwargs: Dict = {},
trainer_kwargs: Dict = {},
logger: Optional[Logger] = None,
):
self.max_epochs = max_epochs

super().__init__(model, {"max_epochs": max_epochs}, logger=logger)
super().__init__(model, {"max_epochs": max_epochs}, logger=logger, model_kwargs=model_kwargs)

def do(self, previous_states: Dict):
"""
Expand Down
3 changes: 2 additions & 1 deletion art/utils/paths.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path

CHECKPOINTS_PATH = Path("art_checkpoints")
EXPERIMENT_LOG_DIR = CHECKPOINTS_PATH / "experiment" / "logs"
EXPERIMENT_DIR_PATH = CHECKPOINTS_PATH / "experiment"
EXPERIMENT_LOG_DIR = EXPERIMENT_DIR_PATH / "logs"


def get_checkpoint_step_dir_path(full_step_name) -> Path:
Expand Down