diff --git a/README.md b/README.md
index bdce397..9971db3 100644
--- a/README.md
+++ b/README.md
@@ -151,7 +151,7 @@ python -m art.cli bert-transfer-learning-tutorial
 ```
 3. A tutorial showing how to use ART for regularization
 ```sh
-python -m art.cli regularization_tutorial
+python -m art.cli regularization-tutorial
 ```
 
 ## API Cheatsheet
diff --git a/art/steps.py b/art/steps.py
index ac5f709..b54c276 100644
--- a/art/steps.py
+++ b/art/steps.py
@@ -26,6 +26,7 @@
 from art.utils.enums import TrainingStage
 from art.utils.paths import get_checkpoint_logs_folder_path
 from art.utils.savers import JSONStepSaver
+from art.utils.ensemble import ArtEnsemble
 
 
 class NoModelUsed:
@@ -719,10 +720,6 @@ def do(self, previous_states: Dict):
         # TODO how to solve this?
 
 
-class Squeeze(ModelStep):
-    pass
-
-
 class TransferLearning(ModelStep):
     """This step tries performing proper transfer learning"""
 
@@ -833,3 +830,74 @@ def change_lr(model):
                 model.lr = self.fine_tune_lr
 
             self.model_modifiers.append(change_lr)
+
+
+class Ensemble(ModelStep):
+    """This step tries to ensemble models"""
+
+    name = "Ensemble"
+    description = "Ensembles models"
+
+    def __init__(
+        self,
+        model: ArtModule,
+        num_models: int = 5,
+        logger: Optional[Logger] = None,
+        trainer_kwargs: Dict = {},
+        model_kwargs: Dict = {},
+        model_modifiers: List[Callable] = [],
+        datamodule_modifiers: List[Callable] = [],
+    ):
+        """
+        This method initializes the step
+
+        Args:
+            model (ArtModule): model to be ensembled
+            num_models (int, optional): number of models to train. Defaults to 5.
+            logger (Logger, optional): logger. Defaults to None.
+            trainer_kwargs (Dict, optional): Kwargs passed to lightning Trainer. Defaults to {}.
+            model_kwargs (Dict, optional): Kwargs passed to model. Defaults to {}.
+            model_modifiers (List[Callable], optional): model modifiers. Defaults to [].
+            datamodule_modifiers (List[Callable], optional): datamodule modifiers. Defaults to [].
+        """
+        super().__init__(
+            model,
+            trainer_kwargs,
+            model_kwargs,
+            model_modifiers,
+            datamodule_modifiers,
+            logger=logger,
+        )
+        self.num_models = num_models
+
+    def do(self, previous_states: Dict):
+        """
+        This method trains num_models models and ensembles them
+
+        Args:
+            previous_states (Dict): previous states
+        """
+        models_paths = []
+        for _ in range(self.num_models):
+            self.reset_trainer(
+                logger=self.trainer.logger, trainer_kwargs=self.trainer_kwargs
+            )
+            self.train(trainer_kwargs={"datamodule": self.datamodule})
+            models_paths.append(self.trainer.checkpoint_callback.best_model_path)
+
+        initialized_models = []
+        for path in models_paths:
+            model = self.model_class.load_from_checkpoint(path)
+            model.eval()
+            initialized_models.append(model)
+
+        self.model = ArtEnsemble(initialized_models)
+        self.validate(trainer_kwargs={"datamodule": self.datamodule})
+
+    def get_check_stage(self):
+        """Returns check stage"""
+        return TrainingStage.VALIDATION.value
+
+    def log_model_params(self, model):
+        self.results["parameters"]["num_models"] = self.num_models
+        super().log_model_params(model)
diff --git a/art/utils/ensemble.py b/art/utils/ensemble.py
new file mode 100644
index 0000000..e2505e3
--- /dev/null
+++ b/art/utils/ensemble.py
@@ -0,0 +1,37 @@
+from art.core import ArtModule
+from art.utils.enums import BATCH, PREDICTION
+
+import torch
+from torch import nn
+
+from typing import List
+from copy import deepcopy
+
+
+class ArtEnsemble(ArtModule):
+    """
+    Base class for ensembles.
+    """
+
+    def __init__(self, models: List[ArtModule]):
+        super().__init__()
+        self.models = nn.ModuleList(models)
+
+    def predict(self, data):
+        # Each member consumes its own copy of the dataloader.
+        predictions = torch.stack([self.predict_on_model_from_dataloader(model, deepcopy(data)) for model in self.models])
+        return torch.mean(predictions, dim=0)
+
+    def predict_on_model_from_dataloader(self, model, dataloader):
+        predictions = []
+        model.to(self.device)
+        for batch in dataloader:
+            batch_processed = model.parse_data({BATCH: batch})
+            predictions.append(model.predict(batch_processed)[PREDICTION])
+        return torch.cat(predictions)
+
+    def log_params(self):
+        return {
+            "num_models": len(self.models),
+            "models": [model.log_params() for model in self.models],
+        }
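For reviewers who want to try the new step, a minimal wiring sketch. `MyModel` and the trainer settings are hypothetical, and whether `model` expects a class or an instance follows `ModelStep`'s existing contract; only the `Ensemble` constructor signature comes from this diff:

```python
from art.core import ArtModule
from art.steps import Ensemble


class MyModel(ArtModule):  # hypothetical minimal module; a real one implements
    ...                    # parse_data / predict / training logic


# Trains num_models copies of MyModel (each Trainer run keeps its best
# checkpoint), wraps the reloaded copies in an ArtEnsemble, and validates
# the averaged predictions.
ensemble_step = Ensemble(
    model=MyModel,
    num_models=5,
    trainer_kwargs={"max_epochs": 10},  # assumed example; any lightning Trainer kwargs
)
```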
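And a framework-free sketch of the averaging scheme `ArtEnsemble.predict` implements: stack every member's predictions along a new leading axis, then take the mean over it. Plain `nn.Linear` modules stand in for the trained `ArtModule`s here:

```python
import torch
from torch import nn

# Three stand-in members; ArtEnsemble holds trained ArtModules instead.
members = nn.ModuleList(nn.Linear(4, 2) for _ in range(3))
x = torch.randn(8, 4)

with torch.no_grad():
    stacked = torch.stack([member(x) for member in members])  # (num_models, batch, out)
    averaged = stacked.mean(dim=0)                            # (batch, out)

print(averaged.shape)  # torch.Size([8, 2])
```

Mean-averaging assumes every member returns predictions of identical shape and on the same scale (logits with logits, probabilities with probabilities).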