feat(wandb): log models as artifacts #6231
Changes from 59 commits
The test-suite diff under review:

```diff
@@ -24,14 +24,8 @@
 from tests.helpers import BoringModel


-def get_warnings(recwarn):
-    warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
-    recwarn.clear()
-    return warnings_text
-
-
 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
-def test_wandb_logger_init(wandb, recwarn):
+def test_wandb_logger_init(wandb):
     """Verify that basic functionality of wandb logger works.
     Wandb doesn't work well with pytest so we have to mock it out here."""

@@ -51,8 +45,6 @@ def test_wandb_logger_init(wandb, recwarn):
     run = wandb.init()
     logger = WandbLogger(experiment=run)
     assert logger.experiment
-    assert run.dir is not None
-    assert logger.save_dir == run.dir

     # test wandb.init not called if there is a W&B run
     wandb.init().log.reset_mock()
@@ -140,10 +132,8 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):

     # mock return values of experiment
     wandb.run = None
-    wandb.init().step = 0
     logger.experiment.id = '1'
     logger.experiment.project_name.return_value = 'project'
-    logger.experiment.step = 0

     for _ in range(2):
         _ = logger.experiment
@@ -164,6 +154,71 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     assert trainer.log_dir == logger.save_dir


+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_wandb_log_model(wandb, tmpdir):
```
A review thread attached to `test_wandb_log_model`:

> Can we add a test for restarting an experiment?

> I'm not sure I understand. If you want a full test (without mocking
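For context, "restarting an experiment" here would mean reattaching the logger to an existing W&B run. A minimal sketch of what such a setup could look like, assuming `WandbLogger`'s behavior of forwarding extra keyword arguments to `wandb.init` (the `resume` flag is `wandb.init` API, not something added in this PR, and the run id is illustrative):

```python
from pytorch_lightning.loggers import WandbLogger

# Reattach to a previous run by id; `resume='allow'` is forwarded to wandb.init.
logger = WandbLogger(project='project', id='1', resume='allow')
```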
```diff
+    """ Test that the logger creates the folders and files in the right place. """
+
+    wandb.run = None
+    model = BoringModel()
+
+    # test log_model=True
+    logger = WandbLogger(log_model=True)
+    logger.experiment.id = '1'
+    logger.experiment.project_name.return_value = 'project'
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+    trainer.fit(model)
+    wandb.init().log_artifact.assert_called_once()
+
+    # test log_model='all'
+    wandb.init().log_artifact.reset_mock()
+    wandb.init.reset_mock()
+    logger = WandbLogger(log_model='all')
+    logger.experiment.id = '1'
+    logger.experiment.project_name.return_value = 'project'
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+    trainer.fit(model)
+    assert wandb.init().log_artifact.call_count == 2
+
+    # test log_model=False
+    wandb.init().log_artifact.reset_mock()
+    wandb.init.reset_mock()
+    logger = WandbLogger(log_model=False)
+    logger.experiment.id = '1'
+    logger.experiment.project_name.return_value = 'project'
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+    trainer.fit(model)
+    assert not wandb.init().log_artifact.called
+
+    # test correct metadata
+    import pytorch_lightning.loggers.wandb as pl_wandb
+    pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
+    wandb.init().log_artifact.reset_mock()
+    wandb.init.reset_mock()
+    wandb.Artifact.reset_mock()
+    logger = pl_wandb.WandbLogger(log_model=True)
+    logger.experiment.id = '1'
+    logger.experiment.project_name.return_value = 'project'
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+    trainer.fit(model)
+    wandb.Artifact.assert_called_once_with(
+        name='model-1',
+        type='model',
+        metadata={
+            'score': None,
+            'original_filename': 'epoch=1-step=5-v3.ckpt',
+            'ModelCheckpoint': {
+                'monitor': None,
+                'mode': 'min',
+                'save_last': None,
+                'save_top_k': None,
+                'save_weights_only': False,
+                '_every_n_train_steps': 0,
+                '_every_n_val_epochs': 1
+            }
+        }
+    )
+
+
 def test_wandb_sanitize_callable_params(tmpdir):
     """
     Callback function are not serializiable. Therefore, we get them a chance to return
```
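To make the semantics those asserts pin down concrete, here is a minimal usage sketch; it is not part of the PR, and `BoringModel` is just the Lightning test helper imported in this file:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.helpers import BoringModel  # Lightning's dummy test model

# Per the asserts above:
#   log_model=True   -> a single log_artifact call (checkpoint uploaded once,
#                       at the end of training)
#   log_model='all'  -> one log_artifact call per saved checkpoint
#                       (2 epochs -> 2 calls)
#   log_model=False  -> log_artifact is never called
logger = WandbLogger(project='project', log_model='all')
trainer = Trainer(logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(BoringModel())
```

And since the metadata assertion only applies once `_WANDB_GREATER_EQUAL_0_10_22` is set, a hedged sketch of reading that metadata back through wandb's public artifact API; the name `model-1` follows the `model-<run id>` pattern the test expects, and the `latest` alias is illustrative:

```python
import wandb

run = wandb.init(project='project')
artifact = run.use_artifact('model-1:latest')
checkpoint_dir = artifact.download()  # local folder containing the .ckpt file
print(artifact.metadata)  # score, original_filename, ModelCheckpoint settings
```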
Review discussion:

> hey guys... we need to standardize for all loggers, not just wandb. let's sync up on this to make sure these changes aren't just for a single logger.

> Let me know if I need to update this on the other loggers as well. I was just trying to take advantage of this PR to clean up the docs, since I often get asked when to use `self.log`, `self.logger.experiment.log`, or even `self.logger[0].experiment.log` (I typically suggest just trying to use `self.log`).

> `self.log_images` could be the API used to log every image-related artifact.
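A sketch contrasting the three entry points mentioned in that exchange, assuming a toy `LightningModule` (nothing here is prescribed by the PR):

```python
import torch
import wandb
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()

        # 1) Preferred: logger-agnostic, Lightning routes it to every attached logger.
        self.log('train_loss', loss)

        # 2) Escape hatch: talk to the underlying wandb run directly,
        #    e.g. for media types the scalar-oriented self.log cannot express.
        self.logger.experiment.log({'inputs': wandb.Histogram(batch.detach().cpu().numpy())})

        # 3) With multiple loggers, self.logger is a collection, hence the indexing.
        # self.logger[0].experiment.log({'train_loss': float(loss)})
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

The last comment's suggestion is that a dedicated `self.log_images` (hypothetical at this point) would fold case 2 back into the framework-managed API.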