From ebbff72071b163500a69986c45c88f70758097ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 27 Oct 2020 03:38:00 +0100 Subject: [PATCH 1/4] Add key --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + tests/checkpointing/test_model_checkpoint.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6c6a1741c31c5..83297397c1d33 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -187,6 +187,7 @@ def on_validation_end(self, trainer, pl_module): def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: return { + "best_model_monitor": self.monitor, "best_model_score": self.best_model_score, "best_model_path": self.best_model_path, } diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 976a91f551e0a..1634b73424dd1 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -507,10 +507,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step")) ch_type = type(model_checkpoint) - assert all(list( - ckpt_last["callbacks"][ch_type][k] == ckpt_last_epoch["callbacks"][ch_type][k] - for k in ("best_model_score", "best_model_path") - )) + assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type] # it is easier to load the model objects than to iterate over the raw dict of tensors model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch) From 682d9438462c28ee72d5c0ff731a100f1d2a33b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 27 Oct 2020 03:55:00 +0100 Subject: [PATCH 2/4] Remove unused variables --- pytorch_lightning/callbacks/model_checkpoint.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 83297397c1d33..f44ce4f57d2aa 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -131,8 +131,6 @@ class ModelCheckpoint(Callback): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_NAME_LAST = "last" - CHECKPOINT_STATE_BEST_SCORE = "checkpoint_callback_best_model_score" - CHECKPOINT_STATE_BEST_PATH = "checkpoint_callback_best_model_path" def __init__( self, From 9bd31aaa1439f19c9e439584dbfe192b1f5fcd6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 27 Oct 2020 03:59:30 +0100 Subject: [PATCH 3/4] Update CHANGELOG [skip ci] --- CHANGELOG.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 08e2e93b93d9a..a81d0d79df8ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `dirpath` and `filename` parameter in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213)) -- Added plugins docs and DDPPlugin to customize ddp across all accelerators([#4258](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285)) +- Added plugins docs and DDPPlugin to customize ddp across all accelerators ([#4258](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285)) - Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586)) @@ -21,7 +21,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162)) -- Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344)) +- Added autogenerated helptext to `Trainer.add_argparse_args` ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344)) + + +- Added "best_model_monitor" key to saved `ModelCheckpoints` ([#4383](https://github.com/PyTorchLightning/pytorch-lightning/pull/4383)) ### Changed From 97a956115a179f7b848050250277e0852ce90935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 27 Oct 2020 15:17:23 +0100 Subject: [PATCH 4/4] best_model_monitor -> monitor --- CHANGELOG.md | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a81d0d79df8ae..b98028bf2606c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added autogenerated helptext to `Trainer.add_argparse_args` ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344)) -- Added "best_model_monitor" key to saved `ModelCheckpoints` ([#4383](https://github.com/PyTorchLightning/pytorch-lightning/pull/4383)) +- Added "monitor" key to saved `ModelCheckpoints` ([#4383](https://github.com/PyTorchLightning/pytorch-lightning/pull/4383)) ### Changed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f44ce4f57d2aa..4c9d3f4e30072 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -185,7 +185,7 @@ def on_validation_end(self, trainer, pl_module): def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: return { - "best_model_monitor": self.monitor, + "monitor": self.monitor, "best_model_score": self.best_model_score, "best_model_path": self.best_model_path, }