Skip to content

Commit

Permalink
Add check for verbose attribute of ModelCheckpoint (#6419)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
ashleve and awaelchli committed Mar 8, 2021
1 parent e1f5eac commit 9eded7f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def check_checkpoint_callback(self, should_update, is_last=False):
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = self.trainer.checkpoint_callbacks

if is_last and any(cb.save_last for cb in callbacks):
if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.lightning_module
Expand Down
11 changes: 8 additions & 3 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,19 +672,24 @@ def test_default_checkpoint_behavior(tmpdir):
@pytest.mark.parametrize('max_epochs', [1, 2])
@pytest.mark.parametrize('should_validate', [True, False])
@pytest.mark.parametrize('save_last', [True, False])
def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_validate, save_last):
@pytest.mark.parametrize('verbose', [True, False])
def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_validate, save_last, verbose):
"""Tests 'Saving latest checkpoint...' log"""
model = LogInTwoMethods()
if not should_validate:
model.validation_step = None
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last)],
callbacks=[
ModelCheckpoint(
monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose
)
],
max_epochs=max_epochs,
)
with caplog.at_level(logging.INFO):
trainer.fit(model)
assert caplog.messages.count('Saving latest checkpoint...') == save_last
assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
Expand Down

0 comments on commit 9eded7f

Please sign in to comment.