From e754e43ab48702e57e7e2455ea3f7eda57fafbb1 Mon Sep 17 00:00:00 2001
From: Tarepan
Date: Sat, 2 Jan 2021 20:14:45 +0900
Subject: [PATCH] change 4/5 - checkpoint_connector

---
 .../trainer/connectors/checkpoint_connector.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 429bddd88b77e..abf2346d996b8 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -43,7 +43,7 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False
 
-    def restore_weights(self, model: LightningModule):
+    def restore_weights(self, model: LightningModule) -> None:
         """
         Attempt to restore a checkpoint (e.g. weights) in this priority:
         1. from HPC weights
@@ -73,12 +73,18 @@ def restore_weights(self, model: LightningModule):
         if self.trainer.on_gpu:
             torch.cuda.empty_cache()
 
-    def restore(self, checkpoint_path: str, on_gpu: bool):
+    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
         """
         Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
         All restored states are listed in return value description of `dump_checkpoint`.
         """
 
+        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
+        fs = get_filesystem(checkpoint_path)
+        if not fs.exists(checkpoint_path):
+            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
+            return False
+
         # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
         checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
 
@@ -94,6 +100,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore training state
         self.restore_training_state(checkpoint)
 
+        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
+        return True
+
     def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         """
         Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object