From a2295cb5223077503085d693729f23c638089680 Mon Sep 17 00:00:00 2001
From: Brendan Fahy
Date: Wed, 12 Aug 2020 10:31:17 +0000
Subject: [PATCH] fix checkpointing to remote file paths (#2925)

---
 pytorch_lightning/callbacks/model_checkpoint.py | 15 ++++++++++-----
 .../trainer/distrib_data_parallel.py            |  9 +++++++--
 pytorch_lightning/trainer/trainer.py            |  5 +++--
 pytorch_lightning/trainer/training_io.py        | 12 +++++++-----
 pytorch_lightning/utilities/cloud_io.py         | 11 ++++++++++-
 5 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 6fee7bdd6cc6bb..0346a1e8575bdf 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,7 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path
 
 
 class ModelCheckpoint(Callback):
@@ -122,10 +122,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if gfile.isdir(filepath):
             self.dirpath, self.filename = filepath, '{epoch}'
         else:
-            filepath = os.path.realpath(filepath)
+            if not is_remote_path(filepath):  # dont normalize remote paths
+                filepath = os.path.realpath(filepath)
             self.dirpath, self.filename = os.path.split(filepath)
-        if not gfile.exists(self.dirpath):
-            makedirs(self.dirpath)
+        makedirs(self.dirpath)  # calls with exist_ok
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -174,7 +174,12 @@ def _del_model(self, filepath):
             # dependencies exist then this will work fine.
             gfile.remove(filepath)
         except AttributeError:
-            os.remove(filepath)
+            if is_remote_path(filepath):
+                log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
+                            " Please install tensorflow to run gfile in full mode"
+                            " if writing checkpoints to remote locations")
+            else:
+                os.remove(filepath)
 
     def _save_model(self, filepath, trainer, pl_module):
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index c550fb648f0ca6..93f29f93118ea3 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -127,6 +127,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 
 """
 
+import io
 import os
 import re
 from abc import ABC, abstractmethod
@@ -146,6 +147,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.cloud_io import cloud_open
 
 
 try:
@@ -435,10 +437,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
             # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
             # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
             # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
+            bytesbuffer = io.BytesIO()
             if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
-                torch.save(model.state_dict(), last_path, _use_new_zipfile_serialization=False)
+                torch.save(model.state_dict(), bytesbuffer, _use_new_zipfile_serialization=False)
             else:
-                torch.save(model.state_dict(), last_path)
+                torch.save(model.state_dict(), bytesbuffer)
+            with cloud_open(last_path, 'wb') as f:
+                f.write(bytesbuffer.getvalue())
             mp_queue.put(last_path)
 
     def save_spawn_weights(self, model):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index edbba05813ad17..126d69fdb13374 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -53,6 +53,7 @@
 from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.cloud_io import is_remote_path
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -880,7 +881,7 @@ def default_root_dir(self) -> str:
         The default location to save artifacts of loggers, checkpoints etc.
         It is used as a fallback if logger or checkpoint callback do not define specific save paths.
         """
-        if "://" in str(self._default_root_dir):
+        if is_remote_path(self._default_root_dir):
             # it is a remote uri, use as is
             return self._default_root_dir
         return os.path.normpath(self._default_root_dir)
@@ -891,7 +892,7 @@ def weights_save_path(self) -> str:
         The default root location to save weights (checkpoints), e.g., when the
         :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
         """
-        if "://" in str(self._weights_save_path):
+        if is_remote_path(self._weights_save_path):
             # it is a remote uri, use as is
             return self._weights_save_path
         return os.path.normpath(self._weights_save_path)
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 7a1613b919a267..08babe22c2f451 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -83,6 +83,7 @@
 
 """
 
+import io
 import os
 import re
 import signal
@@ -104,7 +105,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.utilities.cloud_io import cloud_open, gfile, makedirs
 
 try:
     import torch_xla
@@ -269,15 +270,16 @@ def _atomic_save(self, checkpoint, filepath: str):
             filepath: The path to which the checkpoint will be saved.
                 This points to the file that the checkpoint will be stored in.
         """
-        tmp_path = str(filepath) + ".part"
+        bytesbuffer = io.BytesIO()
         # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
         # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
         # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
         if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
-            torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
+            torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
         else:
-            torch.save(checkpoint, tmp_path)
-        os.replace(tmp_path, filepath)
+            torch.save(checkpoint, bytesbuffer)
+        with cloud_open(filepath, 'wb') as f:
+            f.write(bytesbuffer.getvalue())
 
     def save_checkpoint(self, filepath, weights_only: bool = False):
         checkpoint = self.dump_checkpoint(weights_only)
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index f6b0f5b42b831f..303db975f998e4 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -28,6 +28,14 @@ def load(path_or_url: str, map_location=None):
     return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
 
 
+def is_remote_path(path: pathlike):
+    """Determine if a path is a local path or a remote path like s3://bucket/path
+
+    This should catch paths like s3:// hdfs:// and gcs://
+    """
+    return "://" in str(path)
+
+
 def modern_gfile():
     """Check the version number of tensorboard.
 
@@ -61,6 +69,7 @@ def cloud_open(path: pathlike, mode: str, newline: str = None):
 
 def makedirs(path: pathlike):
     if hasattr(gfile, "makedirs") and modern_gfile():
-        return gfile.makedirs(str(path))
+        if not gfile.exists(str(path)):
+            return gfile.makedirs(str(path))
     # otherwise minimal dependencies are installed and only local files will work
     return os.makedirs(path, exist_ok=True)
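
Reviewer note: the core pattern this patch introduces is "serialize to an in-memory buffer, then stream the bytes through a cloud-aware file handle" instead of writing to a local temp file and renaming it. Below is a minimal sketch of that pattern for trying the new code path; it uses only the `cloud_open` and `is_remote_path` helpers defined in this diff, and the `s3://` destination is a hypothetical placeholder (writing to it requires gfile in full mode, i.e. tensorflow installed, as the new warning in `_del_model` notes). Any local path works with the minimal dependencies.

    import io

    import torch

    from pytorch_lightning.utilities.cloud_io import cloud_open, is_remote_path

    # Hypothetical destination; swap in any local path to test without tensorflow.
    filepath = "s3://my-bucket/checkpoints/epoch=3.ckpt"
    assert is_remote_path(filepath)  # "://" in the string marks it as remote

    checkpoint = {"state_dict": {"weight": torch.zeros(3)}}

    # torch.save accepts any file-like object, so the checkpoint is first
    # serialized entirely in memory and never touches the local disk.
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)

    # cloud_open dispatches to gfile for remote URIs and falls back to plain
    # local file handling otherwise, so one code path covers both cases.
    with cloud_open(filepath, "wb") as f:
        f.write(bytesbuffer.getvalue())

One trade-off worth noting: `_atomic_save` previously wrote to a `.part` file and used `os.replace` for an atomic rename, which the buffer-based write gives up in exchange for remote compatibility, and the whole serialized checkpoint is now held in memory before the write.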