Remove or simplify hardcoded lists of device types #6235

Merged on Jan 10, 2024 (18 commits)
Changes from 17 commits
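The change is mechanical and repeats one pattern across the files below: membership tests against hardcoded device-type tuples such as `('GPU', 'CUDA', 'ROCM')` become a single comparison against `xr.device_type()`, since `PJRT_DEVICE=GPU` is being deprecated in favor of `CUDA` (see the TODO removed in `.circleci/common.sh` below). A minimal before/after sketch of that pattern; the test class and method names here are made up for illustration and do not appear in this diff:

```python
import unittest

import torch_xla.runtime as xr


class DeviceGuardExample(unittest.TestCase):

  # Before: every call site enumerated the GPU aliases by hand.
  @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
                   "not supported on GPU")
  def test_old_style_guard(self):
    ...

  # After: PJRT reports GPUs as CUDA, so one equality check is enough.
  @unittest.skipIf(xr.device_type() == 'CUDA', "not supported on GPU")
  def test_new_style_guard(self):
    ...
```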
3 changes: 1 addition & 2 deletions .circleci/common.sh
@@ -132,8 +132,7 @@ function run_torch_xla_python_tests() {
if [ -x "$(command -v nvidia-smi)" ]; then
# These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
- # TODO(xiowei replace gpu with cuda): remove the test below with PJRT_DEVICE=GPU because PJRT_DEVICE=GPU is being deprecated.
- PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
+ PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# Syncfree SGD optimizer tests
2 changes: 1 addition & 1 deletion test/pjrt/test_ddp.py
@@ -32,7 +32,7 @@ def _ddp_init(index: int = ...):
def test_ddp_init(self):
pjrt.run_multiprocess(self._ddp_init)

- @absltest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
+ @absltest.skipIf(xr.device_type() == 'CUDA',
"GPU device is not supported by pjrt.spawn_threads")
def test_ddp_init_threaded(self):
pjrt.spawn_threads(self._ddp_init)
2 changes: 1 addition & 1 deletion test/pjrt/test_runtime.py
@@ -97,7 +97,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
xr.using_pjrt()

if expect_using_pjrt:
- self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'ROCM', 'GPU'])
+ self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
else:
self.assertIsNone(xr.device_type())

2 changes: 1 addition & 1 deletion test/pjrt/test_runtime_gpu.py
@@ -17,7 +17,7 @@
from absl.testing import absltest, parameterized


- @unittest.skipIf(xr.device_type() not in ('GPU', 'CUDA', 'ROCM'),
+ @unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
class TestExperimentalPjrtGpu(parameterized.TestCase):

72 changes: 24 additions & 48 deletions test/spmd/test_spmd_debugging.py
@@ -28,10 +28,8 @@ def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
- f"Requires PJRT_DEVICE set to `TPU`.")
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_single_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -108,10 +106,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
- f"Requires PJRT_DEVICE set to `TPU`.")
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -159,10 +155,8 @@ def test_single_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
- f"Requires PJRT_DEVICE set to `TPU`.")
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_replicated_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -204,10 +198,8 @@ def test_single_host_replicated_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
- f"Requires PJRT_DEVICE set to `CPU`.")
+ @unittest.skipIf(xr.device_type() != 'CPU',
+ f"Requires PJRT_DEVICE set to `CPU`.")
def test_debugging_spmd_single_host_tiled_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -247,10 +239,8 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
- f"Requires PJRT_DEVICE set to `CPU`.")
+ @unittest.skipIf(xr.device_type() != 'CPU',
+ f"Requires PJRT_DEVICE set to `CPU`.")
def test_single_host_partial_replication_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -291,10 +281,8 @@ def test_single_host_partial_replication_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
- f"Requires PJRT_DEVICE set to `CPU`.")
+ @unittest.skipIf(xr.device_type() != 'CPU',
+ f"Requires PJRT_DEVICE set to `CPU`.")
def test_single_host_replicated_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
@@ -341,10 +329,8 @@ def test_single_host_replicated_cpu(self):
# e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
# e.g.: sharding={replicated}

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
- f"Requires PJRT_DEVICE set to `TPU`.")
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_multi_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -453,10 +439,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
- f"Requires PJRT_DEVICE set to `TPU`.")
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ f"Requires PJRT_DEVICE set to `TPU`.")
def test_multi_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -537,10 +521,8 @@ def test_multi_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
- f"Requires PJRT_DEVICE set to `TPU`.")
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ f"Requires PJRT_DEVICE set to `TPU`.")
def test_multi_host_replicated_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{replicated}'
@@ -573,10 +555,8 @@ def test_multi_host_replicated_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
- f"Requires PJRT_DEVICE set to `CPU`.")
+ @unittest.skipIf(xr.device_type() != 'CPU',
+ f"Requires PJRT_DEVICE set to `CPU`.")
def test_debugging_spmd_multi_host_tiled_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -685,10 +665,8 @@ def test_debugging_spmd_multi_host_tiled_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
- f"Requires PJRT_DEVICE set to `CPU`.")
+ @unittest.skipIf(xr.device_type() != 'CPU',
+ f"Requires PJRT_DEVICE set to `CPU`.")
def test_multi_host_partial_replication_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -769,10 +747,8 @@ def test_multi_host_partial_replication_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output

- @unittest.skipIf(
- not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
- f"Requires PJRT_DEVICE set to `CPU`.")
+ @unittest.skipIf(xr.device_type() != 'CPU',
+ f"Requires PJRT_DEVICE set to `CPU`.")
def test_multi_host_replicated_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{replicated}'
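The twelve guards above all change the same way: instead of combining `not xr.using_pjrt()` with a denylist read from the raw `PJRT_DEVICE` environment variable, the new decorators ask the runtime for its resolved device type. A sketch of the resulting guard, assuming `xr.device_type()` returns one of `CPU`, `CUDA`, `TPU` as asserted in test_runtime.py above; the class below is hypothetical and for illustration only:

```python
import unittest

import torch_xla.runtime as xr


@unittest.skipIf(xr.device_type() != 'TPU',
                 "Requires PJRT_DEVICE set to `TPU`.")
class TpuOnlyShardingVisualization(unittest.TestCase):

  def test_runs_only_on_tpu(self):
    # If this runs at all, the skip guard has already established the device type.
    self.assertEqual(xr.device_type(), 'TPU')
```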
8 changes: 4 additions & 4 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -506,8 +506,8 @@ def test_manager_async_step_tracking(self, tmpdir):
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))

- @unittest.skipUnless(xr.device_type() == 'TPU',
- 'TPU required for worker IP discovery')
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ 'TPU required for worker IP discovery')
@unittest.mock.patch('torch_xla._internal.tpu.get_worker_ips')
def test_master_ip_discovery(self, patched_get_worker_ips):
# A basic test to verify the SPMD codepath returns the correct IP. Two IPs
@@ -543,8 +543,8 @@ def test_preemption_sync_manager(self):
# Scope the PreemptionSyncManager to the lifespan of the test.
torch_xla._XLAC._deactivate_preemption_sync_manager()

- @unittest.skipUnless(xr.device_type() == 'TPU',
- 'TPU required for worker IP discovery')
+ @unittest.skipIf(xr.device_type() != 'TPU',
+ 'TPU required for worker IP discovery')
@run_with_tmpdir
def test_auto_checkpoint(self, tmpdir):
# Create a checkpoint manager with a long save interval
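Note that the flip from `skipUnless` to `skipIf` above is purely cosmetic: `unittest.skipUnless(cond, msg)` and `unittest.skipIf(not cond, msg)` skip in exactly the same cases, as the small sketch below illustrates (the `on_tpu` flag is a stand-in, not taken from the test file):

```python
import unittest

on_tpu = False  # stand-in for xr.device_type() == 'TPU'


class EquivalentGuards(unittest.TestCase):

  # Old form: run only when the condition holds.
  @unittest.skipUnless(on_tpu, 'TPU required for worker IP discovery')
  def test_old_form(self):
    ...

  # New form: skip when the negated condition holds; behavior is identical.
  @unittest.skipIf(not on_tpu, 'TPU required for worker IP discovery')
  def test_new_form(self):
    ...
```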
5 changes: 2 additions & 3 deletions test/spmd/test_xla_sharding_base.py
@@ -9,9 +9,8 @@
import torch_xla.utils.utils as xu


- @unittest.skipIf(not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
- f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
+ @unittest.skipUnless(xr.device_type() in ("TPU", "CPU"),
+ f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
class XlaShardingTest(unittest.TestCase):

class SimpleLinear(nn.Module):
2 changes: 1 addition & 1 deletion test/spmd/test_xla_spmd_python_api_interaction.py
@@ -122,7 +122,7 @@ def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

- @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU', 'CUDA', 'ROCM'],
+ @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'],
f"TPU/GPU autocast test.")
def test_xla_autocast_api(self):
device = xm.xla_device()
2 changes: 1 addition & 1 deletion test/test_ddp.py
@@ -16,7 +16,7 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool):
# We cannot run this guard before XMP,
# see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
device = xm.xla_device()
- if xm.xla_device_hw(device) not in ('GPU', 'TPU', 'CUDA', 'ROCM'):
+ if xm.xla_device_hw(device) not in ('TPU', 'CUDA'):
print(
'Default device {} is not a TPU device'.format(device),
file=sys.stderr)
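The multiprocess tests in the rest of this diff share one structure: the worker function checks the hardware of the default XLA device at runtime (the check cannot run before the multiprocessing entry point, per API_GUIDE.md) and bails out on unsupported back ends. A rough sketch of that shape, using real `torch_xla` entry points but a placeholder test body:

```python
import sys

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  # Guard inside the spawned process: only TPU and CUDA support this test.
  if xm.xla_device_hw(device) not in ('TPU', 'CUDA'):
    print('Default device {} is not a TPU or CUDA device'.format(device),
          file=sys.stderr)
    return
  # ... the actual DDP / collective assertions would go here ...


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```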
10 changes: 5 additions & 5 deletions test/test_fsdp_auto_wrap.py
@@ -31,10 +31,10 @@ def forward(self, x):
hidden2 = self.fc2(x)
return hidden1, hidden2

- @unittest.skipIf(xr.device_type() in (
- 'GPU', 'ROCM', 'CUDA'
- ), "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
- )
+ @unittest.skipIf(
+ xr.device_type() == 'CUDA',
+ "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
+ )
def test(self):
dev = xm.xla_device()
input = torch.zeros([16, 16], device=dev)
@@ -50,7 +50,7 @@ def test(self):

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+ if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
else:
7 changes: 2 additions & 5 deletions test/test_metrics.py
@@ -3,6 +3,7 @@

import torch
import torch_xla
+ import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest
@@ -171,11 +172,7 @@ def test_metrics_report(self):
report = met.metrics_report()
self.assertIn("CachedCompile", report)

- @unittest.skipIf(
- xm.get_xla_supported_devices("CUDA") or
- xm.get_xla_supported_devices("GPU") or
- xm.get_xla_supported_devices("ROCM") or
- xm.get_xla_supported_devices("TPU"), f"This test only works on CPU.")
+ @unittest.skipIf(xr.device_type() != "CPU", f"This test only works on CPU.")
def test_execute_time_metric(self):
# Initialize the client before starting the timer.
xm.xla_device()
2 changes: 1 addition & 1 deletion test/test_mp_all_gather.py
@@ -14,7 +14,7 @@ def _mp_fn(index):
device = xm.xla_device()
world_size = xm.xrt_world_size()
input_list_size = 5
- if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
+ if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
# Testing with a single replica group
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)
2 changes: 1 addition & 1 deletion test/test_mp_distributed_mm.py
@@ -9,7 +9,7 @@
def _mp_fn(index):
device = xm.xla_device()

- if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+ if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
3 changes: 1 addition & 2 deletions test/test_operations.py
@@ -453,8 +453,7 @@ def test_get_real_xla_devices(self):
devices = xm.get_xla_supported_devices()
xla_devices = torch_xla._XLAC._xla_real_devices(devices)
for device, xdevice in zip(devices, xla_devices):
- self.assertTrue(
- re.match(r'(CPU|GPU|TPU|CUDA|ROCM):\d+$', xdevice) is not None)
+ self.assertIsNotNone(re.fullmatch(r'[A-Z]+:\d+$', xdevice))

def test_negative_slice(self):
t = _gen_tensor(32, 24, 32)
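A small aside on the new assertion: `re.fullmatch` already anchors at both ends of the string, so the trailing `$` in the simplified pattern is redundant (though harmless). Either form accepts any `TYPE:index` device string instead of a hardcoded list of types:

```python
import re

# Both succeed: fullmatch requires the whole string to match.
assert re.fullmatch(r'[A-Z]+:\d+$', 'CUDA:0') is not None
assert re.fullmatch(r'[A-Z]+:\d+', 'TPU:3') is not None

# Trailing garbage is still rejected even without the "$".
assert re.fullmatch(r'[A-Z]+:\d+', 'CUDA:0x') is None
```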
2 changes: 1 addition & 1 deletion test/test_torch_distributed_all_gather_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+ if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_torch_distributed_all_reduce_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+ if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

(changes to an additional test file; file name not shown in this capture)
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+ if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_torch_distributed_reduce_scatter_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+ if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet_amp.py
@@ -221,7 +221,7 @@ def train_imagenet():
if FLAGS.amp:
if device_hw == 'TPU':
scaler = None
- elif device_hw in ('GPU', 'CUDA', 'ROCM'):
+ elif device_hw == 'CUDA':
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)

def train_loop_fn(loader, epoch):
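The AMP guard above now keys only on `CUDA`. As a reminder of why the branch exists: on TPU, autocast runs without a gradient scaler, while on CUDA the (sync-free) `GradScaler` path applies. A single-device sketch of that split, assuming the `torch_xla.amp` entry points used by this test; the real test additionally all-reduces gradients across replicas:

```python
import torch_xla.core.xla_model as xm
from torch_xla.amp import GradScaler, autocast

device = xm.xla_device()
# TPU: no scaler needed. CUDA: scale losses to keep fp16 gradients stable.
scaler = GradScaler() if xm.xla_device_hw(device) == 'CUDA' else None


def train_step(model, optimizer, loss_fn, data, target):
  optimizer.zero_grad()
  with autocast(device):
    loss = loss_fn(model(data), target)
  if scaler is not None:
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
  else:
    loss.backward()
    xm.optimizer_step(optimizer)
```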
4 changes: 2 additions & 2 deletions torch_xla/_internal/pjrt.py
@@ -147,7 +147,7 @@ def run_multiprocess(fn: Callable[..., R],
num_processes = plugins.default().physical_chip_count()
elif runtime.device_type() == 'TPU':
num_processes = tpu.num_local_processes()
- elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
+ elif runtime.device_type() == 'CUDA':
num_processes = gpu.num_local_processes()
elif runtime.device_type() == 'NEURON':
num_processes = neuron.num_local_processes()
@@ -216,7 +216,7 @@ def _initialize_single_process(local_rank: int, local_world_size: int):
def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
"""Run function in one process with one thread per addressable device."""
assert runtime.device_type() not in (
- 'GPU', 'ROCM', 'CUDA'), "spawn_threads does not support GPU device"
+ 'CUDA'), "spawn_threads does not support GPU device"
spawn_fn = _SpawnFn(fn, *args)
_run_thread_per_device(
local_rank=0,
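For context, after this change the process-count dispatch in `run_multiprocess` keys each back end off a single device-type string. A simplified sketch of that dispatch follows; the helper name `_num_local_processes` is made up, the module imports mirror those used by `torch_xla/_internal/pjrt.py`, and the real function also handles device plugins:

```python
from torch_xla._internal import gpu, neuron, tpu


def _num_local_processes(device_type: str) -> int:
  # One process per local device for the supported back ends,
  # otherwise fall back to a single process.
  if device_type == 'TPU':
    return tpu.num_local_processes()
  elif device_type == 'CUDA':
    return gpu.num_local_processes()
  elif device_type == 'NEURON':
    return neuron.num_local_processes()
  return 1
```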