feat(wandb): offset logging step when resuming #5050

Merged: 23 commits from feat_wandb_resume into master on Dec 19, 2020.
The diff below shows changes from 21 of the 23 commits.

Commits (23):
f7c7994 feat(wandb): offset logging step when resuming (borisdayma, Dec 9, 2020)
a93cd0a feat(wandb): output warnings (borisdayma, Dec 9, 2020)
d00e57a fix(wandb): allow step to be None (borisdayma, Dec 10, 2020)
dc3794a test(wandb): update tests (borisdayma, Dec 10, 2020)
eb75676 feat(wandb): display warning only once (borisdayma, Dec 10, 2020)
69bf4bb style: fix PEP issues (borisdayma, Dec 10, 2020)
891acf4 Merge branch 'master' into feat_wandb_resume (borisdayma, Dec 10, 2020)
5978805 tests(wandb): fix tests (borisdayma, Dec 10, 2020)
dbc15f2 tests(wandb): improve test (borisdayma, Dec 10, 2020)
ea4f88e style: fix whitespace (borisdayma, Dec 10, 2020)
1c3bc82 feat: improve warning (borisdayma, Dec 10, 2020)
6c02a2c feat(wandb): use variable from class instance (borisdayma, Dec 10, 2020)
63dcd70 tests(wandb): check warnings (borisdayma, Dec 11, 2020)
c9066cd Merge branch 'master' into feat_wandb_resume (borisdayma, Dec 11, 2020)
61cb199 feat(wandb): use WarningCache (borisdayma, Dec 14, 2020)
f7d832e tests(wandb): fix tests (borisdayma, Dec 15, 2020)
b996d2a Merge branch 'master' into feat_wandb_resume (borisdayma, Dec 15, 2020)
d4d500e style: fix formatting (borisdayma, Dec 15, 2020)
e157ce3 Merge branch 'feat_wandb_resume' of https://github.com/borisdayma/pyt… (borisdayma, Dec 15, 2020)
792a94a Merge branch 'master' into feat_wandb_resume (borisdayma, Dec 16, 2020)
f930f27 Merge branch 'master' into feat_wandb_resume (borisdayma, Dec 16, 2020)
91e089d Merge branch 'master' into feat_wandb_resume (borisdayma, Dec 19, 2020)
78e0d19 Merge branch 'master' into feat_wandb_resume (Borda, Dec 19, 2020)
11 changes: 10 additions & 1 deletion pytorch_lightning/loggers/wandb.py
@@ -31,6 +31,7 @@

 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.warning_utils import WarningCache
 
 
 class WandbLogger(LightningLoggerBase):
@@ -66,6 +67,9 @@ class WandbLogger(LightningLoggerBase):
         wandb_logger = WandbLogger()
         trainer = Trainer(logger=wandb_logger)
 
+    Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
+        make sure to use `commit=False` so the logging step does not increase.
+
     See Also:
         - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
           Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
@@ -103,8 +107,9 @@ def __init__(
         self._log_model = log_model
         self._prefix = prefix
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, etc)
+        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
         self._step_offset = 0
+        self.warning_cache = WarningCache()
 
     def __getstate__(self):
         state = self.__dict__.copy()
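The logger keeps a WarningCache so that the "previous step" warning (added below in log_metrics) fires only once instead of on every call, which the tests verify. As a minimal sketch of that warn-once behavior, the real pytorch_lightning.utilities.warning_utils.WarningCache may be implemented differently:

```python
import warnings


class WarningCacheSketch:
    """Illustrative stand-in for WarningCache: warn once per unique message."""

    def __init__(self):
        self._seen = set()

    def warn(self, message):
        # Only emit a given message the first time it is seen.
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message)
```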
@@ -134,6 +139,8 @@ def experiment(self) -> Run:
             self._experiment = wandb.init(
                 name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
                 id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run
+        # offset logging step when resuming a run
+        self._step_offset = self._experiment.step
         # save checkpoints in wandb dir to upload on W&B servers
         if self._log_model:
             self._save_dir = self._experiment.dir
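This is the core of the feature: when `resume='allow'` reattaches to an existing run, wandb reports the step the run had already reached, and the logger stores it as an offset applied to all future steps. A small illustration of the arithmetic (the values here are made up):

```python
# A resumed W&B run had already logged up to step 100, but the new
# Trainer restarts its own step counter at 0.
step_offset = 100        # self._step_offset = self._experiment.step
trainer_step = 0         # first step reported after resuming
wandb_step = trainer_step + step_offset
assert wandb_step == 100  # logging continues where the run left off
```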
@@ -154,6 +161,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
+        if step is not None and step + self._step_offset < self.experiment.step:
+            self.warning_cache.warn('Trying to log at a previous step. Use `commit=False` when logging metrics manually.')
         self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
 
     @property
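The docstring note added above exists because `experiment.log` without `commit=False` advances wandb's global step, which would race ahead of the Trainer's step counter and trigger this new warning. A usage sketch of the recommended pattern (the logger setup and metric name are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)

# Log extra values manually without advancing wandb's global step,
# so the step bookkeeping in WandbLogger.log_metrics stays consistent.
wandb_logger.experiment.log({'custom/value': 0.5}, commit=False)
```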
6 changes: 5 additions & 1 deletion tests/loggers/test_all.py
@@ -74,7 +74,9 @@ def test_loggers_fit_test_all(tmpdir, monkeypatch):
     with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'):
         _test_loggers_fit_test(tmpdir, TestTubeLogger)
 
-    with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
+    with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
+        wandb.run = None
+        wandb.init().step = 0
         _test_loggers_fit_test(tmpdir, WandbLogger)


@@ -368,5 +370,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     # WandB
     with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
         logger = _instantiate_logger(WandbLogger, save_idr=tmpdir, prefix=prefix)
+        wandb.run = None
+        wandb.init().step = 0
         logger.log_metrics({"test": 1.0}, step=0)
         logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
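Both fixtures now pre-set two attributes on the mocked module because the new logger code reads them. An illustrative restatement of why (this snippet is not part of the diff):

```python
from unittest import mock

with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
    # WandbLogger.experiment only calls wandb.init() when wandb.run is None.
    wandb.run = None
    # log_metrics now does integer arithmetic and comparison with
    # experiment.step; a bare MagicMock attribute would make the
    # `step + offset < experiment.step` check behave unpredictably.
    wandb.init().step = 0
```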
28 changes: 27 additions & 1 deletion tests/loggers/test_wandb.py
@@ -22,8 +22,14 @@
 from tests.base import EvalModelTemplate, BoringModel
 
 
+def get_warnings(recwarn):
+    warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
+    recwarn.clear()
+    return warnings_text
+
+
 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
-def test_wandb_logger_init(wandb):
+def test_wandb_logger_init(wandb, recwarn):
     """Verify that basic functionality of wandb logger works.
     Wandb doesn't work well with pytest so we have to mock it out here."""

@@ -34,6 +40,9 @@ def test_wandb_logger_init(wandb):
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
 
+    # mock wandb step
+    wandb.init().step = 0
+
     # test wandb.init not called if there is a W&B run
     wandb.init().log.reset_mock()
     wandb.init.reset_mock()
@@ -49,15 +58,28 @@ def test_wandb_logger_init(wandb):
     logger.log_metrics({'acc': 1.0}, step=3)
     wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
 
     # log hyper parameters
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
     wandb.init().config.update.assert_called_once_with(
         {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
         allow_val_change=True,
     )
 
     # watch a model
     logger.watch('model', 'log', 10)
     wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)
 
+    # verify warning for logging at a previous step
+    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+    # current step from wandb should be 6 (last logged step)
+    logger.experiment.step = 6
+    # logging at step 2 should raise a warning (step_offset is still 3)
+    logger.log_metrics({'acc': 1.0}, step=2)
+    assert 'Trying to log at a previous step' in get_warnings(recwarn)
+    # logging again at step 2 should not display again the same warning
+    logger.log_metrics({'acc': 1.0}, step=2)
+    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+
     assert logger.name == wandb.init().project_name()
     assert logger.version == wandb.init().id

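The numbers in the warning test line up as follows, restating the comments above as a worked example:

```python
step_offset = 3       # offset picked up when the mocked run reported step 3
experiment_step = 6   # wandb's current step after step=3 was logged as step=6
step = 2              # user now tries to log an earlier step
# 2 + 3 = 5 < 6, so 'Trying to log at a previous step' fires exactly once.
assert step + step_offset < experiment_step
```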
@@ -71,6 +93,7 @@ def test_wandb_pickle(wandb, tmpdir):
     class Experiment:
         """ """
         id = 'the_id'
+        step = 0
 
         def project_name(self):
             return 'the_project_name'
@@ -108,8 +131,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     assert logger.name is None
 
     # mock return values of experiment
+    wandb.run = None
+    wandb.init().step = 0
     logger.experiment.id = '1'
     logger.experiment.project_name.return_value = 'project'
+    logger.experiment.step = 0
 
     for _ in range(2):
         _ = logger.experiment