diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index f0c042724538a..8e5311b11dcb1 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -31,6 +31,7 @@
 
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.warning_utils import WarningCache
 
 
 class WandbLogger(LightningLoggerBase):
@@ -66,6 +67,9 @@ class WandbLogger(LightningLoggerBase):
         wandb_logger = WandbLogger()
         trainer = Trainer(logger=wandb_logger)
 
+    Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
+    make sure to use `commit=False` so the logging step does not increase.
+
     See Also:
         - `Tutorial `__
 
@@ -103,8 +107,9 @@ def __init__(
         self._log_model = log_model
         self._prefix = prefix
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, etc)
+        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
         self._step_offset = 0
+        self.warning_cache = WarningCache()
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -134,6 +139,8 @@ def experiment(self) -> Run:
             self._experiment = wandb.init(
                 name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
                 id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run
+            # offset logging step when resuming a run
+            self._step_offset = self._experiment.step
             # save checkpoints in wandb dir to upload on W&B servers
             if self._log_model:
                 self._save_dir = self._experiment.dir
@@ -154,6 +161,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
+        if step is not None and step + self._step_offset < self.experiment.step:
+            self.warning_cache.warn('Trying to log at a previous step. Use `commit=False` when logging metrics manually.')
         self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
 
     @property
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index ea40814b18861..945d8945a22c2 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -74,7 +74,9 @@ def test_loggers_fit_test_all(tmpdir, monkeypatch):
     with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'):
         _test_loggers_fit_test(tmpdir, TestTubeLogger)
 
-    with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
+    with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
+        wandb.run = None
+        wandb.init().step = 0
         _test_loggers_fit_test(tmpdir, WandbLogger)
 
 
@@ -368,5 +370,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     # WandB
-    with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
+    with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
         logger = _instantiate_logger(WandbLogger, save_idr=tmpdir, prefix=prefix)
+        wandb.run = None
+        wandb.init().step = 0
         logger.log_metrics({"test": 1.0}, step=0)
         logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index fa503f5d8eeb1..398ee45ef4aa0 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -22,8 +22,14 @@
 from tests.base import EvalModelTemplate
 
 
+def get_warnings(recwarn):
+    warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
+    recwarn.clear()
+    return warnings_text
+
+
 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
-def test_wandb_logger_init(wandb):
+def test_wandb_logger_init(wandb, recwarn):
     """Verify that basic functionality of wandb logger works.
     Wandb doesn't work well with pytest so we have to mock it out here."""
 
@@ -34,6 +40,9 @@ def test_wandb_logger_init(wandb):
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
 
+    # mock wandb step
+    wandb.init().step = 0
+
     # test wandb.init not called if there is a W&B run
     wandb.init().log.reset_mock()
     wandb.init.reset_mock()
@@ -49,15 +58,28 @@
     logger.log_metrics({'acc': 1.0}, step=3)
     wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
 
+    # log hyper parameters
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
     wandb.init().config.update.assert_called_once_with(
         {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
         allow_val_change=True,
     )
 
+    # watch a model
     logger.watch('model', 'log', 10)
     wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)
 
+    # verify warning for logging at a previous step
+    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+    # current step from wandb should be 6 (last logged step)
+    logger.experiment.step = 6
+    # logging at step 2 should raise a warning (step_offset is still 3)
+    logger.log_metrics({'acc': 1.0}, step=2)
+    assert 'Trying to log at a previous step' in get_warnings(recwarn)
+    # logging again at step 2 should not display again the same warning
+    logger.log_metrics({'acc': 1.0}, step=2)
+    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+
     assert logger.name == wandb.init().project_name()
     assert logger.version == wandb.init().id
 
@@ -71,6 +93,7 @@ def test_wandb_pickle(wandb, tmpdir):
     class Experiment:
         """ """
         id = 'the_id'
+        step = 0
 
         def project_name(self):
             return 'the_project_name'
@@ -108,8 +131,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     assert logger.name is None
 
     # mock return values of experiment
+    wandb.run = None
+    wandb.init().step = 0
     logger.experiment.id = '1'
     logger.experiment.project_name.return_value = 'project'
+    logger.experiment.step = 0
 
     for _ in range(2):
         _ = logger.experiment
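
The docstring note added above exists because `wandb.log` defaults to `commit=True`,
which advances W&B's internal step counter; manual logging next to the Trainer can
therefore push `experiment.step` past the logger's own step and trigger the new
warning. A minimal sketch of the recommended pattern (the metric name is made up;
assumes wandb is installed and a run can be created):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import WandbLogger

    wandb_logger = WandbLogger()
    trainer = Trainer(logger=wandb_logger)

    # commit=False attaches the values to the current step instead of
    # advancing W&B's internal step counter, so later calls to
    # logger.log_metrics(..., step=n) do not land at a previous step.
    wandb_logger.experiment.log({'my_extra_metric': 0.5}, commit=False)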
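
The resume offset recorded in the `experiment` property can be exercised the same
way the updated tests do: with `wandb` mocked out, a run whose step counter has
already advanced shifts subsequent `log_metrics` calls by that offset. A small
sketch in the style of tests/loggers/test_wandb.py (the run id is hypothetical):

    from unittest import mock

    from pytorch_lightning.loggers import WandbLogger

    with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
        wandb.run = None               # force the wandb.init() branch
        wandb.init().step = 6          # resumed run already logged 6 steps
        logger = WandbLogger(id='resumed-run')
        _ = logger.experiment          # records _step_offset = 6
        logger.log_metrics({'acc': 1.0}, step=0)
        # the metrics are shifted by the offset: W&B step 6, not step 0
        wandb.init().log.assert_called_with({'acc': 1.0}, step=6)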