hparams as dict [blocked by 1041] (Lightning-AI#1029)
* hparams as dict

* hparams as dict

* fixing

* fixing

* fixing

* fixing

* typing

* typing

* changelog

* update set hparams

* use setter

* simplify

* changelog

* imports

* pylint

* typing

* Update training_io.py

* Update training_io.py

* Update lightning.py

* Update test_trainer.py

* Update __init__.py

* Update base.py

* Update utils.py

* Update test_trainer.py

* Update training_io.py

* Update test_trainer.py

* Update test_trainer.py

* Update test_trainer.py

* Update test_trainer.py

* Update callback_config.py

* Update callback_config.py

* Update test_trainer.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
2 people authored and tullie committed Apr 3, 2020
1 parent d00009c commit a2ebb11
Showing 18 changed files with 168 additions and 87 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
+- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))

### Changed

@@ -32,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
- Moved functionality of `LightningModule.load_from_metrics` into `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))
+- Changed Checkpoint path parameter from `filepath` to `dirpath` ([#1016](https://github.com/PyTorchLightning/pytorch-lightning/pull/1016))
+- Froze models' `hparams` as a `Namespace` property ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))

### Deprecated

2 changes: 1 addition & 1 deletion docs/source/weights_loading.rst
@@ -62,7 +62,7 @@ The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule
from argparse import Namespace
# usually these come from command line args
-args = Namespace(**{'learning_rate':0.001})
+args = Namespace(learning_rate=0.001)
# define your module to have hparams as the first arg
# this means your checkpoint will have everything that went into making
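For context, a minimal sketch of the documented pattern (the class name and hyper-parameter are illustrative, not from the diff); after this PR the constructor accepts a plain dict as well as a Namespace:

    from argparse import Namespace

    import pytorch_lightning as pl
    import torch

    class LitModel(pl.LightningModule):
        def __init__(self, hparams):
            super().__init__()
            # hparams as the first __init__ arg is what gets saved into the checkpoint
            self.hparams = hparams
            self.layer = torch.nn.Linear(28 * 28, 10)

        def configure_optimizers(self):
            # reads go through the new property, which always returns a Namespace
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    model = LitModel(Namespace(learning_rate=0.001))  # as in the docs above
    model = LitModel({'learning_rate': 0.001})        # also valid after this PR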
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -27,9 +27,11 @@ class ModelCheckpoint(Callback):
# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
+# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
+# if model already exists, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
monitor: quantity to monitor.
verbose: verbosity mode, False or True.
save_top_k: if `save_top_k == k`,
@@ -135,7 +137,7 @@ def _save_model(self, filepath: str) -> None:
if self.save_function is not None:
self.save_function(filepath)
else:
raise ValueError(".save_function() not set")
raise ValueError("Method `.save_function()` not set")

def check_monitor_top_k(self, current: float) -> bool:
less_than_k_models = len(self.best_k_models) < self.save_top_k
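A usage sketch consistent with the docstring above (the path, metric name, and values are illustrative). The added comment documents the automatic `-v0` version suffix applied when a checkpoint with the same name already exists:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # format fields in `filepath` are filled from the metrics available at save time
    checkpoint_callback = ModelCheckpoint(
        filepath='/my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}',
        monitor='val_loss',   # quantity compared by save_top_k
        save_top_k=3,         # keep the three best checkpoints
        verbose=True,
    )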
19 changes: 17 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -5,7 +5,7 @@
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
-from typing import Optional, Union, Dict, Callable
+from typing import Any, Callable, Dict, Optional, Union

import torch
import torch.distributed as dist
@@ -68,6 +68,20 @@ def __init__(self, *args, **kwargs):
#: True if using amp
self.use_amp = False

+@property
+def hparams(self) -> Namespace:
+    if not hasattr(self, '_hparams'):
+        return Namespace()
+    assert isinstance(self._hparams, dict)
+    return Namespace(**self._hparams)
+
+@hparams.setter
+def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+    """Set the model hyper-parameters."""
+    if isinstance(params, Namespace):
+        params = vars(params)
+    self._hparams = params

def print(self, *args, **kwargs):
r"""
Prints only from process 0. Use this in any distributed mode to log only once
@@ -1201,7 +1215,8 @@ def _load_model_state(cls, checkpoint):

if cls_takes_hparams:
if ckpt_hparams is not None:
-hparams = Namespace(**ckpt_hparams)
+is_namespace = checkpoint.get('hparams_type') == 'namespace'
+hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
else:
warnings.warn(
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ contains"
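The property pair above is the heart of the change: hyper-parameters are now stored internally as a dict (`_hparams`), and reading `hparams` rebuilds a `Namespace` on every access — the "frozen" behavior noted in the changelog. A sketch of the resulting behavior, where `MyModule` stands in for any `LightningModule` subclass:

    from argparse import Namespace

    model = MyModule()                               # hypothetical LightningModule subclass

    model.hparams = {'learning_rate': 0.001}         # the setter accepts a dict ...
    model.hparams = Namespace(learning_rate=0.001)   # ... or a Namespace (stored as a dict)

    assert isinstance(model.hparams, Namespace)      # the getter always returns a Namespace

    # the Namespace is rebuilt on each read, so in-place mutation does not persist:
    model.hparams.learning_rate = 0.1
    assert model.hparams.learning_rate == 0.001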
6 changes: 2 additions & 4 deletions pytorch_lightning/core/saving.py
@@ -36,16 +36,14 @@ def on_hpc_load(self, checkpoint):
"""


-def load_hparams_from_tags_csv(tags_csv):
+def load_hparams_from_tags_csv(tags_csv) -> Namespace:
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
return Namespace()

-tags = {}
with open(tags_csv) as f:
csv_reader = csv.reader(f, delimiter=',')
-for row in list(csv_reader)[1:]:
-    tags[row[0]] = convert(row[1])
+tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
ns = Namespace(**tags)
return ns

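For reference, a sketch of the rewritten loader in use. It assumes a `meta_tags.csv` of the kind Lightning writes next to checkpoints (the contents in the comment are illustrative; the first row is a header, which the slice `[1:]` skips):

    from pytorch_lightning.core.saving import load_hparams_from_tags_csv

    # meta_tags.csv (illustrative contents):
    #     key,value
    #     learning_rate,0.001
    #     batch_size,32

    hparams = load_hparams_from_tags_csv('meta_tags.csv')
    print(hparams.learning_rate)   # 0.001 -- each value is parsed by `convert`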
27 changes: 17 additions & 10 deletions pytorch_lightning/loggers/base.py
@@ -1,5 +1,6 @@
import argparse
from abc import ABC, abstractmethod
+from argparse import Namespace
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List

@@ -41,6 +42,12 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""
pass

+def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
+    # in case converting from namespace
+    if isinstance(params, Namespace):
+        params = vars(params)
+    return params

@abstractmethod
def log_hyperparams(self, params: argparse.Namespace):
"""Record hyperparameters.
@@ -50,19 +57,19 @@ def log_hyperparams(self, params: argparse.Namespace):
"""
pass

-def save(self):
+def save(self) -> None:
"""Save log data."""
pass

-def finalize(self, status: str):
+def finalize(self, status: str) -> None:
"""Do any processing that is necessary to finalize an experiment.
Args:
status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
pass

-def close(self):
+def close(self) -> None:
"""Do any cleanup that is necessary to close an experiment."""
pass

@@ -72,7 +79,7 @@ def rank(self) -> int:
return self._rank

@rank.setter
-def rank(self, value: int):
+def rank(self, value: int) -> None:
"""Set the process rank."""
self._rank = value

@@ -107,23 +114,23 @@ def __getitem__(self, index: int) -> LightningLoggerBase:
def experiment(self) -> List[Any]:
return [logger.experiment for logger in self._logger_iterable]

-def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
[logger.log_metrics(metrics, step) for logger in self._logger_iterable]

-def log_hyperparams(self, params: argparse.Namespace):
+def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
[logger.log_hyperparams(params) for logger in self._logger_iterable]

-def save(self):
+def save(self) -> None:
[logger.save() for logger in self._logger_iterable]

-def finalize(self, status: str):
+def finalize(self, status: str) -> None:
[logger.finalize(status) for logger in self._logger_iterable]

-def close(self):
+def close(self) -> None:
[logger.close() for logger in self._logger_iterable]

@LightningLoggerBase.rank.setter
-def rank(self, value: int):
+def rank(self, value: int) -> None:
self._rank = value
for logger in self._logger_iterable:
logger.rank = value
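`_convert_params` is the one normalization point all loggers now share: a `Namespace` becomes a plain dict, a dict passes through unchanged. A standalone sketch of the same logic:

    from argparse import Namespace
    from typing import Any, Dict, Union

    def convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
        # mirrors LightningLoggerBase._convert_params above
        if isinstance(params, Namespace):
            params = vars(params)
        return params

    assert convert_params({'lr': 0.01}) == {'lr': 0.01}
    assert convert_params(Namespace(lr=0.01)) == {'lr': 0.01}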
15 changes: 8 additions & 7 deletions pytorch_lightning/loggers/comet.py
@@ -5,9 +5,9 @@
CometLogger
-------------
"""
-import argparse
+from argparse import Namespace
from logging import getLogger
-from typing import Optional, Dict, Union
+from typing import Optional, Dict, Union, Any

try:
from comet_ml import Experiment as CometExperiment
@@ -162,15 +162,16 @@ def experiment(self) -> CometBaseExperiment:
return self._experiment

@rank_zero_only
-def log_hyperparams(self, params: argparse.Namespace):
-    self.experiment.log_parameters(vars(params))
+def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+    params = self._convert_params(params)
+    self.experiment.log_parameters(params)

@rank_zero_only
def log_metrics(
self,
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
-):
+) -> None:
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
@@ -182,7 +183,7 @@ def reset_experiment(self):
self._experiment = None

@rank_zero_only
-def finalize(self, status: str):
+def finalize(self, status: str) -> None:
r"""
When calling self.experiment.end(), that experiment won't log any more data to Comet. That's why, if you need
to log any more data you need to create an ExistingCometExperiment. For example, to log data when testing your
Expand All @@ -199,7 +200,7 @@ def name(self) -> str:
return self.experiment.project_name

@name.setter
-def name(self, value: str):
+def name(self, value: str) -> None:
self.experiment.set_name(value)

@property
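With the conversion in place, `CometLogger.log_hyperparams` hands Comet a single dict via `log_parameters`. A hedged usage sketch (the API key and project name are placeholders):

    from argparse import Namespace
    from pytorch_lightning.loggers import CometLogger

    logger = CometLogger(api_key='...', project_name='my-project')
    logger.log_hyperparams({'learning_rate': 0.001})        # dict now works
    logger.log_hyperparams(Namespace(learning_rate=0.001))  # Namespace still works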
13 changes: 7 additions & 6 deletions pytorch_lightning/loggers/mlflow.py
@@ -23,10 +23,10 @@ def any_lightning_module_function_or_hook(...):
self.logger.experiment.whatever_ml_flow_supports(...)
"""
-import argparse
+from argparse import Namespace
from logging import getLogger
from time import time
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, Union

try:
import mlflow
@@ -88,12 +88,13 @@ def run_id(self):
return self._run_id

@rank_zero_only
-def log_hyperparams(self, params: argparse.Namespace):
-    for k, v in vars(params).items():
+def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+    params = self._convert_params(params)
+    for k, v in params.items():
self.experiment.log_param(self.run_id, k, v)

@rank_zero_only
-def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
if isinstance(v, str):
@@ -105,7 +106,7 @@ def save(self):
pass

@rank_zero_only
-def finalize(self, status: str = 'FINISHED'):
+def finalize(self, status: str = 'FINISHED') -> None:
if status == 'success':
status = 'FINISHED'
self.experiment.set_terminated(self.run_id, status)
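The MLflow wrapper has no bulk call, so each entry of the normalized dict becomes its own `log_param` call. A usage sketch (the experiment name is illustrative):

    from pytorch_lightning.loggers import MLFlowLogger

    logger = MLFlowLogger(experiment_name='default')
    logger.log_hyperparams({'learning_rate': 0.001, 'batch_size': 32})
    # roughly equivalent to:
    #     experiment.log_param(run_id, 'learning_rate', 0.001)
    #     experiment.log_param(run_id, 'batch_size', 32)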
23 changes: 12 additions & 11 deletions pytorch_lightning/loggers/neptune.py
@@ -6,7 +6,7 @@
NeptuneLogger
--------------
"""
-import argparse
+from argparse import Namespace
from logging import getLogger
from typing import Optional, List, Dict, Any, Union, Iterable

@@ -164,16 +164,17 @@ def experiment(self) -> Experiment:
return self._experiment

@rank_zero_only
-def log_hyperparams(self, params: argparse.Namespace):
-    for key, val in vars(params).items():
+def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+    params = self._convert_params(params)
+    for key, val in params.items():
self.experiment.set_property(f'param__{key}', val)

@rank_zero_only
def log_metrics(
self,
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
-):
+) -> None:
"""Log metrics (numeric values) in Neptune experiments
Args:
@@ -184,7 +185,7 @@ def log_metrics(
self.log_metric(key, val, step=step)

@rank_zero_only
-def finalize(self, status: str):
+def finalize(self, status: str) -> None:
self.experiment.stop()

@property
@@ -207,7 +208,7 @@ def log_metric(
metric_name: str,
metric_value: Union[torch.Tensor, float, str],
step: Optional[int] = None
-):
+) -> None:
"""Log metrics (numeric values) in Neptune experiments
Args:
@@ -224,7 +225,7 @@ def log_text(self, log_name: str, text: str, step: Optional[int] = None):
self.experiment.log_metric(metric_name, x=step, y=metric_value)

@rank_zero_only
-def log_text(self, log_name: str, text: str, step: Optional[int] = None):
+def log_text(self, log_name: str, text: str, step: Optional[int] = None) -> None:
"""Log text data in Neptune experiment
Args:
@@ -235,7 +236,7 @@ def log_text(self, log_name: str, text: str, step: Optional[int] = None):
self.log_metric(log_name, text, step=step)

@rank_zero_only
-def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None):
+def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None) -> None:
"""Log image data in Neptune experiment
Args:
@@ -250,7 +251,7 @@ def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None):
self.experiment.log_image(log_name, x=step, y=image)

@rank_zero_only
-def log_artifact(self, artifact: str, destination: Optional[str] = None):
+def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None:
"""Save an artifact (file) in Neptune experiment storage.
Args:
@@ -261,7 +262,7 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None):
self.experiment.log_artifact(artifact, destination)

@rank_zero_only
-def set_property(self, key: str, value: Any):
+def set_property(self, key: str, value: Any) -> None:
"""Set key-value pair as Neptune experiment property.
Args:
@@ -271,7 +272,7 @@ def set_property(self, key: str, value: Any):
self.experiment.set_property(key, value)

@rank_zero_only
-def append_tags(self, tags: Union[str, Iterable[str]]):
+def append_tags(self, tags: Union[str, Iterable[str]]) -> None:
"""appends tags to neptune experiment
Args:
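Neptune stores each hyper-parameter as an experiment property under a `param__` prefix, one `set_property` call per key. A usage sketch (the project name is a placeholder; a configured Neptune API token is assumed):

    from pytorch_lightning.loggers import NeptuneLogger

    logger = NeptuneLogger(project_name='me/sandbox', experiment_name='hparams-as-dict')
    logger.log_hyperparams({'learning_rate': 0.001})
    # -> sets experiment property 'param__learning_rate' to 0.001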