Skip to content

Commit

Permalink
Support hierarchical dict (#1152)
Browse files Browse the repository at this point in the history
* Add support for hierarchical dict

* Support nested Namespace

* Add docstring

* Migrate hparam flattening to each logger

* Modify URLs in CHANGELOG

* typo

* Simplify the conditional branch about Namespace

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update CHANGELOG.md

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* added examples section to docstring

* renamed _dict -> input_dict

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
S-aiueo32 and Borda committed Mar 19, 2020
1 parent 22a7264 commit 01b8991
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
Expand Down
33 changes: 33 additions & 0 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,39 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:

return params

@staticmethod
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
"""Flatten hierarchical dict e.g. {'a': {'b': 'c'}} -> {'a/b': 'c'}.
Args:
params: Dictionary contains hparams
delimiter: Delimiter to express the hierarchy. Defaults to '/'.
Returns:
Flatten dict.
Examples:
>>> LightningLoggerBase._flatten_dict({'a': {'b': 'c'}})
{'a/b': 'c'}
>>> LightningLoggerBase._flatten_dict({'a': {'b': 123}})
{'a/b': 123}
"""

def _dict_generator(input_dict, prefixes=None):
prefixes = prefixes[:] if prefixes else []
if isinstance(input_dict, dict):
for key, value in input_dict.items():
if isinstance(value, (dict, Namespace)):
value = vars(value) if isinstance(value, Namespace) else value
for d in _dict_generator(value, prefixes + [key]):
yield d
else:
yield prefixes + [key, value if value is not None else str(None)]
else:
yield prefixes + [input_dict if input_dict is None else str(input_dict)]

return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}

@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Returns params with non-primitvies converted to strings for logging
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def experiment(self) -> CometBaseExperiment:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
self.experiment.log_parameters(params)

@rank_zero_only
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def run_id(self):
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
for k, v in params.items():
self.experiment.log_param(self.run_id, k, v)

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def experiment(self) -> Experiment:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
for key, val in params.items():
self.experiment.set_property(f'param__{key}', val)

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def experiment(self) -> SummaryWriter:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
sanitized_params = self._sanitize_params(params)

if parse_version(torch.__version__) < parse_version("1.3.0"):
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/test_tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
params = self._convert_params(params)
params = self._flatten_dict(params)
self.experiment.argparse(Namespace(**params))

@rank_zero_only
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/loggers/trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
return None
if not params:
return
if isinstance(params, dict):
self._trains.connect(params)
else:
self._trains.connect(vars(params))

params = self._convert_params(params)
params = self._flatten_dict(params)
self._trains.connect(params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def test_tensorboard_log_hyperparams(tmpdir):
"int": 1,
"string": "abc",
"bool": True,
"dict": {'a': {'b': 'c'}},
"list": [1, 2, 3],
"namespace": Namespace(foo=3),
"namespace": Namespace(foo=Namespace(bar='buzz')),
"layer": torch.nn.BatchNorm1d
}
logger.log_hyperparams(hparams)

0 comments on commit 01b8991

Please sign in to comment.