Add logger flag to save_hyperparameters #7960

Merged — 15 commits, Jul 13, 2021
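This PR adds a boolean `logger` flag to `save_hyperparameters`, so hyperparameters can still be collected into `self.hparams` without being forwarded to the attached logger. A minimal usage sketch of the new flag, assuming a hypothetical `LitModel` that is not part of this PR:

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3, hidden_dim: int = 128):
        super().__init__()
        # Hyperparameters are still stored in self.hparams, but with
        # logger=False they are not sent to the Trainer's logger.
        self.save_hyperparameters(logger=False)
```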
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -116,7 +116,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992))


- Add support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
- Added support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))


- Added `logger` boolean flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960))


- Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))
2 changes: 2 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -97,6 +97,8 @@ def __init__(
self._has_teardown_test = False
self._has_teardown_predict = False

self._log_hyperparams = True

@property
def train_transforms(self):
"""
1 change: 1 addition & 0 deletions pytorch_lightning/core/lightning.py
@@ -108,6 +108,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = dict()
self._log_hyperparams = True
self._metric_attributes: Optional[Dict[int, str]] = None

# deprecated, will be removed in 1.6
31 changes: 19 additions & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -907,20 +907,27 @@ def _pre_dispatch(self):

def _log_hyperparams(self):
# log hyper-parameters
hparams_initial = None

if self.logger is not None:
# save exp to get started (this is where the first experiment logs are written)
datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
lightning_hparams = self.lightning_module.hparams_initial
colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
if colliding_keys:
raise MisconfigurationException(
f"Error while merging hparams: the keys {colliding_keys} are present "
"in both the LightningModule's and LightningDataModule's hparams."
)

hparams_initial = {**lightning_hparams, **datamodule_hparams}

self.logger.log_hyperparams(hparams_initial)
if self.lightning_module._log_hyperparams and self.datamodule._log_hyperparams:
datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
lightning_hparams = self.lightning_module.hparams_initial
colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
if colliding_keys:
raise MisconfigurationException(
f"Error while merging hparams: the keys {colliding_keys} are present "
"in both the LightningModule's and LightningDataModule's hparams."
)
hparams_initial = {**lightning_hparams, **datamodule_hparams}
elif self.lightning_module._log_hyperparams:
hparams_initial = self.lightning_module.hparams_initial
elif self.datamodule._log_hyperparams:
hparams_initial = self.datamodule.hparams_initial

if hparams_initial is not None:
self.logger.log_hyperparams(hparams_initial)
self.logger.log_graph(self.lightning_module)
self.logger.save()

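For reference, a standalone sketch of the resolution logic introduced above, written against plain dicts and booleans rather than the real Trainer internals (the function and parameter names below are illustrative, not part of the PR):

```python
from typing import Dict, Optional


def resolve_hparams_to_log(
    module_hparams: Dict, module_flag: bool, dm_hparams: Dict, dm_flag: bool
) -> Optional[Dict]:
    """Return the hyperparameters the logger should receive, or None if neither side opted in."""
    if module_flag and dm_flag:
        # Both the LightningModule and the LightningDataModule want their
        # hparams logged: merge them and refuse ambiguous, colliding keys.
        colliding_keys = module_hparams.keys() & dm_hparams.keys()
        if colliding_keys:
            raise ValueError(
                f"Error while merging hparams: the keys {colliding_keys} are present "
                "in both the LightningModule's and LightningDataModule's hparams."
            )
        return {**module_hparams, **dm_hparams}
    if module_flag:
        return dict(module_hparams)
    if dm_flag:
        return dict(dm_hparams)
    return None
```

In the Trainer itself the datamodule side additionally falls back to an empty dict when no datamodule is attached, as the diff above shows.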
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/hparams_mixin.py
@@ -30,7 +30,8 @@ def save_hyperparameters(
self,
*args,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None
frame: Optional[types.FrameType] = None,
logger: bool = True
) -> None:
"""Save arguments to ``hparams`` attribute.

@@ -40,6 +41,7 @@ def save_hyperparameters(
ignore: an argument name or a list of argument names from
class ``__init__`` to be ignored
frame: a frame object. Default is None
logger: Whether to send the hyperparameters to the logger. Default: True

Example::
>>> class ManuallyArgsModel(HyperparametersMixin):
@@ -92,6 +94,7 @@ class ``__init__`` to be ignored
"arg1": 1
"arg3": 3.14
"""
self._log_hyperparams = logger
# the frame needs to be created in this file.
if not frame:
frame = inspect.currentframe().f_back
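Because the flag lives on the shared hyperparameters mixin (and the datamodule change above initializes `_log_hyperparams` as well), the same switch can be used from a `LightningDataModule`. A hedged sketch of combining the two, assuming hypothetical classes; with this setup the Trainer branch shown earlier would log only the module's hparams:

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        # logger defaults to True, so `lr` is sent to the logger.
        self.save_hyperparameters()


class LitData(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        # Kept in self.hparams, but excluded from logger.log_hyperparams.
        self.save_hyperparameters(logger=False)
```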
35 changes: 35 additions & 0 deletions tests/loggers/test_base.py
@@ -17,6 +17,7 @@
from unittest.mock import MagicMock

import numpy as np
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
@@ -290,3 +291,37 @@ def log_hyperparams(self, params):
}
logger.log_hyperparams(Namespace(**np_params))
assert logger.logged_params == sanitized_params


@pytest.mark.parametrize("logger", [True, False])
def test_log_hyperparams_being_called(tmpdir, logger):

class TestLogger(DummyLogger):

def __init__(self):
super().__init__()
self.log_hyperparams_called = False

def log_hyperparams(self, *args, **kwargs):
self.log_hyperparams_called = True

class TestModel(BoringModel):

def __init__(self, param_one, param_two):
super().__init__()
self.save_hyperparameters(logger=logger)

test_logger = TestLogger()
model = TestModel("pytorch", "lightning")

trainer = Trainer(
default_root_dir=tmpdir,
logger=test_logger,
max_epochs=1,
limit_train_batches=0.1,
limit_val_batches=0.1,
num_sanity_val_steps=0,
)
trainer.fit(model)

assert logger == test_logger.log_hyperparams_called