diff --git a/CHANGELOG.md b/CHANGELOG.md index bb0471e969c22..2ce2525d90bfa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) +- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) @@ -58,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index ec257bf444f5c..f6deb9adf58d3 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -178,12 +178,14 @@ Under the hood, Lightning does the following (pseudocode): loss = training_step(batch) losses.append(loss.detach()) + # clear gradients + optimizer.zero_grad() + # backward loss.backward() - # apply and clear grads + # update parameters optimizer.step() - optimizer.zero_grad() Training epoch-level metrics @@ -212,12 +214,14 @@ Here's the pseudocode of what it does under the hood: # forward out = training_step(val_batch) + # clear gradients + optimizer.zero_grad() + # backward loss.backward() - # apply and clear grads + # update parameters optimizer.step() - optimizer.zero_grad() epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs])) @@ -247,12 +251,14 @@ The matching pseudocode is: # forward out = training_step(val_batch) + # clear gradients + optimizer.zero_grad() + # backward loss.backward() - # apply and clear grads + # update parameters optimizer.step() - optimizer.zero_grad() training_epoch_end(outs) @@ -946,9 +952,9 @@ When set to ``False``, Lightning does not automate the optimization process. Thi opt = self.optimizers(use_pl_optimizer=True) loss = ... + opt.zero_grad() self.manual_backward(loss) opt.step() - opt.zero_grad() This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research. 
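Editor's note on the reordering above: the docs now clear gradients before the backward pass and the parameter update, in both the automatic-optimization pseudocode and the manual-optimization snippet. A minimal, hypothetical sketch of a `LightningModule` following that order under manual optimization (the model, layer sizes, and optimizer here are illustrative and not part of this PR):

```python
import torch
from torch import nn
import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):
    """Toy module showing the zero_grad -> manual_backward -> step order."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        # take control of the optimization loop ourselves
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers(use_pl_optimizer=True)
        x, y = batch
        loss = nn.functional.cross_entropy(self.layer(x), y)

        opt.zero_grad()             # clear gradients first
        self.manual_backward(loss)  # then backward
        opt.step()                  # then update parameters
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```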
@@ -1048,11 +1054,13 @@ This is the pseudocode to describe how all the hooks are called during a call to loss = out.loss + on_before_zero_grad() + optimizer_zero_grad() + backward() on_after_backward() + optimizer_step() - on_before_zero_grad() - optimizer_zero_grad() on_train_batch_end(out) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 10c7c5ad59bfd..5614e481e0888 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -75,12 +75,14 @@ Here's the pseudocode for what the trainer does under the hood (showing the trai # train step loss = training_step(batch) + # clear gradients + optimizer.zero_grad() + # backward loss.backward() - # apply and clear grads + # update parameters optimizer.step() - optimizer.zero_grad() losses.append(loss) diff --git a/notebooks/06-mnist-tpu-training.ipynb b/notebooks/06-mnist-tpu-training.ipynb index 5d55388b2efcf..a0dfdceece9b1 100644 @@ -80,7 +80,7 @@ "id": "AYGWh10lRaF1" }, "source": [ - "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl" + "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl" ], "execution_count": null, "outputs": [] diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cc716a9b99ad7..ac9dec5502a5a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -94,8 +94,25 @@ class ModelCheckpoint(Callback): save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). + every_n_train_steps: Number of training steps between checkpoints. + If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training. + To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``every_n_val_epochs``. + every_n_val_epochs: Number of validation epochs between checkpoints. + If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end. + To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``every_n_train_steps``. + Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and + ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` + will only save checkpoints at epochs 0 < E <= N + where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E. period: Interval (number of epochs) between checkpoints. + .. warning:: + This argument has been deprecated in v1.3 and will be removed in v1.5. + + Use ``every_n_val_epochs`` instead.
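To make the new checkpointing knobs concrete, here is a hedged usage sketch (directory names and the commented-out model are placeholders, not part of this PR). `every_n_train_steps` and `every_n_val_epochs` cannot both be enabled on one callback, but two callbacks can cover both cadences:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# checkpoint every 1000 optimizer steps; setting only the step interval
# leaves validation-epoch saving disabled on this callback
step_ckpt = ModelCheckpoint(dirpath="checkpoints/steps", every_n_train_steps=1000)

# checkpoint every 2 validation epochs; replaces the deprecated `period=2`
epoch_ckpt = ModelCheckpoint(dirpath="checkpoints/epochs", every_n_val_epochs=2)

trainer = Trainer(max_epochs=10, callbacks=[step_ckpt, epoch_ckpt])
# trainer.fit(model)  # `model` is any LightningModule
```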
+ Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -166,8 +183,10 @@ def __init__( save_top_k: Optional[int] = None, save_weights_only: bool = False, mode: str = "min", - period: int = 1, - auto_insert_metric_name: bool = True + auto_insert_metric_name: bool = True, + every_n_train_steps: Optional[int] = None, + every_n_val_epochs: Optional[int] = None, + period: Optional[int] = None, ): super().__init__() self.monitor = monitor @@ -175,7 +194,6 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = period self.auto_insert_metric_name = auto_insert_metric_name self._last_global_step_saved = -1 self.current_score = None @@ -189,6 +207,7 @@ def __init__( self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) + self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) self.__validate_init_configuration() def on_pretrain_routine_start(self, trainer, pl_module): @@ -198,10 +217,26 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.__resolve_ckpt_dir(trainer) self.save_function = trainer.save_checkpoint - def on_validation_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, *args, **kwargs) -> None: + """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ + if self._should_skip_saving_checkpoint(trainer): + return + step = trainer.global_step + skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0) + if skip_batch: + return + self.save_checkpoint(trainer) + + def on_validation_end(self, trainer, *args, **kwargs) -> None: """ checkpoints can be saved at the end of the val loop """ + skip = ( + self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 + or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 + ) + if skip: + return self.save_checkpoint(trainer) def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: @@ -229,20 +264,8 @@ def save_checkpoint(self, trainer, unused: Optional = None): " has been removed. 
Support for the old signature will be removed in v1.5", DeprecationWarning ) - epoch = trainer.current_epoch global_step = trainer.global_step - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state != TrainerState.FITTING # don't save anything during non-fit - or trainer.sanity_checking # don't save anything during sanity check - or self.period < 1 # no models are saved - or (epoch + 1) % self.period # skip epoch - or self._last_global_step_saved == global_step # already saved at the last step - ): - return - self._add_backward_monitor_support(trainer) self._validate_monitor_key(trainer) @@ -265,9 +288,32 @@ def save_checkpoint(self, trainer, unused: Optional = None): if trainer.is_global_zero and trainer.logger and hasattr(trainer.logger, 'after_save_checkpoint'): trainer.logger.after_save_checkpoint(proxy(self)) + def _should_skip_saving_checkpoint(self, trainer) -> bool: + from pytorch_lightning.trainer.states import TrainerState + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check + or self._last_global_step_saved == trainer.global_step # already saved at the last step + ) + def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self._every_n_train_steps < 0: + raise MisconfigurationException( + f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0' + ) + if self._every_n_val_epochs < 0: + raise MisconfigurationException( + f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0' + ) + if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0: + raise MisconfigurationException( + f'Invalid values for every_n_train_steps={self._every_n_train_steps}' + f' and every_n_val_epochs={self._every_n_val_epochs}.' + ' Both cannot be enabled at the same time.' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): @@ -314,6 +360,46 @@ def __init_monitor_mode(self, monitor, mode): self.kth_value, self.mode = mode_dict[mode] + def __init_triggers( + self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int] + ) -> None: + + # Default to running once after each validation epoch if neither + # every_n_train_steps nor every_n_val_epochs is set + if every_n_train_steps is None and every_n_val_epochs is None: + self._every_n_val_epochs = 1 + self._every_n_train_steps = 0 + log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") + else: + self._every_n_val_epochs = every_n_val_epochs or 0 + self._every_n_train_steps = every_n_train_steps or 0 + + # period takes precedence over every_n_val_epochs for backwards compatibility + if period is not None: + rank_zero_warn( + 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + self._every_n_val_epochs = period + + self._period = self._every_n_val_epochs + + @property + def period(self) -> Optional[int]: + rank_zero_warn( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+ ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + return self._period + + @period.setter + def period(self, value: Optional[int]) -> None: + rank_zero_warn( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + self._period = value + @rank_zero_only def _del_model(self, filepath: str): if self._fs.exists(filepath): @@ -427,11 +513,8 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], """ filename = self._format_checkpoint_name( - self.filename, - epoch, - step, - metrics, - auto_insert_metric_name=self.auto_insert_metric_name) + self.filename, epoch, step, metrics, auto_insert_metric_name=self.auto_insert_metric_name + ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) @@ -586,9 +669,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A self._save_model(trainer, filepath) if ( - self.save_top_k is None - and self.best_model_path - and self.best_model_path != filepath + self.save_top_k is None and self.best_model_path and self.best_model_path != filepath and trainer.is_global_zero ): self._del_model(self.best_model_path) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 95bd8b3f8cc44..83505913d0186 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,13 +16,14 @@ import platform from abc import ABC from copy import deepcopy -from typing import Callable, Iterable, List, Optional, Tuple, Union +from typing import Callable, Iterable, List, Tuple, Union from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.core import LightningModule +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -36,8 +37,6 @@ class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - global_rank: int - shown_warnings:... 
val_check_interval: float tpu_local_core_rank: int train_dataloader: DataLoader @@ -48,13 +47,10 @@ class TrainerDataLoadingMixin(ABC): test_dataloaders: List[DataLoader] num_test_batches: List[Union[int, float]] limit_train_batches: Union[int, float] - limit_val_batches: Union[int, float] - limit_test_batches: Union[int, float] - replace_sampler_ddp: bool + overfit_batches: Union[int, float] + distributed_sampler_kwargs: dict accelerator: Accelerator - num_nodes: int - num_processes: int - distributed_backend: Optional[str] + accelerator_connector: AcceleratorConnector dev_debugger: InternalDebugger def _worker_check(self, dataloader: DataLoader, name: str) -> None: diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 70db8b36814ca..69d3887fc7718 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -19,8 +19,6 @@ class DeprecatedDistDeviceAttributes: - _distrib_type: DistributedType - _device_type: DeviceType num_gpus: int accelerator_connector: AcceleratorConnector @@ -135,7 +133,7 @@ def use_single_gpu(self, val: bool) -> None: class DeprecatedTrainerAttributes: accelerator: Accelerator - lightning_module = LightningModule + lightning_module: LightningModule sanity_checking: bool @property diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 16060f863884c..6a036b9fcec49 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -14,12 +14,11 @@ import inspect from abc import ABC -from typing import Mapping, Union +from typing import Mapping import torch -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities import DeviceType, DistributedType +from pytorch_lightning.utilities import DistributedType from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.memory import recursive_detach @@ -28,17 +27,8 @@ class TrainerLoggingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - current_epoch: int - _device_type: DeviceType _distrib_type: DistributedType - log_gpu_memory:... - logger: Union[LightningLoggerBase, bool] - global_step: int - global_rank: int - default_root_dir: str - slurm_job_id: int num_gpus: int - logged_metrics:... def metrics_to_scalars(self, metrics): new_metrics = {} diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 54731977cbee9..2795dd4f0af30 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -29,10 +29,7 @@ class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - default_root_dir: str lightning_module: LightningModule - progress_bar_callback:... 
- on_gpu: bool def print_nan_gradients(self) -> None: model = self.lightning_module diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4a8088070f041..e5583b9bbdf86 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -434,11 +434,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): # auto_insert_metric_name=False ckpt_name = ModelCheckpoint._format_checkpoint_name( - 'epoch={epoch:03d}-val_acc={val/acc}', - 3, - 2, - {'val/acc': 0.03}, - auto_insert_metric_name=False) + 'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False + ) assert ckpt_name == 'epoch=003-val_acc=0.03' @@ -524,6 +521,45 @@ def test_none_monitor_save_last(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_last=False) +def test_invalid_every_n_val_epochs(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + +def test_invalid_every_n_train_steps(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_train_steps argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2) + + +def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir): + """ + Test that a MisconfigurationException is raised if both + every_n_val_epochs and every_n_train_steps are enabled together. + """ + with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0) + + +def test_none_every_n_train_steps_val_epochs(tmpdir): + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) + assert checkpoint_callback.period == 1 + assert checkpoint_callback._every_n_val_epochs == 1 + assert checkpoint_callback._every_n_train_steps == 0 + + def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): """ Test that it is possible to save all checkpoints when monitor=None.
""" seed_everything() @@ -567,9 +603,8 @@ def test_model_checkpoint_period(tmpdir, period: int): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -579,6 +614,87 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): + """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """ + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=(2 * every_n_val_epochs), + period=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +def test_ckpt_every_n_train_steps(tmpdir): + """ Tests that the checkpoints are saved every n training steps. """ + + model = LogInTwoMethods() + every_n_train_steps = 16 + max_epochs = 2 + epoch_length = 64 + checkpoint_callback = ModelCheckpoint( + filename="{step}", + every_n_val_epochs=0, + every_n_train_steps=every_n_train_steps, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + progress_bar_refresh_rate=0, + callbacks=[checkpoint_callback], + logger=False, + ) + + trainer.fit(model) + expected = [ + f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps) + ] + assert set(os.listdir(tmpdir)) == set(expected) + + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. 
""" model = LogInTwoMethods() diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 7d8c7d2adeea1..e65ebbab254de 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -104,3 +104,10 @@ def configure_optimizers(self): with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): trainer.fit(model) + + +def test_v1_5_0_model_checkpoint_period(tmpdir): + with no_warning_call(DeprecationWarning): + ModelCheckpoint(dirpath=tmpdir) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + ModelCheckpoint(dirpath=tmpdir, period=1)