Remove or simplify hardcoded lists of device types (#6235)
will-cromar authored and bhavya01 committed Apr 22, 2024
1 parent 04455c2 commit deff513
Showing 32 changed files with 136 additions and 174 deletions.
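
The change applied across these files follows one pattern: device-type guards that enumerated every GPU spelling ('GPU', 'CUDA', 'ROCM') collapse into a single comparison against the canonical name reported by the PJRT runtime ('CPU', 'CUDA', 'TPU', or 'NEURON'). The snippet below is a minimal, self-contained sketch of the old and new guard styles; it is not taken from any of the changed files and assumes a working torch_xla installation.

import unittest

import torch_xla.runtime as xr


class DeviceGuardExample(unittest.TestCase):
  """Illustrative only: old vs. new device-type guards."""

  # Old style: every spelling of the GPU backend had to be listed.
  @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
                   "not supported on GPU")
  def test_old_style_guard(self):
    pass

  # New style: PJRT reports a single canonical name, so one check is enough.
  @unittest.skipIf(xr.device_type() == 'CUDA', "not supported on GPU")
  def test_new_style_guard(self):
    pass


if __name__ == '__main__':
  unittest.main()
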
3 changes: 1 addition & 2 deletions .circleci/common.sh
@@ -132,8 +132,7 @@ function run_torch_xla_python_tests() {
if [ -x "$(command -v nvidia-smi)" ]; then
# These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# TODO(xiowei replace gpu with cuda): remove the test below with PJRT_DEVICE=GPU because PJRT_DEVICE=GPU is being deprecated.
PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# Syncfree SGD optimizer tests
2 changes: 1 addition & 1 deletion test/pjrt/test_ddp.py
@@ -32,7 +32,7 @@ def _ddp_init(index: int = ...):
def test_ddp_init(self):
pjrt.run_multiprocess(self._ddp_init)

@absltest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
@absltest.skipIf(xr.device_type() == 'CUDA',
"GPU device is not supported by pjrt.spawn_threads")
def test_ddp_init_threaded(self):
pjrt.spawn_threads(self._ddp_init)
2 changes: 1 addition & 1 deletion test/pjrt/test_runtime.py
@@ -97,7 +97,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
xr.using_pjrt()

if expect_using_pjrt:
self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'ROCM', 'GPU'])
self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
else:
self.assertIsNone(xr.device_type())

2 changes: 1 addition & 1 deletion test/pjrt/test_runtime_gpu.py
@@ -17,7 +17,7 @@
from absl.testing import absltest, parameterized


@unittest.skipIf(xr.device_type() not in ('GPU', 'CUDA', 'ROCM'),
@unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
class TestExperimentalPjrtGpu(parameterized.TestCase):

72 changes: 24 additions & 48 deletions test/spmd/test_spmd_debugging.py
@@ -28,10 +28,8 @@ def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
@unittest.skipIf(xr.device_type() != 'TPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_single_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -108,10 +106,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
@unittest.skipIf(xr.device_type() != 'TPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -159,10 +155,8 @@ def test_single_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
@unittest.skipIf(xr.device_type() != 'TPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_replicated_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -204,10 +198,8 @@ def test_single_host_replicated_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
f"Requires PJRT_DEVICE set to `CPU`.")
@unittest.skipIf(xr.device_type() != 'CPU',
f"Requires PJRT_DEVICE set to `CPU`.")
def test_debugging_spmd_single_host_tiled_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -247,10 +239,8 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
f"Requires PJRT_DEVICE set to `CPU`.")
@unittest.skipIf(xr.device_type() != 'CPU',
f"Requires PJRT_DEVICE set to `CPU`.")
def test_single_host_partial_replication_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -291,10 +281,8 @@ def test_single_host_partial_replication_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
f"Requires PJRT_DEVICE set to `CPU`.")
@unittest.skipIf(xr.device_type() != 'CPU',
f"Requires PJRT_DEVICE set to `CPU`.")
def test_single_host_replicated_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -341,10 +329,8 @@ def test_single_host_replicated_cpu(self):
# e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
# e.g.: sharding={replicated}

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
@unittest.skipIf(xr.device_type() != 'TPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_multi_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -453,10 +439,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
@unittest.skipIf(xr.device_type() != 'TPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_multi_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -537,10 +521,8 @@ def test_multi_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
@unittest.skipIf(xr.device_type() != 'TPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_multi_host_replicated_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{replicated}'
@@ -573,10 +555,8 @@ def test_multi_host_replicated_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
f"Requires PJRT_DEVICE set to `CPU`.")
@unittest.skipIf(xr.device_type() != 'CPU',
f"Requires PJRT_DEVICE set to `CPU`.")
def test_debugging_spmd_multi_host_tiled_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -685,10 +665,8 @@ def test_debugging_spmd_multi_host_tiled_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
f"Requires PJRT_DEVICE set to `CPU`.")
@unittest.skipIf(xr.device_type() != 'CPU',
f"Requires PJRT_DEVICE set to `CPU`.")
def test_multi_host_partial_replication_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -769,10 +747,8 @@ def test_multi_host_partial_replication_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
f"Requires PJRT_DEVICE set to `CPU`.")
@unittest.skipIf(xr.device_type() != 'CPU',
f"Requires PJRT_DEVICE set to `CPU`.")
def test_multi_host_replicated_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{replicated}'
8 changes: 4 additions & 4 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -506,8 +506,8 @@ def test_manager_async_step_tracking(self, tmpdir):
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))

@unittest.skipUnless(xr.device_type() == 'TPU',
'TPU required for worker IP discovery')
@unittest.skipIf(xr.device_type() != 'TPU',
'TPU required for worker IP discovery')
@unittest.mock.patch('torch_xla._internal.tpu.get_worker_ips')
def test_master_ip_discovery(self, patched_get_worker_ips):
# A basic test to verify the SPMD codepath returns the correct IP. Two IPs
@@ -543,8 +543,8 @@ def test_preemption_sync_manager(self):
# Scope the PreemptionSyncManager to the lifespan of the test.
torch_xla._XLAC._deactivate_preemption_sync_manager()

@unittest.skipUnless(xr.device_type() == 'TPU',
'TPU required for worker IP discovery')
@unittest.skipIf(xr.device_type() != 'TPU',
'TPU required for worker IP discovery')
@run_with_tmpdir
def test_auto_checkpoint(self, tmpdir):
# Create a checkpoint manager with a long save interval
5 changes: 2 additions & 3 deletions test/spmd/test_xla_sharding_base.py
@@ -9,9 +9,8 @@
import torch_xla.utils.utils as xu


@unittest.skipIf(not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
@unittest.skipUnless(xr.device_type() in ("TPU", "CPU"),
f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
class XlaShardingTest(unittest.TestCase):

class SimpleLinear(nn.Module):
2 changes: 1 addition & 1 deletion test/spmd/test_xla_spmd_python_api_interaction.py
@@ -122,7 +122,7 @@ def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

@unittest.skipIf(xr.device_type() not in ['GPU', 'TPU', 'CUDA', 'ROCM'],
@unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'],
f"TPU/GPU autocast test.")
def test_xla_autocast_api(self):
device = xm.xla_device()
2 changes: 1 addition & 1 deletion test/test_ddp.py
@@ -16,7 +16,7 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool):
# We cannot run this guard before XMP,
# see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
device = xm.xla_device()
if xm.xla_device_hw(device) not in ('GPU', 'TPU', 'CUDA', 'ROCM'):
if xm.xla_device_hw(device) not in ('TPU', 'CUDA'):
print(
'Default device {} is not a TPU device'.format(device),
file=sys.stderr)
10 changes: 5 additions & 5 deletions test/test_fsdp_auto_wrap.py
@@ -31,10 +31,10 @@ def forward(self, x):
hidden2 = self.fc2(x)
return hidden1, hidden2

@unittest.skipIf(xr.device_type() in (
'GPU', 'ROCM', 'CUDA'
), "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
)
@unittest.skipIf(
xr.device_type() == 'CUDA',
"This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
)
def test(self):
dev = xm.xla_device()
input = torch.zeros([16, 16], device=dev)
@@ -50,7 +50,7 @@ def test(self):

def _mp_fn(index):
device = xm.xla_device()
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
else:
7 changes: 2 additions & 5 deletions test/test_metrics.py
@@ -3,6 +3,7 @@

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest
@@ -171,11 +172,7 @@ def test_metrics_report(self):
report = met.metrics_report()
self.assertIn("CachedCompile", report)

@unittest.skipIf(
xm.get_xla_supported_devices("CUDA") or
xm.get_xla_supported_devices("GPU") or
xm.get_xla_supported_devices("ROCM") or
xm.get_xla_supported_devices("TPU"), f"This test only works on CPU.")
@unittest.skipIf(xr.device_type() != "CPU", f"This test only works on CPU.")
def test_execute_time_metric(self):
# Initialize the client before starting the timer.
xm.xla_device()
2 changes: 1 addition & 1 deletion test/test_mp_all_gather.py
@@ -14,7 +14,7 @@ def _mp_fn(index):
device = xm.xla_device()
world_size = xm.xrt_world_size()
input_list_size = 5
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
# Testing with a single replica group
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)
2 changes: 1 addition & 1 deletion test/test_mp_distributed_mm.py
@@ -9,7 +9,7 @@
def _mp_fn(index):
device = xm.xla_device()

if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
3 changes: 1 addition & 2 deletions test/test_operations.py
@@ -453,8 +453,7 @@ def test_get_real_xla_devices(self):
devices = xm.get_xla_supported_devices()
xla_devices = torch_xla._XLAC._xla_real_devices(devices)
for device, xdevice in zip(devices, xla_devices):
self.assertTrue(
re.match(r'(CPU|GPU|TPU|CUDA|ROCM):\d+$', xdevice) is not None)
self.assertIsNotNone(re.fullmatch(r'[A-Z]+:\d+$', xdevice))

def test_negative_slice(self):
t = _gen_tensor(32, 24, 32)
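
In the test_operations.py hunk above, an alternation naming each backend is replaced by a generic re.fullmatch(r'[A-Z]+:\d+$', ...) check. The short, torch_xla-independent sketch below shows why the generic pattern accepts any canonical TYPE:index device string; the sample device names are arbitrary.

import re

# Any upper-case device type followed by an ordinal, e.g. "TPU:0" or "CUDA:1",
# without hard-coding the list of backends. fullmatch() anchors the whole
# string, so the trailing "$" is redundant but harmless.
DEVICE_RE = re.compile(r'[A-Z]+:\d+$')

for name in ('CPU:0', 'TPU:3', 'CUDA:1'):
  assert DEVICE_RE.fullmatch(name) is not None

# Partial matches are rejected: trailing text after the ordinal does not pass.
assert DEVICE_RE.fullmatch('CUDA:1 extra') is None
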
2 changes: 1 addition & 1 deletion test/test_torch_distributed_all_gather_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_torch_distributed_all_reduce_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_torch_distributed_reduce_scatter_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet_amp.py
@@ -221,7 +221,7 @@ def train_imagenet():
if FLAGS.amp:
if device_hw == 'TPU':
scaler = None
elif device_hw in ('GPU', 'CUDA', 'ROCM'):
elif device_hw == 'CUDA':
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)

def train_loop_fn(loader, epoch):
4 changes: 2 additions & 2 deletions torch_xla/_internal/pjrt.py
@@ -147,7 +147,7 @@ def run_multiprocess(fn: Callable[..., R],
num_processes = plugins.default().physical_chip_count()
elif runtime.device_type() == 'TPU':
num_processes = tpu.num_local_processes()
elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
elif runtime.device_type() == 'CUDA':
num_processes = gpu.num_local_processes()
elif runtime.device_type() == 'NEURON':
num_processes = neuron.num_local_processes()
@@ -216,7 +216,7 @@ def _initialize_single_process(local_rank: int, local_world_size: int):
def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
"""Run function in one process with one thread per addressable device."""
assert runtime.device_type() not in (
'GPU', 'ROCM', 'CUDA'), "spawn_threads does not support GPU device"
'CUDA'), "spawn_threads does not support GPU device"
spawn_fn = _SpawnFn(fn, *args)
_run_thread_per_device(
local_rank=0,
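
Several of the multiprocess test scripts touched by this commit share the same entry-point guard. The sketch below is an illustrative composite of that pattern after the simplification, not a copy of any single file; the placeholder body and the xmp.spawn launcher are assumptions about how such a script is typically run.

import sys

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  # A single canonical device name is enough; 'GPU' and 'ROCM' no longer
  # need to be listed alongside 'CUDA'.
  if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
    ordinal = xm.get_ordinal()
    print('placeholder test body running on {} (ordinal {})'.format(
        device, ordinal))
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
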