Skip to content

Commit

Permalink
feat(wandb): log in sync with Trainer step (#4405)
Browse files Browse the repository at this point in the history
* feat(wandb): log in sync with Trainer step

* docs: update CHANGELOG

* style(test_wandb): fix formatting

* parentheses

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
  • Loading branch information
3 people committed Oct 28, 2020
1 parent 41de453 commit ff41d80
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed sanitized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))


- `WandbLogger` now logs metrics in sync with the Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))


### Deprecated


Expand Down
9 changes: 7 additions & 2 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(
self._offline = offline
self._log_model = log_model
self._kwargs = kwargs
# step offset used when several Trainers log to a single W&B run (k-fold, etc.)
self._step_offset = 0

def __getstate__(self):
state = self.__dict__.copy()
Expand Down Expand Up @@ -141,8 +143,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)

@property
def save_dir(self) -> Optional[str]:
Expand All @@ -159,6 +160,10 @@ def version(self) -> Optional[str]:
return self._experiment.id if self._experiment else self._id

def finalize(self, status: str) -> None:
# offset future training logged on same W&B run
if self._experiment is not None:
self._step_offset = self._experiment.step

# upload all checkpoints from saving dir
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
10 changes: 8 additions & 2 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ def test_wandb_logger(wandb):
logger = WandbLogger(anonymous=True, offline=True)

logger.log_metrics({'acc': 1.0})
wandb.init().log.assert_called_once_with({'acc': 1.0})
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

wandb.init().log.reset_mock()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

# continue training on same W&B run
wandb.init().step = 3
logger.finalize('success')
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_with({'acc': 1.0}, step=6)

logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
wandb.init().config.update.assert_called_once_with(
Expand Down

0 comments on commit ff41d80

Please sign in to comment.