feat(wandb): offset logging step when resuming (#5050)
* feat(wandb): offset logging step when resuming

* feat(wandb): output warnings

* fix(wandb): allow step to be None

* test(wandb): update tests

* feat(wandb): display warning only once

* style: fix PEP issues

* tests(wandb): fix tests

* tests(wandb): improve test

* style: fix whitespace

* feat: improve warning

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* feat(wandb): use variable from class instance

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* tests(wandb): check warnings

* feat(wandb): use WarningCache

* tests(wandb): fix tests

* style: fix formatting

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people committed Dec 29, 2020
1 parent 5e55b48 commit e69177a
Showing 3 changed files with 42 additions and 3 deletions.
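In outline, the change reads the step counter back from a resumed W&B run and shifts every subsequent Trainer step by that amount. A minimal standalone sketch of the idea (a hypothetical example, not the patched source; the project name and run id are placeholders):

import wandb

# Reattach to an earlier run; with a matching `id`, `resume='allow'` continues it.
run = wandb.init(project='demo', id='my-run', resume='allow')

# The resumed run already knows how far the previous session logged, so use
# that as an offset: a fresh Trainer restarts global_step at 0, and logging
# "into the past" would otherwise be dropped by W&B.
step_offset = run.step

for trainer_step in range(10):
    run.log({'loss': 1.0 / (trainer_step + 1)}, step=trainer_step + step_offset)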
pytorch_lightning/loggers/wandb.py (11 changes: 10 additions & 1 deletion)
@@ -31,6 +31,7 @@

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.warning_utils import WarningCache


class WandbLogger(LightningLoggerBase):
@@ -66,6 +67,9 @@ class WandbLogger(LightningLoggerBase):
wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)
Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
make sure to use `commit=False` so the logging step does not increase.
See Also:
- `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
@@ -103,8 +107,9 @@ def __init__(
self._log_model = log_model
self._prefix = prefix
self._kwargs = kwargs
- # logging multiple Trainer on a single W&B run (k-fold, etc)
+ # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
self._step_offset = 0
self.warning_cache = WarningCache()

def __getstate__(self):
state = self.__dict__.copy()
@@ -134,6 +139,8 @@ def experiment(self) -> Run:
self._experiment = wandb.init(
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run
# offset logging step when resuming a run
self._step_offset = self._experiment.step
# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
self._save_dir = self._experiment.dir
@@ -154,6 +161,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

metrics = self._add_prefix(metrics)
if step is not None and step + self._step_offset < self.experiment.step:
self.warning_cache.warn('Trying to log at a previous step. Use `commit=False` when logging metrics manually.')
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)

@property
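The docstring note above warns against advancing W&B's internal step when logging manually. A short illustration of the `commit=False` pattern it recommends (assuming a live wandb session; the metric names are made up):

import wandb

wandb.init(project='demo')

# Attach extra data to the current step without advancing the step counter ...
wandb.log({'examples_seen': 320}, commit=False)

# ... and let the next regular call commit everything at that same step.
wandb.log({'loss': 0.42})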
tests/loggers/test_all.py (6 changes: 5 additions & 1 deletion)
@@ -74,7 +74,9 @@ def test_loggers_fit_test_all(tmpdir, monkeypatch):
with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'):
_test_loggers_fit_test(tmpdir, TestTubeLogger)

- with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
+ with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
wandb.run = None
wandb.init().step = 0
_test_loggers_fit_test(tmpdir, WandbLogger)


@@ -368,5 +370,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
# WandB
with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
logger = _instantiate_logger(WandbLogger, save_dir=tmpdir, prefix=prefix)
wandb.run = None
wandb.init().step = 0
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
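`log_metrics` routes its warning through `WarningCache`, which is how the commit keeps the message to a single emission. A simplified re-implementation of the pattern, for illustration only (not the actual `pytorch_lightning.utilities.warning_utils` class):

import warnings


class WarningCache:
    """Emit each distinct warning message at most once per cache instance."""

    def __init__(self):
        self._seen = set()

    def warn(self, message: str) -> None:
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message)


cache = WarningCache()
cache.warn('Trying to log at a previous step.')  # emitted
cache.warn('Trying to log at a previous step.')  # suppressed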
tests/loggers/test_wandb.py (28 changes: 27 additions & 1 deletion)
@@ -22,8 +22,14 @@
from tests.base import EvalModelTemplate


def get_warnings(recwarn):
warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
recwarn.clear()
return warnings_text


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
- def test_wandb_logger_init(wandb):
+ def test_wandb_logger_init(wandb, recwarn):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""

@@ -34,6 +40,9 @@ def test_wandb_logger_init(wandb):
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

# mock wandb step
wandb.init().step = 0

# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
wandb.init.reset_mock()
@@ -49,15 +58,28 @@ def test_wandb_logger_init(wandb):
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_with({'acc': 1.0}, step=6)

# log hyper parameters
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
wandb.init().config.update.assert_called_once_with(
{'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
allow_val_change=True,
)

# watch a model
logger.watch('model', 'log', 10)
wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

# verify warning for logging at a previous step
assert 'Trying to log at a previous step' not in get_warnings(recwarn)
# current step from wandb should be 6 (last logged step)
logger.experiment.step = 6
# logging at step 2 should raise a warning (step_offset is still 3)
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' in get_warnings(recwarn)
# logging again at step 2 should not display the same warning again
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' not in get_warnings(recwarn)

assert logger.name == wandb.init().project_name()
assert logger.version == wandb.init().id

@@ -71,6 +93,7 @@ def test_wandb_pickle(wandb, tmpdir):
class Experiment:
""" """
id = 'the_id'
step = 0

def project_name(self):
return 'the_project_name'
@@ -108,8 +131,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
assert logger.name is None

# mock return values of experiment
wandb.run = None
wandb.init().step = 0
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
logger.experiment.step = 0

for _ in range(2):
_ = logger.experiment
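The offset behaviour can be exercised without a network connection by mocking the wandb module, as the tests above do. A condensed sketch under the same assumptions (the MagicMock installed by `mock.patch` returns the same object from every `wandb.init()` call, so attributes set on it persist):

from unittest import mock

from pytorch_lightning.loggers import WandbLogger


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_resumed_run_offsets_step(wandb):
    wandb.run = None       # force the logger to call wandb.init()
    wandb.init().step = 3  # pretend the resumed run already logged 3 steps

    logger = WandbLogger()
    logger.log_metrics({'acc': 1.0}, step=0)

    # step 0 from the new Trainer lands at global step 3 on W&B
    wandb.init().log.assert_called_with({'acc': 1.0}, step=3)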
