diff --git a/CHANGELOG.md b/CHANGELOG.md
index ed7eec7cff7f9..97d2fa55fa4ce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -89,6 +89,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))


+- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
+- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
 - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))


diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 4233933af1b1a..f457e9de7d0fa 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -189,7 +189,7 @@ def on_validation_end(self, trainer, pl_module):
         """
         checkpoints can be saved at the end of the val loop
         """
-        self.save_checkpoint(trainer, pl_module)
+        self.save_checkpoint(trainer)

     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
@@ -204,12 +204,18 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]

-    def save_checkpoint(self, trainer, pl_module):
+    def save_checkpoint(self, trainer, unused: Optional = None):
         """
         Performs the main logic around saving a checkpoint. This method runs on all ranks, it is
         the responsibility of `self.save_function` to handle correct behaviour in distributed training, i.e., saving
         only on rank 0.
         """
+        if unused is not None:
+            rank_zero_warn(
+                "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter"
+                " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
+            )
+
         epoch = trainer.current_epoch
         global_step = trainer.global_step

@@ -218,7 +224,6 @@ def save_checkpoint(self, trainer, pl_module):
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
             or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
-            or self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or self._last_global_step_saved == global_step  # already saved at the last step
@@ -236,11 +241,11 @@ def save_checkpoint(self, trainer, pl_module):

         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
-        # Mode 1: save all checkpoints OR only the top k
-        if self.save_top_k:
-            self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
-
-        # Mode 2: save the last checkpoint
+        # Mode 1: save the top k checkpoints
+        self._save_top_k_checkpoint(trainer, monitor_candidates)
+        # Mode 2: save monitor=None checkpoints
+        self._save_none_monitor_checkpoint(trainer, monitor_candidates)
+        # Mode 3: save last checkpoints
         self._save_last_checkpoint(trainer, monitor_candidates)

     def __validate_init_configuration(self):
@@ -248,16 +253,21 @@
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
-            if self.save_top_k not in [None, -1, 0]:
+            if self.save_top_k not in (None, -1, 0):
                 raise MisconfigurationException(
                     f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
                     ' configuration. No quantity for top_k to track.'
                 )
             if self.save_last:
                 rank_zero_warn(
-                    'ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration.'
+                    'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.'
                     ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).'
                 )
+            if self.save_top_k == -1 and self.save_last:
+                rank_zero_info(
+                    'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)'
+                    ' will duplicate the last checkpoint saved.'
+                )

     def __init_ckpt_dir(self, dirpath, filename, save_top_k):
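# Illustrative sketch (not part of the patch): the user-facing effect of splitting
# `save_checkpoint` into the three independent modes above. The dirpaths and hyper-parameters
# are made up for the example; the calls use the public ModelCheckpoint/Trainer API.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Mode 1 only: keep the best 3 checkpoints according to a monitored metric.
top_k_ckpt = ModelCheckpoint(monitor="val_loss", save_top_k=3, dirpath="checkpoints/top_k")

# Modes 2 and 3: with monitor=None every epoch is saved (save_top_k=-1) and a `last.ckpt`
# copy is written as well. Per the CHANGELOG entries above, this combination previously did
# not save checkpoints; it now saves them and logs that the last checkpoint is duplicated.
all_and_last_ckpt = ModelCheckpoint(monitor=None, save_top_k=-1, save_last=True, dirpath="checkpoints/all")

trainer = Trainer(max_epochs=2, callbacks=[top_k_ckpt, all_and_last_ckpt])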
@@ -293,7 +303,16 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")

-    def _save_model(self, filepath: str, trainer):
+    def _save_model(self, trainer, filepath: str):
+        if trainer.training_type_plugin.rpc_enabled:
+            # RPCPlugin manages saving all model states
+            # TODO: the rpc plugin should wrap trainer.save_checkpoint
+            # instead of us having to do it here manually
+            trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath)
+        else:
+            self._do_save(trainer, filepath)
+
+    def _do_save(self, trainer, filepath: str):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

@@ -307,7 +326,7 @@ def _save_model(self, filepath: str, trainer):
         else:
             raise ValueError(".save_function() not set")

-    def check_monitor_top_k(self, current) -> bool:
+    def check_monitor_top_k(self, current: torch.Tensor) -> bool:
         if current is None:
             return False

@@ -462,17 +481,17 @@ def _validate_monitor_key(self, trainer):

     def _get_metric_interpolated_filepath_name(
         self,
-        ckpt_name_metrics: Dict[str, Any],
+        monitor_candidates: Dict[str, Any],
         epoch: int,
         step: int,
         trainer,
         del_filepath: Optional[str] = None,
     ) -> str:
-        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
+        filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)

         version_cnt = self.STARTING_VERSION
         while self.file_exists(filepath, trainer) and filepath != del_filepath:
-            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
+            filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version_cnt)
             version_cnt += 1

         return filepath
@@ -482,47 +501,32 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates

-    def _save_last_checkpoint(self, trainer, ckpt_name_metrics):
-        should_save_last = self.monitor is None or self.save_last
-        if not should_save_last:
+    def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if not self.save_last:
             return

-        # when user ALSO asked for the 'last.ckpt' change the name
-        if self.save_last:
-            last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST,
-                trainer.current_epoch,
-                trainer.global_step,
-                ckpt_name_metrics,
-            )
-            last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
-        else:
-            last_filepath = self._get_metric_interpolated_filepath_name(
-                ckpt_name_metrics,
-                trainer.current_epoch,
-                trainer.global_step,
-                trainer,
-            )
+        filepath = self._format_checkpoint_name(
+            self.CHECKPOINT_NAME_LAST,
+            trainer.current_epoch,
+            trainer.global_step,
+            monitor_candidates,
+        )
+        filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")

-        if trainer.training_type_plugin.rpc_enabled:
-            # RPCPlugin manages saving all model states
-            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer)
-        else:
-            self._save_model(last_filepath, trainer)
-        if (
-            self.last_model_path and self.last_model_path != last_filepath
-            and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero
-        ):
+        self._save_model(trainer, filepath)
+
+        if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero:
             self._del_model(self.last_model_path)

-        self.last_model_path = last_filepath
-        if self.monitor is None:
-            self.best_model_path = self.last_model_path
+        self.last_model_path = filepath
+
+    def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is None or self.save_top_k == 0:
+            return

-    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
-        current = metrics.get(self.monitor)
-        epoch = metrics.get("epoch")
-        step = metrics.get("step")
+        current = monitor_candidates.get(self.monitor)
+        epoch = monitor_candidates.get("epoch")
+        step = monitor_candidates.get("step")

         # when `val_loss` is being logged and no ModelCheckpoint is being provided
         # `val_loss` will be selected for monitor and need to be reduced to
@@ -533,15 +537,37 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
             current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

         if self.check_monitor_top_k(current):
-            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
-        elif self.monitor is not None and self.verbose:
+            self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
+        elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")

+    def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is not None or self.save_top_k == 0:
+            return
+
+        filepath = self._get_metric_interpolated_filepath_name(
+            monitor_candidates,
+            trainer.current_epoch,
+            trainer.global_step,
+            trainer,
+        )
+        self._save_model(trainer, filepath)
+
+        if (
+            self.save_top_k is None
+            and self.best_model_path
+            and self.best_model_path != filepath
+            and trainer.is_global_zero
+        ):
+            self._del_model(self.best_model_path)
+
+        self.best_model_path = filepath
+
     def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0

     def _update_best_and_save(
-        self, current: torch.Tensor, epoch: int, step: int, trainer, pl_module, ckpt_name_metrics
+        self, current: torch.Tensor, epoch: int, step: int, trainer, monitor_candidates: Dict[str, Any]
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
@@ -554,7 +580,7 @@ def _update_best_and_save(
         if isinstance(current, torch.Tensor) and torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

-        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, step, trainer, del_filepath)

         # save the current score
         self.current_score = current
@@ -575,7 +601,7 @@ def _update_best_and_save(
                 f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
                 f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
-        self._save_model(filepath, trainer)
+        self._save_model(trainer, filepath)

         if del_filepath is not None and filepath != del_filepath:
             self._del_model(del_filepath)
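# Illustrative sketch (not part of the patch): `_save_last_checkpoint` above is now driven only
# by `save_last`, so `save_top_k=0` no longer suppresses the `last.ckpt` file. The dirpath is
# made up for the example; the expected listing matches the updated
# test_model_checkpoint_topk_zero further down in this diff.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(dirpath="checkpoints/", save_top_k=0, save_last=True)
trainer = Trainer(max_epochs=1, callbacks=[ckpt])
# After trainer.fit(model), the checkpoint directory is expected to contain only `last.ckpt`:
# os.listdir("checkpoints/") == ["last.ckpt"]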
diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py
index 0e8ca557e447b..faf528d76b768 100644
--- a/pytorch_lightning/plugins/training_type/rpc.py
+++ b/pytorch_lightning/plugins/training_type/rpc.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from contextlib import suppress
-from typing import List, Optional
+from typing import List, Optional, Callable

 import torch

@@ -63,15 +63,15 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
         rpc._set_rpc_timeout(self.rpc_timeout_sec)
         self._is_rpc_initialized = True

-    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
+    def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
         """
         Override to save model to disk. This is required as the main process will be required to handle
         aggregating model states from RPC processes.

         Args:
-            save_model_fn: The saving function to save final model.
-            last_filepath: The filepath to save the model to.
             trainer: The trainer object.
+            save_model_fn: The saving function to save final model.
+            filepath: The filepath to save the model to.
         """
         raise NotImplementedError
diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py
index 8fd75555ecd14..336c16f0f1a03 100644
--- a/pytorch_lightning/plugins/training_type/rpc_sequential.py
+++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -13,7 +13,7 @@
 # limitations under the License
 import logging
 import os
-from typing import List, Optional
+from typing import List, Optional, Callable

 import torch
 import torch.distributed as torch_distrib
@@ -266,7 +266,7 @@ def configure_ddp(self):
         self._model.require_backward_grad_sync = False

     @rank_zero_only
-    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
+    def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
         model = self.lightning_module
         if not hasattr(model.sequential_module, "foreach_worker"):
             return
@@ -275,7 +275,7 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
             save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True
         )
         model.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
-        save_model_fn(last_filepath, trainer)
+        save_model_fn(trainer, filepath)
         model.sequential_module = current_layers

     def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 48e4a22e1ec05..3d5cddc4537a7 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -17,6 +17,7 @@
 import pickle
 import re
 from argparse import Namespace
+from logging import INFO
 from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock
@@ -500,20 +501,20 @@ def test_none_monitor_top_k(tmpdir):

 def test_none_monitor_save_last(tmpdir):
     """ Test that a warning appears for save_last=True with monitor=None. """
-    with pytest.warns(UserWarning, match=r'ModelCheckpoint\(save_last=True, monitor=None\) is a redundant.*'):
+    with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'):
         ModelCheckpoint(dirpath=tmpdir, save_last=True)
     # These should not fail
     ModelCheckpoint(dirpath=tmpdir, save_last=None)
     ModelCheckpoint(dirpath=tmpdir, save_last=False)


-def test_model_checkpoint_none_monitor(tmpdir):
+def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
     """ Test that it is possible to save all checkpoints when monitor=None. """
     seed_everything()
     model = LogInTwoMethods()

     epochs = 2
-    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
+    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[checkpoint_callback],
@@ -522,17 +523,22 @@
         max_epochs=epochs,
         logger=False,
     )
-    trainer.fit(model)
+
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+        assert "will duplicate the last checkpoint saved" in caplog.text

     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
-    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1-step=19.ckpt'
+    assert checkpoint_callback.best_model_path == tmpdir / 'epoch=1-step=19.ckpt'
+    assert checkpoint_callback.last_model_path == tmpdir / 'last.ckpt'
     assert checkpoint_callback.best_model_score is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''

     # check that the correct ckpts were created
     expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])]
+    expected.append('last.ckpt')
     assert set(os.listdir(tmpdir)) == set(expected)


@@ -560,7 +566,7 @@ def test_model_checkpoint_period(tmpdir, period):
 def test_model_checkpoint_topk_zero(tmpdir):
     """ Test that no checkpoints are saved when save_top_k=0. """
     model = LogInTwoMethods()
-    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
+    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[checkpoint_callback],
@@ -574,8 +580,9 @@
     assert checkpoint_callback.best_model_score is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''
-    # check that no ckpts were created
-    assert len(os.listdir(tmpdir)) == 0
+    # check that only the last ckpt was created
+    assert os.listdir(tmpdir) == ['last.ckpt']
+    assert checkpoint_callback.last_model_path == tmpdir / 'last.ckpt'


 def test_model_checkpoint_topk_all(tmpdir):
@@ -1083,7 +1090,7 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir):
     # check best_k_models state
     assert {Path(f).name for f in mc.best_k_models.keys()} == expected
     # check created ckpts
-    assert set(sorted(os.listdir(tmpdir))) == expected
+    assert set(os.listdir(tmpdir)) == expected


 def test_model_checkpoint_mode_options():
"""Test deprecated functionality which will be removed in v1.5.0""" - from unittest import mock import pytest from torch import optim from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call +def test_v1_5_0_model_checkpoint_save_checkpoint(): + model_ckpt = ModelCheckpoint() + model_ckpt.save_function = lambda *_, **__: None + with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"): + model_ckpt.save_checkpoint(Trainer(), object()) + + @mock.patch('pytorch_lightning.loggers.wandb.wandb') def test_v1_5_0_wandb_unused_sync_step(tmpdir): with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"): diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py index 02dde1903ca5b..9ecc93a9b5055 100644 --- a/tests/plugins/test_rpc_plugin.py +++ b/tests/plugins/test_rpc_plugin.py @@ -58,7 +58,7 @@ def __init__(self, **kwargs): self.rpc_save_model_count = 0 self.worker_optimizer_step_count = 0 - def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None: + def rpc_save_model(self, *_) -> None: self.rpc_save_model_count += 1 def barrier(self, name: Optional[str] = None) -> None: