Skip to content

Commit

Permalink
Remove _unsafe_index implementation. (#5769)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored and ManfeiBai committed Nov 29, 2023
1 parent df17954 commit 2bcdae1
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 36 deletions.
1 change: 0 additions & 1 deletion codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ supported:
- _to_cpu
- _to_copy
- _unsafe_view
- _unsafe_index.Tensor
- adaptive_max_pool2d
- adaptive_max_pool2d_backward
- add.Scalar
Expand Down
28 changes: 0 additions & 28 deletions test/cpp/test_aten_xla_tensor_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2043,34 +2043,6 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) {
ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestUnsafeIndex) {
for (torch::ScalarType scalar_type :
{torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
torch::kLong}) {
torch::Tensor a =
isFloatingType(scalar_type)
? torch::rand({3, 4}, torch::TensorOptions(scalar_type))
: torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type));
for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) {
torch::List<torch::optional<torch::Tensor>> indices{
torch::tensor({0, 1}, torch::TensorOptions(index_scalar_type)),
torch::tensor({2, 3}, torch::TensorOptions(index_scalar_type))};
torch::Tensor c0 = torch::_unsafe_index(a, indices);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::List<torch::optional<torch::Tensor>> xla_indices{
CopyToDevice(*indices.get(0), device),
CopyToDevice(*indices.get(1), device)};
torch::Tensor xla_c0 = torch::_unsafe_index(xla_a, xla_indices);
AllEqual(c0, xla_c0);
});
}
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::index", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_unsafe_index", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestIndexSelect) {
for (torch::ScalarType scalar_type :
{torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
Expand Down
7 changes: 0 additions & 7 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,13 +626,6 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self,
return view_copy_symint(self, c10::fromIntArrayRefSlow(size));
}

at::Tensor XLANativeFunctions::_unsafe_index(
const at::Tensor& self,
const c10::List<c10::optional<at::Tensor>>& indices) {
TORCH_LAZY_FN_COUNTER("xla::");
return index(self, indices);
}

at::Tensor XLANativeFunctions::add(const at::Tensor& self,
const at::Tensor& other,
const at::Scalar& alpha) {
Expand Down

0 comments on commit 2bcdae1

Please sign in to comment.