diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index 1f1c69e22f0..bdb6c38e8cc 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -130,7 +130,6 @@ supported: - _to_cpu - _to_copy - _unsafe_view - - _unsafe_index.Tensor - adaptive_max_pool2d - adaptive_max_pool2d_backward - add.Scalar diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index 5d281a626ec..d59b5b32360 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -2033,34 +2033,6 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) { ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestUnsafeIndex) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); - for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { - torch::List> indices{ - torch::tensor({0, 1}, torch::TensorOptions(index_scalar_type)), - torch::tensor({2, 3}, torch::TensorOptions(index_scalar_type))}; - torch::Tensor c0 = torch::_unsafe_index(a, indices); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::List> xla_indices{ - CopyToDevice(*indices.get(0), device), - CopyToDevice(*indices.get(1), device)}; - torch::Tensor xla_c0 = torch::_unsafe_index(xla_a, xla_indices); - AllEqual(c0, xla_c0); - }); - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_unsafe_index", cpp_test::GetIgnoredCounters()); -} - TEST_F(AtenXlaTensorTest, TestIndexSelect) { for (torch::ScalarType scalar_type : {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 086eb24e51c..b51fbba98cd 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -626,13 +626,6 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, return view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } -at::Tensor XLANativeFunctions::_unsafe_index( - const at::Tensor& self, - const c10::List>& indices) { - TORCH_LAZY_FN_COUNTER("xla::"); - return index(self, indices); -} - at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {