diff --git a/.circleci/common.sh b/.circleci/common.sh
index 88c5fed8efcb..d773eef13ece 100755
--- a/.circleci/common.sh
+++ b/.circleci/common.sh
@@ -148,16 +148,16 @@ function run_torch_xla_python_tests() {
   else
     ./test/run_tests.sh
-    # GPU tests
+    # CUDA tests
     if [ -x "$(command -v nvidia-smi)" ]; then
-      # These tests fail on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
-      PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
-      PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
-      XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
+      # These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
+      PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
+      PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
+      XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
       # Syncfree SGD optimizer tests
       if [ -d ./torch_xla/amp/syncfree ]; then
         echo "Running Syncfree Optimizer Test"
-        PJRT_DEVICE=GPU python test/test_syncfree_optimizers.py
+        PJRT_DEVICE=CUDA python test/test_syncfree_optimizers.py

         # Following test scripts are mainly useful for
         # performance evaluation & comparison among different
@@ -192,9 +192,9 @@ function run_torch_xla_cpp_tests() {
   if [ "$USE_COVERAGE" != "0" ]; then
     # TODO(yeounoh) shard the coverage testing
     if [ -x "$(command -v nvidia-smi)" ]; then
-      PJRT_DEVICE=GPU test/cpp/run_tests.sh $EXTRA_ARGS -L""
+      PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L""
       cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov1.dat
-      PJRT_DEVICE=GPU test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
+      PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
       cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov2.dat
       lcov --add-tracefile /tmp/cov1.dat -a /tmp/cov2.dat -o /tmp/merged.dat
     else
@@ -206,8 +206,8 @@ function run_torch_xla_cpp_tests() {
   else
     # Shard GPU testing
     if [ -x "$(command -v nvidia-smi)" ]; then
-      PJRT_DEVICE=GPU test/cpp/run_tests.sh $EXTRA_ARGS -L""
-      PJRT_DEVICE=GPU test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
+      PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L""
+      PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
     else
       PJRT_DEVICE=CPU test/cpp/run_tests.sh $EXTRA_ARGS -L""
     fi
diff --git a/WORKSPACE b/WORKSPACE
index a4e4027a67ef..c5d0910044b0 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -39,15 +39,14 @@ http_archive(
     patch_tool = "patch",
     patches = [
         "//openxla_patches:cache_urls.diff",
-        "//openxla_patches:f16_abi_clang.diff",
-        "//openxla_patches:gpu_race_condition.diff",
         "//openxla_patches:constexpr_return.diff",
-        "//openxla_patches:pjrt_api_tsl_logging.diff",
-        "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
+        "//openxla_patches:gpu_race_condition.diff",
+        "//openxla_patches:f16_abi_clang.diff",
+        "//openxla_patches:gpu_topk_rewriter.diff",
     ],
-    strip_prefix = "xla-97a5f819faf9ff793b7ba68ff1f31f74f9459c18",
+    strip_prefix = "xla-51b59cfb1999c6f1b3ec59851675044b2c502aae",
     urls = [
-        "https://github.com/openxla/xla/archive/97a5f819faf9ff793b7ba68ff1f31f74f9459c18.tar.gz",
+        "https://github.com/openxla/xla/archive/51b59cfb1999c6f1b3ec59851675044b2c502aae.tar.gz",
     ],
 )
diff --git a/openxla_patches/gpu_topk_rewriter.diff b/openxla_patches/gpu_topk_rewriter.diff
new file mode 100644
index 000000000000..47ee3fa0f0a8
--- /dev/null
+++ b/openxla_patches/gpu_topk_rewriter.diff
@@ -0,0 +1,184 @@
+diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc
+index da872d962..1b7141055 100644
+--- a/xla/service/topk_rewriter.cc
++++ b/xla/service/topk_rewriter.cc
+@@ -196,6 +196,8 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) {
+     return std::nullopt;
+   }
+   const int64_t sort_dim = sort->sort_dimension();
++  const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
++  const bool has_batch = data->shape().rank() == 2;
+
+   bool supported = true;
+   std::optional k;
+@@ -220,15 +222,10 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) {
+       supported = false;
+       break;
+     }
+-    for (int64_t i = 0; i < slice->slice_limits().size(); ++i) {
+-      if (i != sort_dim &&
+-          slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) {
+-        // Slicing along a non-sort dimension isn't supported.
+-        supported = false;
+-        break;
+-      }
+-    }
+-    if (!supported) {
++    if (has_batch && slice->slice_limits(batch_dim) !=
++                         slice->operand(0)->shape().dimensions(batch_dim)) {
++      // Slicing along the batch dimension isn't supported.
++      supported = false;
+       break;
+     }
+     if (k == std::nullopt) {
+@@ -260,57 +257,29 @@ StatusOr TopkRewriter::TransformToCustomCall(
+     HloSortInstruction* sort = DynCast(inst);
+     HloInstruction* data = sort->mutable_operand(0);
+     const PrimitiveType element_type = data->shape().element_type();
+-    const Shape data_shape = data->shape();
+
+-    if (element_type != F32 && element_type != BF16) {
++    if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
++        (element_type != F32 && element_type != BF16)) {
+       continue;
+     }
+
+-    // Sort dimension must be the first or last dimension.
+     const int64_t sort_dim = sort->sort_dimension();
+-    if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) {
+-      continue;
+-    }
++    const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
++    const bool has_batch = data->shape().rank() == 2;
+
+     // Profitability check.
+     if (!is_profitable_to_convert_(sort, *k)) {
+       continue;
+     }
+
+-    HloInstruction* input = data;
+-    const bool has_batch = data_shape.rank() >= 2;
+-    const int64_t input_size = data_shape.dimensions(sort_dim);
+-    int64_t batch_size = 1;
+-    Shape topk_input_shape;
+-
+-    if (has_batch) {
+-      // The TopK custom call expects either a 1d tensor or a 2d tensor with
+-      // the last dimension being the sort dimension. An input with rank > 2
+-      // is reshaped into a 2d tensor by combining non-sort dimensions into a
+-      // single batch dimension. The original non-sort dimensions are
+-      // restored for the outputs with another reshape after the custom call.
+-      batch_size =
+-          ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim);
+-      topk_input_shape =
+-          ShapeUtil::MakeShape(element_type, {batch_size, input_size});
+-
+-      if (data_shape.rank() > 2) {
+-        // Reshape to 2d.
+-        input = comp->AddInstruction(HloInstruction::CreateReshape(
+-            sort_dim == 0
+-                ? ShapeUtil::MakeShape(element_type, {input_size, batch_size})
+-                : ShapeUtil::MakeShape(element_type,
+-                                       {batch_size, input_size}),
+-            input));
+-      }
+-
+-      if (sort_dim == 0) {
+-        // Transpose for the custom call when sorting the first dimension.
+-        input = comp->AddInstruction(
+-            HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0}));
+-      }
+-    } else {
+-      topk_input_shape = data_shape;
++    const int64_t batch_size =
++        has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
++    const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
++    HloInstruction* input = sort->mutable_operand(0);
++    if (has_batch && sort_dim == 0) {
++      input = comp->AddInstruction(HloInstruction::CreateTranspose(
++          ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
++          {1, 0}));
+     }
+
+     Shape topk_shape =
+@@ -331,26 +300,13 @@ StatusOr TopkRewriter::TransformToCustomCall(
+         comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+             topk->shape().tuple_shapes(1), topk, 1));
+
+-    if (has_batch) {
+-      if (sort_dim == 0) {
+-        // Transpose back.
+-        value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+-            ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
+-            value_gte, {1, 0}));
+-        index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+-            ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
+-            {1, 0}));
+-      }
+-      if (data_shape.rank() > 2) {
+-        // Reshape back.
+-        std::vector shape_dim(data_shape.dimensions().begin(),
+-                              data_shape.dimensions().end());
+-        shape_dim[sort_dim] = k.value();
+-        value_gte = comp->AddInstruction(HloInstruction::CreateReshape(
+-            ShapeUtil::MakeShape(element_type, shape_dim), value_gte));
+-        index_gte = comp->AddInstruction(HloInstruction::CreateReshape(
+-            ShapeUtil::MakeShape(S32, shape_dim), index_gte));
+-      }
++    if (has_batch && sort_dim == 0) {
++      value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
++          ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
++          value_gte, {1, 0}));
++      index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
++          ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
++          {1, 0}));
+     }
+
+     for (HloInstruction* user : sort->users()) {
+diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc
+index 36e723737..25ce150e0 100644
+--- a/xla/service/topk_rewriter_test.cc
++++ b/xla/service/topk_rewriter_test.cc
+@@ -326,42 +326,6 @@ ENTRY cluster {
+   EXPECT_THAT(cc->custom_call_target(), "TopK");
+ }
+
+-TEST_F(TopkRewriterTest, RewriteReshape) {
+-  const std::string hlo_string = R"(
+-HloModule module
+-)" + getComparator() + R"(
+-ENTRY cluster {
+-  %arg_tuple.1 = f32[3,8,1234567] parameter(0)
+-  %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2
+-  %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4),
+-    dimensions={2}, is_stable=true, to_apply=%compare
+-  %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0
+-  %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]}
+-  %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1
+-  %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]}
+-  ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31)
+-})";
+-  TF_ASSERT_OK_AND_ASSIGN(auto module,
+-                          ParseAndReturnVerifiedModule(hlo_string));
+-  TopkRewriter rewriter(
+-      [](const HloSortInstruction*, int64_t) { return true; });
+-  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+-  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
+-  EXPECT_TRUE(changed);
+-  EXPECT_THAT(module->entry_computation()->root_instruction(),
+-              GmockMatch(m::Tuple(
+-                  m::Reshape(m::GetTupleElement(
+-                      m::CustomCall(m::Reshape(m::Parameter(0))), 0)),
+-                  m::Reshape(m::GetTupleElement(
+-                      m::CustomCall(m::Reshape(m::Parameter(0))), 1)))));
+-  const HloInstruction* cc = module->entry_computation()
+-                                 ->root_instruction()
+-                                 ->operand(0)
+-                                 ->operand(0)
+-                                 ->operand(0);
+-  EXPECT_THAT(cc->custom_call_target(), "TopK");
+-}
+-
+ TEST_F(TopkRewriterTest, RewriteNoIota) {
+   const std::string hlo_string = R"(
+ HloModule module
diff --git a/openxla_patches/pjrt_api_tsl_logging.diff b/openxla_patches/pjrt_api_tsl_logging.diff
deleted file mode 100644
index 296bed91ad68..000000000000
--- a/openxla_patches/pjrt_api_tsl_logging.diff
+++ /dev/null
@@ -1,21 +0,0 @@
-# Fixes log spam when loading libtpu. We should fix this upstream.
-diff --git a/xla/pjrt/pjrt_api.cc b/xla/pjrt/pjrt_api.cc
-index 132cfaff0..887e842e0 100644
---- a/xla/pjrt/pjrt_api.cc
--+++ b/xla/pjrt/pjrt_api.cc
-@@ -17,7 +17,6 @@ limitations under the License.
-
- #include
-
--#include "absl/log/log.h"
- #include "absl/status/status.h"
- #include "absl/strings/str_cat.h"
- #include "xla/pjrt/c/pjrt_c_api.h"
-@@ -33,6 +32,7 @@ limitations under the License.
- #include "xla/pjrt/c/pjrt_c_api_helpers.h"
- #include "xla/status.h"
- #include "xla/statusor.h"
-+#include "tsl/platform/logging.h"
- #include "tsl/platform/errors.h"
-
- namespace pjrt {
diff --git a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff b/openxla_patches/pjrt_c_api_dynamic_dimensions.diff
deleted file mode 100644
index ee1ec00eced5..000000000000
--- a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff
+++ /dev/null
@@ -1,76 +0,0 @@
-# Partial backport of 6308dba2903e78961ac4122f361bc91b09f36891. Remove in next
-# pin update.
-diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc
-index ef0b6686c..c0341e81e 100644
---- a/xla/pjrt/pjrt_c_api_client.cc
--+++ b/xla/pjrt/pjrt_c_api_client.cc
-@@ -1584,6 +1584,34 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const {
-   return args.num_dynamic_dims > 0;
- }
-
-+absl::Span PjRtCApiBuffer::is_dynamic_dimension() const {
-+  {
-+    absl::MutexLock lock(&mu_);
-+    if (!is_dynamic_dimension_.has_value()) {
-+      absl::InlinedVector& is_dynamic_dimension_value =
-+          is_dynamic_dimension_.emplace();
-+      is_dynamic_dimension_value.assign(dimensions().size(), false);
-+
-+      PJRT_Buffer_DynamicDimensionIndices_Args args;
-+      args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE;
-+      args.priv = nullptr;
-+      args.buffer = buffer_.get();
-+      const PJRT_Api* api = pjrt_c_api();
-+      std::unique_ptr error(
-+          api->PJRT_Buffer_DynamicDimensionIndices(&args),
-+          pjrt::MakeErrorDeleter(api));
-+      if (error && pjrt::GetErrorCode(error.get(), api) ==
-+                       PJRT_Error_Code_UNIMPLEMENTED) {
-+        return *is_dynamic_dimension_;
-+      }
-+      for (int i = 0; i < args.num_dynamic_dims; ++i) {
-+        is_dynamic_dimension_value[args.dynamic_dim_indices[i]] = true;
-+      }
-+    }
-+  }
-+  return *is_dynamic_dimension_;
-+}
-+
- StatusOr> PjRtCApiBuffer::logical_dimensions() {
-   PJRT_Buffer_UnpaddedDimensions_Args args;
-   args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE;
-diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h
-index 9c460f246..279608e60 100644
---- a/xla/pjrt/pjrt_c_api_client.h
--+++ b/xla/pjrt/pjrt_c_api_client.h
-@@ -27,6 +27,7 @@ limitations under the License.
- #include
-
- #include "absl/container/flat_hash_map.h"
-+#include "absl/container/inlined_vector.h"
- #include "absl/log/check.h"
- #include "absl/log/log.h"
- #include "absl/strings/string_view.h"
-@@ -369,11 +370,7 @@ class PjRtCApiBuffer : public PjRtBuffer {
-
-   bool has_dynamic_dimensions() const override;
-
--  absl::Span is_dynamic_dimension() const override {
--    LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. "
--               << "Considering using has_dynamic_dimensions() or "
--                  "logical_dimensions() if applicable.";
--  }
-+  absl::Span is_dynamic_dimension() const override;
-
-   StatusOr> logical_dimensions() override;
-
-@@ -455,6 +452,9 @@ class PjRtCApiBuffer : public PjRtBuffer {
-   std::shared_ptr::Promise> readiness_promise_;
-   // Set and cached the first time layout() is called.
-   mutable std::optional layout_;
-+  // Set and cached the first time is_dynamic_dimension() is called.
-+  mutable std::optional>
-+      is_dynamic_dimension_;
-   // Used to synchronize concurrent setting of cached values.
-   mutable absl::Mutex mu_;
- };
diff --git a/setup.py b/setup.py
index 3a98db22c32f..ba5792190607 100644
--- a/setup.py
+++ b/setup.py
@@ -72,7 +72,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__))
-_libtpu_version = '0.1.dev20230825'
+_libtpu_version = '0.1.dev20231010'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
index 9c4135f64c53..b8599c0e7d6a 100644
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -1512,7 +1512,7 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) {
                            /*cudnn_enabled=*/false);
     };
     torch::Tensor undef;
-    ForEachDevice({XlaDeviceType::GPU, XlaDeviceType::TPU},
+    ForEachDevice({XlaDeviceType::CUDA, XlaDeviceType::TPU},
                   [&](const torch::Device& device) {
                     TestBackward({input, undef_weight ? undef : weight,
                                   undef_weight ? undef : bias},
diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp
index d2e3e284f5b0..d7eb32619c4e 100644
--- a/test/cpp/test_aten_xla_tensor_6.cpp
+++ b/test/cpp/test_aten_xla_tensor_6.cpp
@@ -873,7 +873,7 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
 TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
   XlaDeviceType hw_type =
       static_cast(bridge::GetDefaultDevice()->type());
-  if (hw_type != XlaDeviceType::GPU && hw_type != XlaDeviceType::CPU) {
+  if (hw_type != XlaDeviceType::CUDA && hw_type != XlaDeviceType::CPU) {
     return;
   }
   torch::Tensor growth_tracker =
diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp
index 08b039b9e5f2..6d7a54add0ce 100644
--- a/test/cpp/test_replication.cpp
+++ b/test/cpp/test_replication.cpp
@@ -94,7 +94,7 @@ class ReplicationTest : public AtenXlaTensorTestBase {};
 TEST_F(ReplicationTest, TestNSingleReplication) {
   WithAllDevices(
-      {XlaDeviceType::TPU, XlaDeviceType::GPU},
+      {XlaDeviceType::TPU, XlaDeviceType::CUDA},
       [&](const std::vector& devices,
          const std::vector& all_devices) {
        TestSingleReplication(devices, all_devices);
diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py
index f84cc30ec9eb..7b359311c8f5 100644
--- a/test/pjrt/test_ddp.py
+++ b/test/pjrt/test_ddp.py
@@ -32,7 +32,7 @@ def _ddp_init(index: int = ...):
   def test_ddp_init(self):
     pjrt.run_multiprocess(self._ddp_init)
 
-  @absltest.skipIf(xr.device_type() == 'GPU',
+  @absltest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
                    "GPU device is not supported by pjrt.spawn_threads")
   def test_ddp_init_threaded(self):
     pjrt.spawn_threads(self._ddp_init)
diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py
index 8e500ea4ef0e..8cb930714e09 100644
--- a/test/pjrt/test_runtime.py
+++ b/test/pjrt/test_runtime.py
@@ -16,7 +16,7 @@ class TestExperimentalPjrt(parameterized.TestCase):
   def setUp(self):
     xr.set_device_type('CPU')
 
-  @parameterized.parameters(('CPU', 'CPU'), ('GPU', 'GPU'), ('TPU', 'TPU'),
+  @parameterized.parameters(('CPU', 'CPU'), ('CUDA', 'CUDA'), ('TPU', 'TPU'),
                             ('TPU_C_API', 'TPU'), ('TPU_LEGACY', 'TPU'))
   def test_device_type(self, pjrt_device, expected):
     with mock.patch.dict(os.environ, {'PJRT_DEVICE': pjrt_device}, clear=True):
@@ -61,7 +61,7 @@ def test_xla_device_error(self):
       }, True), ('gpu_num_devives', {
           'GPU_NUM_DEVICES': '4'
       }, True), ('pjrt_gpu', {
-          'PJRT_DEVICE': 'GPU',
+          'PJRT_DEVICE': 'CUDA',
           'GPU_NUM_DEVICES': '4'
       }, True))
   def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
@@ -77,7 +77,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
       xr.using_pjrt()
 
     if expect_using_pjrt:
-      self.assertIn(xr.device_type(), ['CPU', 'GPU', 'TPU'])
+      self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'ROCM', 'GPU'])
     else:
       self.assertIsNone(xr.device_type())
diff --git a/test/pjrt/test_runtime_gpu.py b/test/pjrt/test_runtime_gpu.py
index d82144b2c1ad..77fd4d94fb7e 100644
--- a/test/pjrt/test_runtime_gpu.py
+++ b/test/pjrt/test_runtime_gpu.py
@@ -17,12 +17,12 @@
 from absl.testing import absltest, parameterized
 
 
-@unittest.skipIf(xr.device_type() != 'GPU',
+@unittest.skipIf(xr.device_type() not in ('GPU', 'CUDA', 'ROCM'),
                 f"GPU tests should only run on GPU devices.")
 class TestExperimentalPjrtGpu(parameterized.TestCase):
 
   def setUp(self):
-    xr.set_device_type('GPU')
+    xr.set_device_type('CUDA')
 
     os.environ.update({
         xenv.PJRT_GPU_ASYNC_CLIENT: 'true',
diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py
index 1c77e85b9f66..6835ebe79939 100644
--- a/test/pytorch_test_base.py
+++ b/test/pytorch_test_base.py
@@ -519,7 +519,7 @@ def union_of_disabled_tests(sets):
 DISABLED_TORCH_TESTS = {
     'TPU': prepare_match_set(DISABLED_TORCH_TESTS_TPU),
     'CPU': prepare_match_set(DISABLED_TORCH_TESTS_CPU),
-    'GPU': prepare_match_set(DISABLED_TORCH_TESTS_GPU),
+    'CUDA': prepare_match_set(DISABLED_TORCH_TESTS_GPU),
 }
 
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 9b13aa4494fc..0bb9b965b83a 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -56,7 +56,7 @@ function run_coverage {
 function run_test {
   echo "Running in PjRt runtime: $@"
   if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then
-    PJRT_DEVICE=GPU run_coverage "$@"
+    PJRT_DEVICE=CUDA run_coverage "$@"
   else
    # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
    PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$@"
diff --git a/test/spmd/test_xla_sharding_base.py b/test/spmd/test_xla_sharding_base.py
index 4b83368d380a..54067512ce2a 100644
--- a/test/spmd/test_xla_sharding_base.py
+++ b/test/spmd/test_xla_sharding_base.py
@@ -10,7 +10,7 @@
 
 
 @unittest.skipIf(not xr.using_pjrt() or
-                 xu.getenv_as(xenv.PJRT_DEVICE, str) == "GPU",
+                 xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
                  f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
 class XlaShardingTest(unittest.TestCase):
 
diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py
index 8f9c319ac54b..6fb12f916d2e 100644
--- a/test/spmd/test_xla_spmd_python_api_interaction.py
+++ b/test/spmd/test_xla_spmd_python_api_interaction.py
@@ -122,7 +122,7 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'],
+  @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU', 'CUDA', 'ROCM'],
                    f"TPU/GPU autocast test.")
   def test_xla_autocast_api(self):
     device = xm.xla_device()
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 9caa3017ea86..edbd834b61b7 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -341,7 +341,8 @@ def compare(first, second):
     self.assertFalse(self.is_autocast_enabled())
 
 
-@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), f"GPU autocast test.")
+@unittest.skipIf(not xm.get_xla_supported_devices("CUDA"),
+                 f"CUDA autocast test.")
 class TestAutocastCuda(TestAutocastBase):
 
   def setUp(self):
diff --git a/test/test_ddp.py b/test/test_ddp.py
index 2389cc51f0df..25e53790cc5c 100644
--- a/test/test_ddp.py
+++ b/test/test_ddp.py
@@ -16,7 +16,7 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool):
   # We cannot run this guard before XMP,
   # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
   device = xm.xla_device()
-  if xm.xla_device_hw(device) not in ('GPU', 'TPU'):
+  if xm.xla_device_hw(device) not in ('GPU', 'TPU', 'CUDA', 'ROCM'):
     print(
         'Default device {} is not a TPU device'.format(device),
         file=sys.stderr)
diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py
index 5bd85bb6b94a..b14fb769bc09 100644
--- a/test/test_fsdp_auto_wrap.py
+++ b/test/test_fsdp_auto_wrap.py
@@ -31,10 +31,10 @@ def forward(self, x):
       hidden2 = self.fc2(x)
       return hidden1, hidden2
 
-  @unittest.skipIf(
-      xr.device_type() == 'GPU',
-      "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
-  )
+  @unittest.skipIf(xr.device_type() in (
+      'GPU', 'ROCM', 'CUDA'
+  ), "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
+                  )
   def test(self):
     dev = xm.xla_device()
     input = torch.zeros([16, 16], device=dev)
@@ -50,12 +50,12 @@ def test(self):
 
 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     test = unittest.main(exit=False)
     sys.exit(0 if test.result.wasSuccessful() else 1)
   else:
     print(
-        'Default device {} is not a TPU or GPU device'.format(device),
+        'Default device {} is not a TPU or CUDA device'.format(device),
        file=sys.stderr)
diff --git a/test/test_metrics.py b/test/test_metrics.py
index 8f5b0cfd8503..3037692ea656 100644
--- a/test/test_metrics.py
+++ b/test/test_metrics.py
@@ -164,7 +164,9 @@ def test_metrics_report(self):
     self.assertIn("CachedCompile", report)
 
   @unittest.skipIf(
+      xm.get_xla_supported_devices("CUDA") or
       xm.get_xla_supported_devices("GPU") or
+      xm.get_xla_supported_devices("ROCM") or
       xm.get_xla_supported_devices("TPU"), f"This test only works on CPU.")
   def test_execute_time_metric(self):
     # Initialize the client before starting the timer.
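
Not part of the patch: a minimal Python sketch of the device guard that the test hunks above converge on, based on the `torch_xla.core.xla_model` calls already used in this diff. The helper name `_device_is_tpu_or_gpu` is illustrative only; the tests inline the check instead.

```python
# Illustrative sketch only -- not part of the diff above.
import sys

import torch_xla.core.xla_model as xm


def _device_is_tpu_or_gpu(device):
  # Accept the legacy 'GPU' spelling as well as the new 'CUDA'/'ROCM' names
  # while the rename is rolled out across the code base.
  return xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM')


def _mp_fn(index):
  device = xm.xla_device()
  if not _device_is_tpu_or_gpu(device):
    print(
        'Default device {} is not a TPU or CUDA device'.format(device),
        file=sys.stderr)
    return
  # ... the actual test body would run here ...
```

Keeping the legacy string in the accepted set is what lets the existing `PJRT_DEVICE=GPU` workflows keep passing while the new names land.
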
diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py
index b8fee7e29ba3..3ffeebc963d6 100644
--- a/test/test_mp_all_gather.py
+++ b/test/test_mp_all_gather.py
@@ -13,7 +13,7 @@ def all_gather(tensor, dim):
 def _mp_fn(index):
   device = xm.xla_device()
   world_size = xm.xrt_world_size()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor, dim=0)
diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py
index 3dc45732ac36..eb292f7a53d0 100644
--- a/test/test_mp_distributed_mm.py
+++ b/test/test_mp_distributed_mm.py
@@ -9,7 +9,7 @@ def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
         use_full_mat_mul_precision=True)
diff --git a/test/test_operations.py b/test/test_operations.py
index 24b17cffbc01..220805b3fbea 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -447,7 +447,8 @@ def test_get_real_xla_devices(self):
     devices = xm.get_xla_supported_devices()
     xla_devices = torch_xla._XLAC._xla_real_devices(devices)
     for device, xdevice in zip(devices, xla_devices):
-      self.assertTrue(re.match(r'(CPU|GPU|TPU):\d+$', xdevice) is not None)
+      self.assertTrue(
+          re.match(r'(CPU|GPU|TPU|CUDA|ROCM):\d+$', xdevice) is not None)
 
   def test_negative_slice(self):
     t = _gen_tensor(32, 24, 32)
diff --git a/test/test_torch_distributed_all_gather_xla_backend.py b/test/test_torch_distributed_all_gather_xla_backend.py
index 763c15d6f5bf..f75a019db862 100644
--- a/test/test_torch_distributed_all_gather_xla_backend.py
+++ b/test/test_torch_distributed_all_gather_xla_backend.py
@@ -10,7 +10,7 @@ def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
diff --git a/test/test_torch_distributed_all_reduce_xla_backend.py b/test/test_torch_distributed_all_reduce_xla_backend.py
index 3f0bca31b8fa..9962c824b7d5 100644
--- a/test/test_torch_distributed_all_reduce_xla_backend.py
+++ b/test/test_torch_distributed_all_reduce_xla_backend.py
@@ -10,7 +10,7 @@ def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py
index 79b65a469995..c626faf74474 100644
--- a/test/test_torch_distributed_fsdp_frozen_weight.py
+++ b/test/test_torch_distributed_fsdp_frozen_weight.py
@@ -8,9 +8,9 @@ def _mp_fn(index):
   dev = xm.xla_device()
-  if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
+  if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'):
     print(
-        'Default device {} is not a TPU or GPU device'.format(dev),
+        'Default device {} is not a TPU or CUDA device'.format(dev),
         file=sys.stderr)
     return
diff --git a/test/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/test_torch_distributed_multi_all_reduce_xla_backend.py
index cf16311ca986..e576c3ffb0f4 100644
--- a/test/test_torch_distributed_multi_all_reduce_xla_backend.py
+++ b/test/test_torch_distributed_multi_all_reduce_xla_backend.py
@@ -10,7 +10,7 @@ def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
diff --git a/test/test_torch_distributed_reduce_scatter_xla_backend.py b/test/test_torch_distributed_reduce_scatter_xla_backend.py
index f278567379e0..fd146d98af74 100644
--- a/test/test_torch_distributed_reduce_scatter_xla_backend.py
+++ b/test/test_torch_distributed_reduce_scatter_xla_backend.py
@@ -10,7 +10,7 @@ def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py
index bec112a93781..3ed923897156 100644
--- a/test/test_train_mp_imagenet_amp.py
+++ b/test/test_train_mp_imagenet_amp.py
@@ -221,7 +221,7 @@ def train_imagenet():
   if FLAGS.amp:
     if device_hw == 'TPU':
       scaler = None
-    elif device_hw == 'GPU':
+    elif device_hw in ('GPU', 'CUDA', 'ROCM'):
       scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
 
   def train_loop_fn(loader, epoch):
diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py
index ae4db1183007..3c9363f8d09c 100644
--- a/test/test_train_mp_mnist_amp.py
+++ b/test/test_train_mp_mnist_amp.py
@@ -142,7 +142,7 @@ def train_mnist(flags, **kwargs):
 
   if device_hw == 'TPU':
     scaler = None
-  elif device_hw == 'GPU':
+  elif device_hw == 'CUDA':
     # GradScaler only used for GPU
     scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
   else:
diff --git a/test/test_zero1.py b/test/test_zero1.py
index cb7517265771..e9c3a3eeee62 100644
--- a/test/test_zero1.py
+++ b/test/test_zero1.py
@@ -13,7 +13,7 @@ class XlaZeRO1Test(TestCase):
 
   @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
-  @unittest.skipIf(xr.device_type() == 'GPU',
+  @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
                    "TODO(alanwaketan): Fix it for the token change.")
   def test_zero1(self):
     device = xm.xla_device()
diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py
index 9e7533955e4a..67844b3ab393 100644
--- a/torch_xla/_internal/pjrt.py
+++ b/torch_xla/_internal/pjrt.py
@@ -138,7 +138,7 @@ def run_multiprocess(fn: Callable[..., R],
   """
   if runtime.device_type() == 'TPU':
     num_processes = tpu.num_local_processes()
-  elif runtime.device_type() == 'GPU':
+  elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
     num_processes = gpu.num_local_processes()
     gpu.initialize_distributed_runtime(num_processes)
   elif runtime.device_type() == 'NEURON':
@@ -160,7 +160,7 @@ def run_multiprocess(fn: Callable[..., R],
       itertools.chain.from_iterable(
           result.items() for result in process_results))
 
-  if runtime.device_type() == 'GPU':
+  if runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
     gpu.shutdown_distributed_runtime()
 
   return _merge_replica_results(replica_results)
@@ -210,8 +210,8 @@ def _initialize_single_process(local_rank: int, local_world_size: int):
 
 def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
   """Run function in one process with one thread per addressable device."""
-  assert runtime.device_type(
-  ) != 'GPU', "spawn_threads does not support GPU device"
+  assert runtime.device_type() not in (
+      'GPU', 'ROCM', 'CUDA'), "spawn_threads does not support GPU device"
   spawn_fn = _SpawnFn(fn, *args)
   _run_thread_per_device(
       local_rank=0,
diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py
index fcdd4a40840a..db7dc3d5d9b4 100644
--- a/torch_xla/amp/autocast_mode.py
+++ b/torch_xla/amp/autocast_mode.py
@@ -25,7 +25,7 @@ def __init__(self,
     self._enabled = enabled
     self._xla_device = xm.xla_device_hw(device)
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       backend = 'cuda'
       self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
       if dtype is None:
@@ -70,7 +70,7 @@ def __init__(self,
   def __enter__(self):
     # This ensures that xla autocast is enabled even for XLA:GPU, which calls
     # `torch.amp.autocast_mode.autocast` with `cuda` backend.
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
       self.prev_dtype = torch.get_autocast_xla_dtype(
       )  # type: ignore[attr-defined]
@@ -86,7 +86,7 @@ def __enter__(self):
 
   def __exit__(self, exc_type: Any, exc_val: Any,
                exc_tb: Any):  # type: ignore[override]
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       if self._xla_bfloat16:
         # autocast_xla flags will be set by `torch.autocast` and we need to
        # set autocast flags as we call into `torch.autocast` apis.
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 7f9682c1000d..ff6d015b2b22 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -71,7 +71,7 @@ def is_xla_tensor(tensor):
 
 def parse_xla_device(device):
-  m = re.match(r'(CPU|TPU|GPU|XPU|NEURON):(\d+)$', device)
+  m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device)
   if m:
     return (m.group(1), int(m.group(2)))
 
@@ -89,7 +89,9 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
     The list of device strings.
   """
   xla_devices = _DEVICES.value
-  devkind = [devkind] if devkind else ['TPU', 'GPU', 'XPU', 'NEURON', 'CPU']
+  devkind = [devkind] if devkind else [
+      'TPU', 'GPU', 'XPU', 'NEURON', 'CPU', 'CUDA', 'ROCM'
+  ]
   for kind in devkind:
     kind_devices = []
     for i, device in enumerate(xla_devices):
@@ -181,8 +183,8 @@ def xla_device(n=None, devkind=None):
     n (int, optional): The specific instance (ordinal) to be returned. If
       specified, the specific XLA device instance will be returned. Otherwise
       the first device of `devkind` will be returned.
-    devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`
-      `NEURON` or `CPU`.
+    devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU`
+      `NEURON`, `ROCM` or `CPU`.
 
   Returns:
     A `torch.device` with the requested instance.
@@ -217,7 +219,7 @@ def xla_device_hw(device):
       real device.
 
   Returns:
-    A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`)
+    A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`)
     of the given device.
""" real_device = _xla_real_device(device) diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index 791b0271aca9..08489c30b8b8 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -17,6 +17,10 @@ std::string XlaDeviceTypeToString(XlaDeviceType hw_type) { return "CPU"; case XlaDeviceType::GPU: return "GPU"; + case XlaDeviceType::CUDA: + return "CUDA"; + case XlaDeviceType::ROCM: + return "ROCM"; case XlaDeviceType::TPU: return "TPU"; case XlaDeviceType::XPU: @@ -59,6 +63,12 @@ torch::lazy::BackendDevice ParseDeviceString(const std::string& device_spec) { } else if (device_spec_parts[0] == "CPU") { device_type->type = static_cast>(XlaDeviceType::CPU); + } else if (device_spec_parts[0] == "ROCM") { + device_type->type = + static_cast>(XlaDeviceType::ROCM); + } else if (device_spec_parts[0] == "CUDA") { + device_type->type = + static_cast>(XlaDeviceType::CUDA); } else if (device_spec_parts[0] == "GPU") { device_type->type = static_cast>(XlaDeviceType::GPU); diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index c17cbe85540c..1dc939bb17a1 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -15,7 +15,7 @@ namespace torch_xla { // TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToServer` // until after the paritioning pass. This avoids transfering the full input // tensor to the device. -enum class XlaDeviceType { CPU, GPU, TPU, XPU, NEURON, SPMD }; +enum class XlaDeviceType { CPU, CUDA, ROCM, GPU, TPU, XPU, NEURON, SPMD }; struct DeviceType : public torch::lazy::BackendDeviceType { DeviceType() { type = static_cast(XlaDeviceType::CPU); } diff --git a/torch_xla/csrc/random.cpp b/torch_xla/csrc/random.cpp index 44564c1d24fd..78376797bbe1 100644 --- a/torch_xla/csrc/random.cpp +++ b/torch_xla/csrc/random.cpp @@ -21,6 +21,8 @@ std::string GetDefaultGitGeneratorName() { static_cast(bridge::GetCurrentDevice().type()); switch (hw_type) { case XlaDeviceType::GPU: + case XlaDeviceType::CUDA: + case XlaDeviceType::ROCM: return "three_fry"; default: return "default"; diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index a09768c6a9ef..1e610be7959b 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -329,7 +329,7 @@ class ComputationClient { virtual int GetNumProcesses() const = 0; using DeviceAttribute = - std::variant, float>; + std::variant, float, bool>; virtual const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 4f175be7d717..237531d43cf8 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -109,14 +109,17 @@ PjRtComputationClient::PjRtComputationClient() { client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); } else if (device_type == "TPU" || device_type == "TPU_C_API") { TF_VLOG(1) << "Initializing TFRT TPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")) + .status()); tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); XLA_CHECK(tpu_status.ok()); client_ = std::move(xla::GetCApiClient("TPU").value()); } else if (device_type == "TPU_LEGACY") { XLA_ERROR() << 
"TPU_LEGACY client is no longer available."; - } else if (device_type == "GPU") { + } else if (device_type == "GPU" || device_type == "CUDA" || + device_type == "ROCM") { TF_VLOG(1) << "Initializing PjRt GPU client..."; bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); int local_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); @@ -154,15 +157,18 @@ PjRtComputationClient::PjRtComputationClient() { .value()); } else if (device_type == "XPU") { TF_VLOG(1) << "Initializing PjRt XPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so"))); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) + .status()); client_ = std::move(xla::GetCApiClient("XPU").value()); } else if (device_type == "NEURON") { TF_VLOG(1) << "Initializing PjRt NEURON client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "NEURON", sys_util::GetEnvString(env::kEnvNeuronLibraryPath, - "libneuronpjrt.so"))); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( + env::kEnvNeuronLibraryPath, + "libneuronpjrt.so")) + .status()); client_ = std::move(xla::GetCApiClient("NEURON").value()); } else { XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 04bd60ce9a0a..6322f052265a 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -75,7 +75,9 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor) // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU // so we must manually update Autocast to AutocastCUDA on XLA:GPU. torch::lazy::BackendDevice current_device = bridge::GetCurrentDevice(); - if (static_cast(current_device.type()) == XlaDeviceType::GPU) { + auto dev_type = static_cast(current_device.type()); + if (dev_type == XlaDeviceType::GPU || dev_type == XlaDeviceType::CUDA || + dev_type == XlaDeviceType::ROCM) { auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA); auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA); key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks; diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 3087f3c80f64..649c538bd334 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -41,7 +41,7 @@ def _maybe_select_default_device(): # TODO(wcromar): Detect GPU device elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0: logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=GPU') - os.environ[xenv.PJRT_DEVICE] = 'GPU' + os.environ[xenv.PJRT_DEVICE] = 'CUDA' else: logging.warning('Defaulting to PJRT_DEVICE=CPU') os.environ[xenv.PJRT_DEVICE] = 'CPU'