
Deprecate passing ModelCheckpoint instance to Trainer(checkpoint_callback=...) #4336

Merged
17 commits, merged on Oct 30, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))


- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))

### Removed


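For readers migrating, a minimal sketch of the usage change this deprecation implies (the `dirpath` value is illustrative, not taken from the PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(dirpath="checkpoints/", save_last=True)

# deprecated since v1.1, unsupported from v1.4: instance via checkpoint_callback
trainer = Trainer(checkpoint_callback=checkpoint)

# recommended: pass the instance through the callbacks list
trainer = Trainer(callbacks=[checkpoint])
```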
36 changes: 24 additions & 12 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from typing import Union, Optional

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -44,25 +47,31 @@ def on_trainer_init(
# configure checkpoint callback
# it is important that this is the last callback to run
# pass through the required args to figure out defaults
checkpoint_callback = self.init_default_checkpoint_callback(checkpoint_callback)
if checkpoint_callback:
self.trainer.callbacks.append(checkpoint_callback)

# TODO refactor codebase (tests) to not directly reach into these callbacks
self.trainer.checkpoint_callback = checkpoint_callback
self.configure_checkpoint_callbacks(checkpoint_callback)

# init progress bar
self.trainer._progress_bar_callback = self.configure_progress_bar(
progress_bar_refresh_rate, process_position
)

def init_default_checkpoint_callback(self, checkpoint_callback):
if checkpoint_callback is True:
checkpoint_callback = ModelCheckpoint(dirpath=None, filename=None)
elif checkpoint_callback is False:
checkpoint_callback = None
def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]):
if isinstance(checkpoint_callback, ModelCheckpoint):
# TODO: deprecated, remove this block in v1.4.0
rank_zero_warn(
"Passing a ModelCheckpoint instance to Trainer(checkpoint_callbacks=...)"
" is deprecated since v1.1 and will no longer be supported in v1.4.",
DeprecationWarning
)
self.trainer.callbacks.append(checkpoint_callback)

if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
raise MisconfigurationException(
"Trainer was configured with checkpoint_callback=False but found ModelCheckpoint"
" in callbacks list."
)

return checkpoint_callback
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None))

def configure_progress_bar(self, refresh_rate=1, process_position=0):
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
@@ -83,3 +92,6 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0):
progress_bar_callback = None

return progress_bar_callback

def _trainer_has_checkpoint_callbacks(self):
return len(self.trainer.checkpoint_callbacks) > 0
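Two of the new code paths above are easiest to see from the caller's side. A hedged sketch of both, written as a test fragment (assuming only pytest and the imports shown; not code from the PR):

```python
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# deprecated path: an instance passed to the Trainer argument still works until v1.4
with pytest.warns(DeprecationWarning):
    Trainer(checkpoint_callback=ModelCheckpoint())

# contradictory configuration now fails fast instead of silently checkpointing
with pytest.raises(MisconfigurationException):
    Trainer(checkpoint_callback=False, callbacks=[ModelCheckpoint()])
```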
17 changes: 16 additions & 1 deletion pytorch_lightning/trainer/properties.py
@@ -17,7 +17,7 @@
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Type, TypeVar

from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
@@ -46,6 +46,7 @@ class TrainerProperties(ABC):
_weights_save_path: str
model_connector: ModelConnector
checkpoint_connector: CheckpointConnector
callbacks: List[Callback]

@property
def use_amp(self) -> bool:
@@ -187,6 +188,20 @@ def weights_save_path(self) -> str:
return os.path.normpath(self._weights_save_path)
return self._weights_save_path

@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
"""
The first checkpoint callback in the Trainer.callbacks list, or ``None`` if
no checkpoint callbacks exist.
"""
callbacks = self.checkpoint_callbacks
return callbacks[0] if len(callbacks) > 0 else None

@property
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
""" A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]

def save_checkpoint(self, filepath, weights_only: bool = False):
self.checkpoint_connector.save_checkpoint(filepath, weights_only)

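A short usage sketch of the two new properties (the `monitor` arguments are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt_a = ModelCheckpoint(monitor="val_loss")
ckpt_b = ModelCheckpoint(monitor="val_acc")
trainer = Trainer(callbacks=[ckpt_a, ckpt_b])

assert trainer.checkpoint_callbacks == [ckpt_a, ckpt_b]  # every ModelCheckpoint, in list order
assert trainer.checkpoint_callback is ckpt_a             # the first one, or None if there are none
```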
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -85,7 +85,7 @@ class Trainer(
def __init__(
self,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
checkpoint_callback: Union[ModelCheckpoint, bool] = True,
checkpoint_callback: bool = True,
Contributor:
Q: Why can we not just remove the checkpoint_callback argument altogether?
A: Lightning philosophy is to provide good defaults, and this includes checkpointing. We need a way to turn it off, though, i.e., checkpoint_callback=False.

What about keeping the checkpoint_callback parameter but changing its default value to False? I think it will be pretty annoying to have to set checkpoint_callback=False every time you pass a custom ModelCheckpoint via callbacks. And I think most people use a custom ModelCheckpoint rather than just checkpoint_callback=True.

Contributor Author (@awaelchli):

> What about keeping the checkpoint_callback parameter but changing its default value to False?

That doesn't solve the problem I'm trying to solve here, which is to eliminate ambiguity when restoring the state of the trainer. See the answer to the 2nd FAQ question.

> pretty annoying to have to set checkpoint_callback=False every time you pass a custom ModelCheckpoint

With this PR's proposal, the value is ignored if you pass in a custom one. False is only needed when you want to disable checkpointing completely. I believe I have this covered in a test.

Contributor (@ananthsub), Oct 29, 2020:

I agree with @carmocca, this is super confusing when adding my own checkpoint callback. Given how loose the default checkpoint callback is, and with the coming customizations, I'd rather drop the checkpoint_callback arg altogether and force everything to be configured through callbacks. Given that the callback implementation already exists, I personally don't think it's much to ask for people to instantiate the checkpoint callback (and confirm their settings while doing so) and pass it along to the trainer.

I also think that's a nicer message for users: "See how extensible this framework is" vs. "look at all the magic this trainer configures for you, which you can't change".

Even if that's not in this PR, it feels inevitable that checkpoint_callback=False would eventually become the new default, and later we could drop the arg altogether.

Contributor Author (@awaelchli), Oct 29, 2020:

Yes, fine with me, I don't have a strong preference here.
Note that this PR does NOT close the discussion on whether there should be a checkpoint_callback arg or not.
I'm simply restricting what can be passed to the argument.

It looks like a lot of API change, but it is really more of a bugfix.

callbacks: Optional[List[Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: float = 0,
@@ -169,7 +169,12 @@ def __init__(

callbacks: Add a list of callbacks.

checkpoint_callback: Callback for checkpointing.
checkpoint_callback: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.

.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
v1.1.0 and will be unsupported from v1.4.0.

check_val_every_n_epoch: Check val every n train epochs.

@@ -297,7 +302,6 @@ def __init__(

# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.checkpoint_callback: Union[ModelCheckpoint, bool] = checkpoint_callback
self.callback_connector.on_trainer_init(
callbacks,
checkpoint_callback,
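The docstring change above is the user-visible contract. A sketch of the boolean configurations it describes (my summary, not code from the PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint()

Trainer(checkpoint_callback=True)                    # a default ModelCheckpoint is added
Trainer(checkpoint_callback=True, callbacks=[ckpt])  # the user's instance wins; no default added
Trainer(checkpoint_callback=False)                   # checkpointing disabled entirely
```

Note also the removed line in `__init__`: `checkpoint_callback` is no longer stored as a plain attribute; it now resolves through the read-only property added in `properties.py`, so the callbacks list is the single source of truth.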
2 changes: 0 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -144,7 +144,6 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
trainer.weights_summary = None # not needed before full run
trainer.logger = DummyLogger()
trainer.callbacks = [] # not needed before full run
trainer.checkpoint_callback = False # required for saving
Contributor Author (@awaelchli), Oct 28, 2020:
I removed these from Tuner because ModelCheckpoint now entirely lives in callbacks list, and this is properly backed up by Tuner already.


trainer.limit_train_batches = 1.0
trainer.optimizers, trainer.schedulers = [], [] # required for saving
trainer.model = model # required for saving
@@ -157,7 +156,6 @@ def __scale_batch_restore_params(trainer):
trainer.weights_summary = trainer.__dumped_params['weights_summary']
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
trainer.auto_scale_batch_size = trainer.__dumped_params['auto_scale_batch_size']
trainer.limit_train_batches = trainer.__dumped_params['limit_train_batches']
trainer.model = trainer.__dumped_params['model']
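The removed lines relied on `checkpoint_callback` being a mutable trainer attribute. Since the callback now lives only in `trainer.callbacks`, backing up that list is enough. A condensed sketch of the tuner's existing dump/restore pattern, assuming the `__dumped_params` convention visible above; the function bodies here are illustrative, and the `lr_finder.py` changes below follow the same reasoning:

```python
def __scale_batch_dump_params(trainer):
    # saving the callbacks list implicitly saves any ModelCheckpoint inside it,
    # so no separate checkpoint_callback bookkeeping is needed anymore
    trainer.__dumped_params = {
        'callbacks': trainer.callbacks,
        'logger': trainer.logger,
        'limit_train_batches': trainer.limit_train_batches,
    }


def __scale_batch_restore_params(trainer):
    trainer.callbacks = trainer.__dumped_params['callbacks']
    trainer.logger = trainer.__dumped_params['logger']
    trainer.limit_train_batches = trainer.__dumped_params['limit_train_batches']
    del trainer.__dumped_params
```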
4 changes: 0 additions & 4 deletions pytorch_lightning/tuner/lr_finder.py
@@ -155,9 +155,6 @@ def lr_find(
if trainer.progress_bar_callback:
trainer.progress_bar_callback.disable()

# Disable standard checkpoint & early stopping
trainer.checkpoint_callback = False

# Required for saving the model
trainer.optimizers, trainer.schedulers = [], [],
trainer.model = model
@@ -212,7 +209,6 @@ def __lr_finder_restore_params(trainer, model):
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
trainer.max_steps = trainer.__dumped_params['max_steps']
trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
del trainer.__dumped_params

40 changes: 40 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -743,3 +743,43 @@ def test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, file

assert mc_cb.dirpath == dirpath
assert mc_cb.filename == filename


def test_configure_model_checkpoint(tmpdir):
""" Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """
kwargs = dict(default_root_dir=tmpdir)
callback1 = ModelCheckpoint()
callback2 = ModelCheckpoint()

# no callbacks
trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs)
assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks)
assert trainer.checkpoint_callback is None

# default configuration
trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs)
assert len([c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)]) == 1
assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)

# custom callback passed to callbacks list, checkpoint_callback=True is ignored
trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs)
assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
assert trainer.checkpoint_callback == callback1

# multiple checkpoint callbacks
trainer = Trainer(callbacks=[callback1, callback2], **kwargs)
assert trainer.checkpoint_callback == callback1
assert trainer.checkpoint_callbacks == [callback1, callback2]

with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.4'):
trainer = Trainer(checkpoint_callback=callback1, callbacks=[], **kwargs)
assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
assert trainer.checkpoint_callback == callback1

with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.4"):
trainer = Trainer(checkpoint_callback=callback1, callbacks=[callback2], **kwargs)
assert trainer.checkpoint_callback == callback2
assert trainer.checkpoint_callbacks == [callback2, callback1]

with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)
77 changes: 72 additions & 5 deletions tests/models/test_restore.py
@@ -15,6 +15,7 @@
import logging as log
import os
import pickle
from copy import deepcopy

import cloudpickle
import pytest
@@ -24,7 +25,7 @@

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer, LightningModule, Callback
from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST

@@ -51,24 +52,90 @@ def on_train_end(self, trainer, pl_module):
self._check_properties(trainer, pl_module)


def test_resume_from_checkpoint(tmpdir):
def test_model_properties_resume_from_checkpoint(tmpdir):
""" Test that properties like `current_epoch` and `global_step`
in model and trainer are always the same. """
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer_args = dict(
default_root_dir=tmpdir,
max_epochs=2,
max_epochs=1,
logger=False,
checkpoint_callback=checkpoint_callback,
callbacks=[ModelTrainerPropertyParity()] # this performs the assertions
callbacks=[checkpoint_callback, ModelTrainerPropertyParity()] # this performs the assertions
)
trainer = Trainer(**trainer_args)
trainer.fit(model)

trainer_args.update(max_epochs=2)
trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt"))
trainer.fit(model)


class CaptureCallbacksBeforeTraining(Callback):
callbacks = []

def on_train_start(self, trainer, pl_module):
self.callbacks = deepcopy(trainer.callbacks)


def test_callbacks_state_resume_from_checkpoint(tmpdir):
""" Test that resuming from a checkpoint restores callbacks that persist state. """
model = EvalModelTemplate()
callback_capture = CaptureCallbacksBeforeTraining()

def get_trainer_args():
checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer_args = dict(
default_root_dir=tmpdir,
max_steps=1,
logger=False,
callbacks=[
checkpoint,
callback_capture,
]
)
assert checkpoint.best_model_path == ""
assert checkpoint.best_model_score == 0
return trainer_args

# initial training
trainer = Trainer(**get_trainer_args())
trainer.fit(model)
callbacks_before_resume = deepcopy(trainer.callbacks)

# resumed training
trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt"))
trainer.fit(model)

assert len(callbacks_before_resume) == len(callback_capture.callbacks)

for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
if isinstance(before, ModelCheckpoint):
assert before.best_model_path == after.best_model_path
assert before.best_model_score == after.best_model_score


def test_callbacks_references_resume_from_checkpoint(tmpdir):
""" Test that resuming from a checkpoint sets references as expected. """
model = EvalModelTemplate()
args = {'default_root_dir': tmpdir, 'max_steps': 1, 'logger': False}

# initial training
checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer = Trainer(**args, callbacks=[checkpoint])
assert checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
trainer.fit(model)

# resumed training
new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
# pass in a new checkpoint object, which should take
# precedence over the one in the last.ckpt file
trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt"))
assert checkpoint is not new_checkpoint
assert new_checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
trainer.fit(model)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_running_test_pretrained_model_distrib_dp(tmpdir):
"""Verify `test()` on pretrained model."""
6 changes: 6 additions & 0 deletions tests/test_deprecated.py
@@ -15,6 +15,12 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_tbd_remove_in_v1_4_0(tmpdir):
with pytest.deprecated_call(match='will no longer be supported in v1.4'):
callback = ModelCheckpoint()
Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir)


def test_tbd_remove_in_v1_2_0():
with pytest.deprecated_call(match='will be removed in v1.2'):
checkpoint_cb = ModelCheckpoint(filepath='.')