From 8d871709be71788a03de0a566cd209daeec09fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 00:15:17 +0200 Subject: [PATCH 01/15] first attempt --- .../trainer/connectors/callback_connector.py | 42 +++++++++++++------ pytorch_lightning/trainer/properties.py | 18 +++++++- pytorch_lightning/trainer/trainer.py | 7 ++-- tests/checkpointing/test_model_checkpoint.py | 35 ++++++++++++++++ tests/test_deprecated.py | 7 ++++ 5 files changed, 93 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 187ff237056a2..e02b8d2d2cd94 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os + +from typing import Union + from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -44,25 +47,37 @@ def on_trainer_init( # configure checkpoint callback # it is important that this is the last callback to run # pass through the required args to figure out defaults - checkpoint_callback = self.init_default_checkpoint_callback(checkpoint_callback) - if checkpoint_callback: - self.trainer.callbacks.append(checkpoint_callback) - - # TODO refactor codebase (tests) to not directly reach into these callbacks - self.trainer.checkpoint_callback = checkpoint_callback + self.configure_checkpoint_callbacks(checkpoint_callback) # init progress bar self.trainer._progress_bar_callback = self.configure_progress_bar( progress_bar_refresh_rate, process_position ) - def init_default_checkpoint_callback(self, checkpoint_callback): - if checkpoint_callback is True: - checkpoint_callback = ModelCheckpoint(dirpath=None, filename=None) - elif checkpoint_callback is False: - checkpoint_callback = None + def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): + if isinstance(checkpoint_callback, ModelCheckpoint): + # TODO: deprecated, remove this block in v1.4.0 + rank_zero_warn( + "Passing a ModelCheckpoint instance to Trainer(checkpoint_callback=...)" + " is deprecated since v1.1 and will no longer be supported in v1.4.", + DeprecationWarning + ) + self.trainer.callbacks.append(checkpoint_callback) - return checkpoint_callback + if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False: + raise MisconfigurationException( + "Trainer was configured with checkpoint_callback=False but found ModelCheckpoint" + " in callbacks list." + ) + + if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True: + self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None)) + + if len(self.trainer.checkpoint_callbacks) > 1: + raise MisconfigurationException( + "You added multiple ModelCheckpoint callbacks to the Trainer, but currently only one" + " instance is supported."
+ ) def configure_progress_bar(self, refresh_rate=1, process_position=0): progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] @@ -83,3 +98,6 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0): progress_bar_callback = None return progress_bar_callback + + def _trainer_has_checkpoint_callbacks(self): + return len(self.trainer.checkpoint_callbacks) > 0 diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index afb2f4cb5eb91..963c84e77b2c6 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -17,7 +17,8 @@ from argparse import ArgumentParser, Namespace from typing import List, Optional, Union, Type, TypeVar -from pytorch_lightning.callbacks import ProgressBarBase +from pytorch_lightning import Callback +from pytorch_lightning.callbacks import ProgressBarBase, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector @@ -46,6 +47,7 @@ class TrainerProperties(ABC): _weights_save_path: str model_connector: ModelConnector checkpoint_connector: CheckpointConnector + callbacks: List[Callback] @property def use_amp(self) -> bool: @@ -187,6 +189,20 @@ def weights_save_path(self) -> str: return os.path.normpath(self._weights_save_path) return self._weights_save_path + @property + def checkpoint_callback(self) -> Optional[ModelCheckpoint]: + """ + The first checkpoint callback in the Trainer.callbacks list, or ``None`` if + no checkpoint callbacks exist. + """ + callbacks = self.checkpoint_callbacks + return callbacks[0] if len(callbacks) > 0 else None + + @property + def checkpoint_callbacks(self) -> List[ModelCheckpoint]: + """ A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """ + return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + def save_checkpoint(self, filepath, weights_only: bool = False): self.checkpoint_connector.save_checkpoint(filepath, weights_only) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 44250ae905aba..3e456bba5f12a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -85,7 +85,7 @@ class Trainer( def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, - checkpoint_callback: Union[ModelCheckpoint, bool] = True, + checkpoint_callback: bool = True, callbacks: Optional[List[Callback]] = None, default_root_dir: Optional[str] = None, gradient_clip_val: float = 0, @@ -169,7 +169,9 @@ def __init__( callbacks: Add a list of callbacks. - checkpoint_callback: Callback for checkpointing. + checkpoint_callback: If ``True``, will configure a default ModelCheckpoint callback. + Passing a ModelCheckpoint instance to this argument is deprecated since + v1.1.0 and will be unsupported from v1.4.0. Default: ``True`` check_val_every_n_epoch: Check val every n train epochs. 
@@ -296,7 +298,6 @@ def __init__( # init callbacks # Declare attributes to be set in callback_connector on_trainer_init - self.checkpoint_callback: Union[ModelCheckpoint, bool] = checkpoint_callback self.callback_connector.on_trainer_init( callbacks, checkpoint_callback, diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 976a91f551e0a..096a194ad58fb 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -746,3 +746,38 @@ def test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, file assert mc_cb.dirpath == dirpath assert mc_cb.filename == filename + + +def test_configure_model_checkpoint(tmpdir): + """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """ + kwargs = dict(default_root_dir=tmpdir) + callback = ModelCheckpoint() + + # no callbacks + trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs) + assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks) + assert trainer.checkpoint_callback is None + + # default configuration + trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs) + assert len([c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)]) == 1 + assert isinstance(trainer.checkpoint_callback, ModelCheckpoint) + + # custom callback passed to callbacks list, checkpoint_callback=True is ignored + trainer = Trainer(checkpoint_callback=True, callbacks=[callback], **kwargs) + assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback] + assert trainer.checkpoint_callback == callback + + with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.4'): + trainer = Trainer(checkpoint_callback=callback, callbacks=[], **kwargs) + assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback] + assert trainer.checkpoint_callback == callback + + with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"): + Trainer(checkpoint_callback=False, callbacks=[callback], **kwargs) + + with pytest.raises(MisconfigurationException, match="You added multiple ModelCheckpoint callbacks"): + Trainer(checkpoint_callback=callback, callbacks=[callback], **kwargs) + + with pytest.raises(MisconfigurationException, match="You added multiple ModelCheckpoint callbacks"): + Trainer(callbacks=[callback, callback], **kwargs) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 70eb2a709b195..dd8e855bea493 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -4,12 +4,19 @@ import torch +from pytorch_lightning import Trainer from tests.base import EvalModelTemplate from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities.exceptions import MisconfigurationException +def test_tbd_remove_in_v1_4_0(tmpdir): + with pytest.deprecated_call(match='will no longer be supported in v1.4'): + callback = ModelCheckpoint() + trainer = Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) + + def test_tbd_remove_in_v1_2_0(): with pytest.deprecated_call(match='will be removed in v1.2'): checkpoint_cb = ModelCheckpoint(filepath='.') From d764359d2f7d889e43e1325db32a27e1e58f935b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 00:33:26 +0200 Subject: [PATCH 02/15] update tests --- pytorch_lightning/tuner/batch_size_scaling.py | 2 -- pytorch_lightning/tuner/lr_finder.py | 4 
---- tests/models/test_restore.py | 3 +-- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 87783fbde5d1f..10de8a2d289e5 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -144,7 +144,6 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial): trainer.weights_summary = None # not needed before full run trainer.logger = DummyLogger() trainer.callbacks = [] # not needed before full run - trainer.checkpoint_callback = False # required for saving trainer.limit_train_batches = 1.0 trainer.optimizers, trainer.schedulers = [], [] # required for saving trainer.model = model # required for saving @@ -157,7 +156,6 @@ def __scale_batch_restore_params(trainer): trainer.weights_summary = trainer.__dumped_params['weights_summary'] trainer.logger = trainer.__dumped_params['logger'] trainer.callbacks = trainer.__dumped_params['callbacks'] - trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback'] trainer.auto_scale_batch_size = trainer.__dumped_params['auto_scale_batch_size'] trainer.limit_train_batches = trainer.__dumped_params['limit_train_batches'] trainer.model = trainer.__dumped_params['model'] diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index d0ab33df8e1b8..3107f9f44824a 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -155,9 +155,6 @@ def lr_find( if trainer.progress_bar_callback: trainer.progress_bar_callback.disable() - # Disable standard checkpoint & early stopping - trainer.checkpoint_callback = False - # Required for saving the model trainer.optimizers, trainer.schedulers = [], [], trainer.model = model @@ -212,7 +209,6 @@ def __lr_finder_restore_params(trainer, model): trainer.logger = trainer.__dumped_params['logger'] trainer.callbacks = trainer.__dumped_params['callbacks'] trainer.max_steps = trainer.__dumped_params['max_steps'] - trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback'] model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] del trainer.__dumped_params diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 862294e64765f..a88b61c479187 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -60,8 +60,7 @@ def test_resume_from_checkpoint(tmpdir): default_root_dir=tmpdir, max_epochs=2, logger=False, - checkpoint_callback=checkpoint_callback, - callbacks=[ModelTrainerPropertyParity()] # this performs the assertions + callbacks=[checkpoint_callback, ModelTrainerPropertyParity()] # this performs the assertions ) trainer = Trainer(**trainer_args) trainer.fit(model) From 44425379cc4ce386a87acb14cf2cfa065f6ffc5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 01:10:59 +0200 Subject: [PATCH 03/15] support multiple --- .../trainer/connectors/callback_connector.py | 6 ---- tests/checkpointing/test_model_checkpoint.py | 31 ++++++++++++------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index e02b8d2d2cd94..b8a4276a2d747 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -73,12 +73,6 @@ def configure_checkpoint_callbacks(self, 
checkpoint_callback: Union[ModelCheckpo if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True: self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None)) - if len(self.trainer.checkpoint_callbacks) > 1: - raise MisconfigurationException( - "You added multiple ModelCheckpoint callbacks to the Trainer, but currently only one" - " instance is supported." - ) - def configure_progress_bar(self, refresh_rate=1, process_position=0): progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] if len(progress_bars) > 1: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 096a194ad58fb..62e1491ddfeed 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -751,7 +751,8 @@ def test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, file def test_configure_model_checkpoint(tmpdir): """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """ kwargs = dict(default_root_dir=tmpdir) - callback = ModelCheckpoint() + callback1 = ModelCheckpoint() + callback2 = ModelCheckpoint() # no callbacks trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs) @@ -764,20 +765,26 @@ def test_configure_model_checkpoint(tmpdir): assert isinstance(trainer.checkpoint_callback, ModelCheckpoint) # custom callback passed to callbacks list, checkpoint_callback=True is ignored - trainer = Trainer(checkpoint_callback=True, callbacks=[callback], **kwargs) - assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback] - assert trainer.checkpoint_callback == callback + trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs) + assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1] + assert trainer.checkpoint_callback == callback1 + + # multiple checkpoint callbacks + trainer = Trainer(callbacks=[callback1, callback2], **kwargs) + assert trainer.checkpoint_callback == callback1 + assert trainer.checkpoint_callbacks == [callback1, callback2] with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.4'): - trainer = Trainer(checkpoint_callback=callback, callbacks=[], **kwargs) - assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback] - assert trainer.checkpoint_callback == callback + trainer = Trainer(checkpoint_callback=callback1, callbacks=[], **kwargs) + assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1] + assert trainer.checkpoint_callback == callback1 + + with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.4"): + trainer = Trainer(checkpoint_callback=callback1, callbacks=[callback2], **kwargs) + assert trainer.checkpoint_callback == callback2 + assert trainer.checkpoint_callbacks == [callback2, callback1] with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"): - Trainer(checkpoint_callback=False, callbacks=[callback], **kwargs) + Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs) - with pytest.raises(MisconfigurationException, match="You added multiple ModelCheckpoint callbacks"): - Trainer(checkpoint_callback=callback, callbacks=[callback], **kwargs) - with pytest.raises(MisconfigurationException, match="You added multiple ModelCheckpoint callbacks"): - Trainer(callbacks=[callback, callback], **kwargs) From 
7318e87c911decec6e0a4b254b2f683e8d547ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Oct 2020 20:57:40 +0100 Subject: [PATCH 04/15] test bugfix --- tests/models/test_restore.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index a88b61c479187..9daf112e711b4 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -64,8 +64,10 @@ def test_resume_from_checkpoint(tmpdir): ) trainer = Trainer(**trainer_args) trainer.fit(model) + callbacks_before_resume = trainer.callbacks.copy() trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt")) trainer.fit(model) + assert trainer.callbacks == callbacks_before_resume @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From f8a0b7ef3484d98bcf027897f2c08ee06ff65684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Oct 2020 20:59:59 +0100 Subject: [PATCH 05/15] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1de62b442f25b..e9a46d2a00c70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656)) +- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336)) + ### Removed From 923b3e12a2176bf5076cb3f92528b7b3f74723c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Oct 2020 21:33:14 +0100 Subject: [PATCH 06/15] pep --- tests/checkpointing/test_model_checkpoint.py | 2 -- tests/test_deprecated.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0fa3621a041e7..3bc2ca436ec15 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -783,5 +783,3 @@ def test_configure_model_checkpoint(tmpdir): with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"): Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs) - - diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 23784d854c983..34eb5dae85826 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -19,7 +19,7 @@ def test_tbd_remove_in_v1_4_0(tmpdir): with pytest.deprecated_call(match='will no longer be supported in v1.4'): callback = ModelCheckpoint() - trainer = Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) + Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) def test_tbd_remove_in_v1_2_0(): From 66c34061ec8600b843f8cd72eabe0919c3ef5072 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Oct 2020 21:36:33 +0100 Subject: [PATCH 07/15] pep --- tests/test_deprecated.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 34eb5dae85826..67f38568e2103 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -6,7 +6,6 @@ import torch -from pytorch_lightning import Trainer from tests.base import EvalModelTemplate from pytorch_lightning.metrics.functional.classification import auc From 
80fd9874c9ab16ced9e0bfdaed5cfe07f76d31c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Oct 2020 21:43:58 +0100 Subject: [PATCH 08/15] import order --- pytorch_lightning/trainer/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 963c84e77b2c6..2298df9c8637b 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -17,7 +17,7 @@ from argparse import ArgumentParser, Namespace from typing import List, Optional, Union, Type, TypeVar -from pytorch_lightning import Callback +from pytorch_lightning.callbacks import Callback, ProgressBarBase from pytorch_lightning.callbacks import ProgressBarBase, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector From 78693ccfefd714cf4c6558e887be3c31821df9ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 28 Oct 2020 21:45:41 +0100 Subject: [PATCH 09/15] import --- pytorch_lightning/trainer/properties.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 2298df9c8637b..8d509d41d52bf 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -17,8 +17,7 @@ from argparse import ArgumentParser, Namespace from typing import List, Optional, Union, Type, TypeVar -from pytorch_lightning.callbacks import Callback, ProgressBarBase -from pytorch_lightning.callbacks import ProgressBarBase, ModelCheckpoint +from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector From e1271e1eabfb18d56fc0520e737c2d6ff0509f42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Oct 2020 00:21:15 +0100 Subject: [PATCH 10/15] improve test for resuming --- tests/models/test_restore.py | 49 +++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 9daf112e711b4..8d5cbfc93c02c 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -15,6 +15,7 @@ import logging as log import os import pickle +from copy import deepcopy import cloudpickle import pytest @@ -24,7 +25,7 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils -from pytorch_lightning import Trainer, LightningModule, Callback +from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST @@ -51,23 +52,47 @@ def on_train_end(self, trainer, pl_module): self._check_properties(trainer, pl_module) +class CaptureCallbacksBeforeTraining(Callback): + callbacks = [] + + def on_train_start(self, trainer, pl_module): + self.callbacks = deepcopy(trainer.callbacks) + + def test_resume_from_checkpoint(tmpdir): """ Test that properties like `current_epoch` and `global_step` in model and trainer are always the same. 
""" model = EvalModelTemplate() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) - trainer_args = dict( - default_root_dir=tmpdir, - max_epochs=2, - logger=False, - callbacks=[checkpoint_callback, ModelTrainerPropertyParity()] # this performs the assertions - ) - trainer = Trainer(**trainer_args) + callback_capture = CaptureCallbacksBeforeTraining() + + def get_trainer_args(): + trainer_args = dict( + default_root_dir=tmpdir, + max_epochs=2, + logger=False, + callbacks=[ + ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True), + callback_capture, + ModelTrainerPropertyParity() # this performs the assertions + ] + ) + return trainer_args + + # initial training + trainer = Trainer(**get_trainer_args()) trainer.fit(model) - callbacks_before_resume = trainer.callbacks.copy() - trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt")) + callbacks_before_resume = deepcopy(trainer.callbacks) + + # resumed training + trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt")) trainer.fit(model) - assert trainer.callbacks == callbacks_before_resume + + assert len(callbacks_before_resume) == len(callback_capture.callbacks) + + for before, after in zip(callbacks_before_resume, callback_capture.callbacks): + if isinstance(before, ModelCheckpoint): + assert before.best_model_path == after.best_model_path + assert before.best_model_score == after.best_model_score @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From 616fa06218bccdf367553808f847aead8ba79045 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Oct 2020 00:28:37 +0100 Subject: [PATCH 11/15] test --- tests/models/test_restore.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 8d5cbfc93c02c..0cf6ddba11cb1 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -52,6 +52,25 @@ def on_train_end(self, trainer, pl_module): self._check_properties(trainer, pl_module) +def test_model_properties_resume_from_checkpoint(tmpdir): + """ Test that properties like `current_epoch` and `global_step` + in model and trainer are always the same. """ + model = EvalModelTemplate() + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + trainer_args = dict( + default_root_dir=tmpdir, + max_epochs=1, + logger=False, + callbacks=[checkpoint_callback, ModelTrainerPropertyParity()] # this performs the assertions + ) + trainer = Trainer(**trainer_args) + trainer.fit(model) + + trainer_args.update(max_epochs=2) + trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt")) + trainer.fit(model) + + class CaptureCallbacksBeforeTraining(Callback): callbacks = [] @@ -59,9 +78,8 @@ def on_train_start(self, trainer, pl_module): self.callbacks = deepcopy(trainer.callbacks) -def test_resume_from_checkpoint(tmpdir): - """ Test that properties like `current_epoch` and `global_step` - in model and trainer are always the same. """ +def test_callbacks_state_resume_from_checkpoint(tmpdir): + """ Test that resuming from a checkpoint restores callbacks that persist state. 
""" model = EvalModelTemplate() callback_capture = CaptureCallbacksBeforeTraining() @@ -73,7 +91,6 @@ def get_trainer_args(): callbacks=[ ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True), callback_capture, - ModelTrainerPropertyParity() # this performs the assertions ] ) return trainer_args From c2235b92f09ddf02bed221b8322ff99d53163bb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Oct 2020 00:32:46 +0100 Subject: [PATCH 12/15] update test --- tests/models/test_restore.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 0cf6ddba11cb1..54ae867d844b3 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -84,15 +84,18 @@ def test_callbacks_state_resume_from_checkpoint(tmpdir): callback_capture = CaptureCallbacksBeforeTraining() def get_trainer_args(): + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) trainer_args = dict( default_root_dir=tmpdir, - max_epochs=2, + max_steps=1, logger=False, callbacks=[ - ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True), + checkpoint, callback_capture, ] ) + assert checkpoint.best_model_path == "" + assert checkpoint.best_model_score == 0 return trainer_args # initial training From d2f8791aa07fae0581241c4b39e7187418def76a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Oct 2020 06:16:23 +0100 Subject: [PATCH 13/15] add references test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- tests/models/test_restore.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 54ae867d844b3..848d6127c4cdb 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -115,6 +115,27 @@ def get_trainer_args(): assert before.best_model_score == after.best_model_score +def test_callbacks_references_resume_from_checkpoint(tmpdir): + """ Test that resuming from a checkpoint sets references as expected. 
""" + model = EvalModelTemplate() + args = {'default_root_dir': tmpdir, 'max_steps': 1, 'logger': False} + + # initial training + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + trainer = Trainer(**args, callbacks=[checkpoint]) + assert checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback + trainer.fit(model) + + # resumed training + new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + # pass in a new checkpoint object, which should take + # precedence over the one in the last.ckpt file + trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt")) + assert checkpoint is not new_checkpoint + assert new_checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback + trainer.fit(model) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_running_test_pretrained_model_distrib_dp(tmpdir): """Verify `test()` on pretrained model.""" From fd02011930aa76704035428f719747213f761a7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Oct 2020 22:07:31 +0100 Subject: [PATCH 14/15] docstring suggestion deprecation Co-authored-by: Jeff Yang --- pytorch_lightning/trainer/trainer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index de481f708ed3d..1024b2b5019be 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -169,9 +169,13 @@ def __init__( callbacks: Add a list of callbacks. - checkpoint_callback: If ``True``, will configure a default ModelCheckpoint callback. - Passing a ModelCheckpoint instance to this argument is deprecated since - v1.1.0 and will be unsupported from v1.4.0. Default: ``True`` + checkpoint_callback: If ``True``, enable checkpointing. + It will configure a default ModelCheckpoint callback + if there is no custom ModelCheckpoint callback in `:paramref:~Trainer.callbacks`. + Default: ``True``. + + .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since + v1.1.0 and will be unsupported from v1.4.0. check_val_every_n_epoch: Check val every n train epochs. From c507898e33192f52e0346cf91aa0e5ad0bc36ec0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Oct 2020 22:18:26 +0100 Subject: [PATCH 15/15] paramref --- pytorch_lightning/trainer/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1024b2b5019be..008633273a0d1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -170,9 +170,8 @@ def __init__( callbacks: Add a list of callbacks. checkpoint_callback: If ``True``, enable checkpointing. - It will configure a default ModelCheckpoint callback - if there is no custom ModelCheckpoint callback in `:paramref:~Trainer.callbacks`. - Default: ``True``. + It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``. .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since v1.1.0 and will be unsupported from v1.4.0.