Skip to content

Commit

Permalink
[Fix] TPU Training Type Plugin (#6816)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored and lexierule committed Apr 7, 2021
1 parent edf6289 commit c7422d4
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 129 deletions.
50 changes: 22 additions & 28 deletions pytorch_lightning/plugins/training_type/single_tpu.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os
from typing import Optional, Union

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device

if _TPU_AVAILABLE:
Expand All @@ -15,21 +23,26 @@

class SingleTPUPlugin(SingleDevicePlugin):

def __init__(self, device: Union[torch.device, int]):
if isinstance(device, int):
device = xm.xla_device(device)
def __init__(self, device: int):

device = xm.xla_device(device)
super().__init__(device)

self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0

@property
def on_tpu(self) -> bool:
return True

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
self.model_to_device()
return self._model

@property
def is_distributed(self) -> bool:
return False

def model_to_device(self) -> None:
self._model.to(self.root_device)
Expand All @@ -41,29 +54,10 @@ def pre_dispatch(self) -> None:
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()

def post_dispatch(self) -> None:
model = self.lightning_module

if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
"""
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
model.trainer.save_checkpoint(path)
return path

def on_save(self, checkpoint: dict) -> dict:
"""
Move XLA tensors to CPU before saving
Recommended on XLA Guide:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
"""
return move_data_to_device(checkpoint, torch.device("cpu"))

@property
def is_distributed(self):
return False
96 changes: 7 additions & 89 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import torch
import torch.multiprocessing as mp

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
Expand All @@ -18,28 +16,20 @@

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.core.xla_model import rendezvous
from torch_xla.distributed.parallel_loader import ParallelLoader
from torch_xla.distributed.parallel_loader import MpDeviceLoader
else:
xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class TPUSpawnPlugin(DDPSpawnPlugin):

def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
**kwargs: Dict[str, Any]
) -> None:
super().__init__(
parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
)
def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
self.tpu_local_core_rank = 0
self.start_method = None

Expand All @@ -61,10 +51,9 @@ def distributed_sampler_kwargs(self) -> dict:
def is_distributed(self):
return self.world_size != 1

def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
device = xm.xla_device()
dataloader = xla_pl.ParallelLoader(dataloader, [device])
dataloader = dataloader.per_device_loader(device)
dataloader = MpDeviceLoader(dataloader, device)
return dataloader

def configure_ddp(self) -> None:
Expand Down Expand Up @@ -104,7 +93,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

results = trainer.train_or_test_or_predict()

self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)

# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
Expand All @@ -114,12 +102,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
if self.global_rank == 0:
time.sleep(2)

def __save_end_of_training_weights(self, model: LightningModule) -> None:
# when training ends on these platforms dump weights to get out of the main process
if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

def model_to_device(self) -> None:
self._model.to(xm.xla_device())

Expand Down Expand Up @@ -159,37 +141,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj

def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
"""
Load the temp weights saved in the process
To recover the trained model from the ddp process we load the saved weights
"""

loaded_model = original_model

if self.is_global_zero:
# load weights saved in ddp
path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
loaded_model = original_model.__class__.load_from_checkpoint(path)

# copy loaded weights to old model
original_model.load_state_dict(loaded_model.state_dict())

# remove ddp weights
os.remove(path)

return loaded_model

def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
"""
if model.trainer.is_global_zero:
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
model.trainer.save_checkpoint(path)
return path

def reduce_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.device)
decision = self.reduce(decision, "sum")
decision = bool(decision == self.world_size)
Expand All @@ -213,40 +165,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

return output

def post_dispatch(self) -> None:
# TODO: Check if trainer references can be resolved otherwise
model = self.lightning_module

# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()

# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback is not None:
self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also bets score

# load last weights
if last_path and not self.lightning_module.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self._model = model

# when training completes, load the weights back in main process
self.__load_weights_on_main_process()

def __load_weights_on_main_process(self) -> None:
model = self.lightning_module

# load weights if not interrupted
# TODO: check for trainer reference
if on_colab_kaggle() and not model.trainer.testing:
self.load_spawn_weights(model)

self._model = model

def _close_logger(self, trainer) -> None:
if trainer.logger is not None:
trainer.logger.finalize("success")
Expand Down
21 changes: 11 additions & 10 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def use_dp(self) -> bool:
def use_ddp(self) -> bool:
return self._distrib_type in (
DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN
)

@property
Expand Down Expand Up @@ -291,7 +291,8 @@ def parallel_devices(self) -> Union[List[torch.device], int]:
elif self.on_tpu:
# explicitly don't make a tpu device here!
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
devices = [i for i in self.parallel_device_ids]
if isinstance(self.tpu_cores, int):
devices = list(range(self.tpu_cores))
else:
devices = [torch.device("cpu")] * self.num_processes
return devices
Expand Down Expand Up @@ -369,6 +370,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
Expand All @@ -379,7 +381,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
use_torchelastic_ddp = False

if self.on_tpu:
if use_tpu_spawn:
ddp_plugin_cls = TPUSpawnPlugin
elif use_ddp_sharded:
ddp_plugin_cls = DDPShardedPlugin
Expand All @@ -402,11 +404,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
elif self.on_tpu:
if isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
else:
plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
elif self.on_tpu and isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
Expand Down Expand Up @@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
# special case with TPUs
elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
self._device_type = DeviceType.TPU
if isinstance(self.tpu_cores, int):
self._distrib_type = DistributedType.TPU_SPAWN
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)

Expand All @@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
if self.num_gpus > 0 and not _on_cpu:
self._device_type = DeviceType.GPU

_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
_gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu:
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
Expand Down
13 changes: 13 additions & 0 deletions pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,23 @@ class DistributedType(LightningEnum):
>>> DistributedType.DDP2 in ('ddp2', )
True
"""

@staticmethod
def interactive_compatible_types() -> List['DistributedType']:
"""Returns a list containing interactive compatible DistributeTypes"""
return [
DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN
]

def is_interactive_compatible(self) -> bool:
"""Returns whether self is interactive compatible"""
return self in DistributedType.interactive_compatible_types()

DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
TPU_SPAWN = 'tpu_spawn'
DEEPSPEED = 'deepspeed'
HOROVOD = 'horovod'
DDP_SHARDED = 'ddp_sharded'
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=4,
tpu_cores=1,
limit_train_batches=4,
limit_val_batches=4,
limit_train_batches=0.4,
limit_val_batches=0.4,
gradient_clip_val=0.5,
)

Expand Down

0 comments on commit c7422d4

Please sign in to comment.