diff --git a/test/test_operations.py b/test/test_operations.py
index ed8f5a88151..4db6be38cea 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -1982,6 +1982,22 @@ def foo(x):
     for dtype in test_dtypes:
       test(dtype)
 
+  def test_gelu_backward_different_types(self):
+
+    def foo(grad, inp):
+      return torch.ops.aten.gelu_backward.default(grad, inp)
+
+    grad = torch.rand(10, 10, dtype=torch.bfloat16)
+    inp = torch.rand(10, 10)
+
+    Xgrad = grad.to(xm.xla_device())
+    Xinp = inp.to(xm.xla_device())
+
+    r = foo(grad, inp)
+    Xr = foo(Xgrad, Xinp)
+
+    self.assertEqual(r, Xr.cpu())
+
 
 class MNISTComparator(nn.Module):
 
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 12a49a91ad9..400d885b00c 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1504,8 +1504,10 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
                                              const at::Tensor& self,
                                              c10::string_view approximate) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+  at::ScalarType result_type = at::result_type(grad, self);
   return bridge::AtenFromXlaTensor(tensor_methods::gelu_backward(
-      bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate));
+      bridge::GetXlaTensor(grad.to(result_type)),
+      bridge::GetXlaTensor(self.to(result_type)), approximate));
 }
 
 at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,