Device updates for TPU Pod (#7243)
Authored by kaushikb11 on Apr 30, 2021
Parent: 16d6c98 · Commit: 490cc57
Showing 5 changed files with 42 additions and 4 deletions.
CHANGELOG.md (3 additions, 0 deletions)
@@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `tpu_distributed` check for TPU Spawn barrier ([#7241](https://github.com/PyTorchLightning/pytorch-lightning/pull/7241))


- Added device updates to TPU Spawn for Pod training ([#7243](https://github.com/PyTorchLightning/pytorch-lightning/pull/7243))


- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))


pytorch_lightning/core/step_result.py (1 addition, 1 deletion)
@@ -109,7 +109,7 @@ def log(
        if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
            is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
            # TODO: Find a way to make the reduction only once, so we don't need to clone.
-           if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
+           if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
                value = value.clone()
            else:
                value = torch.tensor(value, device=device, dtype=torch.float)
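Note on the fix above: the bare name `tpu_distributed` is a function object and is always truthy, so the clone branch was taken even when no TPU collective was running; calling it restores the intended check. Below is a rough sketch of what such a helper checks, assuming the `_TPU_AVAILABLE` flag and `torch_xla` are importable; the actual helper in `pytorch_lightning.utilities` may differ in detail.

```python
# Sketch only: mirrors the assumed behaviour of the real `tpu_distributed()`
# helper; not the library's definitive implementation.
from pytorch_lightning.utilities import _TPU_AVAILABLE

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm


def tpu_distributed() -> bool:
    # Distributed TPU training implies more than one XLA replica.
    return _TPU_AVAILABLE and xm.xrt_world_size() > 1
```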
pytorch_lightning/plugins/training_type/tpu_spawn.py (4 additions, 2 deletions)
@@ -35,6 +35,7 @@
from pytorch_lightning.utilities.seed import reset_seed

if _TPU_AVAILABLE:
+   import torch_xla.core.xla_env_vars as xenv
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    from torch_xla.core.xla_model import rendezvous
@@ -58,7 +59,7 @@ def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = F

    @property
    def global_rank(self) -> int:
-       return self.tpu_local_core_rank
+       return self.tpu_global_core_rank

    @property
    def local_rank(self) -> int:
@@ -175,7 +176,8 @@ def model_to_device(self) -> None:
        self.model = self.wrapped_model.to(self.device)

    def barrier(self, name: Optional[str] = None) -> None:
-       if tpu_distributed():
+       # HOST_WORLD_SIZE is None outside the xmp.spawn process
+       if os.getenv(xenv.HOST_WORLD_SIZE, None) and tpu_distributed():
            rendezvous(name)

    def transfer_distrib_spawn_state_on_fit_end(self, results):
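The guarded barrier matters on a Pod: `xenv.HOST_WORLD_SIZE` is only set inside the worker processes forked by `xmp.spawn`, so calling `rendezvous` from outside that context (for example from the `xla_dist` launcher process) would block. A minimal standalone sketch of the same guard, assuming `torch_xla` is installed; `tpu_pod_safe_barrier` is an illustrative name, not part of the library.

```python
import os

import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm


def tpu_pod_safe_barrier(name: str) -> None:
    # HOST_WORLD_SIZE is unset outside the xmp.spawn worker processes, so the
    # rendezvous is skipped there instead of hanging; inside the workers it is
    # only needed when more than one replica participates.
    if os.getenv(xenv.HOST_WORLD_SIZE, None) and xm.xrt_world_size() > 1:
        xm.rendezvous(name)
```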
pytorch_lightning/utilities/xla_device.py (7 additions, 1 deletion)
@@ -66,7 +66,13 @@ def _is_device_tpu() -> bool:
        Return:
            A boolean value indicating if TPU devices are available
        """

+       # For TPU Pod training (e.g. a TPU v3-32 across 4 VMs, where the world
+       # size would be 4), training has to be launched through
+       # `torch_xla.distributed.xla_dist`; TPU_CONFIG is then not available
+       # and running `xm.get_xla_supported_devices("TPU")` is not possible.
+       if xm.xrt_world_size() > 1:
+           return True
        return len(xm.get_xla_supported_devices("TPU")) > 0

    @staticmethod
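To see why the world-size check is a usable proxy for device availability on a Pod, the quantities it relies on can be printed per spawned process. A hedged illustration, assuming `torch_xla` is installed and a TPU is attached; `_inspect` is just an illustrative name.

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _inspect(rank: int) -> None:
    # On a single host (e.g. a v3-8) xrt_world_size() covers the local cores;
    # on a Pod launched via xla_dist it spans every VM in the slice.
    print(f"rank={rank} world_size={xm.xrt_world_size()} device={xm.xla_device()}")


if __name__ == "__main__":
    # nprocs=None lets torch_xla fork one process per available device.
    xmp.spawn(_inspect, nprocs=None)
```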
tests/models/test_tpu.py (27 additions, 0 deletions)
@@ -470,3 +470,30 @@ def teardown(self, stage):

    model = DebugModel()
    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_host_world_size(tmpdir):
+    """Test Host World size env setup on TPU."""
+
+    class DebugModel(BoringModel):
+
+        def on_train_start(self):
+            assert os.environ.get("XRT_HOST_WORLD_SIZE") == str(1)
+
+        def teardown(self, stage):
+            assert "XRT_HOST_WORLD_SIZE" not in os.environ
+
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=4,
+        tpu_cores=8,
+        limit_train_batches=0.4,
+        limit_val_batches=0.4,
+    )
+
+    model = DebugModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
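The new test asserts against the literal environment variable name; a quick check ties that literal to the `xenv` constant the plugin reads, assuming the usual `torch_xla` definition.

```python
# Assumed mapping: torch_xla exposes the env var name as a constant.
import torch_xla.core.xla_env_vars as xenv

assert xenv.HOST_WORLD_SIZE == "XRT_HOST_WORLD_SIZE"
```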
