fix(wandb): prevent WandbLogger from dropping values #5931

Merged 26 commits from feat-wandb_x into master on Feb 27, 2021
Commits (26)
3b13bf4
feat(wandb): use new wandb API
borisdayma Feb 10, 2021
866eebe
feat: handle earlier wandb versions
borisdayma Feb 10, 2021
46a4680
feat: remove unused import
borisdayma Feb 10, 2021
9f3ef61
feat(wandb): regular x-axis for train/step
borisdayma Feb 10, 2021
0b037fc
feat(wandb): offset not needed anymore
borisdayma Feb 10, 2021
e46ba5f
tests(wandb): handle new API
borisdayma Feb 12, 2021
a0940cf
style: remove white space
borisdayma Feb 12, 2021
6b80d73
doc(wandb): update CHANGELOG
borisdayma Feb 12, 2021
8a2818c
feat(wandb): update per API
borisdayma Feb 17, 2021
b49a2b6
Merge branch 'master' into feat-wandb_x
borisdayma Feb 18, 2021
f00dadf
feat(wandb): deprecation of sync_step
borisdayma Feb 19, 2021
6230155
fix(wandb): typo
borisdayma Feb 19, 2021
59cf073
style: fix pep8
borisdayma Feb 19, 2021
d95dd58
Apply suggestions from code review
borisdayma Feb 19, 2021
0b2d77e
Merge branch 'master' of https://github.com/borisdayma/pytorch-lightn…
borisdayma Feb 19, 2021
b0f48e5
docs: update CHANGELOG
borisdayma Feb 19, 2021
b19f5d2
Add deprecation test
carmocca Feb 21, 2021
92d805f
Merge branch 'master' into feat-wandb_x
carmocca Feb 21, 2021
f5f61bf
Merge branch 'master' into feat-wandb_x
tchaton Feb 22, 2021
6639100
Apply suggestions from code review
borisdayma Feb 24, 2021
6d587a0
fix(wandb): tests and typo
borisdayma Feb 24, 2021
3babc78
Merge branch 'master'
borisdayma Feb 24, 2021
67c0215
Merge branch 'master' into feat-wandb_x
awaelchli Feb 25, 2021
8b1b234
Merge branch 'master' into feat-wandb_x
awaelchli Feb 27, 2021
3ced69c
fix changelog
awaelchli Feb 27, 2021
8ffdf0b
fix changelog
awaelchli Feb 27, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


## [1.2.1] - 2021-02-23

### Fixed
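For context on the fix above: the old `sync_step` behavior forwarded Lightning's step to `wandb.log` via the explicit `step` argument, and wandb warns and then drops any row whose `step` is lower than the run's current step (the "Trying to log at a previous step" warning removed below). A minimal sketch of the failure mode in plain wandb; the project name is illustrative:

```python
import wandb

run = wandb.init(project="demo")      # illustrative project name
run.log({"val/acc": 0.80}, step=10)   # wandb's internal step advances to 10
run.log({"train/loss": 0.5}, step=4)  # 4 < 10: wandb warns and drops this row
run.finish()
```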
39 changes: 17 additions & 22 deletions pytorch_lightning/loggers/wandb.py
@@ -26,6 +26,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()

_WANDB_AVAILABLE = _module_available("wandb")

try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
project: The name of the project to which this run will belong.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
prefix: A string to put at the beginning of metric keys.
sync_step: Sync Trainer step with wandb step.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -98,7 +99,7 @@ def __init__(
log_model: Optional[bool] = False,
experiment=None,
prefix: Optional[str] = '',
sync_step: Optional[bool] = True,
sync_step: Optional[bool] = None,
**kwargs
):
if wandb is None:
@@ -114,6 +115,12 @@
'Hint: Set `offline=False` to log your model.'
)

if sync_step is not None:
warning_cache.warn(
"`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
" Metrics are now logged separately and automatically synchronized.", DeprecationWarning
)

super().__init__()
self._name = name
self._save_dir = save_dir
@@ -123,12 +130,8 @@ def __init__(
self._project = project
self._log_model = log_model
self._prefix = prefix
self._sync_step = sync_step
self._experiment = experiment
self._kwargs = kwargs
# logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
self._step_offset = 0
self.warning_cache = WarningCache()

def __getstate__(self):
state = self.__dict__.copy()
@@ -165,12 +168,15 @@ def experiment(self) -> Run:
**self._kwargs
) if wandb.run is None else wandb.run

# offset logging step when resuming a run
self._step_offset = self._experiment.step

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)

return self._experiment

def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -188,15 +194,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

metrics = self._add_prefix(metrics)
if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
self.warning_cache.warn(
'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
' or try logging with `commit=False` when calling manually `wandb.log`.'
)
if self._sync_step:
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
elif step is not None:
self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
if step is not None:
self.experiment.log({**metrics, 'trainer/global_step': step})
else:
self.experiment.log(metrics)

@@ -216,10 +215,6 @@ def version(self) -> Optional[str]:

@rank_zero_only
def finalize(self, status: str) -> None:
# offset future training logged on same W&B run
if self._experiment is not None:
self._step_offset = self._experiment.step

# upload all checkpoints from saving dir
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
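The `define_metric` calls added above are what replace the old step offsets: Lightning's step becomes an ordinary logged value, `trainer/global_step`, and recent wandb clients (hence the `getattr` guard) are told to use it as the default x-axis. A minimal sketch of the same pattern in plain wandb, again with an illustrative project name:

```python
import wandb

run = wandb.init(project="demo")  # illustrative project name

# Declare trainer/global_step, then make it the default x-axis for all
# metrics; step_sync=True reuses its last value when a row omits it.
run.define_metric("trainer/global_step")
run.define_metric("*", step_metric="trainer/global_step", step_sync=True)

# The step is now a plain value inside the row, so out-of-order or
# repeated steps are recorded instead of dropped.
run.log({"train/loss": 0.50, "trainer/global_step": 4})
run.log({"val/acc": 0.80, "trainer/global_step": 4})
run.finish()
```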
9 changes: 9 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -13,13 +13,22 @@
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""

from unittest import mock

import pytest

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_v1_5_0_wandb_unused_sync_step(tmpdir):
with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"):
WandbLogger(sync_step=True)


def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
class OldSignature(Callback):
def on_save_checkpoint(self, trainer, pl_module): # noqa
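For users, the migration covered by the deprecation test above amounts to dropping the argument; a before/after sketch:

```python
from pytorch_lightning.loggers import WandbLogger

# Before: emits a DeprecationWarning as of v1.2.1 (removed in v1.5)
logger = WandbLogger(sync_step=True)

# After: nothing to configure; each row carries trainer/global_step
logger = WandbLogger()
```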
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
@@ -404,4 +404,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
wandb.run = None
wandb.init().step = 0
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0})
33 changes: 3 additions & 30 deletions tests/loggers/test_wandb.py
@@ -41,22 +41,7 @@ def test_wandb_logger_init(wandb, recwarn):
logger = WandbLogger()
logger.log_metrics({'acc': 1.0})
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

# test sync_step functionality
wandb.init().log.reset_mock()
wandb.init.reset_mock()
wandb.run = None
wandb.init().step = 0
logger = WandbLogger(sync_step=False)
logger.log_metrics({'acc': 1.0})
wandb.init().log.assert_called_once_with({'acc': 1.0})
wandb.init().log.reset_mock()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})

# mock wandb step
wandb.init().step = 0

# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
@@ -65,13 +50,12 @@
logger = WandbLogger()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})

# continue training on same W&B run and offset step
wandb.init().step = 3
logger.finalize('success')
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
logger.log_metrics({'acc': 1.0}, step=6)
wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6})

# log hyper parameters
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
@@ -88,17 +72,6 @@ def test_wandb_logger_init(wandb, recwarn):
logger.watch('model', 'log', 10)
wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

# verify warning for logging at a previous step
assert 'Trying to log at a previous step' not in get_warnings(recwarn)
# current step from wandb should be 6 (last logged step)
logger.experiment.step = 6
# logging at step 2 should raise a warning (step_offset is still 3)
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' in get_warnings(recwarn)
# logging again at step 2 should not display again the same warning
logger.log_metrics({'acc': 1.0}, step=2)
assert 'Trying to log at a previous step' not in get_warnings(recwarn)

assert logger.name == wandb.init().project_name()
assert logger.version == wandb.init().id
