From 286b31f0c0c752306e4a80a566b1ec9e82653991 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Thu, 13 Jun 2024 17:03:54 -0700
Subject: [PATCH] Revert the mul change (#7271)

---
 test/test_operations.py          |  2 ++
 torch_xla/csrc/aten_xla_type.cpp | 13 +++++--------
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index 938817a6fd2..6fb0b79d78d 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2099,6 +2099,8 @@ def test(f, xshape, ishapes):
     for xshape, i0shape, i1shape in cases[f2]:
       test(f2, xshape, (i0shape, i1shape))
 
+  @unittest.skipIf(
+      True, "skip since https://github.com/pytorch/xla/pull/7130 is reverted")
   def test_inplace_mul_scalar_different_dtype(self):
     # This tests whether the returned output data-type agrees on PyTorch
     # and XLA sides.
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index dc30734756d..3459b8935e8 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2221,14 +2221,11 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
                                    const at::Tensor& other) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
-                              std::optional<at::ScalarType>);
-  return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
-      .add_input(self)
-      .add_input(other)
-      .cast_inputs_to_common_dtype()
-      .use_opmathtype_for_compute()
-      .run();
+  return DoBinaryOp(self, other,
+                    [&](const XLATensorPtr& xself, const XLATensorPtr& xother,
+                        at::ScalarType dtype) {
+                      return tensor_methods::mul(xself, xother, dtype);
+                    });
 }
 
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
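
Reviewer note (not part of the patch): the test skipped above guards the dtype
contract that PR #7130 changed and this revert restores. A minimal eager-only
sketch of that contract, assuming a recent stock PyTorch build; the real test
additionally runs the same ops on the XLA device and compares results:

    import torch

    # In-place mul with a scalar of a different dtype must keep the
    # tensor's original dtype; the XLA lowering is expected to agree
    # with this eager CPU behavior.
    t = torch.ones(4, dtype=torch.bfloat16)
    t.mul_(2.5)  # Python float scalar participates with "weak" dtype
    assert t.dtype == torch.bfloat16
    assert torch.allclose(t, torch.full((4,), 2.5, dtype=torch.bfloat16))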