From 895b0c2ea14ae1d5c5daf76056de64df28d1b6f2 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 2 Apr 2024 17:00:18 -0700 Subject: [PATCH] Fix div lowering and core aten test script enhancement (#6873) Co-authored-by: Siyuan Liu --- test/test_core_aten_ops.py | 9 +++++++++ torch_xla/csrc/ops/ops.cpp | 4 +++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 4cb2e7c2076..591d2d18a4c 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -16,6 +16,7 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): output2_cpu = output2.detach().cpu() if output2_cpu.dtype != output1.dtype: output2_cpu = output2_cpu.to(output1.dtype) + testcase.assertEqual(output1.shape, output2.shape) testcase.assertTrue( torch.allclose( output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan)) @@ -1174,6 +1175,14 @@ def test_aten_div_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) + def test_aten_div_Tensor_3(self): + args = ( + torch.rand(1, 3, 4, 1), + torch.rand(10), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) + def test_aten_div_Tensor_mode_0(self): def aten_div_Tensor_mode_rounding_mode_trunc(input, other): diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index dfee7621adc..7391f8ff714 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -726,7 +726,9 @@ torch::lazy::NodePtr Div(const torch::lazy::Value& input, return node.ReturnOp(BuildDiv(xla_input, xla_divisor), loctx); }; return GenericOp(torch::lazy::OpKind(at::aten::div), {input, divisor}, - GetXlaShape(input), std::move(lower_fn)); + XlaHelpers::GetPromotedBinaryOpShape(GetXlaShape(input), + GetXlaShape(divisor)), + std::move(lower_fn)); } torch::lazy::NodePtr MaxUnary(const torch::lazy::Value& input) {