From 9d5ad7639cb0e694bb8634ece69235cfd330908d Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Tue, 13 Jul 2021 15:06:36 +0530
Subject: [PATCH] Add logger flag to save_hyperparameters (#7960)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add log flag to save_hyperparameters

* FIx setter

* Add test & Update changelog

* Address comments

* Fix conflicts

* Update trainer

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli

* Fix datamodule hparams fix

* Fix datamodule hparams fix

* Update test with patch

* Update pytorch_lightning/utilities/hparams_mixin.py

Co-authored-by: Adrian Wälchli

* Move log_hyperparams to mixin

* Update hparams mixin

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                                 |  5 +++-
 pytorch_lightning/trainer/trainer.py         | 30 +++++++++++++-------
 pytorch_lightning/utilities/hparams_mixin.py | 11 +++++--
 tests/loggers/test_base.py                   | 29 ++++++++++++++++++-
 4 files changed, 61 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0b35c08bdb198..2a0aeed6401dc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -116,7 +116,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992))
 
 
-- Add support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
+- Added support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
+
+
+- Added `logger` boolean flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960))
 
 
 - Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 306d1c2f06978..ac7a41e3808f2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -907,20 +907,30 @@ def _pre_dispatch(self):
 
     def _log_hyperparams(self):
         # log hyper-parameters
+        hparams_initial = None
+
         if self.logger is not None:
             # save exp to get started (this is where the first experiment logs are written)
-            datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
-            lightning_hparams = self.lightning_module.hparams_initial
-            colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
-            if colliding_keys:
-                raise MisconfigurationException(
-                    f"Error while merging hparams: the keys {colliding_keys} are present "
-                    "in both the LightningModule's and LightningDataModule's hparams."
-                )
+            datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
 
-            hparams_initial = {**lightning_hparams, **datamodule_hparams}
+            if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
+                datamodule_hparams = self.datamodule.hparams_initial
+                lightning_hparams = self.lightning_module.hparams_initial
 
-            self.logger.log_hyperparams(hparams_initial)
+                colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
+                if colliding_keys:
+                    raise MisconfigurationException(
+                        f"Error while merging hparams: the keys {colliding_keys} are present "
+                        "in both the LightningModule's and LightningDataModule's hparams."
+                    )
+                hparams_initial = {**lightning_hparams, **datamodule_hparams}
+            elif self.lightning_module._log_hyperparams:
+                hparams_initial = self.lightning_module.hparams_initial
+            elif datamodule_log_hyperparams:
+                hparams_initial = self.datamodule.hparams_initial
+
+            if hparams_initial is not None:
+                self.logger.log_hyperparams(hparams_initial)
             self.logger.log_graph(self.lightning_module)
             self.logger.save()
 
diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py
index 8dd4b23c89398..b1cb9492e91d5 100644
--- a/pytorch_lightning/utilities/hparams_mixin.py
+++ b/pytorch_lightning/utilities/hparams_mixin.py
@@ -22,15 +22,20 @@
 from pytorch_lightning.utilities.parsing import save_hyperparameters
 
 
-class HyperparametersMixin:
+class HyperparametersMixin(object):
 
     __jit_unused_properties__ = ["hparams", "hparams_initial"]
 
+    def __init__(self) -> None:
+        super().__init__()
+        self._log_hyperparams = True
+
     def save_hyperparameters(
         self,
         *args,
         ignore: Optional[Union[Sequence[str], str]] = None,
-        frame: Optional[types.FrameType] = None
+        frame: Optional[types.FrameType] = None,
+        logger: bool = True
     ) -> None:
         """Save arguments to ``hparams`` attribute.
 
@@ -40,6 +45,7 @@ def save_hyperparameters(
             ignore: an argument name or a list of argument names from
                 class ``__init__`` to be ignored
             frame: a frame object. Default is None
+            logger: Whether to send the hyperparameters to the logger. Default: True
 
         Example::
             >>> class ManuallyArgsModel(HyperparametersMixin):
@@ -92,6 +98,7 @@ class ``__init__`` to be ignored
             "arg1": 1
             "arg3": 3.14
         """
+        self._log_hyperparams = logger
         # the frame needs to be created in this file.
         if not frame:
             frame = inspect.currentframe().f_back
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index 9209083148265..5ecc372ec0acf 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -14,9 +14,10 @@
 import pickle
 from argparse import Namespace
 from typing import Optional
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import numpy as np
+import pytest
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
@@ -290,3 +291,29 @@ def log_hyperparams(self, params):
     }
     logger.log_hyperparams(Namespace(**np_params))
     assert logger.logged_params == sanitized_params
+
+
+@pytest.mark.parametrize("logger", [True, False])
+@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
+def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger):
+
+    class TestModel(BoringModel):
+
+        def __init__(self, param_one, param_two):
+            super().__init__()
+            self.save_hyperparameters(logger=logger)
+
+    model = TestModel("pytorch", "lightning")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+    )
+    trainer.fit(model)
+
+    if logger:
+        log_hyperparams_mock.assert_called()
+    else:
+        log_hyperparams_mock.assert_not_called()
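
Usage sketch (illustrative, not part of the patch): assuming a typical LightningModule subclass, the new ``logger`` flag keeps the hyperparameters in ``self.hparams`` and in checkpoints while skipping ``logger.log_hyperparams``. The class, layer sizes and hyperparameter names below are hypothetical.

    import torch
    from pytorch_lightning import LightningModule


    class LitClassifier(LightningModule):
        def __init__(self, hidden_dim: int = 128, learning_rate: float = 1e-3):
            super().__init__()
            # hparams are still collected into self.hparams and saved to checkpoints,
            # but logger=False sets self._log_hyperparams = False, so the Trainer
            # skips logger.log_hyperparams() for this module's hparams.
            self.save_hyperparameters(logger=False)
            self.layer = torch.nn.Linear(32, self.hparams.hidden_dim)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)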