[Fix] TPU Training Type Plugin #6816

Merged
merged 18 commits on Apr 6, 2021
33 changes: 7 additions & 26 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -11,15 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Union

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device

if _TPU_AVAILABLE:
@@ -28,17 +23,22 @@

class SingleTPUPlugin(SingleDevicePlugin):

def __init__(self, device: Union[torch.device, int]):
def __init__(self, device: int):
if isinstance(device, int):
device = xm.xla_device(device)
super().__init__(device)

self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0

@property
def on_tpu(self) -> bool:
return True

@property
def is_distributed(self):
return False

def model_to_device(self) -> None:
self.model.to(self.root_device)

@@ -49,29 +49,10 @@ def pre_dispatch(self) -> None:
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()

def post_dispatch(self) -> None:
model = self.lightning_module

if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
"""
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
model.trainer.save_checkpoint(path)
return path

def on_save(self, checkpoint: dict) -> dict:
"""
Move XLA tensors to CPU before saving
Recommended on XLA Guide:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
"""
return move_data_to_device(checkpoint, torch.device("cpu"))

@property
def is_distributed(self):
return False
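
For orientation, here is a hedged sketch of the two behaviours the slimmed-down SingleTPUPlugin keeps: resolving an integer core index to an XLA device, and moving XLA tensors to CPU before a checkpoint is written (the Colab/Kaggle spawn-weight handling is gone). The function names are illustrative rather than the actual Lightning internals, and the snippet assumes torch_xla is installed.

import torch
import torch_xla.core.xla_model as xm
from pytorch_lightning.utilities.apply_func import move_data_to_device

def resolve_tpu_device(device: int) -> torch.device:
    # integer core index -> XLA device handle, e.g. xm.xla_device(1) -> xla:1
    return xm.xla_device(device) if isinstance(device, int) else device

def checkpoint_to_cpu(checkpoint: dict) -> dict:
    # per the XLA API guide, save CPU tensors rather than XLA tensors
    return move_data_to_device(checkpoint, torch.device("cpu"))
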
95 changes: 7 additions & 88 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -20,9 +20,7 @@
import torch
import torch.multiprocessing as mp

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -32,28 +30,20 @@

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.core.xla_model import rendezvous
from torch_xla.distributed.parallel_loader import ParallelLoader
from torch_xla.distributed.parallel_loader import MpDeviceLoader
else:
xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class TPUSpawnPlugin(DDPSpawnPlugin):

def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
**kwargs: Dict[str, Any]
) -> None:
super().__init__(
parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
)
def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
self.tpu_local_core_rank = 0
self.start_method = None

@@ -74,10 +64,9 @@ def distributed_sampler_kwargs(self) -> dict:
def is_distributed(self):
return self.world_size != 1

def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
device = xm.xla_device()
dataloader = xla_pl.ParallelLoader(dataloader, [device])
dataloader = dataloader.per_device_loader(device)
dataloader = MpDeviceLoader(dataloader, device)
return dataloader

def configure_ddp(self) -> None:
@@ -115,7 +104,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

results = trainer.run_stage()

self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)

# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
@@ -125,12 +113,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
if self.global_rank == 0:
time.sleep(2)

def __save_end_of_training_weights(self, model: LightningModule) -> None:
# when training ends on these platforms dump weights to get out of the main process
if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

def model_to_device(self) -> None:
self._model.to(xm.xla_device())

@@ -172,37 +154,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj

def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
"""
Load the temp weights saved in the process
To recover the trained model from the ddp process we load the saved weights
"""

loaded_model = original_model

if self.is_global_zero:
# load weights saved in ddp
path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
loaded_model = original_model.__class__.load_from_checkpoint(path)

# copy loaded weights to old model
original_model.load_state_dict(loaded_model.state_dict())

# remove ddp weights
os.remove(path)

return loaded_model

def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
"""
if model.trainer.is_global_zero:
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
model.trainer.save_checkpoint(path)
return path

def reduce_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.device)
decision = self.reduce(decision, "sum")
decision = bool(decision == self.world_size)
@@ -226,39 +178,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

return output

def post_dispatch(self) -> None:
# TODO: Check if trainer references can be resolved otherwise
model = self.lightning_module

# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()

# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback is not None:
self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score

# load last weights
if last_path and model.trainer.state == TrainerState.FITTING:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self._model = model

# when training completes, load the weights back in main process
self.__load_weights_on_main_process()

def __load_weights_on_main_process(self) -> None:
model = self.lightning_module

# load weights if not interrupted
if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING:
self.load_spawn_weights(model)

self._model = model

def _close_logger(self, trainer) -> None:
if trainer.logger is not None:
trainer.logger.finalize("success")
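
For context, a hedged sketch of the dataloader handling that replaces ParallelLoader(...).per_device_loader(...): MpDeviceLoader wraps an ordinary DataLoader and moves each batch onto the XLA device as it is iterated. The dataset and shapes below are made up, and the snippet assumes a torch_xla build that ships MpDeviceLoader.

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

dataset = TensorDataset(torch.randn(64, 32))
loader = DataLoader(dataset, batch_size=8)

device = xm.xla_device()
tpu_loader = MpDeviceLoader(loader, device)  # batches land on the XLA device

for (batch,) in tpu_loader:
    assert batch.device.type == "xla"
    break
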
21 changes: 11 additions & 10 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -257,7 +257,7 @@ def use_dp(self) -> bool:
def use_ddp(self) -> bool:
return self._distrib_type in (
DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN
)

@property
@@ -297,7 +297,8 @@ def parallel_devices(self) -> List[Union[torch.device, int]]:
elif self.on_tpu:
# explicitly don't make a tpu device here!
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
devices = [i for i in self.parallel_device_ids]
if isinstance(self.tpu_cores, int):
devices = list(range(self.tpu_cores))
else:
devices = [torch.device("cpu")] * self.num_processes
return devices
@@ -376,6 +377,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
@@ -386,7 +388,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
use_torchelastic_ddp = False

if self.on_tpu:
if use_tpu_spawn:
ddp_plugin_cls = TPUSpawnPlugin
elif use_ddp_sharded:
ddp_plugin_cls = DDPShardedPlugin
@@ -409,11 +411,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
elif self.on_tpu:
if isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
else:
plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
elif self.on_tpu and isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
@@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
# special case with TPUs
elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
self._device_type = DeviceType.TPU
if isinstance(self.tpu_cores, int):
self._distrib_type = DistributedType.TPU_SPAWN
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)

@@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
if self.num_gpus > 0 and not _on_cpu:
self._device_type = DeviceType.GPU

_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
_gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu:
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
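
Taken together, the connector changes mean the plugin choice now keys off the type of tpu_cores. An illustrative, hedged usage, assuming a TPU runtime such as Colab or a Cloud TPU VM:

from pytorch_lightning import Trainer

# int -> DistributedType.TPU_SPAWN -> TPUSpawnPlugin, one process per core
trainer = Trainer(tpu_cores=8)

# one-element list -> SingleTPUPlugin pinned to that specific core
trainer = Trainer(tpu_cores=[1])
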
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/enums.py
@@ -62,7 +62,9 @@ class DistributedType(LightningEnum):
@staticmethod
def interactive_compatible_types() -> List['DistributedType']:
"""Returns a list containing interactive compatible DistributeTypes"""
return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
return [
DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN
]

def is_interactive_compatible(self) -> bool:
"""Returns whether self is interactive compatible"""
@@ -72,6 +74,7 @@ def is_interactive_compatible(self) -> bool:
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
TPU_SPAWN = 'tpu_spawn'
DEEPSPEED = 'deepspeed'
HOROVOD = 'horovod'
DDP_SHARDED = 'ddp_sharded'
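
A quick, hedged sanity check of the enum addition, assuming a Lightning build that includes this change:

from pytorch_lightning.utilities.enums import DistributedType

assert DistributedType("tpu_spawn") is DistributedType.TPU_SPAWN
assert DistributedType.TPU_SPAWN.is_interactive_compatible()
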
4 changes: 2 additions & 2 deletions tests/models/test_tpu.py
@@ -205,7 +205,7 @@ def validation_step(self, *args, **kwargs):
def test_tpu_grad_norm(tmpdir):
"""Test if grad_norm works on TPU."""
tutils.reset_seed()
trainer_options = dict(
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=4,
@@ -216,7 +216,7 @@ def test_tpu_grad_norm(tmpdir):
)

model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
trainer.fit(model)


@RunIf(tpu=True)
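
Finally, a hedged sketch of the pattern the updated test_tpu_grad_norm follows: build the Trainer directly and call fit, instead of routing through tpipes.run_model_test. The keyword arguments below are illustrative stand-ins; the real ones sit in the collapsed part of the diff.

from pytorch_lightning import Trainer, LightningModule

def run_grad_norm_check(model: LightningModule, tmpdir: str) -> None:
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=4,
        tpu_cores=1,            # assumption: a single TPU core
        track_grad_norm=2,      # assumption: the norm the test tracks
        limit_train_batches=4,  # assumption
        limit_val_batches=4,    # assumption
    )
    trainer.fit(model)
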