Skip to content

Commit

Permalink
Add checkpoint parameter to on_save_checkpoint (#6072)
Browse files Browse the repository at this point in the history
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
  • Loading branch information
2 people authored and lexierule committed Mar 5, 2021
1 parent 4b71a83 commit 9329f58
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 38 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


### Changed


Expand Down
24 changes: 19 additions & 5 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

import abc
from typing import Any
from typing import Any, Dict

from pytorch_lightning.core.lightning import LightningModule

Expand Down Expand Up @@ -177,12 +177,26 @@ def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None:
"""Called when the training is interrupted by ``KeyboardInterrupt``."""
pass

def on_save_checkpoint(self, trainer, pl_module: LightningModule) -> None:
"""Called when saving a model checkpoint, use to persist state."""
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict:
    """
    Hook invoked while a model checkpoint is being saved; use it to persist callback state.

    Args:
        trainer: the current Trainer instance.
        pl_module: the current LightningModule instance.
        checkpoint: the checkpoint dictionary that will be saved.

    Returns:
        The callback state. This base implementation is a no-op and returns ``None``.
    """
    return None

def on_load_checkpoint(self, checkpointed_state) -> None:
"""Called when loading a model checkpoint, use to reload state."""
def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
    """
    Hook invoked when a model checkpoint is loaded; use it to restore callback state.

    Args:
        callback_state: the callback state returned by ``on_save_checkpoint``.
    """
    # Base implementation is intentionally a no-op.
    return None

def on_after_backward(self, trainer, pl_module: LightningModule) -> None:
Expand Down
13 changes: 7 additions & 6 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Monitor a metric and stop training when it stops improving.
"""
from typing import Any, Dict

import numpy as np
import torch
Expand Down Expand Up @@ -140,19 +141,19 @@ def _validate_condition_metric(self, logs):
def monitor_op(self):
return self.mode_dict[self.mode]

def on_save_checkpoint(self, trainer, pl_module):
def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Return the early-stopping attributes that are persisted in a checkpoint.

    ``trainer``, ``pl_module`` and ``checkpoint`` are accepted to satisfy the
    hook signature but are not read here.
    """
    persisted = ('wait_count', 'stopped_epoch', 'best_score', 'patience')
    return {field: getattr(self, field) for field in persisted}

def on_load_checkpoint(self, checkpointed_state):
self.wait_count = checkpointed_state['wait_count']
self.stopped_epoch = checkpointed_state['stopped_epoch']
self.best_score = checkpointed_state['best_score']
self.patience = checkpointed_state['patience']
def on_load_checkpoint(self, callback_state: Dict[str, Any]):
    """Restore the attributes written by ``on_save_checkpoint``.

    Entries are applied in order; a missing key raises ``KeyError`` after the
    preceding entries have already been assigned.
    """
    for field in ('wait_count', 'stopped_epoch', 'best_score', 'patience'):
        setattr(self, field, callback_state[field])

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def on_validation_end(self, trainer, pl_module):
"""
self.save_checkpoint(trainer, pl_module)

def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
"monitor": self.monitor,
"best_model_score": self.best_model_score,
Expand All @@ -220,9 +220,9 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
"dirpath": self.dirpath
}

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.best_model_score = checkpointed_state["best_model_score"]
self.best_model_path = checkpointed_state["best_model_path"]
def on_load_checkpoint(self, callback_state: Dict[str, Any]):
    """Restore ``best_model_score`` and ``best_model_path`` from the saved callback state.

    Other entries of ``callback_state`` are ignored.
    """
    for field in ("best_model_score", "best_model_path"):
        setattr(self, field, callback_state[field])

def save_checkpoint(self, trainer, pl_module):
"""
Expand Down
27 changes: 22 additions & 5 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

from abc import ABC
from copy import deepcopy
from typing import List
from inspect import signature
from typing import List, Dict, Any, Type, Callable

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn


class TrainerCallbackHookMixin(ABC):
Expand Down Expand Up @@ -197,14 +199,29 @@ def on_keyboard_interrupt(self):
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.lightning_module)

def on_save_checkpoint(self):
@staticmethod
def __is_old_signature(fn: Callable) -> bool:
    """
    Tell whether ``fn`` still uses the pre-v1.3 ``on_save_checkpoint`` signature.

    ``fn`` is inspected as given; callers pass a bound method, so ``self`` is
    not part of the inspected parameters. The old form has exactly two
    parameters ``(trainer, pl_module)``.

    Args:
        fn: the hook to inspect.

    Returns:
        ``True`` when ``fn`` has exactly two parameters and the second one is
        not a var-positional (``*args``-style) parameter, ``False`` otherwise.
    """
    from inspect import Parameter  # local import: only this helper needs it

    parameters = list(signature(fn).parameters.values())
    # Check the parameter *kind*, not its name: ``*rest`` can take the extra
    # ``checkpoint`` argument, while a plain parameter named ``args`` cannot.
    return len(parameters) == 2 and parameters[1].kind is not Parameter.VAR_POSITIONAL

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
    """
    Called when saving a model checkpoint.

    Invokes every callback's ``on_save_checkpoint`` hook exactly once and
    gathers the returned states.

    Args:
        checkpoint: the checkpoint dictionary being saved; forwarded to
            callbacks whose hook accepts it.

    Returns:
        A mapping from callback class to the state that callback returned.
        Callbacks returning a falsy state are omitted. Because the key is the
        class, a later callback of the same class replaces an earlier entry.
    """
    callback_states = {}
    for callback in self.callbacks:
        if self.__is_old_signature(callback.on_save_checkpoint):
            rank_zero_warn(
                "`Callback.on_save_checkpoint` signature has changed in v1.3."
                " A `checkpoint` parameter has been added."
                " Support for the old signature will be removed in v1.5",
                DeprecationWarning
            )
            # Legacy hooks only accept (trainer, pl_module).
            state = callback.on_save_checkpoint(self, self.lightning_module)  # noqa: parameter-unfilled
        else:
            state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
        if state:
            callback_states[type(callback)] = state
    return callback_states

def on_load_checkpoint(self, checkpoint):
Expand Down
14 changes: 5 additions & 9 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
if not has_reached_max_steps:
current_epoch += 1

model = self.trainer.lightning_module

checkpoint = {
'epoch': current_epoch,
'global_step': global_step,
'pytorch-lightning_version': pytorch_lightning.__version__,
'state_dict': model.state_dict(),
}

if not weights_only:

# dump callbacks
callback_states = self.trainer.on_save_checkpoint()
checkpoint['callbacks'] = callback_states
checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)

optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
Expand All @@ -305,12 +306,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
elif self.trainer.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()

# add the hyper_parameters and state_dict from the model
model = self.trainer.lightning_module

# dump the module_arguments and state_dict from the model
checkpoint['state_dict'] = model.state_dict()

# dump hyper-parameters
if model.hparams:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),
Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def __init__(self, expected_state, *args, **kwargs):

def on_train_start(self, trainer, pl_module):
    """Check, when an expected state was supplied, that the callback state matches it."""
    if self.expected_state:
        # The hook takes a ``checkpoint`` argument; an empty dict is enough here.
        assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state

def on_validation_end(self, trainer, pl_module):
    """Run the parent hook, then record a copy of the current callback state."""
    super().on_validation_end(trainer, pl_module)
    # The hook takes a ``checkpoint`` argument; an empty dict is enough here.
    self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy())


def test_resume_early_stopping_from_checkpoint(tmpdir):
Expand Down
4 changes: 2 additions & 2 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,9 @@ def __init__(self, expected_count, *args, **kwargs):
def on_train_start(self, trainer, pl_module):
torch.save = Mock(wraps=torch.save)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    """Count how many times the save hook runs, delegating to the parent hook first."""
    # expect all ranks to run but only rank 0 will actually write the checkpoint file
    super().on_save_checkpoint(trainer, pl_module, checkpoint)
    self.on_save_checkpoint_count += 1

def on_train_end(self, trainer, pl_module):
Expand Down
4 changes: 2 additions & 2 deletions tests/deprecated_api/test_remove_1-4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
"""Test deprecated functionality which will be removed in v1.4.0"""
import sys

import pytest
Expand Down Expand Up @@ -243,5 +243,5 @@ def training_step(self, batch, batch_idx):

trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1)

with pytest.warns(DeprecationWarning, match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
trainer.fit(TestModel())
56 changes: 56 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""

import pytest

from pytorch_lightning import Trainer, Callback
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call


def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
    """Saving with a two-argument hook must emit a deprecation warning; the other signatures must not."""

    class OldSignature(Callback):
        # Legacy hook: no ``checkpoint`` parameter.
        def on_save_checkpoint(self, trainer, pl_module):  # noqa
            ...

    class NewSignature(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            ...

    class ValidSignature1(Callback):
        def on_save_checkpoint(self, trainer, *args):
            ...

    class ValidSignature2(Callback):
        def on_save_checkpoint(self, *args):
            ...

    ckpt_path = tmpdir / "test.ckpt"
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=False,
        max_epochs=1,
        callbacks=[OldSignature()],
    )
    trainer.fit(BoringModel())

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.save_checkpoint(ckpt_path)

    trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()]
    with no_warning_call(DeprecationWarning):
        trainer.save_checkpoint(ckpt_path)
19 changes: 19 additions & 0 deletions tests/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
import functools
import os
import traceback
from contextlib import contextmanager
from typing import Optional

import pytest

from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -111,3 +115,18 @@ def inner_f(queue, **kwargs):
assert result == 1, 'expected 1, but returned %s' % result

return wrapper


@contextmanager
def no_warning_call(warning_type, match: Optional[str] = None):
    """
    Assert that the wrapped code does not raise a warning of ``warning_type``.

    Args:
        warning_type: the warning class (subclasses included) that must not be emitted.
        match: optional regular expression. When given, only warnings whose
            message matches it (``re.search``) count as a failure; other
            warnings of ``warning_type`` are tolerated.

    Raises:
        AssertionError: if an offending warning was recorded inside the block.
    """
    import re
    import warnings

    with warnings.catch_warnings(record=True) as record:
        # Record every warning, including ones already triggered once before.
        warnings.simplefilter("always")
        yield

    for w in record:
        if not issubclass(w.category, warning_type):
            continue
        if match is None or re.search(match, str(w.message)):
            raise AssertionError(f"`{warning_type}` was raised: {w}")
4 changes: 2 additions & 2 deletions tests/trainer/connectors/test_callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def test_checkpoint_callbacks_are_last(tmpdir):

class StatefulCallback0(Callback):
    """Test callback whose saved state is the constant ``{"content0": 0}``."""

    def on_save_checkpoint(self, *args):
        return {"content0": 0}


class StatefulCallback1(Callback):
    """Test callback whose saved state is the constant ``{"content1": 1}``."""

    def on_save_checkpoint(self, *args):
        return {"content1": 1}


Expand Down

0 comments on commit 9329f58

Please sign in to comment.