Skip to content

Commit

Permalink
Add Support for Non-primitive types in TensorboardLogger (#1130)
Browse files Browse the repository at this point in the history
* Added support for non-primitive types to tensorboard logger

* added EOF newline

* PEP8

* Updated CHANGELOG for PR #1130. Moved _sanitize_params to base logger. Cleaned up _sanitize_params

* Updated CHANGELOG for PR #1130. Moved _sanitize_params to base logger. Cleaned up _sanitize_params

* changed convert_params to static method

* PEP8

* Cleanup Doctest for _sanitize_params

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Removed OrderedDict import

* Updated import order to conventions

Co-authored-by: Manbir Gulati <manbirgulati@Manbirs-MBP.hsd1.md.comcast.net>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 14, 2020
1 parent f6a7a52 commit da61398
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))


### Changed

Expand Down
28 changes: 27 additions & 1 deletion pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List

import torch


def rank_zero_only(fn: Callable):
"""Decorate a logger method to run it only on the process with rank 0.
Expand Down Expand Up @@ -42,7 +44,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""
pass

def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
Expand All @@ -52,6 +55,29 @@ def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str,

return params

@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Returns params with non-primitvies converted to strings for logging
>>> params = {"float": 0.3,
... "int": 1,
... "string": "abc",
... "bool": True,
... "list": [1, 2, 3],
... "namespace": Namespace(foo=3),
... "layer": torch.nn.BatchNorm1d}
>>> import pprint
>>> pprint.pprint(LightningLoggerBase._sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
{'bool': True,
'float': 0.3,
'int': 1,
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
'list': '[1, 2, 3]',
'namespace': 'Namespace(foo=3)',
'string': 'abc'}
"""
return {k: v if type(v) in [bool, int, float, str, torch.Tensor] else str(v) for k, v in params.items()}

@abstractmethod
def log_hyperparams(self, params: argparse.Namespace):
"""Record hyperparameters.
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def experiment(self) -> SummaryWriter:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
sanitized_params = self._sanitize_params(params)

if parse_version(torch.__version__) < parse_version("1.3.0"):
warn(
Expand All @@ -110,13 +111,14 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
)
else:
from torch.utils.tensorboard.summary import hparams
exp, ssi, sei = hparams(params, {})
exp, ssi, sei = hparams(sanitized_params, {})
writer = self.experiment._get_file_writer()
writer.add_summary(exp)
writer.add_summary(ssi)
writer.add_summary(sei)

# some alternative should be added
self.tags.update(params)
self.tags.update(sanitized_params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
Expand Down
6 changes: 5 additions & 1 deletion tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
from argparse import Namespace

import pytest
import torch
Expand Down Expand Up @@ -108,6 +109,9 @@ def test_tensorboard_log_hyperparams(tmpdir):
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True
"bool": True,
"list": [1, 2, 3],
"namespace": Namespace(foo=3),
"layer": torch.nn.BatchNorm1d
}
logger.log_hyperparams(hparams)

0 comments on commit da61398

Please sign in to comment.