From bc0dc13b54071ada7b6be153090edd8dbaf46df6 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 13 Oct 2023 21:21:05 -0300 Subject: [PATCH 1/3] Add support for `_unsafe_index`. --- codegen/xla_native_functions.yaml | 1 + torch_xla/csrc/aten_xla_type.cpp | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index bdb6c38e8cc..1f1c69e22f0 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -130,6 +130,7 @@ supported: - _to_cpu - _to_copy - _unsafe_view + - _unsafe_index.Tensor - adaptive_max_pool2d - adaptive_max_pool2d_backward - add.Scalar diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index a3f77bd54be..5205ed81ab2 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -626,6 +626,14 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, return view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } +at::Tensor XLANativeFunctions::_unsafe_index( + const at::Tensor& self, + const c10::List<c10::optional<at::Tensor>>& indices) { + TORCH_LAZY_FN_COUNTER("xla::"); + return index(self, indices); +} + + at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { From 1b588dfe3733420e44feed6b0c7df4db5dbdcf35 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 16 Oct 2023 17:06:25 -0300 Subject: [PATCH 2/3] Fix lint issues. 
--- torch_xla/csrc/aten_xla_type.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 5205ed81ab2..777b45e3262 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -633,7 +633,6 @@ at::Tensor XLANativeFunctions::_unsafe_index( return index(self, indices); } - at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { From a5e90908b869af5a5aa3ff9eb7a747eee5d514fb Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 17 Oct 2023 17:54:55 -0300 Subject: [PATCH 3/3] Add tests. --- test/cpp/test_aten_xla_tensor_1.cpp | 28 ++++++++++++++++++++++++++++ test/dynamo/test_bridge.py | 15 +++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index d59b5b32360..5d281a626ec 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -2033,6 +2033,34 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) { ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestUnsafeIndex) { + for (torch::ScalarType scalar_type : + {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + torch::kLong}) { + torch::Tensor a = + isFloatingType(scalar_type) + ? 
torch::rand({3, 4}, torch::TensorOptions(scalar_type)) : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { torch::List<torch::optional<torch::Tensor>> indices{ torch::tensor({0, 1}, torch::TensorOptions(index_scalar_type)), torch::tensor({2, 3}, torch::TensorOptions(index_scalar_type))}; torch::Tensor c0 = torch::_unsafe_index(a, indices); ForEachDevice([&](const torch::Device& device) { torch::Tensor xla_a = CopyToDevice(a, device); torch::List<torch::optional<torch::Tensor>> xla_indices{ CopyToDevice(*indices.get(0), device), CopyToDevice(*indices.get(1), device)}; torch::Tensor xla_c0 = torch::_unsafe_index(xla_a, xla_indices); AllEqual(c0, xla_c0); }); } } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::index", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::_unsafe_index", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestIndexSelect) { for (torch::ScalarType scalar_type : {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 778d77591e4..6d74a135eea 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -78,6 +78,19 @@ def get_random_inputs(self): return (torch.randn(10), torch.randn(10)) +class UpsampleModule(nn.Module): + + def __init__(self): + super().__init__() + self.upsample = nn.Upsample(scale_factor=2) + + def forward(self, x): + return self.upsample(x) + + def get_random_inputs(self): + return (torch.randn((1, 1, 5)),) + + def allclose(expected, actual): def unwrap(cont): @@ -179,6 +192,7 @@ def test_wrapper(self): model = model.to(device=xla_dev) inputs = tuple(inp.to(device=xla_dev) for inp in inputs) + inputs = tuple(inp.requires_grad_() for inp in inputs) # do baseline baseline_model = copy.deepcopy(model) @@ -206,6 +220,7 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase): 
test_training_linear = make_training_test(LinearModule) test_training_maxpool = make_training_test(MaxPoolModule) + test_training_upsample = make_training_test(UpsampleModule) if __name__ == "__main__":