Skip to content

Commit

Permalink
fix(wandb): prevent WandbLogger from dropping values (#5931)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
4 people authored and lexierule committed Mar 5, 2021
1 parent 9329f58 commit 5abfd2c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 53 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
39 changes: 17 additions & 22 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()

_WANDB_AVAILABLE = _module_available("wandb")

try:
Expand Down Expand Up @@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
project: The name of the project to which this run will belong.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
prefix: A string to put at the beginning of metric keys.
sync_step: Sync Trainer step with wandb step.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(
log_model: Optional[bool] = False,
experiment=None,
prefix: Optional[str] = '',
sync_step: Optional[bool] = True,
sync_step: Optional[bool] = None,
**kwargs
):
if wandb is None:
Expand All @@ -108,6 +109,12 @@ def __init__(
'Hint: Set `offline=False` to log your model.'
)

if sync_step is not None:
warning_cache.warn(
"`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
" Metrics are now logged separately and automatically synchronized.", DeprecationWarning
)

super().__init__()
self._name = name
self._save_dir = save_dir
Expand All @@ -117,12 +124,8 @@ def __init__(
self._project = project
self._log_model = log_model
self._prefix = prefix
self._sync_step = sync_step
self._experiment = experiment
self._kwargs = kwargs
# logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
self._step_offset = 0
self.warning_cache = WarningCache()

def __getstate__(self):
state = self.__dict__.copy()
Expand Down Expand Up @@ -159,12 +162,15 @@ def experiment(self) -> Run:
**self._kwargs
) if wandb.run is None else wandb.run

# offset logging step when resuming a run
self._step_offset = self._experiment.step

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)

return self._experiment

def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
Expand All @@ -182,15 +188,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

metrics = self._add_prefix(metrics)
if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
self.warning_cache.warn(
'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
' or try logging with `commit=False` when calling manually `wandb.log`.'
)
if self._sync_step:
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
elif step is not None:
self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
if step is not None:
self.experiment.log({**metrics, 'trainer/global_step': step})
else:
self.experiment.log(metrics)

Expand All @@ -210,10 +209,6 @@ def version(self) -> Optional[str]:

@rank_zero_only
def finalize(self, status: str) -> None:
# offset future training logged on same W&B run
if self._experiment is not None:
self._step_offset = self._experiment.step

# upload all checkpoints from saving dir
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
9 changes: 9 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""

from unittest import mock

import pytest

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_v1_5_0_wandb_unused_sync_step(tmpdir):
with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"):
WandbLogger(sync_step=True)


def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
class OldSignature(Callback):
def on_save_checkpoint(self, trainer, pl_module): # noqa
Expand Down
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,4 +404,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
wandb.run = None
wandb.init().step = 0
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0})
33 changes: 3 additions & 30 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,7 @@ def test_wandb_logger_init(wandb, recwarn):
logger = WandbLogger()
logger.log_metrics({'acc': 1.0})
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

# test sync_step functionality
wandb.init().log.reset_mock()
wandb.init.reset_mock()
wandb.run = None
wandb.init().step = 0
logger = WandbLogger(sync_step=False)
logger.log_metrics({'acc': 1.0})
wandb.init().log.assert_called_once_with({'acc': 1.0})
wandb.init().log.reset_mock()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})

# mock wandb step
wandb.init().step = 0

# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
Expand All @@ -65,13 +50,12 @@ def test_wandb_logger_init(wandb, recwarn):
logger = WandbLogger()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})

# continue training on same W&B run and offset step
wandb.init().step = 3
logger.finalize('success')
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
logger.log_metrics({'acc': 1.0}, step=6)
wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6})

# log hyper parameters
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
Expand All @@ -88,17 +72,6 @@ def test_wandb_logger_init(wandb, recwarn):
logger.watch('model', 'log', 10)
wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

# verify warning for logging at a previous step
assert 'Trying to log at a previous step' not in get_warnings(recwarn)
# current step from wandb should be 6 (last logged step)
logger.experiment.step = 6
# logging at step 2 should raise a warning (step_offset is still 3)
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' in get_warnings(recwarn)
# logging again at step 2 should not display again the same warning
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' not in get_warnings(recwarn)

assert logger.name == wandb.init().project_name()
assert logger.version == wandb.init().id

Expand Down

0 comments on commit 5abfd2c

Please sign in to comment.