Start version suffixes at 1 #5008

Merged
merged 31 commits on Jan 26, 2021

Changes from all commits (31 commits)

180cd54
Rename original filepath to v0
carmocca Dec 7, 2020
8224a30
Clean-up
carmocca Dec 7, 2020
9486c03
Suggestions from code review
carmocca Dec 10, 2020
96780c6
Revert renaming. Start version number at 1
carmocca Dec 10, 2020
6b1a475
Merge branch 'master' into 5000
carmocca Dec 10, 2020
75586d3
Merge branch 'master' into 5000
carmocca Dec 11, 2020
2a17a02
Merge branch 'master' into 5000
carmocca Dec 11, 2020
e721483
Merge remote-tracking branch 'upstream/release/1.2-dev' into 5000
carmocca Jan 15, 2021
df52229
Add ModelCheckpoint.STARTING_VERSION
carmocca Jan 15, 2021
365c549
Apply suggestions from code review
carmocca Jan 15, 2021
662351c
Add note about class attributes
carmocca Jan 15, 2021
3e6f535
Merge remote-tracking branch 'upstream/release/1.2-dev' into 5000
carmocca Jan 15, 2021
7f70313
Update CHANGELOG
carmocca Jan 15, 2021
23516ca
Fix doc
carmocca Jan 15, 2021
7cfac05
Merge branch 'release/1.2-dev' into 5000
rohitgr7 Jan 21, 2021
44216da
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 25, 2021
62abf1c
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 25, 2021
f6b9a9a
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 25, 2021
18f3a35
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 25, 2021
84eccd1
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 25, 2021
94a7109
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 25, 2021
f03d5f7
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
74473d4
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
6b0cd94
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
be9c55c
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
e5d5530
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
7ac4bc2
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
d5c9ed2
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
2a7c0d7
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
eff2e81
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
b60982a
Merge branch 'release/1.2-dev' into 5000
mergify[bot] Jan 26, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))


- Changed `ModelCheckpoint` version suffixes to start at 1 ([#5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008))


- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


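
For readers skimming the changelog, a minimal illustration of the behavior change (filenames are hypothetical; actual names depend on the `filename` template and the monitored metrics):

```python
# Illustrative only: names produced when "epoch=3.ckpt" already exists in dirpath
# and two more checkpoints resolve to the same base name.
old_behaviour = ["epoch=3.ckpt", "epoch=3-v0.ckpt", "epoch=3-v1.ckpt"]  # suffix used to start at v0
new_behaviour = ["epoch=3.ckpt", "epoch=3-v1.ckpt", "epoch=3-v2.ckpt"]  # suffix now starts at v1
print(old_behaviour, new_behaviour, sep="\n")
```
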
53 changes: 31 additions & 22 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -80,10 +80,10 @@ class ModelCheckpoint(Callback):
the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
if ``save_top_k == -1``, all models are saved.
Please note that the monitors are checked every `period` epochs.
Please note that the monitors are checked every ``period`` epochs.
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
appended with a version count starting with ``v1``.
mode: one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
@@ -105,6 +105,17 @@ class ModelCheckpoint(Callback):
.. warning::
This argument has been deprecated in v1.1 and will be removed in v1.3

Note:
For extra customization, ModelCheckpoint includes the following attributes:

- ``CHECKPOINT_JOIN_CHAR = "-"``
- ``CHECKPOINT_NAME_LAST = "last"``
- ``FILE_EXTENSION = ".ckpt"``
- ``STARTING_VERSION = 1``

For example, you can change the default last checkpoint name by doing
``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``

Example::

>>> from pytorch_lightning import Trainer
@@ -128,11 +139,13 @@ class ModelCheckpoint(Callback):
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path

"""

CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1

def __init__(
self,
@@ -485,28 +498,24 @@ def _validate_monitor_key(self, trainer):

def _get_metric_interpolated_filepath_name(
self,
ckpt_name_metrics: Dict[str, Any],
monitor_candidates: Dict[str, Any],
epoch: int,
step: int,
del_filepath: Optional[str] = None
) -> str:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)

version_cnt = 0
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
version = self.STARTING_VERSION
while self._fs.exists(filepath) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
version_cnt += 1

filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
version += 1
return filepath

def _monitor_candidates(self, trainer):
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
Review comment on lines -503 to -505:

Contributor: do we see somewhere that these are already contained?

Contributor Author: @tchaton told me a long time ago (don't remember where). Can you comment?

ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics
monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics)
carmocca marked this conversation as resolved.
monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
return monitor_candidates

def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@@ -517,13 +526,13 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
ckpt_name_metrics,
prefix=self.prefix
monitor_candidates,
prefix=self.prefix,
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
last_filepath = self._get_metric_interpolated_filepath_name(
ckpt_name_metrics, trainer.current_epoch, trainer.global_step
monitor_candidates, trainer.current_epoch, trainer.global_step
)

accelerator_backend = trainer.accelerator_backend
@@ -534,10 +543,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
else:
self._save_model(last_filepath, trainer, pl_module)
if (
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
):
self._del_model(self.last_model_path)
self.last_model_path = last_filepath
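
To make the new logic concrete, here is a minimal, self-contained sketch of the version-suffix loop that `_get_metric_interpolated_filepath_name` now implements. `existing`, `base`, and `resolve_ckpt_name` are hypothetical stand-ins for the filesystem check (`self._fs.exists`) and `format_checkpoint_name`; this is not the library's code:

```python
from typing import Optional, Set

STARTING_VERSION = 1  # mirrors the new ModelCheckpoint.STARTING_VERSION class attribute

def resolve_ckpt_name(existing: Set[str], base: str, del_filepath: Optional[str] = None) -> str:
    """Return the first name that does not collide with an existing checkpoint."""
    filepath = f"{base}.ckpt"
    version = STARTING_VERSION
    while filepath in existing and filepath != del_filepath:
        # CHECKPOINT_JOIN_CHAR is "-", so duplicates become "<base>-v1.ckpt", "<base>-v2.ckpt", ...
        filepath = f"{base}-v{version}.ckpt"
        version += 1
    return filepath

# A rerun over the same dirpath now yields "-v1" for the first collision, not "-v0".
assert resolve_ckpt_name(set(), "epoch=0") == "epoch=0.ckpt"
assert resolve_ckpt_name({"epoch=0.ckpt"}, "epoch=0") == "epoch=0-v1.ckpt"
assert resolve_ckpt_name({"epoch=0.ckpt", "epoch=0-v1.ckpt"}, "epoch=0") == "epoch=0-v2.ckpt"
```
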
91 changes: 58 additions & 33 deletions tests/checkpointing/test_model_checkpoint.py
@@ -738,18 +738,20 @@ def test_val_check_interval_checkpoint_files(tmpdir):
save_top_k=-1,
monitor="val_acc",
mode="max",
verbose=True
)
trainer = Trainer(
default_root_dir=tmpdir,
val_check_interval=0.2,
max_epochs=1,
limit_train_batches=10,
callbacks=[model_checkpoint]
callbacks=[model_checkpoint],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")])
assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]]
files = {p.basename for p in tmpdir.listdir()}
assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]}


def test_current_score(tmpdir):
@@ -844,43 +846,66 @@ def __init__(self, hparams):
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


@pytest.mark.parametrize('max_epochs', [3, 4])
@pytest.mark.parametrize(
'save_top_k, expected',
[
(1, ['curr_epoch.ckpt']),
(2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
]
)
def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
"""
Test that version is added to filename if required and it already exists in dirpath.
Check that previous checkpoints are renamed to have the correct
version suffix when new trainer instances are used
"""
model_checkpoint = ModelCheckpoint(
dirpath=tmpdir,
filename='curr_epoch',
save_top_k=save_top_k,
monitor='epoch',
mode='max',
)
epochs = 2
for i in range(epochs):
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}")
trainer = Trainer(
max_epochs=epochs,
limit_train_batches=1,
limit_val_batches=1,
default_root_dir=tmpdir,
callbacks=[mc],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(BoringModel())

# check best_k_models state
expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} if i else {"epoch=0.ckpt", "epoch=1.ckpt"}
assert {Path(f).name for f in mc.best_k_models.keys()} == expected

# check created ckpts
assert set(f.basename for f in tmpdir.listdir()) == {
"epoch=0.ckpt",
"epoch=1.ckpt",
"epoch=0-v1.ckpt",
"epoch=1-v1.ckpt",
}


def test_ckpt_version_after_rerun_same_trainer(tmpdir):
"""
Check that previous checkpoints are renamed to have the correct
version suffix when the same trainer instance is used
"""
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test")
mc.STARTING_VERSION = 9
trainer = Trainer(
max_epochs=2,
limit_train_batches=1,
limit_val_batches=1,
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_epochs=max_epochs,
limit_train_batches=2,
limit_val_batches=2,
logger=None,
callbacks=[mc],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(BoringModel())
trainer.max_epochs = 4
trainer.fit(BoringModel())

model = BoringModel()
trainer.fit(model)
ckpt_files = os.listdir(tmpdir)
assert set(ckpt_files) == set(expected)

epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION)
expected = {'test.ckpt', *[f"test-v{i}.ckpt" for i in ckpt_range]}
# check best_k_models state
assert {Path(f).name for f in mc.best_k_models.keys()} == expected
# check created ckpts
assert set(sorted(os.listdir(tmpdir))) == expected


def test_model_checkpoint_mode_options():
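
To close, a hedged usage sketch of the class-attribute customization described in the docstring note, mirroring what `test_ckpt_version_after_rerun_same_trainer` does with `STARTING_VERSION`; `MyModel` and the paths are placeholders, not part of this PR:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Override the class attributes on the instance to customize naming
# without touching the __init__ arguments.
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", filename="{epoch}", save_top_k=-1, monitor="epoch")
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"  # example taken from the docstring note
checkpoint_callback.STARTING_VERSION = 9                   # duplicate names become "-v9", "-v10", ...

trainer = Trainer(max_epochs=2, callbacks=[checkpoint_callback])
trainer.fit(MyModel())  # MyModel is a placeholder LightningModule
```
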