
Commit

deprecate passing ModelCheckpoint instance to Trainer(checkpoint_callback=...) (#4336)

* first attempt

* update tests

* support multiple

* test bugfix

* changelog

* pep

* pep

* import order

* import

* improve test for resuming

* test

* update test

* add references test

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* docstring suggestion deprecation

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>

* paramref

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
(cherry picked from commit d1234c5)
awaelchli authored and Borda committed Nov 4, 2020
1 parent e94d48c commit 7d1288b
Showing 9 changed files with 167 additions and 27 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))


- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))

### Removed


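For readers of the CHANGELOG entry above, a minimal usage sketch of the migration (not part of this commit; assumes pytorch-lightning around v1.1 with the `dirpath` argument available on ModelCheckpoint):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(dirpath="checkpoints", save_last=True)

# Deprecated since v1.1, scheduled for removal in v1.4: passing the instance directly.
trainer = Trainer(checkpoint_callback=ckpt)  # emits a DeprecationWarning

# Recommended: pass the instance through `callbacks`; keep `checkpoint_callback` a bool.
trainer = Trainer(checkpoint_callback=True, callbacks=[ckpt])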
36 changes: 24 additions & 12 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from typing import Union, Optional

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -44,25 +47,31 @@ def on_trainer_init(
# configure checkpoint callback
# it is important that this is the last callback to run
# pass through the required args to figure out defaults
checkpoint_callback = self.init_default_checkpoint_callback(checkpoint_callback)
if checkpoint_callback:
self.trainer.callbacks.append(checkpoint_callback)

# TODO refactor codebase (tests) to not directly reach into these callbacks
self.trainer.checkpoint_callback = checkpoint_callback
self.configure_checkpoint_callbacks(checkpoint_callback)

# init progress bar
self.trainer._progress_bar_callback = self.configure_progress_bar(
progress_bar_refresh_rate, process_position
)

def init_default_checkpoint_callback(self, checkpoint_callback):
if checkpoint_callback is True:
checkpoint_callback = ModelCheckpoint(dirpath=None, filename=None)
elif checkpoint_callback is False:
checkpoint_callback = None
def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]):
if isinstance(checkpoint_callback, ModelCheckpoint):
# TODO: deprecated, remove this block in v1.4.0
rank_zero_warn(
"Passing a ModelCheckpoint instance to Trainer(checkpoint_callbacks=...)"
" is deprecated since v1.1 and will no longer be supported in v1.4.",
DeprecationWarning
)
self.trainer.callbacks.append(checkpoint_callback)

if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
raise MisconfigurationException(
"Trainer was configured with checkpoint_callback=False but found ModelCheckpoint"
" in callbacks list."
)

return checkpoint_callback
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None))

def configure_progress_bar(self, refresh_rate=1, process_position=0):
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
@@ -83,3 +92,6 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0):
progress_bar_callback = None

return progress_bar_callback

def _trainer_has_checkpoint_callbacks(self):
return len(self.trainer.checkpoint_callbacks) > 0
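The new `configure_checkpoint_callbacks` above resolves three cases: a deprecated `ModelCheckpoint` instance, a `False` flag that conflicts with a checkpoint already in the callbacks list, and a `True` flag with no user-defined checkpoint. A standalone sketch of that resolution logic, using stand-in names rather than the real pytorch_lightning classes:

from typing import List, Union


class StubModelCheckpoint:
    """Stand-in for pytorch_lightning.callbacks.ModelCheckpoint."""


def resolve_checkpoint_callbacks(
    callbacks: List[object],
    checkpoint_callback: Union[StubModelCheckpoint, bool],
) -> List[object]:
    if isinstance(checkpoint_callback, StubModelCheckpoint):
        # Deprecated path: the instance is still honored, but via the callbacks list.
        print("DeprecationWarning: pass ModelCheckpoint via `callbacks` instead")
        callbacks.append(checkpoint_callback)

    has_checkpoint = any(isinstance(c, StubModelCheckpoint) for c in callbacks)
    if has_checkpoint and checkpoint_callback is False:
        raise ValueError("checkpoint_callback=False but found a ModelCheckpoint in callbacks")
    if not has_checkpoint and checkpoint_callback is True:
        callbacks.append(StubModelCheckpoint())  # default checkpoint callback
    return callbacks


print(len(resolve_checkpoint_callbacks([], True)))   # 1: a default callback is added
print(len(resolve_checkpoint_callbacks([], False)))  # 0: checkpointing disabled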
17 changes: 16 additions & 1 deletion pytorch_lightning/trainer/properties.py
@@ -17,7 +17,7 @@
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Type, TypeVar

from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
@@ -46,6 +46,7 @@ class TrainerProperties(ABC):
_weights_save_path: str
model_connector: ModelConnector
checkpoint_connector: CheckpointConnector
callbacks: List[Callback]

@property
def use_amp(self) -> bool:
@@ -187,6 +188,20 @@ def weights_save_path(self) -> str:
return os.path.normpath(self._weights_save_path)
return self._weights_save_path

@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
"""
The first checkpoint callback in the Trainer.callbacks list, or ``None`` if
no checkpoint callbacks exist.
"""
callbacks = self.checkpoint_callbacks
return callbacks[0] if len(callbacks) > 0 else None

@property
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
""" A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]

def save_checkpoint(self, filepath, weights_only: bool = False):
self.checkpoint_connector.save_checkpoint(filepath, weights_only)

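A short illustration of the two new read-only properties (hypothetical usage, assuming pytorch-lightning around v1.1 is installed):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

first = ModelCheckpoint(dirpath="ckpts_a")
second = ModelCheckpoint(dirpath="ckpts_b")
trainer = Trainer(callbacks=[first, second])

assert trainer.checkpoint_callbacks == [first, second]  # all ModelCheckpoint instances
assert trainer.checkpoint_callback is first             # the first one, or None if there are none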
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -85,7 +85,7 @@ class Trainer(
def __init__(
self,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
checkpoint_callback: Union[ModelCheckpoint, bool] = True,
checkpoint_callback: bool = True,
callbacks: Optional[List[Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: float = 0,
@@ -169,7 +169,12 @@ def __init__(
callbacks: Add a list of callbacks.
checkpoint_callback: Callback for checkpointing.
checkpoint_callback: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
v1.1.0 and will be unsupported from v1.4.0.
check_val_every_n_epoch: Check val every n train epochs.
@@ -297,7 +302,6 @@ def __init__(

# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.checkpoint_callback: Union[ModelCheckpoint, bool] = checkpoint_callback
self.callback_connector.on_trainer_init(
callbacks,
checkpoint_callback,
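A sketch of the bool semantics documented in the updated docstring (hypothetical, assuming pytorch-lightning around v1.1; behavior mirrors the tests added further below):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

assert Trainer(checkpoint_callback=True).checkpoint_callback is not None   # default callback added
assert Trainer(checkpoint_callback=False).checkpoint_callback is None      # checkpointing disabled
user_ckpt = ModelCheckpoint(dirpath="ckpts")
assert Trainer(callbacks=[user_ckpt]).checkpoint_callback is user_ckpt     # user instance wins, no extra default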
2 changes: 0 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -144,7 +144,6 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
trainer.weights_summary = None # not needed before full run
trainer.logger = DummyLogger()
trainer.callbacks = [] # not needed before full run
trainer.checkpoint_callback = False # required for saving
trainer.limit_train_batches = 1.0
trainer.optimizers, trainer.schedulers = [], [] # required for saving
trainer.model = model # required for saving
@@ -157,7 +156,6 @@ def __scale_batch_restore_params(trainer):
trainer.weights_summary = trainer.__dumped_params['weights_summary']
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
trainer.auto_scale_batch_size = trainer.__dumped_params['auto_scale_batch_size']
trainer.limit_train_batches = trainer.__dumped_params['limit_train_batches']
trainer.model = trainer.__dumped_params['model']
4 changes: 0 additions & 4 deletions pytorch_lightning/tuner/lr_finder.py
@@ -155,9 +155,6 @@ def lr_find(
if trainer.progress_bar_callback:
trainer.progress_bar_callback.disable()

# Disable standard checkpoint & early stopping
trainer.checkpoint_callback = False

# Required for saving the model
trainer.optimizers, trainer.schedulers = [], [],
trainer.model = model
@@ -212,7 +209,6 @@ def __lr_finder_restore_params(trainer, model):
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
trainer.max_steps = trainer.__dumped_params['max_steps']
trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
del trainer.__dumped_params

40 changes: 40 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -746,3 +746,43 @@ def test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, filename):

assert mc_cb.dirpath == dirpath
assert mc_cb.filename == filename


def test_configure_model_checkpoint(tmpdir):
""" Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """
kwargs = dict(default_root_dir=tmpdir)
callback1 = ModelCheckpoint()
callback2 = ModelCheckpoint()

# no callbacks
trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs)
assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks)
assert trainer.checkpoint_callback is None

# default configuration
trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs)
assert len([c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)]) == 1
assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)

# custom callback passed to callbacks list, checkpoint_callback=True is ignored
trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs)
assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
assert trainer.checkpoint_callback == callback1

# multiple checkpoint callbacks
trainer = Trainer(callbacks=[callback1, callback2], **kwargs)
assert trainer.checkpoint_callback == callback1
assert trainer.checkpoint_callbacks == [callback1, callback2]

with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.4'):
trainer = Trainer(checkpoint_callback=callback1, callbacks=[], **kwargs)
assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
assert trainer.checkpoint_callback == callback1

with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.4"):
trainer = Trainer(checkpoint_callback=callback1, callbacks=[callback2], **kwargs)
assert trainer.checkpoint_callback == callback2
assert trainer.checkpoint_callbacks == [callback2, callback1]

with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)
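During migration, user code can observe or silence the new warning with the standard warnings machinery; a minimal sketch (hypothetical, assuming pytorch-lightning around v1.1; the matched text comes from the connector change above):

import warnings

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Trainer(checkpoint_callback=ModelCheckpoint(dirpath="ckpts"))

assert any("no longer be supported in v1.4" in str(w.message) for w in caught)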
77 changes: 72 additions & 5 deletions tests/models/test_restore.py
@@ -15,6 +15,7 @@
import logging as log
import os
import pickle
from copy import deepcopy

import cloudpickle
import pytest
@@ -24,7 +25,7 @@

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer, LightningModule, Callback
from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST

@@ -51,24 +52,90 @@ def on_train_end(self, trainer, pl_module):
self._check_properties(trainer, pl_module)


def test_resume_from_checkpoint(tmpdir):
def test_model_properties_resume_from_checkpoint(tmpdir):
""" Test that properties like `current_epoch` and `global_step`
in model and trainer are always the same. """
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer_args = dict(
default_root_dir=tmpdir,
max_epochs=2,
max_epochs=1,
logger=False,
checkpoint_callback=checkpoint_callback,
callbacks=[ModelTrainerPropertyParity()] # this performs the assertions
callbacks=[checkpoint_callback, ModelTrainerPropertyParity()] # this performs the assertions
)
trainer = Trainer(**trainer_args)
trainer.fit(model)

trainer_args.update(max_epochs=2)
trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt"))
trainer.fit(model)


class CaptureCallbacksBeforeTraining(Callback):
callbacks = []

def on_train_start(self, trainer, pl_module):
self.callbacks = deepcopy(trainer.callbacks)


def test_callbacks_state_resume_from_checkpoint(tmpdir):
""" Test that resuming from a checkpoint restores callbacks that persist state. """
model = EvalModelTemplate()
callback_capture = CaptureCallbacksBeforeTraining()

def get_trainer_args():
checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer_args = dict(
default_root_dir=tmpdir,
max_steps=1,
logger=False,
callbacks=[
checkpoint,
callback_capture,
]
)
assert checkpoint.best_model_path == ""
assert checkpoint.best_model_score == 0
return trainer_args

# initial training
trainer = Trainer(**get_trainer_args())
trainer.fit(model)
callbacks_before_resume = deepcopy(trainer.callbacks)

# resumed training
trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt"))
trainer.fit(model)

assert len(callbacks_before_resume) == len(callback_capture.callbacks)

for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
if isinstance(before, ModelCheckpoint):
assert before.best_model_path == after.best_model_path
assert before.best_model_score == after.best_model_score


def test_callbacks_references_resume_from_checkpoint(tmpdir):
""" Test that resuming from a checkpoint sets references as expected. """
model = EvalModelTemplate()
args = {'default_root_dir': tmpdir, 'max_steps': 1, 'logger': False}

# initial training
checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer = Trainer(**args, callbacks=[checkpoint])
assert checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
trainer.fit(model)

# resumed training
new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
# pass in a new checkpoint object, which should take
# precedence over the one in the last.ckpt file
trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt"))
assert checkpoint is not new_checkpoint
assert new_checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
trainer.fit(model)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_running_test_pretrained_model_distrib_dp(tmpdir):
"""Verify `test()` on pretrained model."""
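The restore tests above exercise the user-facing pattern of resuming with a freshly constructed checkpoint callback; a minimal sketch of that pattern (hypothetical paths, assumes pytorch-lightning around v1.1 and a LightningModule subclass `MyModel` defined elsewhere):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(dirpath="ckpts", monitor="early_stop_on", save_last=True)
trainer = Trainer(max_steps=1, callbacks=[checkpoint],
                  resume_from_checkpoint="ckpts/last.ckpt")
# trainer.fit(MyModel())  # the new callback object takes precedence over the one
#                         # serialized in last.ckpt, as asserted in the tests above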
6 changes: 6 additions & 0 deletions tests/test_deprecated.py
@@ -15,6 +15,12 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_tbd_remove_in_v1_4_0(tmpdir):
with pytest.deprecated_call(match='will no longer be supported in v1.4'):
callback = ModelCheckpoint()
Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir)


def test_tbd_remove_in_v1_2_0():
with pytest.deprecated_call(match='will be removed in v1.2'):
checkpoint_cb = ModelCheckpoint(filepath='.')
