[Fix] TPU Training Type Plugin (#6816)
kaushikb11 committed Apr 6, 2021
1 parent eafec7d commit cf8e828
Showing 6 changed files with 36 additions and 129 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


- Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))


- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))


37 changes: 9 additions & 28 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -11,15 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Union

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device

if _TPU_AVAILABLE:
@@ -28,17 +23,22 @@

class SingleTPUPlugin(SingleDevicePlugin):

def __init__(self, device: Union[torch.device, int]):
if isinstance(device, int):
device = xm.xla_device(device)
def __init__(self, device: int):

device = xm.xla_device(device)
super().__init__(device)

self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0

@property
def on_tpu(self) -> bool:
return True

@property
def is_distributed(self) -> bool:
return False

def model_to_device(self) -> None:
self.model.to(self.root_device)

@@ -49,29 +49,10 @@ def pre_dispatch(self) -> None:
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()

def post_dispatch(self) -> None:
model = self.lightning_module

if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
"""
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
model.trainer.save_checkpoint(path)
return path

def on_save(self, checkpoint: dict) -> dict:
"""
Move XLA tensors to CPU before saving
Recommended on XLA Guide:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
"""
return move_data_to_device(checkpoint, torch.device("cpu"))

@property
def is_distributed(self):
return False
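The slimmed-down plugin above keeps `on_save`, which moves every XLA tensor in the checkpoint to CPU before serialization, as the XLA API guide recommends. Below is a minimal sketch of that pattern in plain `torch` so it runs without a TPU; `deep_cpu_move` is a hypothetical stand-in for Lightning's `move_data_to_device`.

```python
import torch

def deep_cpu_move(obj):
    """Recursively move tensors in dicts/lists/tuples to CPU before saving."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: deep_cpu_move(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(deep_cpu_move(v) for v in obj)
    return obj

checkpoint = {"state_dict": {"w": torch.ones(2, 2)}, "epoch": 3}
torch.save(deep_cpu_move(checkpoint), "/tmp/model.ckpt")  # all tensors now on CPU
```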
95 changes: 7 additions & 88 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -20,9 +20,7 @@
import torch
import torch.multiprocessing as mp

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -32,28 +30,20 @@

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.core.xla_model import rendezvous
from torch_xla.distributed.parallel_loader import ParallelLoader
from torch_xla.distributed.parallel_loader import MpDeviceLoader
else:
xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf
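The import block above binds the torch_xla names to `None` when the TPU stack is missing, so the module still imports on CPU-only machines. Here is a sketch of the same guarded-import pattern using `try`/`except`; the library itself keys off the `_TPU_AVAILABLE` utility rather than catching the error.

```python
# Guarded optional import: on machines without torch_xla the names are bound
# to None instead of raising at import time.
try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    from torch_xla.core.xla_model import rendezvous
    from torch_xla.distributed.parallel_loader import MpDeviceLoader
except ImportError:
    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
```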


class TPUSpawnPlugin(DDPSpawnPlugin):

def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
**kwargs: Dict[str, Any]
) -> None:
super().__init__(
parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
)
def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
self.tpu_local_core_rank = 0
self.start_method = None

@@ -74,10 +64,9 @@ def distributed_sampler_kwargs(self) -> dict:
def is_distributed(self):
return self.world_size != 1

def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
device = xm.xla_device()
dataloader = xla_pl.ParallelLoader(dataloader, [device])
dataloader = dataloader.per_device_loader(device)
dataloader = MpDeviceLoader(dataloader, device)
return dataloader
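`MpDeviceLoader` wraps a regular `DataLoader` and delivers each batch already on the XLA device, collapsing the old two-step `ParallelLoader(...).per_device_loader(device)` dance. A sketch of the new usage follows; it runs only on a TPU runtime with `torch_xla` installed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=8)

device = xm.xla_device()
tpu_loader = MpDeviceLoader(loader, device)  # batches arrive on the XLA device

x, y = next(iter(tpu_loader))
assert x.device.type == "xla"
```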

def configure_ddp(self) -> None:
@@ -115,7 +104,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

results = trainer.run_stage()

self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)

# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
Expand All @@ -125,12 +113,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
if self.global_rank == 0:
time.sleep(2)
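This barrier-plus-sleep tail is the actual Colab hang fix: every process meets at a rendezvous (the plugin's barrier is backed by `xm.rendezvous`), then rank 0 lingers so the other workers exit first (see pytorch/xla#1801). A sketch of the pattern in isolation; `finish_process` is a hypothetical name and a TPU runtime is required.

```python
import time
import torch_xla.core.xla_model as xm

def finish_process(global_rank: int) -> None:
    xm.rendezvous("end-process")  # block until every TPU process arrives here
    if global_rank == 0:
        time.sleep(2)             # let the non-zero ranks tear down first
```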

def __save_end_of_training_weights(self, model: LightningModule) -> None:
# when training ends on these platforms dump weights to get out of the main process
if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

def model_to_device(self) -> None:
self._model.to(xm.xla_device())

@@ -172,37 +154,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj

def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
"""
Load the temp weights saved in the process
To recover the trained model from the ddp process we load the saved weights
"""

loaded_model = original_model

if self.is_global_zero:
# load weights saved in ddp
path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
loaded_model = original_model.__class__.load_from_checkpoint(path)

# copy loaded weights to old model
original_model.load_state_dict(loaded_model.state_dict())

# remove ddp weights
os.remove(path)

return loaded_model

def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
"""
if model.trainer.is_global_zero:
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
model.trainer.save_checkpoint(path)
return path

def reduce_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.device)
decision = self.reduce(decision, "sum")
decision = bool(decision == self.world_size)
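The renamed `reduce_boolean_decision` implements an all-ranks vote: each process contributes 0 or 1, the values are sum-reduced, and the result is `True` only when the sum equals the world size. Below is a device-free sketch of the same logic, with the collective faked by a list so it runs anywhere.

```python
import torch

def reduce_boolean_decision(per_rank_decisions):
    """True only if every rank voted True (i.e. the sum equals world_size)."""
    votes = torch.tensor([int(d) for d in per_rank_decisions])
    return bool(votes.sum() == len(per_rank_decisions))

assert reduce_boolean_decision([True, True, True]) is True
assert reduce_boolean_decision([True, False, True]) is False
```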
@@ -226,39 +178,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

return output

def post_dispatch(self) -> None:
# TODO: Check if trainer references can be resolved otherwise
model = self.lightning_module

# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()

# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback is not None:
self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score

# load last weights
if last_path and model.trainer.state == TrainerState.FITTING:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self._model = model

# when training completes, load the weights back in main process
self.__load_weights_on_main_process()

def __load_weights_on_main_process(self) -> None:
model = self.lightning_module

# load weights if not interrupted
if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING:
self.load_spawn_weights(model)

self._model = model

def _close_logger(self, trainer) -> None:
if trainer.logger is not None:
trainer.logger.finalize("success")
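The deleted `post_dispatch` override duplicated the parent-process handoff that the inherited `DDPSpawnPlugin` already performs: the spawned worker pushes the best checkpoint path, the last checkpoint path, and the results onto a multiprocessing queue, and the parent pops them in the same order. A toy single-process sketch of that ordering contract; the names are illustrative, not the library API.

```python
import torch.multiprocessing as mp

queue = mp.get_context("spawn").SimpleQueue()

# worker side: push in a fixed order
queue.put("/tmp/best.ckpt")   # best checkpoint path
queue.put("/tmp/last.ckpt")   # last checkpoint path
queue.put({"val_loss": 0.1})  # results

# parent side: pop in exactly the same order
best_path, last_path, results = queue.get(), queue.get(), queue.get()
assert best_path == "/tmp/best.ckpt"
```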
21 changes: 11 additions & 10 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -257,7 +257,7 @@ def use_dp(self) -> bool:
def use_ddp(self) -> bool:
return self._distrib_type in (
DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN
)

@property
@@ -297,7 +297,8 @@ def parallel_devices(self) -> List[Union[torch.device, int]]:
elif self.on_tpu:
# explicitly don't make a tpu device here!
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
devices = [i for i in self.parallel_device_ids]
if isinstance(self.tpu_cores, int):
devices = list(range(self.tpu_cores))
else:
devices = [torch.device("cpu")] * self.num_processes
return devices
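Note that for TPUs the connector now derives the device list from the `tpu_cores` count and deliberately returns integer ordinals rather than `torch.device` objects (see issue #3169); `xm.xla_device(ordinal)` is only called later, inside each spawned process. A two-line illustration:

```python
tpu_cores = 8  # e.g. Trainer(tpu_cores=8)
parallel_devices = list(range(tpu_cores))  # [0, 1, ..., 7] -- ints on purpose
```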
@@ -376,6 +377,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
@@ -386,7 +388,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
use_torchelastic_ddp = False

if self.on_tpu:
if use_tpu_spawn:
ddp_plugin_cls = TPUSpawnPlugin
elif use_ddp_sharded:
ddp_plugin_cls = DDPShardedPlugin
@@ -409,11 +411,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
elif self.on_tpu:
if isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
else:
plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
elif self.on_tpu and isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
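Condensing the routing above: an integer `tpu_cores` now flows through the `TPU_SPAWN` distributed type to `TPUSpawnPlugin`, while a one-element list (pinning a specific core) still selects `SingleTPUPlugin`. A standalone sketch of just that decision, with the plugin classes replaced by labels:

```python
def select_tpu_plugin(tpu_cores):
    # int -> spawn one process per core; [n] -> run on the single core n
    if isinstance(tpu_cores, int):
        return "TPUSpawnPlugin", list(range(tpu_cores))
    if isinstance(tpu_cores, list) and len(tpu_cores) == 1:
        return "SingleTPUPlugin", tpu_cores[0]
    raise ValueError("tpu_cores must be an int or a one-element list")

assert select_tpu_plugin(8) == ("TPUSpawnPlugin", [0, 1, 2, 3, 4, 5, 6, 7])
assert select_tpu_plugin([5]) == ("SingleTPUPlugin", 5)
```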
@@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
# special case with TPUs
elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
self._device_type = DeviceType.TPU
if isinstance(self.tpu_cores, int):
self._distrib_type = DistributedType.TPU_SPAWN
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)

@@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
if self.num_gpus > 0 and not _on_cpu:
self._device_type = DeviceType.GPU

_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
_gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu:
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/enums.py
@@ -62,7 +62,9 @@ class DistributedType(LightningEnum):
@staticmethod
def interactive_compatible_types() -> List['DistributedType']:
"""Returns a list containing interactive compatible DistributeTypes"""
return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
return [
DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN
]

def is_interactive_compatible(self) -> bool:
"""Returns whether self is interactive compatible"""
@@ -72,6 +74,7 @@ def is_interactive_compatible(self) -> bool:
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
TPU_SPAWN = 'tpu_spawn'
DEEPSPEED = 'deepspeed'
HOROVOD = 'horovod'
DDP_SHARDED = 'ddp_sharded'
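For reference, `TPU_SPAWN` behaves like any other `LightningEnum` member: it maps to the string `'tpu_spawn'` and is now listed as safe for interactive environments such as Colab notebooks. A sketch using a plain `str` Enum as a stand-in for `LightningEnum`:

```python
from enum import Enum

class DistributedType(str, Enum):
    DP = "dp"
    DDP_SPAWN = "ddp_spawn"
    DDP_SHARDED_SPAWN = "ddp_sharded_spawn"
    TPU_SPAWN = "tpu_spawn"

INTERACTIVE_COMPATIBLE = {
    DistributedType.DP,
    DistributedType.DDP_SPAWN,
    DistributedType.DDP_SHARDED_SPAWN,
    DistributedType.TPU_SPAWN,
}

assert DistributedType("tpu_spawn") in INTERACTIVE_COMPATIBLE
```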
4 changes: 2 additions & 2 deletions tests/models/test_tpu.py
@@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=4,
tpu_cores=1,
limit_train_batches=4,
limit_val_batches=4,
limit_train_batches=0.4,
limit_val_batches=0.4,
gradient_clip_val=0.5,
)

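The test now passes floats: for `limit_train_batches`/`limit_val_batches`, a float such as 0.4 means 40% of the batches per epoch, whereas the previous int 4 meant exactly four batches. An illustrative Trainer configuration matching the updated test (constructing it with `tpu_cores=1` requires an actual TPU runtime):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    progress_bar_refresh_rate=0,
    max_epochs=4,
    tpu_cores=1,
    limit_train_batches=0.4,  # fraction of training batches per epoch
    limit_val_batches=0.4,    # fraction of validation batches per epoch
    gradient_clip_val=0.5,
)
```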
