Revert the mul change (#7271) (#7285)
JackCaoG authored Jun 17, 2024
1 parent fda5828 commit 7296133
Showing 2 changed files with 7 additions and 8 deletions.
2 changes: 2 additions & 0 deletions test/test_operations.py
@@ -2099,6 +2099,8 @@ def test(f, xshape, ishapes):
     for xshape, i0shape, i1shape in cases[f2]:
       test(f2, xshape, (i0shape, i1shape))
 
+  @unittest.skipIf(
+      True, "skip since https://github.com/pytorch/xla/pull/7130 is reverted")
   def test_inplace_mul_scalar_different_dtype(self):
     # This tests whether the returned output data-type agrees on PyTorch
     # and XLA sides.
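The test skipped above checks that an in-place multiply by a scalar of a different dtype produces the same output dtype on the PyTorch and XLA sides. A rough pure-Python sketch of that contract, with no torch dependency (`FakeTensor` and its `mul_` method are hypothetical stand-ins, not torch_xla APIs):

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor carrying a dtype tag."""

    def __init__(self, value, dtype):
        self.value = value
        self.dtype = dtype

    def mul_(self, scalar):
        # In-place ops must keep self's dtype, even if the scalar's
        # implicit dtype is wider -- mirroring PyTorch's in-place
        # type-promotion rule.
        self.value = self.value * scalar
        return self


t = FakeTensor(2.0, "float32")
t.mul_(2.5)  # out-of-place, 2.5 would promote the result to float64
assert t.dtype == "float32"
```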
13 changes: 5 additions & 8 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2221,14 +2221,11 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
                                    const at::Tensor& other) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
-                              std::optional<at::ScalarType>);
-  return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
-      .add_input(self)
-      .add_input(other)
-      .cast_inputs_to_common_dtype()
-      .use_opmathtype_for_compute()
-      .run();
+  return DoBinaryOp(self, other,
+                    [&](const XLATensorPtr& xself, const XLATensorPtr& xother,
+                        at::ScalarType dtype) {
+                      return tensor_methods::mul(xself, xother, dtype);
+                    });
 }
 
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
