fix(wandb): prevent WandbLogger from dropping values #5931

Merged (26 commits, Feb 27, 2021)
Changes from 19 commits

Commits (26 total):
3b13bf4  feat(wandb): use new wandb API (borisdayma, Feb 10, 2021)
866eebe  feat: handle earlier wandb versions (borisdayma, Feb 10, 2021)
46a4680  feat: remove unused import (borisdayma, Feb 10, 2021)
9f3ef61  feat(wandb): regular x-axis for train/step (borisdayma, Feb 10, 2021)
0b037fc  feat(wandb): offset not needed anymore (borisdayma, Feb 10, 2021)
e46ba5f  tests(wandb): handle new API (borisdayma, Feb 12, 2021)
a0940cf  style: remove white space (borisdayma, Feb 12, 2021)
6b80d73  doc(wandb): update CHANGELOG (borisdayma, Feb 12, 2021)
8a2818c  feat(wandb): update per API (borisdayma, Feb 17, 2021)
b49a2b6  Merge branch 'master' into feat-wandb_x (borisdayma, Feb 18, 2021)
f00dadf  feat(wandb): deprecation of sync_step (borisdayma, Feb 19, 2021)
6230155  fix(wandb): typo (borisdayma, Feb 19, 2021)
59cf073  style: fix pep8 (borisdayma, Feb 19, 2021)
d95dd58  Apply suggestions from code review (borisdayma, Feb 19, 2021)
0b2d77e  Merge branch 'master' of https://github.com/borisdayma/pytorch-lightn… (borisdayma, Feb 19, 2021)
b0f48e5  docs: update CHANGELOG (borisdayma, Feb 19, 2021)
b19f5d2  Add deprecation test (carmocca, Feb 21, 2021)
92d805f  Merge branch 'master' into feat-wandb_x (carmocca, Feb 21, 2021)
f5f61bf  Merge branch 'master' into feat-wandb_x (tchaton, Feb 22, 2021)
6639100  Apply suggestions from code review (borisdayma, Feb 24, 2021)
6d587a0  fix(wandb): tests and typo (borisdayma, Feb 24, 2021)
3babc78  Merge branch 'master' (borisdayma, Feb 24, 2021)
67c0215  Merge branch 'master' into feat-wandb_x (awaelchli, Feb 25, 2021)
8b1b234  Merge branch 'master' into feat-wandb_x (awaelchli, Feb 27, 2021)
3ced69c  fix changelog (awaelchli, Feb 27, 2021)
8ffdf0b  fix changelog (awaelchli, Feb 27, 2021)
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -24,7 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))


- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)


- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
38 changes: 16 additions & 22 deletions pytorch_lightning/loggers/wandb.py
@@ -26,6 +26,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()

_WANDB_AVAILABLE = _module_available("wandb")

try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
project: The name of the project to which this run will belong.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
prefix: A string to put at the beginning of metric keys.
sync_step: Sync Trainer step with wandb step.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -92,7 +93,7 @@ def __init__(
log_model: Optional[bool] = False,
experiment=None,
prefix: Optional[str] = '',
sync_step: Optional[bool] = True,
sync_step: Optional[bool] = None,
**kwargs
):
if wandb is None:
@@ -108,6 +109,12 @@
'Hint: Set `offline=False` to log your model.'
)

if sync_step is not None:
warning_cache.warn(
"`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
" Metrics are now logged separately and automatically synchronized.", DeprecationWarning
)

super().__init__()
self._name = name
self._save_dir = save_dir
@@ -117,12 +124,8 @@
self._project = project
self._log_model = log_model
self._prefix = prefix
self._sync_step = sync_step
self._experiment = experiment
self._kwargs = kwargs
# logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
self._step_offset = 0
self.warning_cache = WarningCache()

def __getstate__(self):
state = self.__dict__.copy()
@@ -159,12 +162,14 @@ def experiment(self) -> Run:
**self._kwargs
) if wandb.run is None else wandb.run

# offset logging step when resuming a run
self._step_offset = self._experiment.step

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("*", step_metric='train/step', step_sync=True)

Contributor:
Would it make sense to have this step metric name as an argument for customization?

Contributor Author (borisdayma):
It could be useful, but there would be a few extra things to think through:

  • What should happen if that metric never appears?
  • What is the start value if it is not logged from the beginning (for example, if it is an evaluation metric)?

If you call this method explicitly, you can actually override this default behavior.
Let me check with the W&B team what they plan to do when the step_metric is not monotonic, to get an idea of how we should handle these custom cases (and feel free to share any example you have in mind).
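
For reference, a minimal sketch of what such an explicit override could look like, assuming a recent wandb release that exposes `define_metric`; the project name and metric names below are purely illustrative, not part of this PR:

```python
# Illustrative only: override the default "train/step" x-axis that WandbLogger
# configures, by calling define_metric on the underlying run yourself.
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="my-project")  # hypothetical project name
run = logger.experiment  # starts the wandb run and applies the default define_metric("*", ...)

if getattr(run, "define_metric", None):
    # e.g. plot validation metrics against a custom "epoch" metric instead
    run.define_metric("epoch")
    run.define_metric("val/*", step_metric="epoch")
```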


return self._experiment

def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -182,15 +187,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

metrics = self._add_prefix(metrics)
if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
self.warning_cache.warn(
'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
' or try logging with `commit=False` when calling manually `wandb.log`.'
)
if self._sync_step:
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
elif step is not None:
self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
if step is not None:
self.experiment.log({**metrics, 'train/step': step})
else:
self.experiment.log(metrics)

@@ -210,10 +208,6 @@ def version(self) -> Optional[str]:

@rank_zero_only
def finalize(self, status: str) -> None:
# offset future training logged on same W&B run
if self._experiment is not None:
self._step_offset = self._experiment.step

# upload all checkpoints from saving dir
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
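
As a rough illustration of the new behavior, the snippet below mirrors the style of the tests further down, mocking the wandb module the same way they do; it is a sketch of the expected call, not code taken from the PR:

```python
# Sketch: with this change, log_metrics forwards the trainer step inside the
# metrics dict as 'train/step' instead of passing step= to wandb.log, so
# resuming a run no longer needs a step offset.
from unittest import mock

from pytorch_lightning.loggers import WandbLogger

with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
    wandb.run = None  # force the logger to create a new run via wandb.init
    logger = WandbLogger()
    logger.log_metrics({'acc': 0.9}, step=7)
    wandb.init().log.assert_called_once_with({'acc': 0.9, 'train/step': 7})
```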
25 changes: 25 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -0,0 +1,25 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""
from unittest import mock

import pytest

from pytorch_lightning.loggers import WandbLogger


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_v1_5_0_wandb_unused_sync_step(tmpdir):
with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"):
WandbLogger(sync_step=True)
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
@@ -404,4 +404,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
wandb.run = None
wandb.init().step = 0
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'train/step': 0})
33 changes: 3 additions & 30 deletions tests/loggers/test_wandb.py
@@ -41,22 +41,7 @@ def test_wandb_logger_init(wandb, recwarn):
logger = WandbLogger()
logger.log_metrics({'acc': 1.0})
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

# test sync_step functionality
wandb.init().log.reset_mock()
wandb.init.reset_mock()
wandb.run = None
wandb.init().step = 0
logger = WandbLogger(sync_step=False)
logger.log_metrics({'acc': 1.0})
wandb.init().log.assert_called_once_with({'acc': 1.0})
wandb.init().log.reset_mock()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})

# mock wandb step
wandb.init().step = 0

# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
@@ -65,13 +50,12 @@ def test_wandb_logger_init(wandb, recwarn):
logger = WandbLogger()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'acc': 1.0, 'train/step': 3})

# continue training on same W&B run and offset step
wandb.init().step = 3
logger.finalize('success')
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
logger.log_metrics({'acc': 1.0}, step=6)
wandb.init().log.assert_called_with({'acc': 1.0, 'train/step': 6})

# log hyper parameters
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
@@ -88,17 +72,6 @@ def test_wandb_logger_init(wandb, recwarn):
logger.watch('model', 'log', 10)
wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

# verify warning for logging at a previous step
assert 'Trying to log at a previous step' not in get_warnings(recwarn)
# current step from wandb should be 6 (last logged step)
logger.experiment.step = 6
# logging at step 2 should raise a warning (step_offset is still 3)
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' in get_warnings(recwarn)
# logging again at step 2 should not display again the same warning
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' not in get_warnings(recwarn)

assert logger.name == wandb.init().project_name()
assert logger.version == wandb.init().id
