Skip to content

Commit

Permalink
Deprecate prefix argument in ModelCheckpoint (#4765)
Browse files Browse the repository at this point in the history
* Deprecate prefix in ModelCheckpoint

* chlog

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 21, 2020
1 parent 8e91cee commit 37b388e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `prefix` argument in `ModelCheckpoint` ([#4765](https://github.com/PyTorchLightning/pytorch-lightning/pull/4765))


### Removed
Expand Down
16 changes: 15 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ class ModelCheckpoint(Callback):
saved (``model.save_weights(filepath)``), else the full model
is saved (``model.save(filepath)``).
period: Interval (number of epochs) between checkpoints.
prefix: A string to put at the beginning of checkpoint filename.
.. warning::
This argument has been deprecated in v1.1 and will be removed in v1.3
dirpath: directory to save the model file.
Expand Down Expand Up @@ -167,6 +171,12 @@ def __init__(
if save_top_k is None and monitor is not None:
self.save_top_k = 1

if prefix:
rank_zero_warn(
'Argument `prefix` is deprecated in v1.1 and will be removed in v1.3.'
' Please prepend your prefix in `filename` instead.', DeprecationWarning
)

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(filepath, dirpath, filename, save_top_k)
self.__validate_init_configuration()
Expand Down Expand Up @@ -380,7 +390,11 @@ def _format_checkpoint_name(
if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])

if prefix:
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])

return filename

def format_checkpoint_name(
self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
Expand Down
4 changes: 4 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def test_tbd_remove_in_v1_3_0(tmpdir):
callback = ModelCheckpoint()
Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir)

# Deprecate prefix
with pytest.deprecated_call(match='will be removed in v1.3'):
callback = ModelCheckpoint(prefix='temp')


def test_tbd_remove_in_v1_2_0():
with pytest.deprecated_call(match='will be removed in v1.2'):
Expand Down

0 comments on commit 37b388e

Please sign in to comment.