Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hparams as dict #1029

Merged
merged 33 commits into from
Mar 4, 2020
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
12851f7
hparams as dict
Borda Mar 3, 2020
6811137
hparams as dict
Borda Mar 3, 2020
77eb6c8
fixing
Borda Mar 3, 2020
804b8d8
fixing
Borda Mar 3, 2020
69cd0ec
fixing
Borda Mar 3, 2020
1c523fd
fixing
Borda Mar 3, 2020
f42ac58
typing
Borda Mar 3, 2020
5e6c5a0
typing
Borda Mar 3, 2020
b7601ba
chnagelog
Borda Mar 3, 2020
a56c9fa
update set hparams
Borda Mar 3, 2020
68c17ef
use setter
Borda Mar 3, 2020
02849b7
simplify
Borda Mar 3, 2020
ffdb2e2
chnagelog
Borda Mar 3, 2020
58e95aa
imports
Borda Mar 4, 2020
f9b0649
pylint
Borda Mar 4, 2020
4460bd0
Merge branch 'master' into hparams
williamFalcon Mar 4, 2020
716d9a7
typing
Borda Mar 4, 2020
adf1987
Update training_io.py
williamFalcon Mar 4, 2020
ffac889
Update training_io.py
williamFalcon Mar 4, 2020
fc31e48
Update lightning.py
williamFalcon Mar 4, 2020
e125035
Update test_trainer.py
williamFalcon Mar 4, 2020
73aad62
Update __init__.py
williamFalcon Mar 4, 2020
ee74462
Update base.py
williamFalcon Mar 4, 2020
3fd203e
Update utils.py
williamFalcon Mar 4, 2020
edddda0
Update test_trainer.py
williamFalcon Mar 4, 2020
25c29f3
Update training_io.py
williamFalcon Mar 4, 2020
34e49ab
Update test_trainer.py
williamFalcon Mar 4, 2020
0ba1b2a
Update test_trainer.py
williamFalcon Mar 4, 2020
ea948df
Update test_trainer.py
williamFalcon Mar 4, 2020
a21491c
Update test_trainer.py
williamFalcon Mar 4, 2020
12d4d56
Update callback_config.py
williamFalcon Mar 4, 2020
cc96dcd
Update callback_config.py
williamFalcon Mar 4, 2020
018091d
Update test_trainer.py
williamFalcon Mar 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))

### Changed

Expand All @@ -32,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
- Moved functionality of `LightningModule.load_from_metrics` into `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))
- Changed Checkpoint path parameter from `filepath` to `dirpath` ([#1016](https://github.com/PyTorchLightning/pytorch-lightning/pull/1016))
- Freezed models `hparams` as `Namespace` property ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))

### Deprecated

Expand Down
2 changes: 1 addition & 1 deletion docs/source/weights_loading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ The Lightning checkpoint also saves the hparams (hyperparams) passed into the Li
from argparse import Namespace

# usually these come from command line args
args = Namespace(**{'learning_rate':0.001})
args = Namespace(learning_rate=0.001)

# define you module to have hparams as the first arg
# this means your checkpoint will have everything that went into making
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ class ModelCheckpoint(Callback):

# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')

# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
# if model already exits, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt


monitor: quantity to monitor.
verbose: verbosity mode, False or True.
save_top_k: if `save_top_k == k`,
Expand Down Expand Up @@ -135,7 +137,7 @@ def _save_model(self, filepath: str) -> None:
if self.save_function is not None:
self.save_function(filepath)
else:
raise ValueError(".save_function() not set")
raise ValueError("Method `.save_function()` not set")

def check_monitor_top_k(self, current: float) -> bool:
less_than_k_models = len(self.best_k_models) < self.save_top_k
Expand Down
19 changes: 17 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Optional, Union, Dict, Callable
from typing import Any, Callable, Dict, Optional, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -68,6 +68,20 @@ def __init__(self, *args, **kwargs):
#: True if using amp
self.use_amp = False

@property
def hparams(self) -> Namespace:
if not hasattr(self, '_hparams'):
return Namespace()
assert isinstance(self._hparams, dict)
return Namespace(**self._hparams)

@hparams.setter
def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
"""Set the model hyper-parameters."""
if isinstance(params, Namespace):
params = vars(params)
self._hparams = params

def print(self, *args, **kwargs):
r"""
Prints only from process 0. Use this in any distributed mode to log only once
Expand Down Expand Up @@ -1201,7 +1215,8 @@ def _load_model_state(cls, checkpoint):

if cls_takes_hparams:
if ckpt_hparams is not None:
hparams = Namespace(**ckpt_hparams)
is_namespace = checkpoint.get('hparams_type') == 'namespace'
hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
else:
warnings.warn(
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ contains"
Expand Down
6 changes: 2 additions & 4 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,14 @@ def on_hpc_load(self, checkpoint):
"""


def load_hparams_from_tags_csv(tags_csv):
def load_hparams_from_tags_csv(tags_csv) -> Namespace:
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
return Namespace()

tags = {}
with open(tags_csv) as f:
csv_reader = csv.reader(f, delimiter=',')
for row in list(csv_reader)[1:]:
tags[row[0]] = convert(row[1])
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
ns = Namespace(**tags)
return ns

Expand Down
27 changes: 17 additions & 10 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List

Expand Down Expand Up @@ -41,6 +42,12 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""
pass

def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
return params

@abstractmethod
def log_hyperparams(self, params: argparse.Namespace):
"""Record hyperparameters.
Expand All @@ -50,19 +57,19 @@ def log_hyperparams(self, params: argparse.Namespace):
"""
pass

def save(self):
def save(self) -> None:
"""Save log data."""
pass

def finalize(self, status: str):
def finalize(self, status: str) -> None:
"""Do any processing that is necessary to finalize an experiment.

Args:
status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
pass

def close(self):
def close(self) -> None:
"""Do any cleanup that is necessary to close an experiment."""
pass

Expand All @@ -72,7 +79,7 @@ def rank(self) -> int:
return self._rank

@rank.setter
def rank(self, value: int):
def rank(self, value: int) -> None:
"""Set the process rank."""
self._rank = value

Expand Down Expand Up @@ -107,23 +114,23 @@ def __getitem__(self, index: int) -> LightningLoggerBase:
def experiment(self) -> List[Any]:
return [logger.experiment for logger in self._logger_iterable]

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
[logger.log_metrics(metrics, step) for logger in self._logger_iterable]

def log_hyperparams(self, params: argparse.Namespace):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
[logger.log_hyperparams(params) for logger in self._logger_iterable]

def save(self):
def save(self) -> None:
[logger.save() for logger in self._logger_iterable]

def finalize(self, status: str):
def finalize(self, status: str) -> None:
[logger.finalize(status) for logger in self._logger_iterable]

def close(self):
def close(self) -> None:
[logger.close() for logger in self._logger_iterable]

@LightningLoggerBase.rank.setter
def rank(self, value: int):
def rank(self, value: int) -> None:
self._rank = value
for logger in self._logger_iterable:
logger.rank = value
Expand Down
15 changes: 8 additions & 7 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
CometLogger
-------------
"""
import argparse
from argparse import Namespace
from logging import getLogger
from typing import Optional, Dict, Union
from typing import Optional, Dict, Union, Any

try:
from comet_ml import Experiment as CometExperiment
Expand Down Expand Up @@ -162,15 +162,16 @@ def experiment(self) -> CometBaseExperiment:
return self._experiment

@rank_zero_only
def log_hyperparams(self, params: argparse.Namespace):
self.experiment.log_parameters(vars(params))
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
self.experiment.log_parameters(params)

@rank_zero_only
def log_metrics(
self,
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
):
) -> None:
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
Expand All @@ -182,7 +183,7 @@ def reset_experiment(self):
self._experiment = None

@rank_zero_only
def finalize(self, status: str):
def finalize(self, status: str) -> None:
r"""
When calling self.experiment.end(), that experiment won't log any more data to Comet. That's why, if you need
to log any more data you need to create an ExistingCometExperiment. For example, to log data when testing your
Expand All @@ -199,7 +200,7 @@ def name(self) -> str:
return self.experiment.project_name

@name.setter
def name(self, value: str):
def name(self, value: str) -> None:
self.experiment.set_name(value)

@property
Expand Down
13 changes: 7 additions & 6 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def any_lightning_module_function_or_hook(...):
self.logger.experiment.whatever_ml_flow_supports(...)

"""
import argparse
from argparse import Namespace
from logging import getLogger
from time import time
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, Union

try:
import mlflow
Expand Down Expand Up @@ -88,12 +88,13 @@ def run_id(self):
return self._run_id

@rank_zero_only
def log_hyperparams(self, params: argparse.Namespace):
for k, v in vars(params).items():
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
for k, v in params.items():
self.experiment.log_param(self.run_id, k, v)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
if isinstance(v, str):
Expand All @@ -105,7 +106,7 @@ def save(self):
pass

@rank_zero_only
def finalize(self, status: str = 'FINISHED'):
def finalize(self, status: str = 'FINISHED') -> None:
if status == 'success':
status = 'FINISHED'
self.experiment.set_terminated(self.run_id, status)
Expand Down
23 changes: 12 additions & 11 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
NeptuneLogger
--------------
"""
import argparse
from argparse import Namespace
from logging import getLogger
from typing import Optional, List, Dict, Any, Union, Iterable

Expand Down Expand Up @@ -164,16 +164,17 @@ def experiment(self) -> Experiment:
return self._experiment

@rank_zero_only
def log_hyperparams(self, params: argparse.Namespace):
for key, val in vars(params).items():
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
for key, val in params.items():
self.experiment.set_property(f'param__{key}', val)

@rank_zero_only
def log_metrics(
self,
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
):
) -> None:
"""Log metrics (numeric values) in Neptune experiments

Args:
Expand All @@ -184,7 +185,7 @@ def log_metrics(
self.log_metric(key, val, step=step)

@rank_zero_only
def finalize(self, status: str):
def finalize(self, status: str) -> None:
self.experiment.stop()

@property
Expand All @@ -207,7 +208,7 @@ def log_metric(
metric_name: str,
metric_value: Union[torch.Tensor, float, str],
step: Optional[int] = None
):
) -> None:
"""Log metrics (numeric values) in Neptune experiments

Args:
Expand All @@ -224,7 +225,7 @@ def log_metric(
self.experiment.log_metric(metric_name, x=step, y=metric_value)

@rank_zero_only
def log_text(self, log_name: str, text: str, step: Optional[int] = None):
def log_text(self, log_name: str, text: str, step: Optional[int] = None) -> None:
"""Log text data in Neptune experiment

Args:
Expand All @@ -235,7 +236,7 @@ def log_text(self, log_name: str, text: str, step: Optional[int] = None):
self.log_metric(log_name, text, step=step)

@rank_zero_only
def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None):
def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None) -> None:
"""Log image data in Neptune experiment

Args:
Expand All @@ -250,7 +251,7 @@ def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] =
self.experiment.log_image(log_name, x=step, y=image)

@rank_zero_only
def log_artifact(self, artifact: str, destination: Optional[str] = None):
def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None:
"""Save an artifact (file) in Neptune experiment storage.

Args:
Expand All @@ -261,7 +262,7 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None):
self.experiment.log_artifact(artifact, destination)

@rank_zero_only
def set_property(self, key: str, value: Any):
def set_property(self, key: str, value: Any) -> None:
"""Set key-value pair as Neptune experiment property.

Args:
Expand All @@ -271,7 +272,7 @@ def set_property(self, key: str, value: Any):
self.experiment.set_property(key, value)

@rank_zero_only
def append_tags(self, tags: Union[str, Iterable[str]]):
def append_tags(self, tags: Union[str, Iterable[str]]) -> None:
"""appends tags to neptune experiment

Args:
Expand Down
Loading