None check for filepath in ModelCheckpoint (#1654)
Check if the optional filepath is None before checking if it exists

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
yukw777 and mergify[bot] committed Apr 29, 2020
Parent: 9b86aea · Commit: 42d5cfc
Showing 3 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
@@ -16,11 +16,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Fixed ModelCheckpoint not None checking filepath ([#1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654))


## [0.7.5] - 2020-04-27

### Changed

-- Allow logging of metrics together with `hparams` ([#1630](https://github.com/PyTorchLightning/pytorch-lightning/pull/1630))
+- Allow metrics logged together with hparams ([#1630](https://github.com/PyTorchLightning/pytorch-lightning/pull/1630))

@@ -51,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ddp_cpu` backend for testing ddp without GPUs ([#1158](https://github.com/PyTorchLightning/pytorch-lightning/pull/1158))
- Added [Horovod](http://horovod.ai) support as a distributed backend `Trainer(distributed_backend='horovod')` ([#1529](https://github.com/PyTorchLightning/pytorch-lightning/pull/1529))
- Added support for 8 core distributed training on Kaggle TPU's ([#1568](https://github.com/PyTorchLightning/pytorch-lightning/pull/1568))
-- Added support for native AMP ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561), [#1580](https://github.com/PyTorchLightning/pytorch-lightning/pull/1580))
+- Added support for native AMP ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561), [#1580](https://github.com/PyTorchLightning/pytorch-lightning/pull/1580))

### Changed

@@ -78,7 +80,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
- Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
- Fixed `LightningModule` - mixing hparams and arguments in `LightningModule.__init__()` crashes load_from_checkpoint() ([#1505](https://github.com/PyTorchLightning/pytorch-lightning/pull/1505))
-- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
+- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
- Allow use of sweeps with `WandbLogger` ([#1512](https://github.com/PyTorchLightning/pytorch-lightning/pull/1512))
- Fixed a bug that caused the `callbacks` Trainer argument to reference a global variable ([#1534](https://github.com/PyTorchLightning/pytorch-lightning/pull/1534)).
- Fixed a bug that set all boolean CLI arguments from `Trainer.add_argparse_args` always to True ([#1571](https://github.com/PyTorchLightning/pytorch-lightning/pull/1571))
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -86,7 +86,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                 save_top_k: int = 1, save_weights_only: bool = False,
                 mode: str = 'auto', period: int = 1, prefix: str = ''):
        super().__init__()
-        if save_top_k > 0 and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
            rank_zero_warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                "All files in this directory will be deleted when a checkpoint is saved!"
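The guard matters because `filepath` is typed `Optional[str]` and defaults to `None`: before this fix, constructing the callback without a `filepath` crashed, since `os.path.isdir(None)` raises a `TypeError` rather than returning `False` (`isdir` only swallows `OSError` and `ValueError`). A minimal sketch of the short-circuiting check; the helper name is hypothetical, not part of the library:

```python
import os

def warn_if_checkpoint_dir_not_empty(filepath, save_top_k):
    # Hypothetical standalone distillation of the guarded check above.
    # `filepath is not None` must be evaluated before os.path.isdir(filepath),
    # because os.path.isdir(None) raises TypeError instead of returning False.
    if (save_top_k > 0 and filepath is not None
            and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0):
        print(f"Checkpoint directory {filepath} exists and is not empty "
              "with save_top_k != 0. All files in this directory will be "
              "deleted when a checkpoint is saved!")

warn_if_checkpoint_dir_not_empty(None, save_top_k=1)  # no TypeError, no warning
warn_if_checkpoint_dir_not_empty(".", save_top_k=1)   # warns if cwd is non-empty
```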
6 changes: 4 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -1,3 +1,4 @@
+import pytest
import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
@@ -249,7 +250,8 @@ def test_pickling(tmpdir):
pickle.dumps(early_stopping)


-def test_model_checkpoint_with_non_string_input(tmpdir):
+@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
+def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """ Test that None in checkpoint callback is valid and that chkp_path is
    set correctly """
    tutils.reset_seed()
@@ -260,7 +262,7 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):
    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

-    checkpoint = ModelCheckpoint(filepath=None, save_top_k=-1)
+    checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

    trainer = Trainer(default_root_dir=tmpdir,
                      checkpoint_callback=checkpoint,
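With the parametrization, the constructor's guard is exercised both where the directory check runs (`save_top_k` of `1` or `2`) and where it short-circuits (`-1` or `0`). A brief usage sketch of what the test verifies, assuming the 0.7.x-era API used in the test above, where `checkpoint_callback` accepts a `ModelCheckpoint` instance:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# After the fix, omitting filepath no longer raises TypeError at
# construction time for any save_top_k value; the checkpoint path is
# resolved later from the trainer's default_root_dir.
for save_top_k in (-1, 0, 1, 2):
    checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
    trainer = Trainer(default_root_dir='/tmp/lightning_demo',
                      checkpoint_callback=checkpoint)
```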
