From 5518f7677f163a59836655973a3a7f355f9224b8 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Sun, 4 Apr 2021 12:58:43 +0530
Subject: [PATCH 01/15] Update TPU Training Type Plugin

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 85a1c2fe1c2a6..ff8235b5ab57d 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -32,12 +32,11 @@
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.distributed.xla_multiprocessing as xmp
     from torch_xla.core.xla_model import rendezvous
-    from torch_xla.distributed.parallel_loader import ParallelLoader
+    from torch_xla.distributed.parallel_loader import MpDeviceLoader
 else:
-    xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
+    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf
@@ -74,10 +73,9 @@ def distributed_sampler_kwargs(self) -> dict:
     def is_distributed(self):
         return self.world_size != 1
 
-    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
+    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
         device = xm.xla_device()
-        dataloader = xla_pl.ParallelLoader(dataloader, [device])
-        dataloader = dataloader.per_device_loader(device)
+        dataloader = MpDeviceLoader(dataloader, device)
         return dataloader
 
     def configure_ddp(self) -> None:

From 90578b5ad636aa04ac6d78333b6263563fd35ea6 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Sun, 4 Apr 2021 21:31:04 +0530
Subject: [PATCH 02/15] Update post dispatch

---
 .../plugins/training_type/tpu_spawn.py | 69 +++----------------
 1 file changed, 9 insertions(+), 60 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index ff8235b5ab57d..6e625e946876c 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -22,7 +22,6 @@
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -113,7 +112,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
 
         results = trainer.run_stage()
 
-        self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
         # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
@@ -123,12 +121,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.global_rank == 0:
             time.sleep(2)
 
-    def __save_end_of_training_weights(self, model: LightningModule) -> None:
-        # when training ends on these platforms dump weights to get out of the main process
-        if on_colab_kaggle():
-            rank_zero_warn("cleaning up... please do not interrupt")
-            self.save_spawn_weights(model)
-
     def model_to_device(self) -> None:
         self._model.to(xm.xla_device())
 
@@ -170,36 +162,6 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj
-    def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
-        """
-        Load the temp weights saved in the process
-        To recover the trained model from the ddp process we load the saved weights
-        """
-
-        loaded_model = original_model
-
-        if self.is_global_zero:
-            # load weights saved in ddp
-            path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            loaded_model = original_model.__class__.load_from_checkpoint(path)
-
-            # copy loaded weights to old model
-            original_model.load_state_dict(loaded_model.state_dict())
-
-            # remove ddp weights
-            os.remove(path)
-
-        return loaded_model
-
-    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
-        """
-        Dump a temporary checkpoint after ddp ends to get weights out of the process
-        """
-        if model.trainer.is_global_zero:
-            path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            model.trainer.save_checkpoint(path)
-            return path
-
     def reduce_decision(self, decision: bool) -> bool:
         decision = torch.tensor(int(decision), device=self.device)
         decision = self.reduce(decision, "sum")
         decision = bool(decision == self.world_size)
@@ -225,37 +187,24 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
         return output
 
     def post_dispatch(self) -> None:
-        # TODO: Check if trainer references can be resolved otherwise
-        model = self.lightning_module
-
         # restore main state with best weights
         best_path = self.mp_queue.get()
         last_path = self.mp_queue.get()
         self._results = self.mp_queue.get()
 
+        # recover the weights of the processes trained in the children
+        self.__recover_child_process_weights(best_path, last_path)
+
+    def __recover_child_process_weights(self, best_path, last_path):
         # transfer back the best path to the trainer
-        if self.lightning_module.trainer.checkpoint_callback is not None:
+        if self.lightning_module.trainer.checkpoint_callback:
             self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also bets score
+        # todo, pass also best score
 
         # load last weights
-        if last_path and model.trainer.state == TrainerState.FITTING:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt)
-
-        self._model = model
-
-        # when training completes, load the weights back in main process
-        self.__load_weights_on_main_process()
-
-    def __load_weights_on_main_process(self) -> None:
-        model = self.lightning_module
-
-        # load weights if not interrupted
-        if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING:
-            self.load_spawn_weights(model)
-
-        self._model = model
+        if last_path and self.lightning_module.trainer.state == TrainerState.FITTING:
+            ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
+            self.lightning_module.load_state_dict(ckpt)
 
     def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:

From ad87ffdaebaabb507467e33d9fd0d7c6f2982a8e Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 01:25:11 +0530
Subject: [PATCH 03/15] Update TPU Spawn

---
 .../plugins/training_type/tpu_spawn.py | 20 -------------------
 1 file changed, 20 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 6e625e946876c..455affb7f740e 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -186,26 +186,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
         return output
 
-    def post_dispatch(self) -> None:
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-        self._results = self.mp_queue.get()
-
-        # recover the weights of the processes trained in the children
-        self.__recover_child_process_weights(best_path, last_path)
-
-    def __recover_child_process_weights(self, best_path, last_path):
-        # transfer back the best path to the trainer
-        if self.lightning_module.trainer.checkpoint_callback:
-            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also best score
-
-        # load last weights
-        if last_path and self.lightning_module.trainer.state == TrainerState.FITTING:
-            ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
-            self.lightning_module.load_state_dict(ckpt)
-
     def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:
             trainer.logger.finalize("success")

From 5e6eb38e2b3ff9d8d30a162ba3d1211de8640030 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 01:30:41 +0530
Subject: [PATCH 04/15] fix reduce boolean decision

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 455affb7f740e..75520ba77c1a4 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -162,7 +162,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj
 
-    def reduce_decision(self, decision: bool) -> bool:
+    def reduce_boolean_decision(self, decision: bool) -> bool:
         decision = torch.tensor(int(decision), device=self.device)
         decision = self.reduce(decision, "sum")
         decision = bool(decision == self.world_size)

From 3dba5b90068b0f0aec9cdcec4b58fc34efd79df1 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 01:47:44 +0530
Subject: [PATCH 05/15] Update TPU Spawn

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 75520ba77c1a4..09a00aac1d1ce 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -43,15 +43,8 @@
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
 
-    def __init__(
-        self,
-        parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
-        **kwargs: Dict[str, Any]
-    ) -> None:
-        super().__init__(
-            parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
-        )
+    def __init__(self, parallel_devices: Optional[List[torch.device]] = None, **kwargs: Dict[str, Any]) -> None:
+        super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False, **kwargs)
         self.tpu_local_core_rank = 0
         self.start_method = None

From 197ab092b2674707eb4c1f7127e94c8d9924e885 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 13:27:31 +0530
Subject: [PATCH 06/15] Update type hint

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 09a00aac1d1ce..e5f7e8385de45 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -43,7 +43,7 @@
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
 
-    def __init__(self, parallel_devices: Optional[List[torch.device]] = None, **kwargs: Dict[str, Any]) -> None:
+    def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
         super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False, **kwargs)
         self.tpu_local_core_rank = 0
         self.start_method = None

From db8ab871a213d2f1c22cec641e97f5ff4aacdf9c Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 13:57:50 +0530
Subject: [PATCH 07/15] Fix Acc connector for TPUs

---
 .../plugins/training_type/tpu_spawn.py | 2 +-
 .../connectors/accelerator_connector.py | 21 ++++++++++---------
 pytorch_lightning/utilities/enums.py | 1 +
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index e5f7e8385de45..6a9a52147382b 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -44,7 +44,7 @@
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
 
     def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
-        super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False, **kwargs)
+        super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
         self.tpu_local_core_rank = 0
         self.start_method = None
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 30d2b48975a84..1e00d33cdf05a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -257,7 +257,7 @@ def use_dp(self) -> bool:
     def use_ddp(self) -> bool:
         return self._distrib_type in (
             DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
-            DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
+            DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN
         )
 
     @property
@@ -297,7 +297,8 @@ def parallel_devices(self) -> List[Union[torch.device, int]]:
         elif self.on_tpu:
             # explicitly don't make a tpu device here!
             # https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
-            devices = [i for i in self.parallel_device_ids]
+            if isinstance(self.tpu_cores, int):
+                devices = list(range(self.tpu_cores))
         else:
             devices = [torch.device("cpu")] * self.num_processes
         return devices
@@ -376,6 +377,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
         use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
         use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
         use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
+        use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
         use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
         use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
         use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
@@ -386,7 +388,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
         if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
             use_torchelastic_ddp = False
 
-        if self.on_tpu:
+        if use_tpu_spawn:
             ddp_plugin_cls = TPUSpawnPlugin
         elif use_ddp_sharded:
             ddp_plugin_cls = DDPShardedPlugin
@@ -409,11 +411,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
             plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
         elif self.use_horovod:
             plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
-        elif self.on_tpu:
-            if isinstance(self.tpu_cores, list):
-                plugin = SingleTPUPlugin(self.tpu_id)
-            else:
-                plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
+        elif self.on_tpu and isinstance(self.tpu_cores, list):
+            plugin = SingleTPUPlugin(self.tpu_id)
         else:
             single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
             plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
@@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         # special case with TPUs
         elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
             self._device_type = DeviceType.TPU
+            if isinstance(self.tpu_cores, int):
+                self._distrib_type = DistributedType.TPU_SPAWN
 
         elif self.distributed_backend and self._distrib_type is None:
             self._distrib_type = DistributedType(self.distributed_backend)
@@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         if self.num_gpus > 0 and not _on_cpu:
             self._device_type = DeviceType.GPU
 
-        _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
         # DP and DDP2 cannot run without GPU
-        if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu:
+        if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
             rank_zero_warn(
                 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
             )
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 169481fa63e67..decd635aca979 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -72,6 +72,7 @@ def is_interactive_compatible(self) -> bool:
     DDP = 'ddp'
     DDP2 = 'ddp2'
     DDP_SPAWN = 'ddp_spawn'
+    TPU_SPAWN = 'tpu_spawn'
     DEEPSPEED = 'deepspeed'
     HOROVOD = 'horovod'
     DDP_SHARDED = 'ddp_sharded'

From ae04ea8b7b56531c294b9837fd123e9ec2dda833 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 14:13:42 +0530
Subject: [PATCH 08/15] Update Single device tpu

---
 .../plugins/training_type/single_tpu.py | 24 ++++----------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index b8d670ff16881..af2aa5e753367 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -36,9 +36,14 @@ def __init__(self, device: Union[torch.device, int]):
         self.tpu_local_core_rank = 0
         self.tpu_global_core_rank = 0
 
+    @property
     def on_tpu(self) -> bool:
         return True
 
+    @property
+    def is_distributed(self):
+        return False
+
     def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
@@ -49,21 +54,6 @@ def pre_dispatch(self) -> None:
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
 
-    def post_dispatch(self) -> None:
-        model = self.lightning_module
-
-        if on_colab_kaggle():
-            rank_zero_warn("cleaning up... please do not interrupt")
-            self.save_spawn_weights(model)
-
-    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
-        """
-        Dump a temporary checkpoint after ddp ends to get weights out of the process
-        """
-        path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-        model.trainer.save_checkpoint(path)
-        return path
-
     def on_save(self, checkpoint: dict) -> dict:
         """
         Move XLA tensors to CPU before saving
         https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
         """
         return move_data_to_device(checkpoint, torch.device("cpu"))
-
-    @property
-    def is_distributed(self):
-        return False

From 2f40f2ee58184aa65bde93765bb5188efc968d86 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 14:49:06 +0530
Subject: [PATCH 09/15] fix code format issue

---
 pytorch_lightning/plugins/training_type/single_tpu.py | 8 +-------
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 -
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index af2aa5e753367..4dd996d570ac0 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -11,15 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-from typing import Optional, Union
-
 import torch
 
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
 if _TPU_AVAILABLE:
@@ -28,7 +22,7 @@
 
 class SingleTPUPlugin(SingleDevicePlugin):
 
-    def __init__(self, device: Union[torch.device, int]):
+    def __init__(self, device: int):
         if isinstance(device, int):
             device = xm.xla_device(device)
         super().__init__(device)
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 6a9a52147382b..68068935127e2 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -20,7 +20,6 @@
 import torch
 import torch.multiprocessing as mp
 
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn

From 3ee4a8b89f186629b549a8ab194e6e39d0e25cd0 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 14:53:48 +0530
Subject: [PATCH 10/15] fix

---
 pytorch_lightning/plugins/training_type/single_tpu.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 4dd996d570ac0..445072bf0112d 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -14,6 +14,7 @@
 import torch
 
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
 if _TPU_AVAILABLE:

From 30ae218e4639faf9b4f8b5fb629129fc94bb3bc6 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 15:56:14 +0530
Subject: [PATCH 11/15] tpu spawn interactive

---
 pytorch_lightning/utilities/enums.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index decd635aca979..1ec04549cd87e 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -62,7 +62,9 @@ class DistributedType(LightningEnum):
     @staticmethod
     def interactive_compatible_types() -> List['DistributedType']:
         """Returns a list containing interactive compatible DistributeTypes"""
-        return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
+        return [
+            DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN
+        ]
 
     def is_interactive_compatible(self) -> bool:
         """Returns whether self is interactive compatible"""

From b3fbad516834dcf47fc5f4fa51a893f4886d6af8 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 16:57:19 +0530
Subject: [PATCH 12/15] fix tests

---
 tests/models/test_tpu.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index b2ed0db87d8d5..f3a00818f7654 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -205,7 +205,7 @@ def validation_step(self, *args, **kwargs):
 def test_tpu_grad_norm(tmpdir):
     """Test if grad_norm works on TPU."""
     tutils.reset_seed()
-    trainer_options = dict(
+    trainer = Trainer(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=4,
@@ -216,7 +216,7 @@ def test_tpu_grad_norm(tmpdir):
     )
 
     model = BoringModel()
-    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+    trainer.fit(model)
 
 
 @RunIf(tpu=True)

From a81bcfa9c4ff40d0053eaa916bee79b12e5dda01 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Mon, 5 Apr 2021 22:59:51 +0530
Subject: [PATCH 13/15] update changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81846809fbf85..bdf2530116b92 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -197,6 +197,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
 
+- Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
+
+
 ## [1.2.6] - 2021-03-30
 
 ### Changed

From 3c912cb83bf699cf4c5c28ff7ae597c39668402a Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Tue, 6 Apr 2021 01:39:53 +0530
Subject: [PATCH 14/15] Update values

---
 tests/models/test_tpu.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index f3a00818f7654..f541c578df1af 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -205,18 +205,18 @@ def validation_step(self, *args, **kwargs):
 def test_tpu_grad_norm(tmpdir):
     """Test if grad_norm works on TPU."""
     tutils.reset_seed()
-    trainer = Trainer(
+    trainer_options = dict(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=4,
         tpu_cores=1,
-        limit_train_batches=4,
-        limit_val_batches=4,
+        limit_train_batches=0.4,
+        limit_val_batches=0.4,
         gradient_clip_val=0.5,
     )
 
     model = BoringModel()
-    trainer.fit(model)
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
 
 
 @RunIf(tpu=True)

From 851a8d971a55c724096d964791fef9bb96f34777 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Tue, 6 Apr 2021 13:01:59 +0530
Subject: [PATCH 15/15] refactor

---
 pytorch_lightning/plugins/training_type/single_tpu.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 445072bf0112d..df1238861ae06 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -24,8 +24,8 @@
 class SingleTPUPlugin(SingleDevicePlugin):
 
     def __init__(self, device: int):
-        if isinstance(device, int):
-            device = xm.xla_device(device)
+
+        device = xm.xla_device(device)
         super().__init__(device)
 
         self.tpu_local_core_rank = 0
@@ -36,7 +36,7 @@ def on_tpu(self) -> bool:
         return True
 
     @property
-    def is_distributed(self):
+    def is_distributed(self) -> bool:
         return False
 
     def model_to_device(self) -> None:
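
Note: the sketch below is not part of the patch series. It only illustrates the data-loading change
that PATCH 01 switches to (MpDeviceLoader instead of ParallelLoader + per_device_loader), with a toy
dataset, assuming torch_xla is installed and a TPU is attached.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.parallel_loader import MpDeviceLoader

    # an ordinary CPU DataLoader, as a user would pass it to the Trainer
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=8)

    device = xm.xla_device()

    # MpDeviceLoader wraps the loader and moves each batch onto the XLA device,
    # replacing the old ParallelLoader(...).per_device_loader(device) two-step.
    tpu_loader = MpDeviceLoader(loader, device)

    for x, y in tpu_loader:
        print(x.device)  # e.g. xla:1 -- the batch already lives on the TPU core
        break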