diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 168e983585c2a..e45dfb3dfd175 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -255,17 +255,22 @@ def term_handler(self, signum, frame):
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
+    def _atomic_save(self, checkpoint, filepath):
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, filepath)
+
     def save_checkpoint(self, filepath):
         checkpoint = self.dump_checkpoint()
 
         # do the actual save
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
     def restore(self, checkpoint_path, on_gpu):
         # if on_gpu:
@@ -412,12 +417,12 @@ def hpc_save(self, folderpath, logger):
 
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
         return filepath
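
For context, a minimal standalone sketch of the atomic-save pattern this diff introduces (the atomic_save name and the demo path below are illustrative, not part of the Lightning API): the checkpoint is first serialized to a sibling ".part" file, then swapped into place with os.replace, which is atomic for paths on the same filesystem, so an interrupted save can no longer leave a truncated checkpoint behind.

import os

import torch


def atomic_save(obj, filepath):
    """Serialize obj to filepath without ever exposing a partial file."""
    tmp_path = str(filepath) + ".part"
    torch.save(obj, tmp_path)       # a crash here leaves only the .part file
    os.replace(tmp_path, filepath)  # atomic rename; the old file survives until now


if __name__ == "__main__":
    atomic_save({"step": 1, "weights": torch.zeros(3)}, "demo.ckpt")
    print(torch.load("demo.ckpt")["step"])  # -> 1

Keeping the temporary file next to the target (rather than in a separate temp directory) matters here: os.replace cannot rename across filesystems, and the rename is only guaranteed atomic when source and destination share one.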