Hyperparameters for datamodule #3792

Merged
Changes from 54 commits
Commits
55 commits
fdd3651
Extract hparam functions to mixin.
Oct 2, 2020
9629b32
Make LightningDataModule inherit from HyperparametersMixin.
Oct 2, 2020
2771894
Add function to extend hparams.
Oct 2, 2020
ea744ae
Add hparams of DataModule to model before training.
Oct 2, 2020
3ea5f32
Change examples due to cyclic import.
Oct 2, 2020
5eb202e
Merge branch 'master' into hyperparameters_for_datamodule
tilman151 Oct 5, 2020
0eaf76a
Merge branch 'master' into hyperparameters_for_datamodule
Oct 22, 2020
9ea7989
Add initital_hparams to mixin and move/rename extend_hparams.
Oct 22, 2020
f3ea974
Update unit tests.
Oct 22, 2020
fc8af41
Add hparams from datamodule to hparams_inital, too.
Oct 22, 2020
3f8d44f
Test if datamodule hparams are logged to trainer loggers, too.
Oct 22, 2020
822cce4
Simplify error handling.
Oct 22, 2020
0f4dc64
Change args of add_datamodule_hparams from hparams to datamodule itself.
Oct 22, 2020
0af9968
Add hparams of datamodule only if it has some.
Oct 22, 2020
0369c9f
Add one more unit test.
Oct 22, 2020
6dde1d2
Fix pep8 complaint.
Oct 22, 2020
374932e
Merge branch 'master' into hyperparameters_for_datamodule
tilman151 Oct 22, 2020
17ca42c
Merge branch 'master' into hyperparameters_for_datamodule
tilman151 Oct 23, 2020
4d8dc2a
Merge branch 'master' into hyperparameters_for_datamodule
Nov 13, 2020
d050be0
Merge remote-tracking branch 'origin/hyperparameters_for_datamodule' …
Nov 13, 2020
001f754
Merge branch 'master' into hyperparameters_for_datamodule
tilman151 Nov 16, 2020
64acd07
Merge branch 'master' into hyperparameters_for_datamodule
tilman151 Nov 18, 2020
958b6e5
Merge branch 'master' into hyperparameters_for_datamodule
Nov 19, 2020
873b02a
Make training work for SaveHparamsModel.
Nov 19, 2020
2268d6e
Merge branch 'master' into hyperparameters_for_datamodule
tilman151 Nov 20, 2020
548e976
Merge branch 'master' into hyperparameters_for_datamodule
rohitgr7 Nov 22, 2020
b0e1f04
Merge branch 'master' into hyperparameters_for_datamodule
tilman151 Nov 24, 2020
ebb53fb
Update pytorch_lightning/core/lightning.py
tilman151 Dec 1, 2020
c36ef72
Update pytorch_lightning/utilities/hparams_mixin.py
tilman151 Dec 1, 2020
5e2d8d5
Merge branch 'master' into hyperparameters_for_datamodule
Jun 21, 2021
eaaf94e
Fix merge conflicts.
Jun 21, 2021
806a1e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
6a80679
Fix code style issues.
Jun 21, 2021
3481039
Merge remote-tracking branch 'origin/hyperparameters_for_datamodule' …
Jun 21, 2021
2228e69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
dbca3e4
Merge branch 'master' into hyperparameters_for_datamodule
Jun 22, 2021
c14361e
Fix indentation error from merge.
Jun 22, 2021
a26bd9e
Merge branch 'master' into hyperparameters_for_datamodule
Jun 22, 2021
5371b2c
Extract hparam merging function.
Jun 24, 2021
eebc6ab
Hold model and data hparams separately and merge on logging.
Jun 24, 2021
7faed29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
1ade79f
Merge branch 'master' into hyperparameters_for_datamodule
ethanwharris Jun 28, 2021
8f27cb0
Fixes
ethanwharris Jun 28, 2021
05fad8f
Merge branch 'master' into hyperparameters_for_datamodule
kaushikb11 Jul 8, 2021
d12d3c1
Update hparams mixin
kaushikb11 Jul 8, 2021
d9c74d9
Update trainer & Lightning module
kaushikb11 Jul 8, 2021
d2f477a
Fix torchscript issue
kaushikb11 Jul 8, 2021
2ba18c9
Update test
kaushikb11 Jul 8, 2021
55a15c0
Update tests
kaushikb11 Jul 8, 2021
0d05086
Update datamodule test
kaushikb11 Jul 8, 2021
a08304e
Remove merge_hparams & update tests
kaushikb11 Jul 8, 2021
d517fd3
Update changelog
kaushikb11 Jul 8, 2021
6297bd6
Remove hparams setter
kaushikb11 Jul 9, 2021
43c75fe
Update pytorch_lightning/utilities/hparams_mixin.py
kaushikb11 Jul 9, 2021
d3d5cf7
Merge branch 'master' into hyperparameters_for_datamodule
kaushikb11 Jul 9, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -146,6 +146,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))


- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792))


### Changed


3 changes: 2 additions & 1 deletion pytorch_lightning/core/datamodule.py
@@ -22,9 +22,10 @@
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin


class LightningDataModule(CheckpointHooks, DataHooks):
class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.
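
With `LightningDataModule` now inheriting from `HyperparametersMixin`, a datamodule can call `save_hyperparameters()` the same way a `LightningModule` does. A minimal sketch of the intended usage (not part of the diff; the class name and constructor arguments are made up for illustration):

```python
import pytorch_lightning as pl


class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        # provided by HyperparametersMixin: records data_dir and batch_size
        self.save_hyperparameters()


dm = MyDataModule(batch_size=64)
print(dm.hparams)          # "batch_size": 64, "data_dir": ./data
print(dm.hparams_initial)  # frozen copy of the same values
```
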
122 changes: 6 additions & 116 deletions pytorch_lightning/core/lightning.py
@@ -14,18 +14,15 @@
"""The LightningModule - an nn.Module with many additional features."""

import collections
import copy
import inspect
import logging
import numbers
import os
import tempfile
import types
import uuid
from abc import ABC
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
@@ -38,15 +35,16 @@
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin
from pytorch_lightning.utilities.parsing import collect_init_args
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
@@ -58,6 +56,7 @@
class LightningModule(
ABC,
DeviceDtypeModuleMixin,
HyperparametersMixin,
GradInformation,
ModelIO,
ModelHooks,
@@ -70,8 +69,6 @@ class LightningModule(
__jit_unused_properties__ = [
"datamodule",
"example_input_array",
"hparams",
"hparams_initial",
"on_gpu",
"current_epoch",
"global_step",
@@ -82,7 +79,7 @@ class LightningModule(
"automatic_optimization",
"truncated_bptt_steps",
"loaded_optimizer_states_dict",
] + DeviceDtypeModuleMixin.__jit_unused_properties__
] + DeviceDtypeModuleMixin.__jit_unused_properties__ + HyperparametersMixin.__jit_unused_properties__

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
@@ -1832,92 +1829,6 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
parents_arguments.update(args)
return self_arguments, parents_arguments

def save_hyperparameters(
self,
*args,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None
) -> None:
"""Save model arguments to the ``hparams`` attribute.

Args:
args: single object of type :class:`dict`, :class:`~argparse.Namespace`, `OmegaConf`
or strings representing the argument names in ``__init__``.
ignore: an argument name or a list of argument names in ``__init__`` to be ignored
frame: a frame object. Default is ``None``.

Example::

>>> class ManuallyArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # manually assign arguments
... self.save_hyperparameters('arg1', 'arg3')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14

>>> class AutomaticArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # equivalent automatic
... self.save_hyperparameters()
... def forward(self, *args, **kwargs):
... ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14

>>> class SingleArgModel(LightningModule):
... def __init__(self, params):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(params)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14

>>> class ManuallyArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # pass argument(s) to ignore as a string or in a list
... self.save_hyperparameters(ignore='arg2')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
"""
# the frame needs to be created in this file.
if not frame:
frame = inspect.currentframe().f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)

def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
if isinstance(hp, Namespace):
hp = vars(hp)
if isinstance(hp, dict):
hp = AttributeDict(hp)
elif isinstance(hp, PRIMITIVE_TYPES):
raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
raise ValueError(f"Unsupported config type of {type(hp)}.")

if isinstance(hp, dict) and isinstance(self.hparams, dict):
self.hparams.update(hp)
else:
self._hparams = hp

@torch.no_grad()
def to_onnx(
self,
@@ -2049,27 +1960,6 @@ def to_torchscript(

return torchscript_module

@property
def hparams(self) -> Union[AttributeDict, dict, Namespace]:
"""
The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
"""
if not hasattr(self, "_hparams"):
self._hparams = AttributeDict()
return self._hparams

@property
def hparams_initial(self) -> AttributeDict:
"""
The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only.
Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`.
"""
if not hasattr(self, "_hparams_initial"):
return AttributeDict()
# prevent any change
return copy.deepcopy(self._hparams_initial)

@property
def model_size(self) -> float:
"""
15 changes: 14 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -903,11 +903,24 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT,

def _pre_dispatch(self):
self.accelerator.pre_dispatch(self)
self._log_hyperparams()

def _log_hyperparams(self):
# log hyper-parameters
if self.logger is not None:
# save exp to get started (this is where the first experiment logs are written)
self.logger.log_hyperparams(self.lightning_module.hparams_initial)
datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
lightning_hparams = self.lightning_module.hparams_initial
colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
if colliding_keys:
raise MisconfigurationException(
f"Error while merging hparams: the keys {colliding_keys} are present "
"in both the LightningModule's and LightningDataModule's hparams."
)

hparams_initial = {**lightning_hparams, **datamodule_hparams}

self.logger.log_hyperparams(hparams_initial)
self.logger.log_graph(self.lightning_module)
self.logger.save()

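The merge in `_log_hyperparams` is strict: if the model and the datamodule both saved a hyperparameter under the same key, the trainer raises a `MisconfigurationException` rather than silently letting one value overwrite the other. A hedged sketch of the failure case (class names and the `batch_size` argument are illustrative, not taken from the diff):

```python
import pytorch_lightning as pl


class Model(pl.LightningModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()  # hparams_initial: {"batch_size": 32}


class Data(pl.LightningDataModule):
    def __init__(self, batch_size: int = 64):
        super().__init__()
        self.save_hyperparameters()  # hparams_initial: {"batch_size": 64}


# With a logger attached (the default), fitting would raise
# MisconfigurationException during _log_hyperparams, because "batch_size"
# appears in both hparams_initial dicts:
#
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(Model(), datamodule=Data())
```
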
131 changes: 131 additions & 0 deletions pytorch_lightning/utilities/hparams_mixin.py
@@ -0,0 +1,131 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import types
from argparse import Namespace
from typing import Optional, Sequence, Union

from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.parsing import save_hyperparameters


class HyperparametersMixin:

__jit_unused_properties__ = ["hparams", "hparams_initial"]

def save_hyperparameters(
self,
*args,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None
) -> None:
"""Save arguments to ``hparams`` attribute.

Args:
args: single object of `dict`, `Namespace` or `OmegaConf`
or strings representing the argument names in ``__init__``
ignore: an argument name or a list of argument names from
class ``__init__`` to be ignored
frame: a frame object. Default is None

Example::
>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # manually assign arguments
... self.save_hyperparameters('arg1', 'arg3')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14

>>> class AutomaticArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # equivalent automatic
... self.save_hyperparameters()
... def forward(self, *args, **kwargs):
... ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14

>>> class SingleArgModel(HyperparametersMixin):
... def __init__(self, params):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(params)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14

>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # pass argument(s) to ignore as a string or in a list
... self.save_hyperparameters(ignore='arg2')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
"""
# the frame needs to be created in this file.
if not frame:
frame = inspect.currentframe().f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)

def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
hp = self._to_hparams_dict(hp)

if isinstance(hp, dict) and isinstance(self.hparams, dict):
self.hparams.update(hp)
else:
self._hparams = hp

@staticmethod
def _to_hparams_dict(hp: Union[dict, Namespace, str]):
if isinstance(hp, Namespace):
hp = vars(hp)
if isinstance(hp, dict):
hp = AttributeDict(hp)
elif isinstance(hp, PRIMITIVE_TYPES):
raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
raise ValueError(f"Unsupported config type of {type(hp)}.")
return hp

@property
def hparams(self) -> Union[AttributeDict, dict, Namespace]:
if not hasattr(self, "_hparams"):
self._hparams = AttributeDict()
return self._hparams

@property
def hparams_initial(self) -> AttributeDict:
if not hasattr(self, "_hparams_initial"):
return AttributeDict()
# prevent any change
return copy.deepcopy(self._hparams_initial)
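
The new mixin also keeps the distinction between the mutable `hparams` and the frozen `hparams_initial`. A minimal sketch of the intended behaviour, assuming (as in the previous `LightningModule` implementation) that `save_hyperparameters` also records a deep-copied snapshot for `hparams_initial`:

```python
from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin


class Config(HyperparametersMixin):
    def __init__(self, lr: float = 1e-3, layers: int = 2):
        super().__init__()
        self.save_hyperparameters()


cfg = Config()
cfg.hparams.lr = 1e-2          # hparams is a mutable AttributeDict
print(cfg.hparams.lr)          # 0.01
print(cfg.hparams_initial.lr)  # 0.001, the value captured at save time
```
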
13 changes: 13 additions & 0 deletions tests/core/test_datamodules.py
@@ -22,6 +22,7 @@

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.datamodules import ClassifDataModule
@@ -551,3 +552,15 @@ def test_dm_init_from_datasets_dataloaders(iterable):
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
])


class DataModuleWithHparams(LightningDataModule):

def __init__(self, arg0, arg1, kwarg0=None):
super().__init__()
self.save_hyperparameters()


def test_simple_hyperparameters_saving():
data = DataModuleWithHparams(10, "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})