Fix ModelCheckpoint(monitor=None, save_last=True) not saving checkpoints #6136

Merged: 17 commits on Mar 7, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -89,6 +89,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))


- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))


- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))


- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))


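The two `ModelCheckpoint` entries above describe configurations that previously produced no checkpoint at all. A minimal sketch of those setups (not taken from the PR; the `dirpath` and `max_epochs` values are arbitrary):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Case 1: nothing is monitored, but the last checkpoint is requested.
last_only = ModelCheckpoint(dirpath="checkpoints/", monitor=None, save_last=True)

# Case 2: top-k saving is disabled, but the last checkpoint is requested.
no_top_k = ModelCheckpoint(dirpath="checkpoints/", save_top_k=0, save_last=True)

# Before this fix neither callback wrote `last.ckpt`; with the fix both do.
trainer = Trainer(callbacks=[last_only], max_epochs=3)
```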
136 changes: 81 additions & 55 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -189,7 +189,7 @@ def on_validation_end(self, trainer, pl_module):
"""
checkpoints can be saved at the end of the val loop
"""
self.save_checkpoint(trainer, pl_module)
self.save_checkpoint(trainer)

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
@@ -204,12 +204,18 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
self.best_model_score = callback_state["best_model_score"]
self.best_model_path = callback_state["best_model_path"]

def save_checkpoint(self, trainer, pl_module):
def save_checkpoint(self, trainer, unused: Optional = None):
"""
Performs the main logic around saving a checkpoint.
This method runs on all ranks, it is the responsibility of `self.save_function`
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
"""
if unused is not None:
rank_zero_warn(
"`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter"
" has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
)

epoch = trainer.current_epoch
global_step = trainer.global_step
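
With the signature change above, manual callers should drop the `pl_module` argument; the old form still works but emits a deprecation warning. A hedged sketch of the two call sites (`mc`, `trainer`, and `pl_module` are illustrative names):

```python
# assuming `mc` is a ModelCheckpoint instance attached to `trainer`
mc.save_checkpoint(trainer)             # new signature (v1.3+)
mc.save_checkpoint(trainer, pl_module)  # deprecated: warns, support removed in v1.5
```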

@@ -218,7 +224,6 @@ def save_checkpoint(self, trainer, pl_module):
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or self._last_global_step_saved == global_step # already saved at the last step
@@ -236,28 +241,33 @@

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

# Mode 2: save the last checkpoint
# Mode 1: save the top k checkpoints
self._save_top_k_checkpoint(trainer, monitor_candidates)
# Mode 2: save monitor=None checkpoints
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
# Mode 3: save last checkpoints
self._save_last_checkpoint(trainer, monitor_candidates)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
if self.monitor is None:
# None: save last epoch, -1: save all epochs, 0: nothing is saved
if self.save_top_k not in [None, -1, 0]:
if self.save_top_k not in (None, -1, 0):
raise MisconfigurationException(
f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
' configuration. No quantity for top_k to track.'
)
if self.save_last:
rank_zero_warn(
'ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration.'
'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.'
' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).'
)
if self.save_top_k == -1 and self.save_last:
rank_zero_info(
'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)'
' will duplicate the last checkpoint saved.'
)

def __init_ckpt_dir(self, dirpath, filename, save_top_k):
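
The validation above rejects top-k tracking without a monitored quantity and warns when `save_last` is redundant. A short illustrative sketch of which combinations pass, warn, or raise (not part of the PR's test suite):

```python
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

ModelCheckpoint(monitor=None, save_top_k=-1)   # ok: save a checkpoint every period
ModelCheckpoint(monitor=None, save_top_k=0)    # ok: top-k saving disabled
ModelCheckpoint(monitor=None, save_last=True)  # ok, but warns that it is redundant

try:
    ModelCheckpoint(monitor=None, save_top_k=3)  # rejected: no quantity to rank by
except MisconfigurationException as err:
    print(err)
```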

@@ -293,7 +303,16 @@ def _del_model(self, filepath: str):
self._fs.rm(filepath)
log.debug(f"Removed checkpoint: {filepath}")

def _save_model(self, filepath: str, trainer):
def _save_model(self, trainer, filepath: str):
if trainer.training_type_plugin.rpc_enabled:
# RPCPlugin manages saving all model states
# TODO: the rpc plugin should wrap trainer.save_checkpoint
# instead of us having to do it here manually
trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath)
else:
self._do_save(trainer, filepath)

def _do_save(self, trainer, filepath: str):
# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

@@ -307,7 +326,7 @@ def _save_model(self, filepath: str, trainer):
else:
raise ValueError(".save_function() not set")

def check_monitor_top_k(self, current) -> bool:
def check_monitor_top_k(self, current: torch.Tensor) -> bool:
if current is None:
return False

@@ -462,17 +481,17 @@ def _validate_monitor_key(self, trainer):

def _get_metric_interpolated_filepath_name(
self,
ckpt_name_metrics: Dict[str, Any],
monitor_candidates: Dict[str, Any],
epoch: int,
step: int,
trainer,
del_filepath: Optional[str] = None,
) -> str:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)

version_cnt = self.STARTING_VERSION
while self.file_exists(filepath, trainer) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version_cnt)
version_cnt += 1

return filepath
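
The helper above interpolates the monitored values into the checkpoint filename and appends a version suffix whenever the target file already exists. An illustrative sketch (the `filename` pattern and metric values are examples, not from the PR):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

mc = ModelCheckpoint(dirpath="checkpoints/", filename="{epoch}-{val_loss:.2f}", monitor="val_loss")
# format_checkpoint_name interpolates to something like "checkpoints/epoch=3-val_loss=0.25.ckpt";
# _get_metric_interpolated_filepath_name then appends a "-v<n>" suffix if that file already exists.
print(mc.format_checkpoint_name(3, 1200, {"epoch": 3, "val_loss": 0.25}))
```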
@@ -482,47 +501,32 @@ def _monitor_candidates(self, trainer):
monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
return monitor_candidates

def _save_last_checkpoint(self, trainer, ckpt_name_metrics):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
if not self.save_last:
return

# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
last_filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
ckpt_name_metrics,
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
last_filepath = self._get_metric_interpolated_filepath_name(
ckpt_name_metrics,
trainer.current_epoch,
trainer.global_step,
trainer,
)
filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
monitor_candidates,
)
filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")

if trainer.training_type_plugin.rpc_enabled:
# RPCPlugin manages saving all model states
trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer)
else:
self._save_model(last_filepath, trainer)
if (
self.last_model_path and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero
):
self._save_model(trainer, filepath)

if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero:
self._del_model(self.last_model_path)
self.last_model_path = last_filepath

if self.monitor is None:
self.best_model_path = self.last_model_path
self.last_model_path = filepath

def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
if self.monitor is None or self.save_top_k == 0:
return

def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
current = metrics.get(self.monitor)
epoch = metrics.get("epoch")
step = metrics.get("step")
current = monitor_candidates.get(self.monitor)
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")

# when `val_loss` is being logged and no ModelCheckpoint is being provided
# `val_loss` will be selected for monitor and need to be reduced to
@@ -533,15 +537,37 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

if self.check_monitor_top_k(current):
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
elif self.monitor is not None and self.verbose:
self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
elif self.verbose:
rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")

def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
if self.monitor is not None or self.save_top_k == 0:
return

filepath = self._get_metric_interpolated_filepath_name(
monitor_candidates,
trainer.current_epoch,
trainer.global_step,
trainer,
)
self._save_model(trainer, filepath)

if (
self.save_top_k is None
and self.best_model_path
and self.best_model_path != filepath
and trainer.is_global_zero
):
self._del_model(self.best_model_path)

self.best_model_path = filepath

def _is_valid_monitor_key(self, metrics):
return self.monitor in metrics or len(metrics) == 0

def _update_best_and_save(
self, current: torch.Tensor, epoch: int, step: int, trainer, pl_module, ckpt_name_metrics
self, current: torch.Tensor, epoch: int, step: int, trainer, monitor_candidates: Dict[str, Any]
):
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

@@ -554,7 +580,7 @@ def _update_best_and_save(
if isinstance(current, torch.Tensor) and torch.isnan(current):
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, step, trainer, del_filepath)

# save the current score
self.current_score = current
@@ -575,7 +601,7 @@ def _update_best_and_save(
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
)
self._save_model(filepath, trainer)
self._save_model(trainer, filepath)

if del_filepath is not None and filepath != del_filepath:
self._del_model(del_filepath)
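Taken together, the rewritten `save_checkpoint` above now runs three independent modes in sequence: top-k checkpoints when a monitor is set, a plain checkpoint when it is not, and `last.ckpt` whenever `save_last=True`. A hedged end-to-end sketch of the fixed behaviour (assumes `model` is any `LightningModule` and `tmp_path` an empty directory; neither is defined in this PR):

```python
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


def check_last_ckpt_saved(model, tmp_path):
    # Before this PR, save_top_k=0 short-circuited save_checkpoint entirely,
    # so `last.ckpt` was never written. Now only the last checkpoint is saved.
    callback = ModelCheckpoint(dirpath=tmp_path, save_top_k=0, save_last=True)
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_epochs=2)
    trainer.fit(model)
    assert os.path.isfile(os.path.join(tmp_path, "last.ckpt"))
```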
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/rpc.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from contextlib import suppress
from typing import List, Optional
from typing import List, Optional, Callable

import torch

@@ -63,15 +63,15 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
rpc._set_rpc_timeout(self.rpc_timeout_sec)
self._is_rpc_initialized = True

def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
"""
Override to save model to disk.
This is required as the main process will be required to handle aggregating model states from RPC processes.

Args:
save_model_fn: The saving function to save final model.
last_filepath: The filepath to save the model to.
trainer: The trainer object.
save_model_fn: The saving function to save final model.
filepath: The filepath to save the model to.
"""
raise NotImplementedError

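With the reordered `rpc_save_model` signature above, a plugin override must now accept the trainer first and call the saving function as `save_model_fn(trainer, filepath)`. A hypothetical subclass sketch (the class name and body are illustrative, not from the PR):

```python
from typing import Callable

from pytorch_lightning.plugins.training_type.rpc import RPCPlugin
from pytorch_lightning.utilities import rank_zero_only


class MyRPCPlugin(RPCPlugin):

    @rank_zero_only
    def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
        # Gather or merge worker states here if required, then delegate to the
        # checkpoint callback's saving function using the new argument order.
        save_model_fn(trainer, filepath)
```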
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -13,7 +13,7 @@
# limitations under the License
import logging
import os
from typing import List, Optional
from typing import List, Optional, Callable

import torch
import torch.distributed as torch_distrib
@@ -266,7 +266,7 @@ def configure_ddp(self):
self._model.require_backward_grad_sync = False

@rank_zero_only
def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
model = self.lightning_module
if not hasattr(model.sequential_module, "foreach_worker"):
return
Expand All @@ -275,7 +275,7 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True
)
model.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
save_model_fn(last_filepath, trainer)
save_model_fn(trainer, filepath)
model.sequential_module = current_layers

def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: