From 6078fee85eb0a30db5d513a45c5759d28c9e9948 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 11 Mar 2024 20:33:54 -0300 Subject: [PATCH 1/3] Add fix and test. --- test/test_operations.py | 12 ++++++++++++ torch_xla/csrc/tensor_ops.cpp | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_operations.py b/test/test_operations.py index 187fb62e8ff..50027629e5e 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1607,6 +1607,18 @@ def test_emb_bf16(self): emb_out = emb(index) assert emb_out.dtype == torch.bfloat16 + def test_embedding_int_indices(self): + model = torch.nn.Embedding(1024, 10) + + def test_on_device(device): + m = copy.deepcopy(model).to(device) + index = torch.ones(1, dtype=torch.int, device=device) + return m(index) + + out = test_on_device("cpu") + out_x = test_on_device(xm.xla_device()) + self.assertEqual(out, out_x.cpu()) + def test_transpose_1d(self): def test_fn(t1): diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index 2ac9e5608a5..a66ee923475 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -245,7 +245,7 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output, XLATensorPtr Embedding(const XLATensorPtr& weight, const XLATensorPtr& indices) { XLA_CHECK_EQ(weight->shape().get().rank(), 2); - XLA_CHECK_EQ(indices->dtype(), at::ScalarType::Long); + XLA_CHECK(indices->dtype() == at::kLong || indices->dtype() == at::kInt); if (indices->shape().get().rank() == 1) { return tensor_methods::index_select(weight, 0, indices); From f30dc782d94ca494685177a66cb2adcf3e90184e Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 11 Mar 2024 20:39:10 -0300 Subject: [PATCH 2/3] Add one more test case. --- test/test_operations.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 50027629e5e..48d3599d9be 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1610,14 +1610,17 @@ def test_emb_bf16(self): def test_embedding_int_indices(self): model = torch.nn.Embedding(1024, 10) - def test_on_device(device): - m = copy.deepcopy(model).to(device) - index = torch.ones(1, dtype=torch.int, device=device) - return m(index) - - out = test_on_device("cpu") - out_x = test_on_device(xm.xla_device()) - self.assertEqual(out, out_x.cpu()) + # 1 and 2-dimensional tensors. + # They have different execution paths. + for shape in ((5,), (2, 5)): + def test_on_device(device): + m = copy.deepcopy(model).to(device) + index = torch.ones(shape, dtype=torch.int, device=device) + return m(index) + + out = test_on_device("cpu") + out_x = test_on_device(xm.xla_device()) + self.assertEqual(out, out_x.cpu()) def test_transpose_1d(self): From 1bb662311f0efc6cef019b88560c1dad82dcab3a Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 11 Mar 2024 20:56:24 -0300 Subject: [PATCH 3/3] Fix lint issue. --- test/test_operations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_operations.py b/test/test_operations.py index 48d3599d9be..85d65ae294b 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1613,6 +1613,7 @@ def test_embedding_int_indices(self): # 1 and 2-dimensional tensors. # They have different execution paths. for shape in ((5,), (2, 5)): + def test_on_device(device): m = copy.deepcopy(model).to(device) index = torch.ones(shape, dtype=torch.int, device=device)