Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wandb): log models as artifacts #6231

Merged
merged 60 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
bfb8872
feat(wandb): log models as artifacts
borisdayma Feb 27, 2021
541b001
feat: add Logger.connect
borisdayma Feb 28, 2021
bbd8633
fix: circular ref with type checking
borisdayma Mar 1, 2021
3365261
feat(wandb): use connect method
borisdayma Mar 1, 2021
dfd7553
style: pep8
borisdayma Mar 1, 2021
6950d3d
fix(configure_logger): logger can be bool
borisdayma Mar 1, 2021
f9cc20f
feat(connect): Trainer is not optional
borisdayma Mar 1, 2021
c518d71
feat(configure_logger): make trainer a proxy
borisdayma Mar 3, 2021
9b9aaa6
fix: unused import
borisdayma Mar 3, 2021
eb2080d
docs: more explicit doc
borisdayma Mar 3, 2021
7d98a99
doc: update docstring
borisdayma Mar 3, 2021
a6ad9aa
feat: ModelCheckpoint metadata
borisdayma Mar 3, 2021
444a4eb
Merge branch 'master' into feat_artifacts
borisdayma Mar 3, 2021
52b642f
feat: 1 checkpoint = 1 artifact
borisdayma Mar 4, 2021
765d081
feat: proxy typing + apply suggestions
borisdayma Mar 4, 2021
49f3688
Merge branch 'master' into feat_artifacts
borisdayma Mar 4, 2021
4a55e46
feat: don't log same model twice
borisdayma Mar 4, 2021
f16231c
fix: typo
borisdayma Mar 4, 2021
cbbf8ff
feat: log artifacts during training
borisdayma Mar 4, 2021
123cd88
fix: docs build
borisdayma Mar 4, 2021
0822d5d
feat: use proxy ref
borisdayma Mar 4, 2021
ee5b1d1
Merge branch 'master' into feat_artifacts
borisdayma Mar 4, 2021
947ab7a
fix: mypy
borisdayma Mar 4, 2021
03af2c3
fix: unused import
borisdayma Mar 4, 2021
743903c
fix: continuous logging logic
borisdayma Mar 4, 2021
363b3ac
fix: formatting
borisdayma Mar 5, 2021
7e331c1
docs: update log_model
borisdayma Mar 5, 2021
b438940
docs(wandb): improve log_model
borisdayma Mar 5, 2021
0dc78cc
feat(wandb): more explicit artifact name
borisdayma Mar 5, 2021
78cfc7c
feat(wandb): simplify artifact name
borisdayma Mar 5, 2021
eeed466
docs(wandb): improve documentation
borisdayma Mar 7, 2021
5227329
Merge branch 'master'
borisdayma Mar 7, 2021
cc0fcd6
test: after_save_checkpoint called
borisdayma Mar 7, 2021
a71603d
docs(wandb): fix typo
borisdayma Mar 7, 2021
ded7204
test(wandb): test log_model
borisdayma Mar 7, 2021
1b88a5e
feat(wandb): min version
borisdayma Mar 7, 2021
4f35813
test(wandb): fix directory creation
borisdayma Mar 7, 2021
876dbee
docs: update CHANGELOG
borisdayma Mar 8, 2021
ba1e937
test(wandb): fix variable not defined
borisdayma Mar 8, 2021
9593557
Merge branch 'master' into feat_artifacts
borisdayma Mar 8, 2021
fe98f4f
feat: after_save_checkpoint on rank 0 only
borisdayma Mar 9, 2021
4b38fc4
Merge branch 'master' into feat_artifacts
borisdayma Mar 10, 2021
b59fdf1
Merge branch 'master' into feat_artifacts
borisdayma Mar 11, 2021
13a730b
Merge branch 'master' into feat_artifacts
borisdayma Mar 12, 2021
aa904ce
feat: handle new args of ModelCheckpoint
borisdayma Mar 12, 2021
27c49eb
test(wandb): check correct metadata
borisdayma Mar 12, 2021
e0a9578
tests(wandb): unused fixture
borisdayma Mar 14, 2021
bbf4683
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
borisdayma Mar 14, 2021
58193e8
feat: logger.after_save_checkpoint always exists
borisdayma Mar 14, 2021
fda377f
test: wandb fixture required
borisdayma Mar 14, 2021
ce6c912
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
borisdayma Apr 1, 2021
5e39044
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
borisdayma Apr 8, 2021
62d5cae
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
borisdayma May 14, 2021
0b7bb39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
c06fc8f
test(wandb): parameter unset
borisdayma May 14, 2021
0ca8310
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
borisdayma May 26, 2021
0ca6abb
formatting
awaelchli May 27, 2021
f6f8f61
typo fix
awaelchli May 27, 2021
1faa389
fix typo in docs
awaelchli May 27, 2021
e0f302f
Merge branch 'master' into feat_artifacts
awaelchli May 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))


- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))
- MLFlowLogger now accepts `run_name` as an constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))


Expand Down
13 changes: 9 additions & 4 deletions docs/source/common/loggers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere exc
Weights and Biases
==================

`Weights and Biases <https://www.wandb.com/>`_ is a third-party logger.
`Weights and Biases <https://docs.wandb.ai/integrations/lightning/>`_ is a third-party logger.
To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following.
First, install the package:

Expand All @@ -215,9 +215,14 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. code-block:: python

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(offline=True)

# instrument experiment with W&B
wandb_logger = WandbLogger(project='MNIST', log_model='all')
trainer = Trainer(logger=wandb_logger)

# log gradients and model topology
wandb_logger.watch(model)

The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
:class:`~pytorch_lightning.core.lightning.LightningModule`.

Expand All @@ -226,8 +231,8 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
self.logger.experiment.log({
"generated_images": [wandb.Image(some_img, caption="...")]
self.log({
"generated_images": [wandb.Image(some_img, caption="...")]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey guys... we need to standardize for all loggers, not just wnb. let's sync up on this to make sure these changes aren't just for a single logger.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if I need to update this on the other loggers as well.

I was just trying to take advantage of this PR to clean up the doc since I often get asked on when to use self.log, self.logger.experiment.log or even self.logger[0].experiment.log (I typically suggest to just try to use self.log).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.log_images could be the API used to log every image related artefacts.

})
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

.. seealso::
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from weakref import proxy

import numpy as np
import torch
Expand Down Expand Up @@ -330,6 +331,10 @@ def save_checkpoint(self, trainer: 'pl.Trainer', unused: Optional['pl.LightningM
# Mode 3: save last checkpoints
self._save_last_checkpoint(trainer, monitor_candidates)

# notify loggers
if trainer.is_global_zero and trainer.logger:
trainer.logger.after_save_checkpoint(proxy(self))

def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
from pytorch_lightning.trainer.states import TrainerFn
return (
Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
from weakref import ReferenceType

import numpy as np
import torch

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only

Expand Down Expand Up @@ -71,6 +73,15 @@ def __init__(
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
self._agg_default_func = agg_default_func

def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
"""
Called after model checkpoint callback saves a new checkpoint

Args:
model_checkpoint: the model checkpoint callback instance
"""
pass

def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
Expand Down Expand Up @@ -357,6 +368,10 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
def __getitem__(self, index: int) -> LightningLoggerBase:
return [logger for logger in self._logger_iterable][index]

def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
for logger in self._logger_iterable:
logger.after_save_checkpoint(checkpoint_callback)

def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
Expand Down
89 changes: 75 additions & 14 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,26 @@
Weights and Biases Logger
-------------------------
"""
import operator
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union
from weakref import ReferenceType

import torch.nn as nn

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()

_WANDB_AVAILABLE = _module_available("wandb")
_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")

try:
import wandb
Expand All @@ -40,7 +46,7 @@

class WandbLogger(LightningLoggerBase):
r"""
Log using `Weights and Biases <https://www.wandb.com/>`_.
Log using `Weights and Biases <https://docs.wandb.ai/integrations/lightning>`_.

Install it with pip:

Expand All @@ -56,7 +62,15 @@ class WandbLogger(LightningLoggerBase):
version: Same as id.
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
as W&B artifacts.

* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
* if ``log_model == False`` (default), no checkpoint is logged.

prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
Expand All @@ -71,15 +85,16 @@ class WandbLogger(LightningLoggerBase):

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
wandb_logger = WandbLogger()

# instrument experiment with W&B
wandb_logger = WandbLogger(project='MNIST', log_model='all')
trainer = Trainer(logger=wandb_logger)

Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
make sure to use `commit=False` so the logging step does not increase.
# log gradients and model topology
wandb_logger.watch(model)

See Also:
- `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
on how to use W&B with PyTorch Lightning
- `Demo in Google Colab <http://wandb.me/lightning>`__ with model logging
- `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__

"""
Expand Down Expand Up @@ -114,6 +129,13 @@ def __init__(
'Hint: Set `offline=False` to log your model.'
)

if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
warning_cache.warn(
f'Providing log_model={log_model} requires wandb version >= 0.10.22'
' for logging associated model metadata.\n'
'Hint: Upgrade with `pip install --ugrade wandb`.'
)

if sync_step is not None:
warning_cache.warn(
"`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
Expand All @@ -125,6 +147,8 @@ def __init__(
self._log_model = log_model
self._prefix = prefix
self._experiment = experiment
self._logged_model_time = {}
self._checkpoint_callback = None
# set wandb init arguments
anonymous_lut = {True: 'allow', False: None}
self._wandb_init = dict(
Expand Down Expand Up @@ -168,10 +192,6 @@ def experiment(self) -> Run:
os.environ['WANDB_MODE'] = 'dryrun'
self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
Expand Down Expand Up @@ -213,8 +233,49 @@ def version(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id

def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
# log checkpoints as artifacts
if self._log_model == 'all' or self._log_model is True and checkpoint_callback.save_top_k == -1:
self._scan_and_log_checkpoints(checkpoint_callback)
elif self._log_model is True:
self._checkpoint_callback = checkpoint_callback

@rank_zero_only
def finalize(self, status: str) -> None:
# upload all checkpoints from saving dir
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
# log checkpoints as artifacts
if self._checkpoint_callback:
self._scan_and_log_checkpoints(self._checkpoint_callback)

def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
# get checkpoints to be saved with associated score
checkpoints = {
checkpoint_callback.last_model_path: checkpoint_callback.current_score,
checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
**checkpoint_callback.best_k_models
}
checkpoints = sorted([(Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()])
checkpoints = [
c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
]

# log iteratively all new checkpoints
for t, p, s in checkpoints:
metadata = {
'score': s,
'original_filename': Path(p).name,
'ModelCheckpoint': {
k: getattr(checkpoint_callback, k)
for k in [
'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps',
'_every_n_val_epochs'
]
# ensure it does not break if `ModelCheckpoint` args change
if hasattr(checkpoint_callback, k)
}
} if _WANDB_GREATER_EQUAL_0_10_22 else None
artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
artifact.add_file(p, name='model.ckpt')
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
self._logged_model_time[p] = t
5 changes: 5 additions & 0 deletions tests/loggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self):
self.hparams_logged = None
self.metrics_logged = {}
self.finalized = False
self.after_save_checkpoint_called = False

@property
def experiment(self):
Expand Down Expand Up @@ -92,6 +93,9 @@ def name(self):
def version(self):
return "1"

def after_save_checkpoint(self, checkpoint_callback):
self.after_save_checkpoint_called = True


def test_custom_logger(tmpdir):

Expand All @@ -115,6 +119,7 @@ def training_step(self, batch, batch_idx):
assert trainer.state.finished, f"Training failed with {trainer.state}"
assert logger.hparams_logged == model.hparams
assert logger.metrics_logged != {}
assert logger.after_save_checkpoint_called
assert logger.finalized_status == "success"


Expand Down
77 changes: 66 additions & 11 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,8 @@
from tests.helpers import BoringModel


def get_warnings(recwarn):
warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
recwarn.clear()
return warnings_text


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger_init(wandb, recwarn):
def test_wandb_logger_init(wandb):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""

Expand All @@ -51,8 +45,6 @@ def test_wandb_logger_init(wandb, recwarn):
run = wandb.init()
logger = WandbLogger(experiment=run)
assert logger.experiment
assert run.dir is not None
assert logger.save_dir == run.dir

# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
Expand Down Expand Up @@ -140,10 +132,8 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):

# mock return values of experiment
wandb.run = None
wandb.init().step = 0
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
logger.experiment.step = 0

for _ in range(2):
_ = logger.experiment
Expand All @@ -164,6 +154,71 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
assert trainer.log_dir == logger.save_dir


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_log_model(wandb, tmpdir):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test for restarting an experiment ?

Copy link
Contributor Author

@borisdayma borisdayma Mar 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand.
If you mean reloading a model, it is not done by the integration but by wandb api (see demo) and interact with the servers (to download the artifact). The logic of uploading a file and re-downloading the same file would probably be more specific to wandb itself.

If you want a full test (without mocking wandb), I think there was issues in the past which lead to mocking all loggers but I could try to recreate one.

""" Test that the logger creates the folders and files in the right place. """

wandb.run = None
model = BoringModel()

# test log_model=True
logger = WandbLogger(log_model=True)
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
wandb.init().log_artifact.assert_called_once()

# test log_model='all'
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
logger = WandbLogger(log_model='all')
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
assert wandb.init().log_artifact.call_count == 2

# test log_model=False
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
logger = WandbLogger(log_model=False)
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
assert not wandb.init().log_artifact.called

# test correct metadata
import pytorch_lightning.loggers.wandb as pl_wandb
pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
wandb.Artifact.reset_mock()
logger = pl_wandb.WandbLogger(log_model=True)
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
wandb.Artifact.assert_called_once_with(
name='model-1',
type='model',
metadata={
'score': None,
'original_filename': 'epoch=1-step=5-v3.ckpt',
'ModelCheckpoint': {
'monitor': None,
'mode': 'min',
'save_last': None,
'save_top_k': None,
'save_weights_only': False,
'_every_n_train_steps': 0,
'_every_n_val_epochs': 1
}
}
)


def test_wandb_sanitize_callable_params(tmpdir):
"""
Callback function are not serializiable. Therefore, we get them a chance to return
Expand Down