Skip to content

Commit

Permalink
Add logger flag to save_hyperparameters (#7960)
Browse files Browse the repository at this point in the history
* Add log flag to save_hyperparameters

* FIx setter

* Add test & Update changelog

* Address comments

* Fix conflicts

* Update trainer

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Fix datamodule hparams fix

* Fix datamodule hparams fix

* Update test with patch

* Update pytorch_lightning/utilities/hparams_mixin.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Move log_hyperparams to mixin

* Update hparams mixin

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
kaushikb11 and awaelchli authored Jul 13, 2021
1 parent 1d3b7f2 commit 9d5ad76
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 14 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992))


- Add support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
- Added support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))


- Added `logger` boolean flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960))


- Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))
Expand Down
30 changes: 20 additions & 10 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,20 +907,30 @@ def _pre_dispatch(self):

def _log_hyperparams(self):
# log hyper-parameters
hparams_initial = None

if self.logger is not None:
# save exp to get started (this is where the first experiment logs are written)
datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
lightning_hparams = self.lightning_module.hparams_initial
colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
if colliding_keys:
raise MisconfigurationException(
f"Error while merging hparams: the keys {colliding_keys} are present "
"in both the LightningModule's and LightningDataModule's hparams."
)
datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

hparams_initial = {**lightning_hparams, **datamodule_hparams}
if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
datamodule_hparams = self.datamodule.hparams_initial
lightning_hparams = self.lightning_module.hparams_initial

self.logger.log_hyperparams(hparams_initial)
colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
if colliding_keys:
raise MisconfigurationException(
f"Error while merging hparams: the keys {colliding_keys} are present "
"in both the LightningModule's and LightningDataModule's hparams."
)
hparams_initial = {**lightning_hparams, **datamodule_hparams}
elif self.lightning_module._log_hyperparams:
hparams_initial = self.lightning_module.hparams_initial
elif datamodule_log_hyperparams:
hparams_initial = self.datamodule.hparams_initial

if hparams_initial is not None:
self.logger.log_hyperparams(hparams_initial)
self.logger.log_graph(self.lightning_module)
self.logger.save()

Expand Down
11 changes: 9 additions & 2 deletions pytorch_lightning/utilities/hparams_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,20 @@
from pytorch_lightning.utilities.parsing import save_hyperparameters


class HyperparametersMixin:
class HyperparametersMixin(object):

__jit_unused_properties__ = ["hparams", "hparams_initial"]

def __init__(self) -> None:
super().__init__()
self._log_hyperparams = True

def save_hyperparameters(
self,
*args,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None
frame: Optional[types.FrameType] = None,
logger: bool = True
) -> None:
"""Save arguments to ``hparams`` attribute.
Expand All @@ -40,6 +45,7 @@ def save_hyperparameters(
ignore: an argument name or a list of argument names from
class ``__init__`` to be ignored
frame: a frame object. Default is None
logger: Whether to send the hyperparameters to the logger. Default: True
Example::
>>> class ManuallyArgsModel(HyperparametersMixin):
Expand Down Expand Up @@ -92,6 +98,7 @@ class ``__init__`` to be ignored
"arg1": 1
"arg3": 3.14
"""
self._log_hyperparams = logger
# the frame needs to be created in this file.
if not frame:
frame = inspect.currentframe().f_back
Expand Down
29 changes: 28 additions & 1 deletion tests/loggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import pickle
from argparse import Namespace
from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
Expand Down Expand Up @@ -290,3 +291,29 @@ def log_hyperparams(self, params):
}
logger.log_hyperparams(Namespace(**np_params))
assert logger.logged_params == sanitized_params


@pytest.mark.parametrize("logger", [True, False])
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger):

class TestModel(BoringModel):

def __init__(self, param_one, param_two):
super().__init__()
self.save_hyperparameters(logger=logger)

model = TestModel("pytorch", "lightning")
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=0.1,
limit_val_batches=0.1,
num_sanity_val_steps=0,
)
trainer.fit(model)

if logger:
log_hyperparams_mock.assert_called()
else:
log_hyperparams_mock.assert_not_called()

0 comments on commit 9d5ad76

Please sign in to comment.