Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hierarchical dict #1152

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
Expand Down
33 changes: 33 additions & 0 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,39 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:

return params

@staticmethod
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
"""Flatten hierarchical dict e.g. {'a': {'b': 'c'}} -> {'a/b': 'c'}.

Args:
params: Dictionary contains hparams
delimiter: Delimiter to express the hierarchy. Defaults to '/'.

Returns:
Flatten dict.

Examples:
>>> LightningLoggerBase._flatten_dict({'a': {'b': 'c'}})
{'a/b': 'c'}
>>> LightningLoggerBase._flatten_dict({'a': {'b': 123}})
{'a/b': 123}
"""
S-aiueo32 marked this conversation as resolved.
Show resolved Hide resolved

def _dict_generator(input_dict, prefixes=None):
prefixes = prefixes[:] if prefixes else []
if isinstance(input_dict, dict):
for key, value in input_dict.items():
if isinstance(value, (dict, Namespace)):
value = vars(value) if isinstance(value, Namespace) else value
for d in _dict_generator(value, prefixes + [key]):
yield d
else:
yield prefixes + [key, value if value is not None else str(None)]
else:
yield prefixes + [input_dict if input_dict is None else str(input_dict)]

return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}

@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Returns params with non-primitvies converted to strings for logging
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def experiment(self) -> CometBaseExperiment:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
self.experiment.log_parameters(params)

@rank_zero_only
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def run_id(self):
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
for k, v in params.items():
self.experiment.log_param(self.run_id, k, v)

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def experiment(self) -> Experiment:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
for key, val in params.items():
self.experiment.set_property(f'param__{key}', val)

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def experiment(self) -> SummaryWriter:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
sanitized_params = self._sanitize_params(params)

if parse_version(torch.__version__) < parse_version("1.3.0"):
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/test_tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
params = self._convert_params(params)
params = self._flatten_dict(params)
self.experiment.argparse(Namespace(**params))

@rank_zero_only
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/loggers/trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
return None
if not params:
return
if isinstance(params, dict):
self._trains.connect(params)
else:
self._trains.connect(vars(params))

params = self._convert_params(params)
params = self._flatten_dict(params)
self._trains.connect(params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ def test_tensorboard_log_hyperparams(tmpdir):
"int": 1,
"string": "abc",
"bool": True,
"dict": {'a': {'b': 'c'}},
"list": [1, 2, 3],
"namespace": Namespace(foo=3),
"namespace": Namespace(foo=Namespace(bar='buzz')),
"layer": torch.nn.BatchNorm1d
}
logger.log_hyperparams(hparams)