Add logger flag to save_hyperparameters #7960

Merged (15 commits) on Jul 13, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -116,7 +116,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992))


-- Add support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
+- Added support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))


+- Added `logger` boolean flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960))
+
+
 - Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))
30 changes: 20 additions & 10 deletions pytorch_lightning/trainer/trainer.py
@@ -907,20 +907,30 @@ def _pre_dispatch(self):

     def _log_hyperparams(self):
         # log hyper-parameters
+        hparams_initial = None
+
         if self.logger is not None:
             # save exp to get started (this is where the first experiment logs are written)
-            datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
-            lightning_hparams = self.lightning_module.hparams_initial
-            colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
-            if colliding_keys:
-                raise MisconfigurationException(
-                    f"Error while merging hparams: the keys {colliding_keys} are present "
-                    "in both the LightningModule's and LightningDataModule's hparams."
-                )
+            datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

-            hparams_initial = {**lightning_hparams, **datamodule_hparams}
+            if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
+                datamodule_hparams = self.datamodule.hparams_initial
+                lightning_hparams = self.lightning_module.hparams_initial

-            self.logger.log_hyperparams(hparams_initial)
+                colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
+                if colliding_keys:
+                    raise MisconfigurationException(
+                        f"Error while merging hparams: the keys {colliding_keys} are present "
+                        "in both the LightningModule's and LightningDataModule's hparams."
+                    )
+                hparams_initial = {**lightning_hparams, **datamodule_hparams}
+            elif self.lightning_module._log_hyperparams:
+                hparams_initial = self.lightning_module.hparams_initial
+            elif datamodule_log_hyperparams:
+                hparams_initial = self.datamodule.hparams_initial

+            if hparams_initial is not None:
+                self.logger.log_hyperparams(hparams_initial)
             self.logger.log_graph(self.lightning_module)
             self.logger.save()

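To make the control flow above easier to follow, here is a small standalone sketch of the same precedence rules: hyperparameters are merged only when both the LightningModule and the LightningDataModule opted in, colliding keys raise an error, and nothing is logged when both opted out. The function name `merge_logged_hparams` and the plain `ValueError` (instead of `MisconfigurationException`) are illustrative, not part of the PR.

```python
from typing import Dict, Optional


def merge_logged_hparams(
    module_hparams: Dict,
    datamodule_hparams: Optional[Dict],
    module_logs: bool,
    datamodule_logs: bool,
) -> Optional[Dict]:
    """Simplified mirror of the branching in Trainer._log_hyperparams."""
    datamodule_logs = datamodule_logs and datamodule_hparams is not None
    if module_logs and datamodule_logs:
        colliding = module_hparams.keys() & datamodule_hparams.keys()
        if colliding:
            raise ValueError(f"the keys {colliding} are present in both hparams")
        return {**module_hparams, **datamodule_hparams}
    if module_logs:
        return module_hparams
    if datamodule_logs:
        return datamodule_hparams
    return None  # nothing is sent to the logger


# The datamodule opted out, so only the module's hparams would be logged.
print(merge_logged_hparams({"lr": 0.1}, {"batch_size": 32}, True, False))
```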
10 changes: 8 additions & 2 deletions pytorch_lightning/utilities/hparams_mixin.py
@@ -22,15 +22,19 @@
 from pytorch_lightning.utilities.parsing import save_hyperparameters


-class HyperparametersMixin:
+class HyperparametersMixin(object):

     __jit_unused_properties__ = ["hparams", "hparams_initial"]

+    def __init__(self) -> None:
+        self._log_hyperparams = True
+
     def save_hyperparameters(
         self,
         *args,
         ignore: Optional[Union[Sequence[str], str]] = None,
-        frame: Optional[types.FrameType] = None
+        frame: Optional[types.FrameType] = None,
+        logger: bool = True
     ) -> None:
         """Save arguments to ``hparams`` attribute.

@@ -40,6 +44,7 @@ def save_hyperparameters(
             ignore: an argument name or a list of argument names from
                 class ``__init__`` to be ignored
             frame: a frame object. Default is None
+            logger: Whether to send the hyperparameters to the logger. Default: True

         Example::
             >>> class ManuallyArgsModel(HyperparametersMixin):
@@ -92,6 +97,7 @@ class ``__init__`` to be ignored
             "arg1": 1
             "arg3": 3.14
         """
+        self._log_hyperparams = logger
         # the frame needs to be created in this file.
         if not frame:
             frame = inspect.currentframe().f_back
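In user code, the new flag lets a model keep its arguments in `self.hparams` (and in checkpoints) while skipping the logger. A minimal usage sketch; the `LitModel` class and its argument names are made up for illustration:

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3, hidden_dim: int = 128):
        super().__init__()
        # Arguments are still collected into self.hparams and saved with
        # checkpoints, but the trainer will not forward them to its logger.
        self.save_hyperparameters(logger=False)


model = LitModel()
print(model.hparams)  # learning_rate and hidden_dim remain accessible
```

Calling `self.save_hyperparameters()` without the flag keeps the previous behaviour, since `logger` defaults to `True`.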
29 changes: 28 additions & 1 deletion tests/loggers/test_base.py
@@ -14,9 +14,10 @@
 import pickle
 from argparse import Namespace
 from typing import Optional
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

 import numpy as np
+import pytest

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
@@ -290,3 +291,29 @@ def log_hyperparams(self, params):
     }
     logger.log_hyperparams(Namespace(**np_params))
     assert logger.logged_params == sanitized_params
+
+
+@pytest.mark.parametrize("logger", [True, False])
+@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
+def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger):
+
+    class TestModel(BoringModel):
+
+        def __init__(self, param_one, param_two):
+            super().__init__()
+            self.save_hyperparameters(logger=logger)
+
+    model = TestModel("pytorch", "lightning")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+    )
+    trainer.fit(model)
+
+    if logger:
+        log_hyperparams_mock.assert_called()
+    else:
+        log_hyperparams_mock.assert_not_called()
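Because `_log_hyperparams` is set in `HyperparametersMixin`, which the trainer also queries on the datamodule, the same opt-out should apply to a `LightningDataModule` on versions where it exposes `save_hyperparameters`. A hedged sketch; `LitData` and its arguments are illustrative:

```python
import pytorch_lightning as pl


class LitData(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        # Keep the datamodule's hparams locally without merging them into
        # the hyperparameters sent to the trainer's logger.
        self.save_hyperparameters(logger=False)


dm = LitData()
print(dm.hparams)
```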