From fdd3651ebc1c0041ebe02308ba664dd882c5b53e Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 2 Oct 2020 11:23:18 +0200 Subject: [PATCH 01/36] Extract hparam functions to mixin. We want to use the hyperparameter saving code in the datamodule, too. --- pytorch_lightning/core/lightning.py | 127 +---------------- pytorch_lightning/utilities/hparams_mixin.py | 139 +++++++++++++++++++ 2 files changed, 144 insertions(+), 122 deletions(-) create mode 100644 pytorch_lightning/utilities/hparams_mixin.py diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dc77ebc8cc395..a87217d6977a6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -15,10 +15,8 @@ import collections import inspect import os -import re import tempfile from abc import ABC -from argparse import Namespace from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -27,17 +25,14 @@ from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities.parsing import ( - AttributeDict, - collect_init_args, - get_init_args, -) +from pytorch_lightning.utilities.parsing import collect_init_args from torch import ScriptModule, Tensor from torch.nn import Module from torch.nn.parallel import DistributedDataParallel @@ -55,6 +50,7 @@ class LightningModule( ABC, DeviceDtypeModuleMixin, + HyperparametersMixin, GradInformation, ModelIO, ModelHooks, @@ -67,9 +63,8 @@ class LightningModule( __jit_unused_properties__ = [ "datamodule", "example_input_array", - "hparams", "on_gpu", - ] + DeviceDtypeModuleMixin.__jit_unused_properties__ + ] + DeviceDtypeModuleMixin.__jit_unused_properties__ + HyperparametersMixin.__jit_unused_properties__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1498,88 +1493,6 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: parents_arguments.update(args) return self_arguments, parents_arguments - def save_hyperparameters(self, *args, frame=None) -> None: - """Save all model arguments. - - Args: - args: single object of `dict`, `NameSpace` or `OmegaConf` - or string names or argumenst from class `__init__` - - >>> from collections import OrderedDict - >>> class ManuallyArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # manually assign arguments - ... self.save_hyperparameters('arg1', 'arg3') - ... def forward(self, *args, **kwargs): - ... ... - >>> model = ManuallyArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg3": 3.14 - - >>> class AutomaticArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # equivalent automatic - ... self.save_hyperparameters() - ... def forward(self, *args, **kwargs): - ... ... 
- >>> model = AutomaticArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg2": abc - "arg3": 3.14 - - >>> class SingleArgModel(LightningModule): - ... def __init__(self, params): - ... super().__init__() - ... # manually assign single argument - ... self.save_hyperparameters(params) - ... def forward(self, *args, **kwargs): - ... ... - >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) - >>> model.hparams - "p1": 1 - "p2": abc - "p3": 3.14 - """ - if not frame: - frame = inspect.currentframe().f_back - init_args = get_init_args(frame) - assert init_args, "failed to inspect the self init" - if not args: - hp = init_args - self._hparams_name = "kwargs" if hp else None - else: - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] - if len(isx_non_str) == 1: - hp = args[isx_non_str[0]] - cand_names = [k for k, v in init_args.items() if v == hp] - self._hparams_name = cand_names[0] if cand_names else None - else: - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} - self._hparams_name = "kwargs" - - # `hparams` are expected here - if hp: - self._set_hparams(hp) - - def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: - if isinstance(hp, Namespace): - hp = vars(hp) - if isinstance(hp, dict): - hp = AttributeDict(hp) - elif isinstance(hp, PRIMITIVE_TYPES): - raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") - elif not isinstance(hp, ALLOWED_CONFIG_TYPES): - raise ValueError(f"Unsupported config type of {type(hp)}.") - - if isinstance(hp, dict) and isinstance(self.hparams, dict): - self.hparams.update(hp) - else: - self._hparams = hp - def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs): """Saves the model in ONNX format @@ -1676,33 +1589,3 @@ def to_torchscript( torch.jit.save(scripted_module, file_path) return scripted_module - - @property - def hparams(self) -> Union[AttributeDict, str]: - if not hasattr(self, "_hparams"): - self._hparams = AttributeDict() - return self._hparams - - @hparams.setter - def hparams(self, hp: Union[dict, Namespace, Any]): - hparams_assignment_name = self.__get_hparams_assignment_variable() - self._hparams_name = hparams_assignment_name - self._set_hparams(hp) - - def __get_hparams_assignment_variable(self): - """""" - """ - looks at the code of the class to figure out what the user named self.hparams - this only happens when the user explicitly sets self.hparams - """ - try: - class_code = inspect.getsource(self.__class__) - lines = class_code.split("\n") - for line in lines: - line = re.sub(r"\s+", "", line, flags=re.UNICODE) - if ".hparams=" in line: - return line.split("=")[1] - except Exception as e: - return "hparams" - - return None diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py new file mode 100644 index 0000000000000..13095af1e1c62 --- /dev/null +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -0,0 +1,139 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from argparse import Namespace +from typing import Union, Any + +from pytorch_lightning.core.saving import PRIMITIVE_TYPES, ALLOWED_CONFIG_TYPES +from pytorch_lightning.utilities import AttributeDict +from pytorch_lightning.utilities.parsing import get_init_args + + +class HyperparametersMixin: + + __jit_unused_properties__ = ["hparams"] + + def save_hyperparameters(self, *args, frame=None) -> None: + """Save all model arguments. + + Args: + args: single object of `dict`, `NameSpace` or `OmegaConf` + or string names or argumenst from class `__init__` + + >>> from collections import OrderedDict + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # manually assign arguments + ... self.save_hyperparameters('arg1', 'arg3') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 + + >>> class AutomaticArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # equivalent automatic + ... self.save_hyperparameters() + ... def forward(self, *args, **kwargs): + ... ... + >>> model = AutomaticArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg2": abc + "arg3": 3.14 + + >>> class SingleArgModel(LightningModule): + ... def __init__(self, params): + ... super().__init__() + ... # manually assign single argument + ... self.save_hyperparameters(params) + ... def forward(self, *args, **kwargs): + ... ... + >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) + >>> model.hparams + "p1": 1 + "p2": abc + "p3": 3.14 + """ + if not frame: + frame = inspect.currentframe().f_back + init_args = get_init_args(frame) + assert init_args, "failed to inspect the self init" + if not args: + hp = init_args + self._hparams_name = "kwargs" if hp else None + else: + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] + if len(isx_non_str) == 1: + hp = args[isx_non_str[0]] + cand_names = [k for k, v in init_args.items() if v == hp] + self._hparams_name = cand_names[0] if cand_names else None + else: + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} + self._hparams_name = "kwargs" + + # `hparams` are expected here + if hp: + self._set_hparams(hp) + + def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: + if isinstance(hp, Namespace): + hp = vars(hp) + if isinstance(hp, dict): + hp = AttributeDict(hp) + elif isinstance(hp, PRIMITIVE_TYPES): + raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") + elif not isinstance(hp, ALLOWED_CONFIG_TYPES): + raise ValueError(f"Unsupported config type of {type(hp)}.") + + if isinstance(hp, dict) and isinstance(self.hparams, dict): + self.hparams.update(hp) + else: + self._hparams = hp + + @property + def hparams(self) -> Union[AttributeDict, str]: + if not hasattr(self, "_hparams"): + self._hparams = AttributeDict() + return self._hparams + + @hparams.setter + def hparams(self, hp: Union[dict, Namespace, Any]): + hparams_assignment_name = self.__get_hparams_assignment_variable() + self._hparams_name = hparams_assignment_name + self._set_hparams(hp) + + def __get_hparams_assignment_variable(self): + """""" + """ + looks at the code of the class to figure out what the user named self.hparams + this only happens when the user explicitly sets self.hparams + """ + try: 
+ class_code = inspect.getsource(self.__class__) + lines = class_code.split("\n") + for line in lines: + line = re.sub(r"\s+", "", line, flags=re.UNICODE) + if ".hparams=" in line: + return line.split("=")[1] + except Exception as e: + return "hparams" + + return None From 9629b322fe58e916378689bc3d99c5df8f81248f Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 2 Oct 2020 11:35:54 +0200 Subject: [PATCH 02/36] Make LightningDataModule inherit from HyperparametersMixin. A DataModule can now save its hyperparameters just like a LightningModule. --- pytorch_lightning/core/datamodule.py | 4 +++- tests/base/datamodules.py | 2 ++ tests/core/test_datamodules.py | 6 ++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 51928c757b8d4..142189304bad2 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -23,6 +23,8 @@ from pytorch_lightning.utilities import parsing, rank_zero_only from torch.utils.data import DataLoader +from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin + class _DataModuleWrapper(type): def __init__(self, *args, **kwargs): @@ -92,7 +94,7 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn -class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper): +class LightningDataModule(DataHooks, CheckpointHooks, HyperparametersMixin, metaclass=_DataModuleWrapper): """ A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models. diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index 3ea0615db6e94..2d139032f30f9 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -14,6 +14,8 @@ def __init__(self, data_dir: str = "./"): self.non_picklable = None self.checkpoint_state: Optional[str] = None + self.save_hyperparameters() + def prepare_data(self): TrialMNIST(self.data_dir, train=True, download=True) TrialMNIST(self.data_dir, train=False, download=True) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 5325ca828e47b..e02e8e21ca791 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -6,6 +6,7 @@ import torch from pytorch_lightning import LightningDataModule, Trainer, seed_everything +from pytorch_lightning.utilities import AttributeDict from tests.base import EvalModelTemplate from tests.base.datamodules import TrialMNISTDataModule from tests.base.develop_utils import reset_seed @@ -419,3 +420,8 @@ def transfer_batch_to_device(self, data, device): expected = torch.device('cuda', 0) assert dm.hook_called assert batch_gpu.samples.device == batch_gpu.targets.device == expected + + +def test_simple_hyperparameters_saving(): + data = TrialMNISTDataModule() + assert data.hparams == AttributeDict({'data_dir': data.data_dir}) From 2771894e077f5ad52247a7d8523335eb0219acaf Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 2 Oct 2020 12:58:19 +0200 Subject: [PATCH 03/36] Add function to extend hparams. The function takes a dict or namespace and adds the contained hparams to the existing ones. If a hparam already exists, an error is thrown to avoid overwriting it. 
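A rough sketch of the intended usage (the model class, argument names and values below are illustrative only and not part of this patch):

    from argparse import Namespace
    from pytorch_lightning import LightningModule

    class SketchModel(LightningModule):
        def __init__(self, lr=0.01):
            super().__init__()
            self.save_hyperparameters()             # hparams is now {"lr": 0.01}

    model = SketchModel()
    model.extend_hparams({"batch_size": 32})        # a plain dict is accepted
    model.extend_hparams(Namespace(num_workers=4))  # so is an argparse Namespace
    assert model.hparams.batch_size == 32
    model.extend_hparams({"lr": 0.1})               # raises ValueError, 'lr' already exists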
--- pytorch_lightning/utilities/hparams_mixin.py | 25 ++++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 13095af1e1c62..c0eba6fc39a49 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -93,7 +93,26 @@ def save_hyperparameters(self, *args, frame=None) -> None: if hp: self._set_hparams(hp) + def extend_hparams(self, hparams): + hparams = self._to_hparams_dict(hparams) + if not hasattr(self, '_hparams'): + self._hparams_name = 'extended' + self._hparams = hparams + else: + colliding_keys = [key for key in hparams.keys() if key in self.hparams] + if colliding_keys: + raise ValueError(f'The keys {colliding_keys} are already present in the hparams.') + self.hparams.update(hparams) + def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: + hp = self._to_hparams_dict(hp) + + if isinstance(hp, dict) and isinstance(self.hparams, dict): + self.hparams.update(hp) + else: + self._hparams = hp + + def _to_hparams_dict(self, hp): if isinstance(hp, Namespace): hp = vars(hp) if isinstance(hp, dict): @@ -102,11 +121,7 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") elif not isinstance(hp, ALLOWED_CONFIG_TYPES): raise ValueError(f"Unsupported config type of {type(hp)}.") - - if isinstance(hp, dict) and isinstance(self.hparams, dict): - self.hparams.update(hp) - else: - self._hparams = hp + return hp @property def hparams(self) -> Union[AttributeDict, str]: From ea744ae2be888cdfda1f1560629b6330898fcbdf Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 2 Oct 2020 13:00:20 +0200 Subject: [PATCH 04/36] Add hparams of DataModule to model before training. To log and checkpoint the hparams of the DataModule together with the model, we add them to the model's hparams before training. 
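Sketched end to end (illustrative only; a real run would also need dataloaders and a training_step, which are omitted here, and SketchModel is the hypothetical model from the sketch above):

    from pytorch_lightning import LightningDataModule, Trainer

    class SketchDataModule(LightningDataModule):
        def __init__(self, data_dir="./data"):
            super().__init__()
            self.save_hyperparameters()   # datamodule hparams: {"data_dir": "./data"}

    model = SketchModel()                 # model hparams: {"lr": 0.01}
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, datamodule=SketchDataModule())
    # setup_fit() merges the datamodule hparams into the model, so loggers and
    # checkpoints see {"lr": 0.01, "data_dir": "./data"}; a duplicate key raises ValueError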
--- pytorch_lightning/trainer/training_loop.py | 6 ++ tests/models/test_hparams.py | 87 ++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 99318b9f34324..cb5463121c46c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -88,6 +88,12 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # clean hparams if hasattr(model, 'hparams'): parsing.clean_namespace(model.hparams) + if hasattr(datamodule, 'hparams'): + parsing.clean_namespace(datamodule.hparams) + try: + model.extend_hparams(datamodule.hparams) + except ValueError as e: + raise ValueError(f'Error while adding data module hparams to model: {e}') # links data to the trainer self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 807d5dcc869fe..422767725caa4 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -1,3 +1,4 @@ +import copy import os import pickle from argparse import Namespace @@ -13,6 +14,7 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml from pytorch_lightning.utilities import AttributeDict, is_picklable from tests.base import EvalModelTemplate, TrialMNIST +from tests.base.datamodules import TrialMNISTDataModule class SaveHparamsModel(EvalModelTemplate): @@ -540,3 +542,88 @@ def test_args(tmpdir): raw_checkpoint_path = _raw_checkpoint_path(trainer) model = SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path) assert model.hparams == hparams + + +def test_extending_existing_hparams(tmpdir): + """Test that the new hparams are added to the existing ones.""" + hparams = {'arg1': 'abc'} + model = EvalModelTemplate() + old_hparams = copy.deepcopy(model.hparams) + model.extend_hparams(hparams) + + old_hparams.update(hparams) + assert old_hparams == model.hparams + + +def test_extending_non_existing_hparams(tmpdir): + """Test that hparams are created if they do not exist yet when we try to extend them.""" + class DummyModel(LightningModule): + pass + + hparams = {'arg1': 'abc'} + model = DummyModel() + model.extend_hparams(hparams) + + assert hparams == model.hparams + + +def test_extending_with_namespace(tmpdir): + """Test that we can extend hparams with a namespace.""" + hparams = Namespace(arg1='abc') + model = EvalModelTemplate() + old_hparams = copy.deepcopy(model.hparams) + model.extend_hparams(hparams) + + old_hparams.update(vars(hparams)) + assert old_hparams == model.hparams + + +def test_extend_with_unsupported_hparams(tmpdir): + """Test that usupported hparams structures raise an error when extending.""" + hparams = ('arg1', 'abc') + model = EvalModelTemplate() + + with pytest.raises(ValueError): + model.extend_hparams(hparams) + + +def test_extend_with_primitive_hparams(tmpdir): + """Test that primitives raise an error when extending.""" + hparams = 5 + model = EvalModelTemplate() + + with pytest.raises(ValueError): + model.extend_hparams(hparams) + + +def test_extend_with_collision(tmp_path): + """Test that new hparams cannot collide with existing hparams.""" + model = EvalModelTemplate() + with pytest.raises(ValueError): + model.extend_hparams({'batch_size': 5}) + + +def test_adding_datamodule_hparams(tmpdir): + """Test that hparams from datamodule are added to the checkpoint.""" + model = SaveHparamsModel({'arg1': 5, 'arg2': 'abc'}) + data = 
TrialMNISTDataModule() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer.fit(model, datamodule=data) + + hparams = model.hparams + hparams.update(data.hparams) + raw_checkpoint_path = _raw_checkpoint_path(trainer) + model = SaveHparamsModel.load_from_checkpoint(raw_checkpoint_path) + assert model.hparams == hparams + + +def test_colliding_datamodule_hparams(tmpdir): + """Test that colliding hparams from the datamodule are caught.""" + model = SaveHparamsModel({'data_dir': 'abc', 'arg2': 'abc'}) + data = TrialMNISTDataModule() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + with pytest.raises(ValueError, match='Error while adding data module hparams to model:'): + trainer.fit(model, datamodule=data) From 3ea5f32d41c2d49b7ed55aa4c2fe994abdcc2f17 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 2 Oct 2020 13:47:31 +0200 Subject: [PATCH 05/36] Change examples due to cyclic import. --- pytorch_lightning/utilities/hparams_mixin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index c0eba6fc39a49..b8cc3cd57a150 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -34,7 +34,7 @@ def save_hyperparameters(self, *args, frame=None) -> None: or string names or argumenst from class `__init__` >>> from collections import OrderedDict - >>> class ManuallyArgsModel(LightningModule): + >>> class ManuallyArgsModel(HyperparametersMixin): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() ... # manually assign arguments @@ -46,7 +46,7 @@ def save_hyperparameters(self, *args, frame=None) -> None: "arg1": 1 "arg3": 3.14 - >>> class AutomaticArgsModel(LightningModule): + >>> class AutomaticArgsModel(HyperparametersMixin): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() ... # equivalent automatic @@ -59,7 +59,7 @@ def save_hyperparameters(self, *args, frame=None) -> None: "arg2": abc "arg3": 3.14 - >>> class SingleArgModel(LightningModule): + >>> class SingleArgModel(HyperparametersMixin): ... def __init__(self, params): ... super().__init__() ... # manually assign single argument From 9ea79899a9e40aee07d1716e489a6f9036c2c877 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 10:57:00 +0200 Subject: [PATCH 06/36] Add initital_hparams to mixin and move/rename extend_hparams. 
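The new hparams_initial property is a deep-copied snapshot taken when the hyperparameters are saved, so later runtime changes do not leak into it. Roughly (illustrative values, SketchModel as in the earlier sketch):

    model = SketchModel(lr=0.01)             # __init__ calls self.save_hyperparameters()
    model.hparams.lr = 0.1                   # mutate at runtime
    assert model.hparams.lr == 0.1
    assert model.hparams_initial.lr == 0.01  # the snapshot is untouched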
--- pytorch_lightning/core/lightning.py | 11 +++++++ pytorch_lightning/trainer/training_loop.py | 2 +- pytorch_lightning/utilities/hparams_mixin.py | 33 ++++++++++++-------- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e27f42d4775a8..6e1498c509702 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1514,3 +1514,14 @@ def to_torchscript( torch.jit.save(torchscript_module, file_path) return torchscript_module + + def add_datamodule_hparams(self, hparams): + hparams = self._to_hparams_dict(hparams) + if not hasattr(self, '_hparams'): + self._hparams_name = 'extended' + self._hparams = hparams + else: + colliding_keys = [key for key in hparams.keys() if key in self.hparams] + if colliding_keys: + raise ValueError(f'The keys {colliding_keys} are already present in the hparams.') + self.hparams.update(hparams) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e879ef902fa12..ec17a88c12333 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -98,7 +98,7 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): if hasattr(datamodule, 'hparams'): parsing.clean_namespace(datamodule.hparams) try: - model.extend_hparams(datamodule.hparams) + model.add_datamodule_hparams(datamodule.hparams) except ValueError as e: raise ValueError(f'Error while adding data module hparams to model: {e}') diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index b8cc3cd57a150..cd615d8ee38e2 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy import inspect import re from argparse import Namespace @@ -24,7 +24,10 @@ class HyperparametersMixin: - __jit_unused_properties__ = ["hparams"] + __jit_unused_properties__ = [ + "hparams", + "hparams_initial" + ] def save_hyperparameters(self, *args, frame=None) -> None: """Save all model arguments. 
@@ -77,9 +80,11 @@ def save_hyperparameters(self, *args, frame=None) -> None: init_args = get_init_args(frame) assert init_args, "failed to inspect the self init" if not args: + # take all arguments hp = init_args self._hparams_name = "kwargs" if hp else None else: + # take only listed arguments in `save_hparams` isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] if len(isx_non_str) == 1: hp = args[isx_non_str[0]] @@ -92,17 +97,8 @@ def save_hyperparameters(self, *args, frame=None) -> None: # `hparams` are expected here if hp: self._set_hparams(hp) - - def extend_hparams(self, hparams): - hparams = self._to_hparams_dict(hparams) - if not hasattr(self, '_hparams'): - self._hparams_name = 'extended' - self._hparams = hparams - else: - colliding_keys = [key for key in hparams.keys() if key in self.hparams] - if colliding_keys: - raise ValueError(f'The keys {colliding_keys} are already present in the hparams.') - self.hparams.update(hparams) + # make deep copy so there is not other runtime changes reflected + self._hparams_initial = copy.deepcopy(self._hparams) def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: hp = self._to_hparams_dict(hp) @@ -121,6 +117,7 @@ def _to_hparams_dict(self, hp): raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") elif not isinstance(hp, ALLOWED_CONFIG_TYPES): raise ValueError(f"Unsupported config type of {type(hp)}.") + return hp @property @@ -129,11 +126,21 @@ def hparams(self) -> Union[AttributeDict, str]: self._hparams = AttributeDict() return self._hparams + @property + def hparams_initial(self) -> AttributeDict: + if not hasattr(self, "_hparams_initial"): + return AttributeDict() + # prevent any change + return copy.deepcopy(self._hparams_initial) + @hparams.setter def hparams(self, hp: Union[dict, Namespace, Any]): hparams_assignment_name = self.__get_hparams_assignment_variable() self._hparams_name = hparams_assignment_name self._set_hparams(hp) + # this resolves case when user does not uses `save_hyperparameters` and do hard assignement in init + if not hasattr(self, "_hparams_initial"): + self._hparams_initial = copy.deepcopy(self._hparams) def __get_hparams_assignment_variable(self): """""" From f3ea974bc711cbd0a660b353e2f2ecdb9dfa02d5 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 10:57:30 +0200 Subject: [PATCH 07/36] Update unit tests. 
--- tests/models/test_hparams.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 47448d049ae33..f68e6610584dc 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -618,7 +618,7 @@ def test_extending_existing_hparams(tmpdir): hparams = {'arg1': 'abc'} model = EvalModelTemplate() old_hparams = copy.deepcopy(model.hparams) - model.extend_hparams(hparams) + model.add_datamodule_hparams(hparams) old_hparams.update(hparams) assert old_hparams == model.hparams @@ -632,7 +632,7 @@ class DummyModel(LightningModule): hparams = {'arg1': 'abc'} model = DummyModel() - model.extend_hparams(hparams) + model.add_datamodule_hparams(hparams) assert hparams == model.hparams @@ -642,7 +642,7 @@ def test_extending_with_namespace(tmpdir): hparams = Namespace(arg1='abc') model = EvalModelTemplate() old_hparams = copy.deepcopy(model.hparams) - model.extend_hparams(hparams) + model.add_datamodule_hparams(hparams) old_hparams.update(vars(hparams)) assert old_hparams == model.hparams @@ -654,7 +654,7 @@ def test_extend_with_unsupported_hparams(tmpdir): model = EvalModelTemplate() with pytest.raises(ValueError): - model.extend_hparams(hparams) + model.add_datamodule_hparams(hparams) def test_extend_with_primitive_hparams(tmpdir): @@ -663,14 +663,14 @@ def test_extend_with_primitive_hparams(tmpdir): model = EvalModelTemplate() with pytest.raises(ValueError): - model.extend_hparams(hparams) + model.add_datamodule_hparams(hparams) def test_extend_with_collision(tmp_path): """Test that new hparams cannot collide with existing hparams.""" model = EvalModelTemplate() with pytest.raises(ValueError): - model.extend_hparams({'batch_size': 5}) + model.add_datamodule_hparams({'batch_size': 5}) def test_adding_datamodule_hparams(tmpdir): From fc8af4189d51faa748676c3d6382a2f1bc0293be Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 11:07:30 +0200 Subject: [PATCH 08/36] Add hparams from datamodule to hparams_inital, too. 
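Sketch of the intended effect (illustrative keys; the model is assumed to have already saved its own hyperparameters):

    model.add_datamodule_hparams({"data_dir": "./data"})
    assert model.hparams["data_dir"] == "./data"
    assert model.hparams_initial["data_dir"] == "./data"   # now updated as well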
--- pytorch_lightning/core/lightning.py | 5 ++++- tests/models/test_hparams.py | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6e1498c509702..84b21342831b0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1518,10 +1518,13 @@ def to_torchscript( def add_datamodule_hparams(self, hparams): hparams = self._to_hparams_dict(hparams) if not hasattr(self, '_hparams'): - self._hparams_name = 'extended' self._hparams = hparams + self._hparams_initial = hparams else: colliding_keys = [key for key in hparams.keys() if key in self.hparams] + colliding_keys += [key for key in hparams.keys() if key in self.hparams_initial] + colliding_keys = set(colliding_keys) if colliding_keys: raise ValueError(f'The keys {colliding_keys} are already present in the hparams.') self.hparams.update(hparams) + self._hparams_initial.update(hparams) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index f68e6610584dc..82f2662048765 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -622,6 +622,7 @@ def test_extending_existing_hparams(tmpdir): old_hparams.update(hparams) assert old_hparams == model.hparams + assert old_hparams == model.hparams_initial def test_extending_non_existing_hparams(tmpdir): @@ -635,6 +636,7 @@ class DummyModel(LightningModule): model.add_datamodule_hparams(hparams) assert hparams == model.hparams + assert hparams == model.hparams_initial def test_extending_with_namespace(tmpdir): @@ -646,6 +648,7 @@ def test_extending_with_namespace(tmpdir): old_hparams.update(vars(hparams)) assert old_hparams == model.hparams + assert old_hparams == model.hparams_initial def test_extend_with_unsupported_hparams(tmpdir): @@ -686,6 +689,7 @@ def test_adding_datamodule_hparams(tmpdir): raw_checkpoint_path = _raw_checkpoint_path(trainer) model = SaveHparamsModel.load_from_checkpoint(raw_checkpoint_path) assert model.hparams == hparams + assert model.hparams_initial == hparams def test_colliding_datamodule_hparams(tmpdir): From 3f8d44fc86b835b109af2ca413e6fb1e56f110de Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 11:10:37 +0200 Subject: [PATCH 09/36] Test if datamodule hparams are logged to trainer loggers, too. --- tests/models/test_hparams.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 82f2662048765..23ef007c99416 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -688,8 +688,12 @@ def test_adding_datamodule_hparams(tmpdir): hparams.update(data.hparams) raw_checkpoint_path = _raw_checkpoint_path(trainer) model = SaveHparamsModel.load_from_checkpoint(raw_checkpoint_path) - assert model.hparams == hparams - assert model.hparams_initial == hparams + assert hparams == model.hparams + assert hparams == model.hparams_initial + + path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE) + logged_hparams = load_hparams_from_yaml(path_yaml) + assert hparams == logged_hparams def test_colliding_datamodule_hparams(tmpdir): From 822cce4138586e7c95c497fcf21eca01fd94deb4 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 11:14:24 +0200 Subject: [PATCH 10/36] Simplify error handling. The function for adding hparams is now only available in lightning module. The error message is now specific enough without reraising it in the training loop. 
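With this change a collision surfaces as a single, specific error raised by the model itself, for example (assuming the model already has a batch_size hyperparameter):

    model.add_datamodule_hparams({'batch_size': 5})
    # ValueError: Error while adding datamodule hparams: the keys {'batch_size'}
    # are already present in the model hparams.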
--- pytorch_lightning/core/lightning.py | 3 ++- pytorch_lightning/trainer/training_loop.py | 5 +---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 84b21342831b0..c3e7d2383a9a7 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1525,6 +1525,7 @@ def add_datamodule_hparams(self, hparams): colliding_keys += [key for key in hparams.keys() if key in self.hparams_initial] colliding_keys = set(colliding_keys) if colliding_keys: - raise ValueError(f'The keys {colliding_keys} are already present in the hparams.') + raise ValueError(f'Error while adding datamodule hparams: ' + f'the keys {colliding_keys} are already present in the model hparams.') self.hparams.update(hparams) self._hparams_initial.update(hparams) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ec17a88c12333..cd446aadff847 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -97,10 +97,7 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): parsing.clean_namespace(model.hparams) if hasattr(datamodule, 'hparams'): parsing.clean_namespace(datamodule.hparams) - try: - model.add_datamodule_hparams(datamodule.hparams) - except ValueError as e: - raise ValueError(f'Error while adding data module hparams to model: {e}') + model.add_datamodule_hparams(datamodule.hparams) # links data to the trainer self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) From 0f4dc643d0b8d4ba1139d3dcfde2a188d31a7589 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 12:00:40 +0200 Subject: [PATCH 11/36] Change args of add_datamodule_hparams from hparams to datamodule itself. --- pytorch_lightning/core/lightning.py | 27 ++++++++--------- pytorch_lightning/trainer/training_loop.py | 7 +++-- tests/models/test_hparams.py | 35 +++++++++++++++++----- 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c3e7d2383a9a7..a48fd1ac37e2e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -13,33 +13,30 @@ # limitations under the License. 
import collections -import copy import inspect import os -import re import tempfile from abc import ABC -from argparse import Namespace -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from torch import ScriptModule, Tensor +from torch.nn import Module +from torch.optim.optimizer import Optimizer + from pytorch_lightning import _logger as log +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel -from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin -from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.saving import ModelIO +from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin from pytorch_lightning.utilities.parsing import collect_init_args -from torch import ScriptModule, Tensor -from torch.nn import Module -from torch.optim.optimizer import Optimizer - +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() @@ -1515,8 +1512,8 @@ def to_torchscript( return torchscript_module - def add_datamodule_hparams(self, hparams): - hparams = self._to_hparams_dict(hparams) + def add_datamodule_hparams(self, datamodule: LightningDataModule): + hparams = self._to_hparams_dict(datamodule.hparams) if not hasattr(self, '_hparams'): self._hparams = hparams self._hparams_initial = hparams diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index cd446aadff847..2b3e0cb4c97c6 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -92,12 +92,13 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # bind logger and other properties self.trainer.model_connector.copy_trainer_model_properties(model) + # Add hparams from datamodule + if datamodule is not None: + model.add_datamodule_hparams(datamodule) + # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) - if hasattr(datamodule, 'hparams'): - parsing.clean_namespace(datamodule.hparams) - model.add_datamodule_hparams(datamodule.hparams) # links data to the trainer self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 23ef007c99416..8327614a9ba21 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -15,6 +15,7 @@ import os import pickle from argparse import Namespace +from typing import Optional import cloudpickle import pytest @@ -22,9 +23,10 @@ from fsspec.implementations.local import LocalFileSystem from omegaconf import OmegaConf, Container from torch.nn import functional as F -from 
torch.utils.data import DataLoader +from torch.utils.data import DataLoader, TensorDataset from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml from pytorch_lightning.utilities import AttributeDict, is_picklable from tests.base import EvalModelTemplate, TrialMNIST, BoringModel @@ -613,12 +615,29 @@ def test_model_with_fsspec_as_parameter(tmpdir): trainer.test() +class DataModuleWithHparams(LightningDataModule): + def __init__(self, hparams): + super().__init__() + + self.hparams = hparams + self._data = None + + def prepare_data(self, *args, **kwargs): + pass + + def setup(self, stage: Optional[str] = None): + self._data = TensorDataset(torch.randn(100, 20)) + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + return DataLoader(self._data, batch_size=10) + + def test_extending_existing_hparams(tmpdir): """Test that the new hparams are added to the existing ones.""" hparams = {'arg1': 'abc'} model = EvalModelTemplate() old_hparams = copy.deepcopy(model.hparams) - model.add_datamodule_hparams(hparams) + model.add_datamodule_hparams(DataModuleWithHparams(hparams)) old_hparams.update(hparams) assert old_hparams == model.hparams @@ -633,7 +652,7 @@ class DummyModel(LightningModule): hparams = {'arg1': 'abc'} model = DummyModel() - model.add_datamodule_hparams(hparams) + model.add_datamodule_hparams(DataModuleWithHparams(hparams)) assert hparams == model.hparams assert hparams == model.hparams_initial @@ -644,7 +663,7 @@ def test_extending_with_namespace(tmpdir): hparams = Namespace(arg1='abc') model = EvalModelTemplate() old_hparams = copy.deepcopy(model.hparams) - model.add_datamodule_hparams(hparams) + model.add_datamodule_hparams(DataModuleWithHparams(hparams)) old_hparams.update(vars(hparams)) assert old_hparams == model.hparams @@ -657,7 +676,7 @@ def test_extend_with_unsupported_hparams(tmpdir): model = EvalModelTemplate() with pytest.raises(ValueError): - model.add_datamodule_hparams(hparams) + model.add_datamodule_hparams(DataModuleWithHparams(hparams)) def test_extend_with_primitive_hparams(tmpdir): @@ -666,14 +685,14 @@ def test_extend_with_primitive_hparams(tmpdir): model = EvalModelTemplate() with pytest.raises(ValueError): - model.add_datamodule_hparams(hparams) + model.add_datamodule_hparams(DataModuleWithHparams(hparams)) def test_extend_with_collision(tmp_path): """Test that new hparams cannot collide with existing hparams.""" model = EvalModelTemplate() with pytest.raises(ValueError): - model.add_datamodule_hparams({'batch_size': 5}) + model.add_datamodule_hparams(DataModuleWithHparams({'batch_size': 5})) def test_adding_datamodule_hparams(tmpdir): @@ -703,5 +722,5 @@ def test_colliding_datamodule_hparams(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(ValueError, match='Error while adding data module hparams to model:'): + with pytest.raises(ValueError, match='Error while adding datamodule hparams: '): trainer.fit(model, datamodule=data) From 0af99685619d91d20d41b0d5e3f629a30a5a2668 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 12:02:55 +0200 Subject: [PATCH 12/36] Add hparams of datamodule only if it has some. 
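Sketch of the guard's effect (illustrative; the datamodule below is hypothetical and simply never calls save_hyperparameters or assigns hparams):

    import copy
    from pytorch_lightning import LightningDataModule

    class NoHparamsDataModule(LightningDataModule):
        pass                                  # no hyperparameters saved anywhere

    before = copy.deepcopy(model.hparams)
    model.add_datamodule_hparams(NoHparamsDataModule())
    assert model.hparams == before            # nothing is added and nothing raises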
--- pytorch_lightning/core/lightning.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a48fd1ac37e2e..ae0f2b2753e76 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1513,6 +1513,10 @@ def to_torchscript( return torchscript_module def add_datamodule_hparams(self, datamodule: LightningDataModule): + """Add the hparams of a LightningDataModule to the hparams and hparams_initial of this module.""" + if not hasattr(datamodule, 'hparams'): + return + hparams = self._to_hparams_dict(datamodule.hparams) if not hasattr(self, '_hparams'): self._hparams = hparams From 0369c9f899b24a3c6bf068a6e06b833ca04115f8 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 12:06:44 +0200 Subject: [PATCH 13/36] Add one more unit test. --- tests/models/test_hparams.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 8327614a9ba21..f6b982d9710f6 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -632,6 +632,21 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader: return DataLoader(self._data, batch_size=10) +class DataModuleWithoutHparams(LightningDataModule): + def __init__(self): + super().__init__() + self._data = None + + def prepare_data(self, *args, **kwargs): + pass + + def setup(self, stage: Optional[str] = None): + self._data = TensorDataset(torch.randn(100, 20)) + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + return DataLoader(self._data, batch_size=10) + + def test_extending_existing_hparams(tmpdir): """Test that the new hparams are added to the existing ones.""" hparams = {'arg1': 'abc'} @@ -724,3 +739,12 @@ def test_colliding_datamodule_hparams(tmpdir): with pytest.raises(ValueError, match='Error while adding datamodule hparams: '): trainer.fit(model, datamodule=data) + + +def test_adding_hparams_of_datamodule_without_hparams(tmpdir): + model = EvalModelTemplate() + hparams = copy.deepcopy(model.hparams) + model.add_datamodule_hparams(DataModuleWithoutHparams()) + + assert hparams == model.hparams + assert hparams == model.hparams_initial From 6dde1d21c65505bbc5fe9a65a6b38d3e0d67bed2 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 22 Oct 2020 12:10:38 +0200 Subject: [PATCH 14/36] Fix pep8 complaint. --- pytorch_lightning/utilities/hparams_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index cd615d8ee38e2..ac6076447cf44 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -25,8 +25,8 @@ class HyperparametersMixin: __jit_unused_properties__ = [ - "hparams", - "hparams_initial" + "hparams", + "hparams_initial" ] def save_hyperparameters(self, *args, frame=None) -> None: From 873b02ae6f0deeb54dff4b2081a7a41573f66005 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 19 Nov 2020 16:57:35 +0100 Subject: [PATCH 15/36] Make training work for SaveHparamsModel. Had to train in order to look if data module hparams are logged. 
--- tests/models/test_hparams.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index d4e2a743f7100..2a29fbb9b0bb5 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -40,6 +40,11 @@ def __init__(self, hparams): super().__init__() self.save_hyperparameters(hparams) + def training_step(self, batch, batch_idx): + output = self.layer(batch[0]) + loss = self.loss(batch, output) + return {"loss": loss} + class AssignHparamsModel(BoringModel): """ Tests that a model can take an object with explicit setter """ @@ -48,6 +53,20 @@ def __init__(self, hparams): self.hparams = hparams +class BoringDataModule(LightningDataModule): + def __init__(self, hparams): + super().__init__() + self.data = None + + self.hparams = hparams + + def setup(self, stage: Optional[str] = None): + self.data = torch.randn(10, 32) + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + return DataLoader(TensorDataset(self.data), batch_size=10) + + def decorate(func): @functools.wraps(func) def wrapper(*args, **kwargs): @@ -746,7 +765,7 @@ def test_extend_with_collision(tmp_path): def test_adding_datamodule_hparams(tmpdir): """Test that hparams from datamodule are added to the checkpoint.""" model = SaveHparamsModel({'arg1': 5, 'arg2': 'abc'}) - data = TrialMNISTDataModule() + data = BoringDataModule({'data_dir': 'foo'}) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) trainer.fit(model, datamodule=data) @@ -766,7 +785,7 @@ def test_adding_datamodule_hparams(tmpdir): def test_colliding_datamodule_hparams(tmpdir): """Test that colliding hparams from the datamodule are caught.""" model = SaveHparamsModel({'data_dir': 'abc', 'arg2': 'abc'}) - data = TrialMNISTDataModule() + data = BoringDataModule({'data_dir': 'foo'}) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) From ebb53fb408644f1352e9388fce785666d08b6d45 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Tue, 1 Dec 2020 09:14:40 +0100 Subject: [PATCH 16/36] Update pytorch_lightning/core/lightning.py Co-authored-by: Jirka Borovec --- pytorch_lightning/core/lightning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ec09abe117704..2b19e35b492ee 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1727,7 +1727,8 @@ def add_datamodule_hparams(self, datamodule: LightningDataModule): colliding_keys += [key for key in hparams.keys() if key in self.hparams_initial] colliding_keys = set(colliding_keys) if colliding_keys: - raise ValueError(f'Error while adding datamodule hparams: ' - f'the keys {colliding_keys} are already present in the model hparams.') + raise ValueError( + f'Error while adding datamodule hparams: the keys {colliding_keys} are already present in the model hparams.' 
+ ) self.hparams.update(hparams) self._hparams_initial.update(hparams) From c36ef72dcaefce04401e2c2a89e0802eb3e6dfbb Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Tue, 1 Dec 2020 09:15:35 +0100 Subject: [PATCH 17/36] Update pytorch_lightning/utilities/hparams_mixin.py Co-authored-by: Jirka Borovec --- pytorch_lightning/utilities/hparams_mixin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index ac6076447cf44..2bd1a468dfb0f 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -143,7 +143,6 @@ def hparams(self, hp: Union[dict, Namespace, Any]): self._hparams_initial = copy.deepcopy(self._hparams) def __get_hparams_assignment_variable(self): - """""" """ looks at the code of the class to figure out what the user named self.hparams this only happens when the user explicitly sets self.hparams From eaaf94e2723d2473749e542f54b15c9149bf81a3 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Mon, 21 Jun 2021 14:02:37 +0200 Subject: [PATCH 18/36] Fix merge conflicts. --- .../trainer/connectors/data_connector.py | 3 + pytorch_lightning/utilities/hparams_mixin.py | 134 +++++++++--------- tests/core/test_datamodules.py | 29 +++- tests/models/test_hparams.py | 30 ++-- 4 files changed, 115 insertions(+), 81 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 4ff7e5aa21a42..618958b7dcd87 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -125,6 +125,9 @@ def attach_datamodule( if is_overridden(method, datamodule): setattr(model, method, getattr(datamodule, method)) + # Add hparams from datamodule + model.add_datamodule_hparams(datamodule) + # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') for hook in batch_transfer_hooks: diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 2bd1a468dfb0f..3715e084e2270 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -14,12 +14,13 @@ import copy import inspect import re +import types from argparse import Namespace -from typing import Union, Any +from typing import Any, Optional, Sequence, Union -from pytorch_lightning.core.saving import PRIMITIVE_TYPES, ALLOWED_CONFIG_TYPES +from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES from pytorch_lightning.utilities import AttributeDict -from pytorch_lightning.utilities.parsing import get_init_args +from pytorch_lightning.utilities.parsing import save_hyperparameters class HyperparametersMixin: @@ -29,76 +30,71 @@ class HyperparametersMixin: "hparams_initial" ] - def save_hyperparameters(self, *args, frame=None) -> None: - """Save all model arguments. - + def save_hyperparameters( + self, + *args, + ignore: Optional[Union[Sequence[str], str]] = None, + frame: Optional[types.FrameType] = None + ) -> None: + """Save model arguments to ``hparams`` attribute. Args: args: single object of `dict`, `NameSpace` or `OmegaConf` - or string names or argumenst from class `__init__` - - >>> from collections import OrderedDict - >>> class ManuallyArgsModel(HyperparametersMixin): - ... 
def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # manually assign arguments - ... self.save_hyperparameters('arg1', 'arg3') - ... def forward(self, *args, **kwargs): - ... ... - >>> model = ManuallyArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg3": 3.14 - - >>> class AutomaticArgsModel(HyperparametersMixin): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # equivalent automatic - ... self.save_hyperparameters() - ... def forward(self, *args, **kwargs): - ... ... - >>> model = AutomaticArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg2": abc - "arg3": 3.14 - - >>> class SingleArgModel(HyperparametersMixin): - ... def __init__(self, params): - ... super().__init__() - ... # manually assign single argument - ... self.save_hyperparameters(params) - ... def forward(self, *args, **kwargs): - ... ... - >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) - >>> model.hparams - "p1": 1 - "p2": abc - "p3": 3.14 + or string names or arguments from class ``__init__`` + ignore: an argument name or a list of argument names from + class ``__init__`` to be ignored + frame: a frame object. Default is None + Example:: + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # manually assign arguments + ... self.save_hyperparameters('arg1', 'arg3') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 + >>> class AutomaticArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # equivalent automatic + ... self.save_hyperparameters() + ... def forward(self, *args, **kwargs): + ... ... + >>> model = AutomaticArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg2": abc + "arg3": 3.14 + >>> class SingleArgModel(LightningModule): + ... def __init__(self, params): + ... super().__init__() + ... # manually assign single argument + ... self.save_hyperparameters(params) + ... def forward(self, *args, **kwargs): + ... ... + >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) + >>> model.hparams + "p1": 1 + "p2": abc + "p3": 3.14 + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # pass argument(s) to ignore as a string or in a list + ... self.save_hyperparameters(ignore='arg2') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 """ + # the frame needs to be created in this file. 
if not frame: frame = inspect.currentframe().f_back - init_args = get_init_args(frame) - assert init_args, "failed to inspect the self init" - if not args: - # take all arguments - hp = init_args - self._hparams_name = "kwargs" if hp else None - else: - # take only listed arguments in `save_hparams` - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] - if len(isx_non_str) == 1: - hp = args[isx_non_str[0]] - cand_names = [k for k, v in init_args.items() if v == hp] - self._hparams_name = cand_names[0] if cand_names else None - else: - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} - self._hparams_name = "kwargs" - - # `hparams` are expected here - if hp: - self._set_hparams(hp) - # make deep copy so there is not other runtime changes reflected - self._hparams_initial = copy.deepcopy(self._hparams) + save_hyperparameters(self, *args, ignore=ignore, frame=frame) def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: hp = self._to_hparams_dict(hp) @@ -121,7 +117,7 @@ def _to_hparams_dict(self, hp): return hp @property - def hparams(self) -> Union[AttributeDict, str]: + def hparams(self) -> Union[AttributeDict, dict, Namespace]: if not hasattr(self, "_hparams"): self._hparams = AttributeDict() return self._hparams diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ff7b3bf816a29..b5568c8b0a0f7 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -13,15 +13,16 @@ # limitations under the License. import pickle from argparse import ArgumentParser -from typing import Any, Dict +from typing import Any, Dict, Optional from unittest import mock -from unittest.mock import call, PropertyMock +from unittest.mock import PropertyMock, call import pytest import torch from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities import AttributeDict from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel from tests.helpers.datamodules import ClassifDataModule @@ -551,6 +552,26 @@ def test_dm_init_from_datasets_dataloaders(iterable): ]) +class DataModuleWithHparams(LightningDataModule): + def __init__(self, arg0, arg1, kwarg0=None): + super().__init__() + + self.arg0 = arg0 + self.arg1 = arg1 + self.kwarg0 = kwarg0 + + self.save_hyperparameters() + + def prepare_data(self, *args, **kwargs): + pass + + def setup(self, stage: Optional[str] = None): + pass + + def train_dataloader(self, *args, **kwargs): + pass + + def test_simple_hyperparameters_saving(): - data = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0) - assert data.hparams == AttributeDict({'data_dir': data.data_dir}) + data = DataModuleWithHparams(10, "foo", kwarg0="bar") + assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"}) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 5ef0b333b02d8..6ec0d0433e07a 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -30,7 +30,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml -from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable +from pytorch_lightning.utilities import AttributeDict, _HYDRA_EXPERIMENTAL_AVAILABLE, is_picklable 
from tests.helpers import BoringModel, RandomDataset if _HYDRA_EXPERIMENTAL_AVAILABLE: @@ -783,7 +783,7 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader: def test_extending_existing_hparams(tmpdir): """Test that the new hparams are added to the existing ones.""" hparams = {'arg1': 'abc'} - model = EvalModelTemplate() + model = CustomBoringModel() old_hparams = copy.deepcopy(model.hparams) model.add_datamodule_hparams(DataModuleWithHparams(hparams)) @@ -809,7 +809,7 @@ class DummyModel(LightningModule): def test_extending_with_namespace(tmpdir): """Test that we can extend hparams with a namespace.""" hparams = Namespace(arg1='abc') - model = EvalModelTemplate() + model = CustomBoringModel() old_hparams = copy.deepcopy(model.hparams) model.add_datamodule_hparams(DataModuleWithHparams(hparams)) @@ -821,7 +821,7 @@ def test_extending_with_namespace(tmpdir): def test_extend_with_unsupported_hparams(tmpdir): """Test that usupported hparams structures raise an error when extending.""" hparams = ('arg1', 'abc') - model = EvalModelTemplate() + model = CustomBoringModel() with pytest.raises(ValueError): model.add_datamodule_hparams(DataModuleWithHparams(hparams)) @@ -830,7 +830,7 @@ def test_extend_with_unsupported_hparams(tmpdir): def test_extend_with_primitive_hparams(tmpdir): """Test that primitives raise an error when extending.""" hparams = 5 - model = EvalModelTemplate() + model = CustomBoringModel() with pytest.raises(ValueError): model.add_datamodule_hparams(DataModuleWithHparams(hparams)) @@ -838,11 +838,25 @@ def test_extend_with_primitive_hparams(tmpdir): def test_extend_with_collision(tmp_path): """Test that new hparams cannot collide with existing hparams.""" - model = EvalModelTemplate() + model = CustomBoringModel() with pytest.raises(ValueError): model.add_datamodule_hparams(DataModuleWithHparams({'batch_size': 5})) +class BoringDataModule(LightningDataModule): + def __init__(self, hparams): + super().__init__() + self.data = None + + self.hparams = hparams + + def setup(self, stage: Optional[str] = None): + self.data = torch.randn(10, 32) + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + return DataLoader(TensorDataset(self.data), batch_size=10) + + def test_adding_datamodule_hparams(tmpdir): """Test that hparams from datamodule are added to the checkpoint.""" model = SaveHparamsModel({'arg1': 5, 'arg2': 'abc'}) @@ -859,7 +873,7 @@ def test_adding_datamodule_hparams(tmpdir): assert hparams == model.hparams_initial path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE) - logged_hparams = load_hparams_from_yaml(path_yaml) + logged_hparams = AttributeDict(load_hparams_from_yaml(path_yaml)) assert hparams == logged_hparams @@ -875,7 +889,7 @@ def test_colliding_datamodule_hparams(tmpdir): def test_adding_hparams_of_datamodule_without_hparams(tmpdir): - model = EvalModelTemplate() + model = CustomBoringModel() hparams = copy.deepcopy(model.hparams) model.add_datamodule_hparams(DataModuleWithoutHparams()) From 806a1e5abd2e23078fb1463bd7e22f1e3d16038d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Jun 2021 12:04:00 +0000 Subject: [PATCH 19/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/lightning.py | 12 ++++++------ pytorch_lightning/utilities/hparams_mixin.py | 5 +---- tests/core/test_datamodules.py | 3 ++- tests/models/test_hparams.py | 5 ++++- 4 files changed, 13 
insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2ed5fb110888a..67d14c879a7f8 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -32,12 +32,12 @@ from torch.optim.optimizer import Optimizer from torchmetrics import Metric +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -1737,10 +1737,10 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: @torch.no_grad() def to_onnx( - self, - file_path: Union[str, Path], - input_sample: Optional[Any] = None, - **kwargs, + self, + file_path: Union[str, Path], + input_sample: Optional[Any] = None, + **kwargs, ): """Saves the model in ONNX format @@ -1880,7 +1880,7 @@ def add_datamodule_hparams(self, datamodule: LightningDataModule): colliding_keys = set(colliding_keys) if colliding_keys: raise ValueError( - f'Error while adding datamodule hparams: the keys {colliding_keys} are already present in the model hparams.' + f'Error while adding datamodule hparams: the keys {colliding_keys} are already present in the model hparams.' ) self.hparams.update(hparams) self._hparams_initial.update(hparams) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 3715e084e2270..c661d48bd8b2c 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -25,10 +25,7 @@ class HyperparametersMixin: - __jit_unused_properties__ = [ - "hparams", - "hparams_initial" - ] + __jit_unused_properties__ = ["hparams", "hparams_initial"] def save_hyperparameters( self, diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index b5568c8b0a0f7..d5698b7188ada 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -15,7 +15,7 @@ from argparse import ArgumentParser from typing import Any, Dict, Optional from unittest import mock -from unittest.mock import PropertyMock, call +from unittest.mock import call, PropertyMock import pytest import torch @@ -553,6 +553,7 @@ def test_dm_init_from_datasets_dataloaders(iterable): class DataModuleWithHparams(LightningDataModule): + def __init__(self, arg0, arg1, kwarg0=None): super().__init__() diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 6ec0d0433e07a..91ea3a335a677 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -30,7 +30,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml -from pytorch_lightning.utilities import AttributeDict, _HYDRA_EXPERIMENTAL_AVAILABLE, is_picklable +from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable from tests.helpers 
import BoringModel, RandomDataset if _HYDRA_EXPERIMENTAL_AVAILABLE: @@ -749,6 +749,7 @@ def test_dataclass_lightning_module(tmpdir): class DataModuleWithHparams(LightningDataModule): + def __init__(self, hparams): super().__init__() @@ -766,6 +767,7 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader: class DataModuleWithoutHparams(LightningDataModule): + def __init__(self): super().__init__() self._data = None @@ -844,6 +846,7 @@ def test_extend_with_collision(tmp_path): class BoringDataModule(LightningDataModule): + def __init__(self, hparams): super().__init__() self.data = None From 6a8067901bf6c89c270d6fde1c67373721476705 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Mon, 21 Jun 2021 14:18:00 +0200 Subject: [PATCH 20/36] Fix code style issues. --- pytorch_lightning/core/lightning.py | 13 ++++++------- pytorch_lightning/utilities/hparams_mixin.py | 15 ++++++++++----- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2ed5fb110888a..7e1538c4ad7d0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -14,13 +14,11 @@ """nn.Module with additional great features.""" import collections -import copy import inspect import logging import numbers import os import tempfile -import types import uuid from abc import ABC from pathlib import Path @@ -32,12 +30,12 @@ from torch.optim.optimizer import Optimizer from torchmetrics import Metric +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES -from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -45,9 +43,9 @@ from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin -from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters +from pytorch_lightning.utilities.parsing import collect_init_args from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT, _METRIC_COLLECTION from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -1880,7 +1878,8 @@ def add_datamodule_hparams(self, datamodule: LightningDataModule): colliding_keys = set(colliding_keys) if colliding_keys: raise ValueError( - f'Error while adding datamodule hparams: the keys {colliding_keys} are already present in the model hparams.' + f'Error while adding datamodule hparams: the keys {colliding_keys} ' + f'are already present in the model hparams.' 
) self.hparams.update(hparams) self._hparams_initial.update(hparams) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 3715e084e2270..b37f0dc7eba70 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -37,14 +37,16 @@ def save_hyperparameters( frame: Optional[types.FrameType] = None ) -> None: """Save model arguments to ``hparams`` attribute. + Args: args: single object of `dict`, `NameSpace` or `OmegaConf` or string names or arguments from class ``__init__`` ignore: an argument name or a list of argument names from class ``__init__`` to be ignored frame: a frame object. Default is None + Example:: - >>> class ManuallyArgsModel(LightningModule): + >>> class ManuallyArgsModel(HyperparametersMixin): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() ... # manually assign arguments @@ -55,7 +57,8 @@ class ``__init__`` to be ignored >>> model.hparams "arg1": 1 "arg3": 3.14 - >>> class AutomaticArgsModel(LightningModule): + + >>> class AutomaticArgsModel(HyperparametersMixin): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() ... # equivalent automatic @@ -67,7 +70,8 @@ class ``__init__`` to be ignored "arg1": 1 "arg2": abc "arg3": 3.14 - >>> class SingleArgModel(LightningModule): + + >>> class SingleArgModel(HyperparametersMixin): ... def __init__(self, params): ... super().__init__() ... # manually assign single argument @@ -79,7 +83,8 @@ class ``__init__`` to be ignored "p1": 1 "p2": abc "p3": 3.14 - >>> class ManuallyArgsModel(LightningModule): + + >>> class ManuallyArgsModel(HyperparametersMixin): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() ... # pass argument(s) to ignore as a string or in a list @@ -150,7 +155,7 @@ def __get_hparams_assignment_variable(self): line = re.sub(r"\s+", "", line, flags=re.UNICODE) if ".hparams=" in line: return line.split("=")[1] - except Exception as e: + except Exception: return "hparams" return None From 2228e69d243ab5710b7cef4f51321fd3eb8d5da8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Jun 2021 12:21:22 +0000 Subject: [PATCH 21/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 934edc0d6599e..ab96bd1b268ca 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin from pytorch_lightning.utilities.parsing import collect_init_args from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT, _METRIC_COLLECTION +from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -1879,7 +1879,7 @@ def add_datamodule_hparams(self, datamodule: LightningDataModule): if colliding_keys: raise ValueError( f'Error while adding datamodule hparams: the keys {colliding_keys} ' - f'are already present in the model hparams.' + f'are already present in the model hparams.' 
) self.hparams.update(hparams) self._hparams_initial.update(hparams) From c14361e847f5da1470a93ff19cad3ebc23139fbf Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Tue, 22 Jun 2021 17:16:16 +0200 Subject: [PATCH 22/36] Fix indentation error from merge. --- pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index d0d4dae0086f1..215ace26a18f1 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -124,7 +124,7 @@ def attach_datamodule( setattr(model, method, getattr(datamodule, method)) # Add hparams from datamodule - model.add_datamodule_hparams(datamodule) + model.add_datamodule_hparams(datamodule) # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') From 5371b2cd07138e771747f5bd131684779f640a12 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 24 Jun 2021 15:10:10 +0200 Subject: [PATCH 23/36] Extract hparam merging function. --- pytorch_lightning/utilities/hparams_mixin.py | 22 ++++++++++++++++ tests/models/test_hparams.py | 27 +++++++++++++++----- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 750f47868ebed..4fe32116c8de2 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -156,3 +156,25 @@ def __get_hparams_assignment_variable(self): return "hparams" return None + + +def merge_hparams(lightning_hparams: dict, data_hparams: dict) -> dict: + """Merge the hparams of a LightningModule and a LightningDataModule and return them. + + If there is an overlap between the hparams in the LightningModule and in the + LightningDataModule an exception is raised. + + Args: + lightning_hparams: the hyperparameters of a LightningModule + data_hparams: the hyperparameters of a LightningDataModule + """ + colliding_keys = [key for key in data_hparams.keys() if key in lightning_hparams] + if colliding_keys: + raise ValueError( + f'Error while merging hparams: the keys {colliding_keys} are present ' + f'in both the LightningModules and LightningDataModules hparams.' 
+ ) + merged_hparams = copy.deepcopy(lightning_hparams) + merged_hparams.update(data_hparams) + + return merged_hparams diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 91ea3a335a677..5a7ec7da2dbc8 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -30,7 +30,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml -from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable +from pytorch_lightning.utilities import AttributeDict, _HYDRA_EXPERIMENTAL_AVAILABLE, is_picklable +from pytorch_lightning.utilities.hparams_mixin import merge_hparams from tests.helpers import BoringModel, RandomDataset if _HYDRA_EXPERIMENTAL_AVAILABLE: @@ -880,15 +881,27 @@ def test_adding_datamodule_hparams(tmpdir): assert hparams == logged_hparams +def test_merging_hparams(tmpdir): + model_hparams = {'arg1': 'abc', 'arg2': 'abc'} + data_hparams = {'data_dir': 'foo'} + merged_hparams = merge_hparams(model_hparams, data_hparams) + + # Merged hparams contain all keys + assert all(key in merged_hparams for key in model_hparams.keys()) + assert all(key in merged_hparams for key in data_hparams.keys()) + + # Original dicts are not modified + assert not any(key in model_hparams for key in data_hparams.keys()) + assert not any(key in data_hparams for key in model_hparams.keys()) + + def test_colliding_datamodule_hparams(tmpdir): """Test that colliding hparams from the datamodule are caught.""" - model = SaveHparamsModel({'data_dir': 'abc', 'arg2': 'abc'}) - data = BoringDataModule({'data_dir': 'foo'}) - - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + model_hparams = {'data_dir': 'abc', 'arg2': 'abc'} + data_hparams = {'data_dir': 'foo'} - with pytest.raises(ValueError, match='Error while adding datamodule hparams: '): - trainer.fit(model, datamodule=data) + with pytest.raises(ValueError): + merge_hparams(model_hparams, data_hparams) def test_adding_hparams_of_datamodule_without_hparams(tmpdir): From eebc6abd8440463dcf1ace122d8fcd52b2d9a615 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 24 Jun 2021 16:08:11 +0200 Subject: [PATCH 24/36] Hold model and data hparams separately and merge on logging. 
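Instead of copying the datamodule hparams onto the LightningModule, the model and the datamodule now each keep their own hparams_initial, and the trainer only merges the two dictionaries at logging time via merge_hparams (added in the previous commit to pytorch_lightning/utilities/hparams_mixin.py). A rough, illustrative sketch of the intended behaviour; the hparam keys and values below are made up for the example and are not part of this change:

    from pytorch_lightning.utilities.hparams_mixin import merge_hparams

    # hypothetical hparams, standing in for LightningModule.hparams_initial
    # and LightningDataModule.hparams_initial
    model_hparams = {"lr": 1e-3, "hidden_dim": 128}
    data_hparams = {"data_dir": "./data", "batch_size": 32}

    # neither input dict is mutated; the merged copy is what gets logged
    merged = merge_hparams(model_hparams, data_hparams)
    print(merged)  # {'lr': 0.001, 'hidden_dim': 128, 'data_dir': './data', 'batch_size': 32}

    # a key present in both dicts would raise a ValueError instead of silently overwriting

This keeps each component's hyperparameters self-contained while the logger still receives a single flat set.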
--- .../trainer/connectors/data_connector.py | 3 - pytorch_lightning/trainer/trainer.py | 10 +- tests/models/test_hparams.py | 153 ++++++------------ 3 files changed, 57 insertions(+), 109 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 215ace26a18f1..c21238f06fe8f 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -123,9 +123,6 @@ def attach_datamodule( if is_overridden(method, datamodule): setattr(model, method, getattr(datamodule, method)) - # Add hparams from datamodule - model.add_datamodule_hparams(datamodule) - # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') for hook in batch_transfer_hooks: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index aa4de1b70ce96..e21cabd4e7d8f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -65,9 +65,11 @@ from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.hparams_mixin import merge_hparams from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS, _EVALUATE_OUTPUT, \ + _PREDICT_OUTPUT log = logging.getLogger(__name__) # warnings to ignore in trainer @@ -896,7 +898,11 @@ def _pre_dispatch(self): # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) - self.logger.log_hyperparams(self.lightning_module.hparams_initial) + hparams_initial = merge_hparams( + self.lightning_module.hparams_initial, + {} if self.datamodule is None else self.datamodule.hparams_initial, + ) + self.logger.log_hyperparams(hparams_initial) self.logger.log_graph(self.lightning_module) self.logger.save() diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 5a7ec7da2dbc8..b294e09c079ac 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -18,6 +18,7 @@ from argparse import Namespace from dataclasses import dataclass from typing import Optional +from unittest import mock import cloudpickle import pytest @@ -749,22 +750,13 @@ def test_dataclass_lightning_module(tmpdir): assert model.hparams == dict(mandatory=33, optional="cocofruit") -class DataModuleWithHparams(LightningDataModule): - - def __init__(self, hparams): - super().__init__() - - self.hparams = hparams - self._data = None - - def prepare_data(self, *args, **kwargs): - pass - - def setup(self, stage: Optional[str] = None): - self._data = TensorDataset(torch.randn(100, 20)) +class NoHparamsModel(BoringModel): + """ Tests a model without hparams. 
""" - def train_dataloader(self, *args, **kwargs) -> DataLoader: - return DataLoader(self._data, batch_size=10) + def training_step(self, batch, batch_idx): + output = self.layer(batch[0]) + loss = self.loss(batch, output) + return {"loss": loss} class DataModuleWithoutHparams(LightningDataModule): @@ -777,76 +769,13 @@ def prepare_data(self, *args, **kwargs): pass def setup(self, stage: Optional[str] = None): - self._data = TensorDataset(torch.randn(100, 20)) + self.data = torch.randn(10, 32) def train_dataloader(self, *args, **kwargs) -> DataLoader: - return DataLoader(self._data, batch_size=10) - - -def test_extending_existing_hparams(tmpdir): - """Test that the new hparams are added to the existing ones.""" - hparams = {'arg1': 'abc'} - model = CustomBoringModel() - old_hparams = copy.deepcopy(model.hparams) - model.add_datamodule_hparams(DataModuleWithHparams(hparams)) - - old_hparams.update(hparams) - assert old_hparams == model.hparams - assert old_hparams == model.hparams_initial - - -def test_extending_non_existing_hparams(tmpdir): - """Test that hparams are created if they do not exist yet when we try to extend them.""" - - class DummyModel(LightningModule): - pass - - hparams = {'arg1': 'abc'} - model = DummyModel() - model.add_datamodule_hparams(DataModuleWithHparams(hparams)) - - assert hparams == model.hparams - assert hparams == model.hparams_initial - - -def test_extending_with_namespace(tmpdir): - """Test that we can extend hparams with a namespace.""" - hparams = Namespace(arg1='abc') - model = CustomBoringModel() - old_hparams = copy.deepcopy(model.hparams) - model.add_datamodule_hparams(DataModuleWithHparams(hparams)) - - old_hparams.update(vars(hparams)) - assert old_hparams == model.hparams - assert old_hparams == model.hparams_initial - - -def test_extend_with_unsupported_hparams(tmpdir): - """Test that usupported hparams structures raise an error when extending.""" - hparams = ('arg1', 'abc') - model = CustomBoringModel() - - with pytest.raises(ValueError): - model.add_datamodule_hparams(DataModuleWithHparams(hparams)) - - -def test_extend_with_primitive_hparams(tmpdir): - """Test that primitives raise an error when extending.""" - hparams = 5 - model = CustomBoringModel() - - with pytest.raises(ValueError): - model.add_datamodule_hparams(DataModuleWithHparams(hparams)) - - -def test_extend_with_collision(tmp_path): - """Test that new hparams cannot collide with existing hparams.""" - model = CustomBoringModel() - with pytest.raises(ValueError): - model.add_datamodule_hparams(DataModuleWithHparams({'batch_size': 5})) + return DataLoader(TensorDataset(self.data), batch_size=10) -class BoringDataModule(LightningDataModule): +class DataModuleWithHparams(LightningDataModule): def __init__(self, hparams): super().__init__() @@ -861,24 +790,49 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader: return DataLoader(TensorDataset(self.data), batch_size=10) -def test_adding_datamodule_hparams(tmpdir): - """Test that hparams from datamodule are added to the checkpoint.""" - model = SaveHparamsModel({'arg1': 5, 'arg2': 'abc'}) - data = BoringDataModule({'data_dir': 'foo'}) +def _get_mock_logger(tmpdir): + mock_logger = mock.MagicMock(name="logger") + mock_logger.name = "mock_logger" + mock_logger.save_dir = tmpdir + mock_logger.version = "0" + del mock_logger.__iter__ + return mock_logger - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + +@pytest.mark.parametrize("model", (SaveHparamsModel({'arg1': 5, 'arg2': 'abc'}), NoHparamsModel())) 
+@pytest.mark.parametrize("data", (DataModuleWithHparams({'data_dir': 'foo'}), DataModuleWithoutHparams())) +def test_adding_datamodule_hparams(tmpdir, model, data): + """Test that hparams from datamodule and model are logged.""" + org_model_hparams = copy.deepcopy(model.hparams_initial) + org_data_hparams = copy.deepcopy(data.hparams_initial) + + mock_logger = _get_mock_logger(tmpdir) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger) trainer.fit(model, datamodule=data) - hparams = model.hparams - hparams.update(data.hparams) - raw_checkpoint_path = _raw_checkpoint_path(trainer) - model = SaveHparamsModel.load_from_checkpoint(raw_checkpoint_path) - assert hparams == model.hparams - assert hparams == model.hparams_initial + # Hparams of model and data were not modified + assert org_model_hparams == model.hparams + assert org_data_hparams == data.hparams - path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE) - logged_hparams = AttributeDict(load_hparams_from_yaml(path_yaml)) - assert hparams == logged_hparams + # Merged hparams were logged + merged_hparams = copy.deepcopy(org_model_hparams) + merged_hparams.update(org_data_hparams) + mock_logger.log_hyperparams.assert_called_with(merged_hparams) + + +def test_no_datamodule_for_hparams(tmpdir): + """Test that hparams model are logged if no datamodule is used.""" + model = SaveHparamsModel({'arg1': 5, 'arg2': 'abc'}) + org_model_hparams = copy.deepcopy(model.hparams_initial) + data = DataModuleWithoutHparams() + data.setup() + + mock_logger = _get_mock_logger(tmpdir) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger) + trainer.fit(model, data.train_dataloader()) + + # Merged hparams were logged + mock_logger.log_hyperparams.assert_called_with(org_model_hparams) def test_merging_hparams(tmpdir): @@ -902,12 +856,3 @@ def test_colliding_datamodule_hparams(tmpdir): with pytest.raises(ValueError): merge_hparams(model_hparams, data_hparams) - - -def test_adding_hparams_of_datamodule_without_hparams(tmpdir): - model = CustomBoringModel() - hparams = copy.deepcopy(model.hparams) - model.add_datamodule_hparams(DataModuleWithoutHparams()) - - assert hparams == model.hparams - assert hparams == model.hparams_initial From 7faed2946e50b703477409d17aaf393f3128c65c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Jun 2021 14:09:33 +0000 Subject: [PATCH 25/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 3 +-- pytorch_lightning/utilities/hparams_mixin.py | 4 ++-- tests/models/test_hparams.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e21cabd4e7d8f..f31df8c9f68b8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -68,8 +68,7 @@ from pytorch_lightning.utilities.hparams_mixin import merge_hparams from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS, _EVALUATE_OUTPUT, \ - _PREDICT_OUTPUT +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS log = logging.getLogger(__name__) # warnings to ignore in trainer diff --git 
a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 4fe32116c8de2..c9576f89d8542 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -171,8 +171,8 @@ def merge_hparams(lightning_hparams: dict, data_hparams: dict) -> dict: colliding_keys = [key for key in data_hparams.keys() if key in lightning_hparams] if colliding_keys: raise ValueError( - f'Error while merging hparams: the keys {colliding_keys} are present ' - f'in both the LightningModules and LightningDataModules hparams.' + f'Error while merging hparams: the keys {colliding_keys} are present ' + f'in both the LightningModules and LightningDataModules hparams.' ) merged_hparams = copy.deepcopy(lightning_hparams) merged_hparams.update(data_hparams) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index b294e09c079ac..da8990572cb86 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -31,7 +31,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml -from pytorch_lightning.utilities import AttributeDict, _HYDRA_EXPERIMENTAL_AVAILABLE, is_picklable +from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable from pytorch_lightning.utilities.hparams_mixin import merge_hparams from tests.helpers import BoringModel, RandomDataset From 8f27cb053d17937cb9ee73a2f4df379e0a0fb4e1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 28 Jun 2021 16:47:24 +0100 Subject: [PATCH 26/36] Fixes --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 64a05d02fd4a4..51c7516bd7c62 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -36,7 +36,7 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES +from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors From d12d3c17762c7f26c79d04f8ebde8ed5a23d4f64 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 8 Jul 2021 23:49:38 +0530 Subject: [PATCH 27/36] Update hparams mixin --- pytorch_lightning/core/lightning.py | 86 -------------------- pytorch_lightning/utilities/hparams_mixin.py | 4 +- 2 files changed, 2 insertions(+), 88 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b55fac1806bf8..6238f4c03782d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1829,92 +1829,6 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: parents_arguments.update(args) return self_arguments, parents_arguments - def save_hyperparameters( - self, - *args, - ignore: Optional[Union[Sequence[str], str]] = None, - frame: Optional[types.FrameType] = None - ) -> None: - """Save model arguments to the ``hparams`` attribute. 
- - Args: - args: single object of type :class:`dict`, :class:`~argparse.Namespace`, `OmegaConf` - or strings representing the argument names in ``__init__``. - ignore: an argument name or a list of argument names in ``__init__`` to be ignored - frame: a frame object. Default is ``None``. - - Example:: - - >>> class ManuallyArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # manually assign arguments - ... self.save_hyperparameters('arg1', 'arg3') - ... def forward(self, *args, **kwargs): - ... ... - >>> model = ManuallyArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg3": 3.14 - - >>> class AutomaticArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # equivalent automatic - ... self.save_hyperparameters() - ... def forward(self, *args, **kwargs): - ... ... - >>> model = AutomaticArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg2": abc - "arg3": 3.14 - - >>> class SingleArgModel(LightningModule): - ... def __init__(self, params): - ... super().__init__() - ... # manually assign single argument - ... self.save_hyperparameters(params) - ... def forward(self, *args, **kwargs): - ... ... - >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) - >>> model.hparams - "p1": 1 - "p2": abc - "p3": 3.14 - - >>> class ManuallyArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # pass argument(s) to ignore as a string or in a list - ... self.save_hyperparameters(ignore='arg2') - ... def forward(self, *args, **kwargs): - ... ... - >>> model = ManuallyArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg3": 3.14 - """ - # the frame needs to be created in this file. 
- if not frame: - frame = inspect.currentframe().f_back - save_hyperparameters(self, *args, ignore=ignore, frame=frame) - - def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: - if isinstance(hp, Namespace): - hp = vars(hp) - if isinstance(hp, dict): - hp = AttributeDict(hp) - elif isinstance(hp, PRIMITIVE_TYPES): - raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") - elif not isinstance(hp, ALLOWED_CONFIG_TYPES): - raise ValueError(f"Unsupported config type of {type(hp)}.") - - if isinstance(hp, dict) and isinstance(self.hparams, dict): - self.hparams.update(hp) - else: - self._hparams = hp - @torch.no_grad() def to_onnx( self, diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index c9576f89d8542..e2dec9a2f2a54 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -106,7 +106,8 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: else: self._hparams = hp - def _to_hparams_dict(self, hp): + @staticmethod + def _to_hparams_dict(hp: Union[dict, Namespace, str]): if isinstance(hp, Namespace): hp = vars(hp) if isinstance(hp, dict): @@ -115,7 +116,6 @@ def _to_hparams_dict(self, hp): raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") elif not isinstance(hp, ALLOWED_CONFIG_TYPES): raise ValueError(f"Unsupported config type of {type(hp)}.") - return hp @property From d9c74d9424928c5bcec7008979252487c72cdc53 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 00:36:28 +0530 Subject: [PATCH 28/36] Update trainer & Lightning module --- pytorch_lightning/core/lightning.py | 22 ---------------------- pytorch_lightning/trainer/trainer.py | 9 +++++---- 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6238f4c03782d..971ed3d975494 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -31,7 +31,6 @@ from torch.optim.optimizer import Optimizer from torchmetrics import Metric -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary @@ -1960,27 +1959,6 @@ def to_torchscript( return torchscript_module - def add_datamodule_hparams(self, datamodule: LightningDataModule): - """Add the hparams of a LightningDataModule to the hparams and hparams_initial of this module.""" - if not hasattr(datamodule, 'hparams'): - return - - hparams = self._to_hparams_dict(datamodule.hparams) - if not hasattr(self, '_hparams'): - self._hparams = hparams - self._hparams_initial = hparams - else: - colliding_keys = [key for key in hparams.keys() if key in self.hparams] - colliding_keys += [key for key in hparams.keys() if key in self.hparams_initial] - colliding_keys = set(colliding_keys) - if colliding_keys: - raise ValueError( - f'Error while adding datamodule hparams: the keys {colliding_keys} ' - f'are already present in the model hparams.' 
- ) - self.hparams.update(hparams) - self._hparams_initial.update(hparams) - @property def model_size(self) -> float: """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ecd4d4a7ee402..29caea2e5e2ed 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -904,14 +904,15 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, def _pre_dispatch(self): self.accelerator.pre_dispatch(self) + self._log_hyperparams() + def _log_hyperparams(self): # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) - hparams_initial = merge_hparams( - self.lightning_module.hparams_initial, - {} if self.datamodule is None else self.datamodule.hparams_initial, - ) + datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {} + hparams_initial = merge_hparams(self.lightning_module.hparams_initial, datamodule_hparams) + self.logger.log_hyperparams(hparams_initial) self.logger.log_graph(self.lightning_module) self.logger.save() From d2f477a79df479044b64074009d875d7ce386f1e Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 02:25:33 +0530 Subject: [PATCH 29/36] Fix torchscript issue --- pytorch_lightning/core/lightning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 971ed3d975494..735f8ab160c1f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -78,6 +78,7 @@ class LightningModule( "model_size", "automatic_optimization", "truncated_bptt_steps", + "loaded_optimizer_states_dict", ] + DeviceDtypeModuleMixin.__jit_unused_properties__ + HyperparametersMixin.__jit_unused_properties__ def __init__(self, *args: Any, **kwargs: Any) -> None: From 2ba18c9f7eeb3447dac053592f8941bb2203fc6d Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 02:34:35 +0530 Subject: [PATCH 30/36] Update test --- tests/models/test_hparams.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index da8990572cb86..7537de5fce5c8 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -780,8 +780,7 @@ class DataModuleWithHparams(LightningDataModule): def __init__(self, hparams): super().__init__() self.data = None - - self.hparams = hparams + self.save_hyperparameters(hparams) def setup(self, stage: Optional[str] = None): self.data = torch.randn(10, 32) From 55a15c093e86c5fef2b0f564570771f670fb760a Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 03:17:14 +0530 Subject: [PATCH 31/36] Update tests --- tests/models/test_hparams.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 7537de5fce5c8..0203597410637 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -17,7 +17,6 @@ import pickle from argparse import Namespace from dataclasses import dataclass -from typing import Optional from unittest import mock import cloudpickle @@ -25,7 +24,7 @@ import torch from fsspec.implementations.local import LocalFileSystem from omegaconf import Container, OmegaConf -from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks 
import ModelCheckpoint @@ -46,11 +45,6 @@ def __init__(self, hparams): super().__init__() self.save_hyperparameters(hparams) - def training_step(self, batch, batch_idx): - output = self.layer(batch[0]) - loss = self.loss(batch, output) - return {"loss": loss} - def decorate(func): @@ -753,40 +747,21 @@ def test_dataclass_lightning_module(tmpdir): class NoHparamsModel(BoringModel): """ Tests a model without hparams. """ - def training_step(self, batch, batch_idx): - output = self.layer(batch[0]) - loss = self.loss(batch, output) - return {"loss": loss} - class DataModuleWithoutHparams(LightningDataModule): - def __init__(self): - super().__init__() - self._data = None - - def prepare_data(self, *args, **kwargs): - pass - - def setup(self, stage: Optional[str] = None): - self.data = torch.randn(10, 32) - def train_dataloader(self, *args, **kwargs) -> DataLoader: - return DataLoader(TensorDataset(self.data), batch_size=10) + return DataLoader(RandomDataset(32, 64), batch_size=32) class DataModuleWithHparams(LightningDataModule): def __init__(self, hparams): super().__init__() - self.data = None self.save_hyperparameters(hparams) - def setup(self, stage: Optional[str] = None): - self.data = torch.randn(10, 32) - def train_dataloader(self, *args, **kwargs) -> DataLoader: - return DataLoader(TensorDataset(self.data), batch_size=10) + return DataLoader(RandomDataset(32, 64), batch_size=32) def _get_mock_logger(tmpdir): From 0d0508667ad748185aac60a5d89d58b3ad6b030f Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 03:25:07 +0530 Subject: [PATCH 32/36] Update datamodule test --- tests/core/test_datamodules.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 21e69bbd3810f..6203e93e63e2f 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -13,7 +13,7 @@ # limitations under the License. 
import pickle from argparse import ArgumentParser -from typing import Any, Dict, Optional +from typing import Any, Dict from unittest import mock from unittest.mock import call, PropertyMock @@ -558,22 +558,8 @@ class DataModuleWithHparams(LightningDataModule): def __init__(self, arg0, arg1, kwarg0=None): super().__init__() - - self.arg0 = arg0 - self.arg1 = arg1 - self.kwarg0 = kwarg0 - self.save_hyperparameters() - def prepare_data(self, *args, **kwargs): - pass - - def setup(self, stage: Optional[str] = None): - pass - - def train_dataloader(self, *args, **kwargs): - pass - def test_simple_hyperparameters_saving(): data = DataModuleWithHparams(10, "foo", kwarg0="bar") From a08304e9a48d6c6ac43bd3d8a58ded1b7a5178de Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 04:10:53 +0530 Subject: [PATCH 33/36] Remove merge_hparams & update tests --- pytorch_lightning/trainer/trainer.py | 11 ++++++-- pytorch_lightning/utilities/hparams_mixin.py | 22 --------------- tests/models/test_hparams.py | 29 ++++++-------------- 3 files changed, 17 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 29caea2e5e2ed..4c034ac843361 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -77,7 +77,6 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.hparams_mixin import merge_hparams from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -911,7 +910,15 @@ def _log_hyperparams(self): if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {} - hparams_initial = merge_hparams(self.lightning_module.hparams_initial, datamodule_hparams) + lightning_hparams = self.lightning_module.hparams_initial + colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys() + if colliding_keys: + raise MisconfigurationException( + f"Error while merging hparams: the keys {colliding_keys} are present " + "in both the LightningModule's and LightningDataModule's hparams." + ) + + hparams_initial = {**lightning_hparams, **datamodule_hparams} self.logger.log_hyperparams(hparams_initial) self.logger.log_graph(self.lightning_module) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index e2dec9a2f2a54..c6c202d23cfbf 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -156,25 +156,3 @@ def __get_hparams_assignment_variable(self): return "hparams" return None - - -def merge_hparams(lightning_hparams: dict, data_hparams: dict) -> dict: - """Merge the hparams of a LightningModule and a LightningDataModule and return them. - - If there is an overlap between the hparams in the LightningModule and in the - LightningDataModule an exception is raised. 
- - Args: - lightning_hparams: the hyperparameters of a LightningModule - data_hparams: the hyperparameters of a LightningDataModule - """ - colliding_keys = [key for key in data_hparams.keys() if key in lightning_hparams] - if colliding_keys: - raise ValueError( - f'Error while merging hparams: the keys {colliding_keys} are present ' - f'in both the LightningModules and LightningDataModules hparams.' - ) - merged_hparams = copy.deepcopy(lightning_hparams) - merged_hparams.update(data_hparams) - - return merged_hparams diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 0203597410637..1ba92bd754d60 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -31,7 +31,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable -from pytorch_lightning.utilities.hparams_mixin import merge_hparams +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset if _HYDRA_EXPERIMENTAL_AVAILABLE: @@ -803,30 +803,17 @@ def test_no_datamodule_for_hparams(tmpdir): mock_logger = _get_mock_logger(tmpdir) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger) - trainer.fit(model, data.train_dataloader()) + trainer.fit(model, datamodule=data) # Merged hparams were logged mock_logger.log_hyperparams.assert_called_with(org_model_hparams) -def test_merging_hparams(tmpdir): - model_hparams = {'arg1': 'abc', 'arg2': 'abc'} - data_hparams = {'data_dir': 'foo'} - merged_hparams = merge_hparams(model_hparams, data_hparams) - - # Merged hparams contain all keys - assert all(key in merged_hparams for key in model_hparams.keys()) - assert all(key in merged_hparams for key in data_hparams.keys()) - - # Original dicts are not modified - assert not any(key in model_hparams for key in data_hparams.keys()) - assert not any(key in data_hparams for key in model_hparams.keys()) +def test_colliding_hparams(tmpdir): + model = SaveHparamsModel({'data_dir': 'abc', 'arg2': 'abc'}) + data = DataModuleWithHparams({'data_dir': 'foo'}) -def test_colliding_datamodule_hparams(tmpdir): - """Test that colliding hparams from the datamodule are caught.""" - model_hparams = {'data_dir': 'abc', 'arg2': 'abc'} - data_hparams = {'data_dir': 'foo'} - - with pytest.raises(ValueError): - merge_hparams(model_hparams, data_hparams) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + with pytest.raises(MisconfigurationException, match=r'Error while merging hparams:'): + trainer.fit(model, datamodule=data) From d517fd381c5154ba10d8f373f8184363114ee181 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 04:13:19 +0530 Subject: [PATCH 34/36] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e8e1749f4350..cb15ded665e0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -146,6 +146,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307)) +- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792)) + + ### Changed From 6297bd62b0f9c7d3e99c46be0526cf7868d64970 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 14:15:01 +0530 Subject: [PATCH 35/36] Remove hparams setter --- pytorch_lightning/utilities/hparams_mixin.py | 29 +------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index c6c202d23cfbf..94a232972abeb 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -13,10 +13,9 @@ # limitations under the License. import copy import inspect -import re import types from argparse import Namespace -from typing import Any, Optional, Sequence, Union +from typing import Optional, Sequence, Union from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES from pytorch_lightning.utilities import AttributeDict @@ -130,29 +129,3 @@ def hparams_initial(self) -> AttributeDict: return AttributeDict() # prevent any change return copy.deepcopy(self._hparams_initial) - - @hparams.setter - def hparams(self, hp: Union[dict, Namespace, Any]): - hparams_assignment_name = self.__get_hparams_assignment_variable() - self._hparams_name = hparams_assignment_name - self._set_hparams(hp) - # this resolves case when user does not uses `save_hyperparameters` and do hard assignement in init - if not hasattr(self, "_hparams_initial"): - self._hparams_initial = copy.deepcopy(self._hparams) - - def __get_hparams_assignment_variable(self): - """ - looks at the code of the class to figure out what the user named self.hparams - this only happens when the user explicitly sets self.hparams - """ - try: - class_code = inspect.getsource(self.__class__) - lines = class_code.split("\n") - for line in lines: - line = re.sub(r"\s+", "", line, flags=re.UNICODE) - if ".hparams=" in line: - return line.split("=")[1] - except Exception: - return "hparams" - - return None From 43c75fe6458ca47cfee721bd79872f77e577b5a2 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 9 Jul 2021 18:19:47 +0530 Subject: [PATCH 36/36] Update pytorch_lightning/utilities/hparams_mixin.py Co-authored-by: Ethan Harris --- pytorch_lightning/utilities/hparams_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 94a232972abeb..8dd4b23c89398 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -32,7 +32,7 @@ def save_hyperparameters( ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: - """Save model arguments to ``hparams`` attribute. + """Save arguments to ``hparams`` attribute. Args: args: single object of `dict`, `NameSpace` or `OmegaConf`