From cf8e828559450095cfdead77b074d05b131f34b6 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Tue, 6 Apr 2021 15:02:44 +0530 Subject: [PATCH] [Fix] TPU Training Type Plugin (#6816) --- CHANGELOG.md | 3 + .../plugins/training_type/single_tpu.py | 37 ++------ .../plugins/training_type/tpu_spawn.py | 95 ++----------------- .../connectors/accelerator_connector.py | 21 ++-- pytorch_lightning/utilities/enums.py | 5 +- tests/models/test_tpu.py | 4 +- 6 files changed, 36 insertions(+), 129 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c65c87371bda4..a956b63c5fb54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) +- Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816)) + + - Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588)) diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index b8d670ff16881..df1238861ae06 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -11,15 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Optional, Union - import torch -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin -from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle -from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device if _TPU_AVAILABLE: @@ -28,17 +23,22 @@ class SingleTPUPlugin(SingleDevicePlugin): - def __init__(self, device: Union[torch.device, int]): - if isinstance(device, int): - device = xm.xla_device(device) + def __init__(self, device: int): + + device = xm.xla_device(device) super().__init__(device) self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 + @property def on_tpu(self) -> bool: return True + @property + def is_distributed(self) -> bool: + return False + def model_to_device(self) -> None: self.model.to(self.root_device) @@ -49,21 +49,6 @@ def pre_dispatch(self) -> None: self.tpu_local_core_rank = xm.get_local_ordinal() self.tpu_global_core_rank = xm.get_ordinal() - def post_dispatch(self) -> None: - model = self.lightning_module - - if on_colab_kaggle(): - rank_zero_warn("cleaning up... please do not interrupt") - self.save_spawn_weights(model) - - def save_spawn_weights(self, model: LightningModule) -> Optional[str]: - """ - Dump a temporary checkpoint after ddp ends to get weights out of the process - """ - path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt") - model.trainer.save_checkpoint(path) - return path - def on_save(self, checkpoint: dict) -> dict: """ Move XLA tensors to CPU before saving @@ -71,7 +56,3 @@ def on_save(self, checkpoint: dict) -> dict: https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors """ return move_data_to_device(checkpoint, torch.device("cpu")) - - @property - def is_distributed(self): - return False diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 85a1c2fe1c2a6..68068935127e2 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -20,9 +20,7 @@ import torch import torch.multiprocessing as mp -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin -from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -32,12 +30,11 @@ if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as xla_pl import torch_xla.distributed.xla_multiprocessing as xmp from torch_xla.core.xla_model import rendezvous - from torch_xla.distributed.parallel_loader import ParallelLoader + from torch_xla.distributed.parallel_loader import MpDeviceLoader else: - xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5 + xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 if _OMEGACONF_AVAILABLE: from omegaconf import DictConfig, ListConfig, OmegaConf @@ -45,15 +42,8 @@ class TPUSpawnPlugin(DDPSpawnPlugin): - def __init__( - self, - parallel_devices: Optional[List[torch.device]] = None, - num_nodes: int = 1, - **kwargs: Dict[str, Any] - ) -> None: - super().__init__( - parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs - ) + def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None: + super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False) self.tpu_local_core_rank = 0 self.start_method = None @@ -74,10 +64,9 @@ def distributed_sampler_kwargs(self) -> dict: def is_distributed(self): return self.world_size != 1 - def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader: + def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader: device = xm.xla_device() - dataloader = xla_pl.ParallelLoader(dataloader, [device]) - dataloader = dataloader.per_device_loader(device) + dataloader = MpDeviceLoader(dataloader, device) return dataloader def configure_ddp(self) -> None: @@ -115,7 +104,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: results = trainer.run_stage() - self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 @@ -125,12 +113,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: if self.global_rank == 0: time.sleep(2) - def __save_end_of_training_weights(self, model: LightningModule) -> None: - # when training ends on these platforms dump weights to get out of the main process - if on_colab_kaggle(): - rank_zero_warn("cleaning up... please do not interrupt") - self.save_spawn_weights(model) - def model_to_device(self) -> None: self._model.to(xm.xla_device()) @@ -172,37 +154,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: obj = torch.load(buffer) return obj - def load_spawn_weights(self, original_model: LightningModule) -> LightningModule: - """ - Load the temp weights saved in the process - To recover the trained model from the ddp process we load the saved weights - """ - - loaded_model = original_model - - if self.is_global_zero: - # load weights saved in ddp - path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt") - loaded_model = original_model.__class__.load_from_checkpoint(path) - - # copy loaded weights to old model - original_model.load_state_dict(loaded_model.state_dict()) - - # remove ddp weights - os.remove(path) - - return loaded_model - - def save_spawn_weights(self, model: LightningModule) -> Optional[str]: - """ - Dump a temporary checkpoint after ddp ends to get weights out of the process - """ - if model.trainer.is_global_zero: - path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt") - model.trainer.save_checkpoint(path) - return path - - def reduce_decision(self, decision: bool) -> bool: + def reduce_boolean_decision(self, decision: bool) -> bool: decision = torch.tensor(int(decision), device=self.device) decision = self.reduce(decision, "sum") decision = bool(decision == self.world_size) @@ -226,39 +178,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ return output - def post_dispatch(self) -> None: - # TODO: Check if trainer references can be resolved otherwise - model = self.lightning_module - - # restore main state with best weights - best_path = self.mp_queue.get() - last_path = self.mp_queue.get() - self._results = self.mp_queue.get() - - # transfer back the best path to the trainer - if self.lightning_module.trainer.checkpoint_callback is not None: - self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also bets score - - # load last weights - if last_path and model.trainer.state == TrainerState.FITTING: - ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) - - self._model = model - - # when training completes, load the weights back in main process - self.__load_weights_on_main_process() - - def __load_weights_on_main_process(self) -> None: - model = self.lightning_module - - # load weights if not interrupted - if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING: - self.load_spawn_weights(model) - - self._model = model - def _close_logger(self, trainer) -> None: if trainer.logger is not None: trainer.logger.finalize("success") diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 30d2b48975a84..1e00d33cdf05a 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -257,7 +257,7 @@ def use_dp(self) -> bool: def use_ddp(self) -> bool: return self._distrib_type in ( DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED + DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN ) @property @@ -297,7 +297,8 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: elif self.on_tpu: # explicitly don't make a tpu device here! # https://github.com/PyTorchLightning/pytorch-lightning/issues/3169 - devices = [i for i in self.parallel_device_ids] + if isinstance(self.tpu_cores, int): + devices = list(range(self.tpu_cores)) else: devices = [torch.device("cpu")] * self.num_processes return devices @@ -376,6 +377,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN use_ddp_cpu_spawn = self.use_ddp and self.on_cpu + use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED @@ -386,7 +388,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: if os.environ.get("PL_IN_DDP_SUBPROCESS", False): use_torchelastic_ddp = False - if self.on_tpu: + if use_tpu_spawn: ddp_plugin_cls = TPUSpawnPlugin elif use_ddp_sharded: ddp_plugin_cls = DDPShardedPlugin @@ -409,11 +411,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = DataParallelPlugin(parallel_devices=self.parallel_devices) elif self.use_horovod: plugin = HorovodPlugin(parallel_devices=self.parallel_devices) - elif self.on_tpu: - if isinstance(self.tpu_cores, list): - plugin = SingleTPUPlugin(self.tpu_id) - else: - plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores))) + elif self.on_tpu and isinstance(self.tpu_cores, list): + plugin = SingleTPUPlugin(self.tpu_id) else: single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids) plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu")) @@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): # special case with TPUs elif self.distributed_backend == 'tpu' or self.tpu_cores is not None: self._device_type = DeviceType.TPU + if isinstance(self.tpu_cores, int): + self._distrib_type = DistributedType.TPU_SPAWN elif self.distributed_backend and self._distrib_type is None: self._distrib_type = DistributedType(self.distributed_backend) @@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): if self.num_gpus > 0 and not _on_cpu: self._device_type = DeviceType.GPU - _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) # DP and DDP2 cannot run without GPU - if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu: + if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu: rank_zero_warn( 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' ) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 169481fa63e67..1ec04549cd87e 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -62,7 +62,9 @@ class DistributedType(LightningEnum): @staticmethod def interactive_compatible_types() -> List['DistributedType']: """Returns a list containing interactive compatible DistributeTypes""" - return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN] + return [ + DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN + ] def is_interactive_compatible(self) -> bool: """Returns whether self is interactive compatible""" @@ -72,6 +74,7 @@ def is_interactive_compatible(self) -> bool: DDP = 'ddp' DDP2 = 'ddp2' DDP_SPAWN = 'ddp_spawn' + TPU_SPAWN = 'tpu_spawn' DEEPSPEED = 'deepspeed' HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index b2ed0db87d8d5..f541c578df1af 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir): progress_bar_refresh_rate=0, max_epochs=4, tpu_cores=1, - limit_train_batches=4, - limit_val_batches=4, + limit_train_batches=0.4, + limit_val_batches=0.4, gradient_clip_val=0.5, )