From 77f2cba51bb383840c0842024db03f884653dc7f Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 19 Oct 2023 15:11:35 -0300 Subject: [PATCH] Add support for `_unsafe_index`. (#5707) * Add support for `_unsafe_index`. * Fix lint issues. * Add tests. --- codegen/xla_native_functions.yaml | 1 + test/cpp/test_aten_xla_tensor_1.cpp | 28 ++++++++++++++++++++++++++++ test/dynamo/test_bridge.py | 15 +++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 7 +++++++ 4 files changed, 51 insertions(+) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index bdb6c38e8cc0..1f1c69e22f0a 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -130,6 +130,7 @@ supported: - _to_cpu - _to_copy - _unsafe_view + - _unsafe_index.Tensor - adaptive_max_pool2d - adaptive_max_pool2d_backward - add.Scalar diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index d59b5b323605..5d281a626eca 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -2033,6 +2033,34 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) { ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestUnsafeIndex) { + for (torch::ScalarType scalar_type : + {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + torch::kLong}) { + torch::Tensor a = + isFloatingType(scalar_type) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); + for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { + torch::List> indices{ + torch::tensor({0, 1}, torch::TensorOptions(index_scalar_type)), + torch::tensor({2, 3}, torch::TensorOptions(index_scalar_type))}; + torch::Tensor c0 = torch::_unsafe_index(a, indices); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + torch::List> xla_indices{ + CopyToDevice(*indices.get(0), device), + CopyToDevice(*indices.get(1), device)}; + torch::Tensor xla_c0 = torch::_unsafe_index(xla_a, xla_indices); + AllEqual(c0, xla_c0); + }); + } + } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_unsafe_index", cpp_test::GetIgnoredCounters()); +} + TEST_F(AtenXlaTensorTest, TestIndexSelect) { for (torch::ScalarType scalar_type : {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 8ca735073bec..a419f9335db3 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -78,6 +78,19 @@ def get_random_inputs(self): return (torch.randn(10), torch.randn(10)) +class UpsampleModule(nn.Module): + + def __init__(self): + super().__init__() + self.upsample = nn.Upsample(scale_factor=2) + + def forward(self, x): + return self.upsample(x) + + def get_random_inputs(self): + return (torch.randn((1, 1, 5)),) + + def allclose(expected, actual): def unwrap(cont): @@ -179,6 +192,7 @@ def test_wrapper(self): model = model.to(device=xla_dev) inputs = tuple(inp.to(device=xla_dev) for inp in inputs) + inputs = tuple(inp.requires_grad_() for inp in inputs) # do baseline baseline_model = copy.deepcopy(model) @@ -206,6 +220,7 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase): test_training_linear = make_training_test(LinearModule) test_training_maxpool = make_training_test(MaxPoolModule) + test_training_upsample = make_training_test(UpsampleModule) def test_non_tensor_args_for_partition(self): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index b51fbba98cdf..086eb24e51c1 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -626,6 +626,13 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, return view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } +at::Tensor XLANativeFunctions::_unsafe_index( + const at::Tensor& self, + const c10::List>& indices) { + TORCH_LAZY_FN_COUNTER("xla::"); + return index(self, indices); +} + at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {