Skip to content

Commit

Permalink
Add support for _unsafe_index. (#5707)
Browse files Browse the repository at this point in the history
* Add support for `_unsafe_index`.

* Fix lint issues.

* Add tests.
  • Loading branch information
ysiraichi authored Oct 19, 2023
1 parent 6ea9947 commit bccbb5a
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 0 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ supported:
- _to_cpu
- _to_copy
- _unsafe_view
- _unsafe_index.Tensor
- adaptive_max_pool2d
- adaptive_max_pool2d_backward
- add.Scalar
Expand Down
28 changes: 28 additions & 0 deletions test/cpp/test_aten_xla_tensor_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2033,6 +2033,34 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) {
ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestUnsafeIndex) {
for (torch::ScalarType scalar_type :
{torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
torch::kLong}) {
torch::Tensor a =
isFloatingType(scalar_type)
? torch::rand({3, 4}, torch::TensorOptions(scalar_type))
: torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type));
for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) {
torch::List<torch::optional<torch::Tensor>> indices{
torch::tensor({0, 1}, torch::TensorOptions(index_scalar_type)),
torch::tensor({2, 3}, torch::TensorOptions(index_scalar_type))};
torch::Tensor c0 = torch::_unsafe_index(a, indices);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::List<torch::optional<torch::Tensor>> xla_indices{
CopyToDevice(*indices.get(0), device),
CopyToDevice(*indices.get(1), device)};
torch::Tensor xla_c0 = torch::_unsafe_index(xla_a, xla_indices);
AllEqual(c0, xla_c0);
});
}
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::index", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_unsafe_index", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestIndexSelect) {
for (torch::ScalarType scalar_type :
{torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
Expand Down
15 changes: 15 additions & 0 deletions test/dynamo/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ def get_random_inputs(self):
return (torch.randn(10), torch.randn(10))


class UpsampleModule(nn.Module):

def __init__(self):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2)

def forward(self, x):
return self.upsample(x)

def get_random_inputs(self):
return (torch.randn((1, 1, 5)),)


def allclose(expected, actual):

def unwrap(cont):
Expand Down Expand Up @@ -179,6 +192,7 @@ def test_wrapper(self):

model = model.to(device=xla_dev)
inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
inputs = tuple(inp.requires_grad_() for inp in inputs)

# do baseline
baseline_model = copy.deepcopy(model)
Expand Down Expand Up @@ -206,6 +220,7 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):

test_training_linear = make_training_test(LinearModule)
test_training_maxpool = make_training_test(MaxPoolModule)
test_training_upsample = make_training_test(UpsampleModule)

def test_non_tensor_args_for_partition(self):

Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,13 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self,
return view_copy_symint(self, c10::fromIntArrayRefSlow(size));
}

at::Tensor XLANativeFunctions::_unsafe_index(
const at::Tensor& self,
const c10::List<c10::optional<at::Tensor>>& indices) {
TORCH_LAZY_FN_COUNTER("xla::");
return index(self, indices);
}

at::Tensor XLANativeFunctions::add(const at::Tensor& self,
const at::Tensor& other,
const at::Scalar& alpha) {
Expand Down

0 comments on commit bccbb5a

Please sign in to comment.