From b52365a3dc14e4b62ee262a8fe0b1e1a0a56e1fa Mon Sep 17 00:00:00 2001
From: Frederik Diehl
Date: Wed, 15 Jan 2020 16:18:01 +0100
Subject: [PATCH 1/2] Added atomic checkpoint creation

---
 pytorch_lightning/trainer/training_io.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 168e983585c2a..e45dfb3dfd175 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -255,17 +255,22 @@ def term_handler(self, signum, frame):
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
+    def _atomic_save(self, checkpoint, filepath):
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, filepath)
+
     def save_checkpoint(self, filepath):
         checkpoint = self.dump_checkpoint()
 
         # do the actual save
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
     def restore(self, checkpoint_path, on_gpu):
         # if on_gpu:
@@ -412,12 +417,12 @@ def hpc_save(self, folderpath, logger):
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
         return filepath
 

From eda52d5c6e98c2765886a90ecf3d1cdb3ad80e4a Mon Sep 17 00:00:00 2001
From: Frederik Diehl
Date: Mon, 20 Jan 2020 16:00:59 +0100
Subject: [PATCH 2/2] Added documentation for _atomic_checkpoint

---
 pytorch_lightning/trainer/training_io.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index e45dfb3dfd175..e5848c12495c1 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -256,6 +256,18 @@ def term_handler(self, signum, frame):
     # MODEL SAVE CHECKPOINT
     # --------------------
     def _atomic_save(self, checkpoint, filepath):
+        """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
+
+        This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once
+        saving is finished.
+
+        Args:
+            checkpoint (object): The object to save.
+                Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
+                accepts.
+            filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
+                This points to the file that the checkpoint will be stored in.
+        """
         tmp_path = str(filepath) + ".part"
         torch.save(checkpoint, tmp_path)
         os.replace(tmp_path, filepath)
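
The write-to-temp-then-rename pattern these patches introduce can be exercised outside the Trainer as well. Below is a minimal standalone sketch of the same idea, assuming only torch and the standard library; the module-level atomic_save helper, the example state dict, and the model.ckpt path are illustrative and not part of the patch:

    import os

    import torch


    def atomic_save(checkpoint, filepath):
        """Standalone sketch of the pattern from the patches above."""
        # Serialize to a sibling ".part" file first; a crash during
        # torch.save leaves at worst a stray .part file, never a
        # truncated checkpoint at the final path.
        tmp_path = str(filepath) + ".part"
        torch.save(checkpoint, tmp_path)
        # Atomically swap the finished file into place (rename(2) on
        # POSIX; a same-volume replace on Windows).
        os.replace(tmp_path, filepath)


    # Example usage with an arbitrary picklable object:
    state = {"epoch": 3, "weights": torch.nn.Linear(4, 2).state_dict()}
    atomic_save(state, "model.ckpt")

Because the ".part" suffix keeps the temporary file in the same directory (and therefore on the same filesystem) as the destination, os.replace is a single atomic rename: a concurrent reader sees either the previous checkpoint or the complete new one, never a partially written file.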