Support SPMD through the xla:// init_method (#5706)
* Support SPMD through the xla:// init_method

* Ensure compatibility with multithreading

* Add get_master_ip to runtime

* Use for loop instead of iterator
jonb377 authored and bhavya01 committed Apr 22, 2024
1 parent 053034a commit 6f73a5b
Showing 5 changed files with 61 additions and 3 deletions.
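
For context, a job would typically reach the new SPMD codepath through the xla:// init_method roughly as in the sketch below. This is a minimal sketch, not part of the commit: the 'gloo' backend choice is an assumption (swap in whichever backend the job actually uses), and enabling SPMD via the XLA_USE_SPMD flag mirrors the existing is_spmd() check visible in the torch_xla/runtime.py hunk.

import os
os.environ['XLA_USE_SPMD'] = '1'  # flag read by xr.is_spmd()

import torch.distributed as dist
import torch_xla.distributed.xla_backend  # XLA distributed backend module, per the torch_xla docs pattern
import torch_xla.runtime as xr

# With SPMD enabled, the xla:// rendezvous yields one rank per host process.
dist.init_process_group('gloo', init_method='xla://')
assert dist.get_world_size() == xr.process_count()
assert dist.get_rank() == xr.process_index()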
7 changes: 7 additions & 0 deletions test/pjrt/test_runtime_tpu.py
@@ -238,6 +238,13 @@ def test_execute_time_metric(self):
f"Expected exectue time of {i} to take more than "
f"{expected_time_seconds} seconds, got {v / 1e9} seconds")

  @mock.patch('torch_xla._internal.tpu.get_worker_ips')
  def test_master_ip_discovery(self, patched_get_worker_ips):
    # A basic test to verify the non-SPMD codepath returns the correct IP. Two
    # IPs are needed to avoid the short-circuit return of localhost.
    patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
    self.assertTrue(xr.get_master_ip(), '10.0.0.1')


if __name__ == '__main__':
  absltest.main()
9 changes: 9 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -477,6 +477,15 @@ def test_manager_async_step_tracking(self, tmpdir):
            torch.allclose(v, new_state_dict[k])
            for k, v in state_dict.items()))

  @unittest.skipUnless(xr.device_type() == 'TPU',
                       'TPU required for worker IP discovery')
  @unittest.mock.patch('torch_xla._internal.tpu.get_worker_ips')
  def test_master_ip_discovery(self, patched_get_worker_ips):
    # A basic test to verify the SPMD codepath returns the correct IP. Two IPs
    # are needed to avoid the short-circuit return of localhost.
    patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
    self.assertTrue(xr.get_master_ip(), '10.0.0.1')


if __name__ == '__main__':
  test = unittest.main()
7 changes: 5 additions & 2 deletions torch_xla/_internal/rendezvous.py
@@ -28,7 +28,6 @@ def pjrt_rendezvous_handler(url: str,
  ) == 'TPU' else 'localhost'

  master_port = xu.getenv_as('MASTER_PORT', int, 12355)
  world_size = xr.world_size()
  with _store_lock:
    global _store
    if not _store:
@@ -44,4 +43,8 @@ def pjrt_rendezvous_handler(url: str,
          xr.process_count(),
          is_master=xr.process_index() == 0)

  yield (_store, xr.global_ordinal(), world_size)
  # In SPMD, the world size and rank are determined by the process count and
  # index, while in multiprocess they are based on the device count and ordinal.
  world_size = xr.process_count() if xr.is_spmd() else xr.world_size()
  rank = xr.process_index() if xr.is_spmd() else xr.global_ordinal()
  yield (_store, rank, world_size)
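
To make the comment above concrete, here is the same selection written out with illustrative values; the 2-host, 4-chip-per-host topology is assumed only for the sake of the numbers.

import torch_xla.runtime as xr

if xr.is_spmd():
  # SPMD: one participating process per host.
  world_size, rank = xr.process_count(), xr.process_index()  # 2, one of 0..1
else:
  # Multiprocess: one participating rank per global device.
  world_size, rank = xr.world_size(), xr.global_ordinal()  # 8, one of 0..7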
30 changes: 29 additions & 1 deletion torch_xla/_internal/tpu.py
@@ -1,5 +1,6 @@
import functools
import glob
from ipaddress import ip_address
import operator
import os
import pathlib
@@ -10,6 +11,7 @@
import yaml

import torch
import torch_xla
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
@@ -268,23 +270,49 @@ def configure_topology(local_rank: int,


def discover_master_worker_ip(use_localhost: bool = True) -> str:
"""Find the IP of the TPU host with TPU:0.
"""Find the IP of the master TPU host.
In multiprocess, this is the host with TPU:0.
In SPMD mode, this is the host running process 0.
TPU device IDs are nondeterministic and independent from Cloud TPU worker IDs.
Args:
use_localhost: if there is only one TPU host, return 'localhost` instead
of that host's internal IP.
"""
  import torch_xla.runtime as xr
  worker_ips = get_worker_ips()
  if len(worker_ips) == 1:
    return 'localhost'

  tpu_env = get_tpu_env()
  current_worker_id = int(tpu_env[xenv.WORKER_ID])
  if xr.is_spmd():
    return _spmd_find_master_ip(worker_ips[current_worker_id])

  t = torch.tensor([current_worker_id], device=xm.xla_device())
  xm.collective_broadcast([t])
  xm.mark_step()

  master_worker_id = int(t.cpu())
  return worker_ips[master_worker_id]


def _spmd_find_master_ip(current_worker_ip: str) -> str:
  import torch_xla.runtime as xr
  import torch_xla.experimental.xla_sharding as xs
  from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
  ip_int = int(ip_address(current_worker_ip))
  n_dev = xr.global_runtime_device_count()
  local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices())
  # Create a global (n_dev x 2) tensor containing all process indices and IPs,
  # and find the process 0 IP as the master IP.
  shard = torch.LongTensor([[xr.process_index(), ip_int]])
  op_sharding = xs.Mesh(range(n_dev), (n_dev, 1)).get_op_sharding((0, 1))
  global_tensor = from_cpu_shards([shard] * local_ndev, op_sharding).cpu()
  # Process 0 may not control device 0, so we must do a linear search.
  for proc, ip in global_tensor.tolist():
    if proc == 0:
      return str(ip_address(ip))
  raise RuntimeError('Could not find IP of host running process 0')
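
The address exchange above ships each host's IP through a LongTensor shard, so the IP is encoded as an integer on the way in and decoded with ipaddress on the way out. A self-contained round trip with a made-up IPv4 address (an IPv4 address fits easily in an int64 element):

from ipaddress import ip_address

ip_int = int(ip_address('10.0.0.1'))  # 167772161
assert str(ip_address(ip_int)) == '10.0.0.1'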
11 changes: 11 additions & 0 deletions torch_xla/runtime.py
@@ -242,3 +242,14 @@ def is_spmd():
"""Returns if SPMD is set for execution."""
# TODO(yeounoh) replace this when we fully deprecate the flag.
return xu.check_env_flag('XLA_USE_SPMD')


@requires_pjrt
def get_master_ip() -> str:
"""Retrieve the master worker IP for the runtime. This calls into
backend-specific discovery APIs.
Returns master worker's IP address as a string."""
if device_type() == 'TPU':
return tpu.discover_master_worker_ip()
raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
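
One hypothetical way a caller could use the new helper outside the rendezvous handler; the MASTER_ADDR seeding shown here is illustrative and not part of the commit.

import os
import torch_xla.runtime as xr

# Seed MASTER_ADDR from the discovered master host before setting up a TPU process group.
if xr.device_type() == 'TPU':
  os.environ.setdefault('MASTER_ADDR', xr.get_master_ip())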
