From 6b8cfe6f3d0ac738446220c0f03750c106d6e2ce Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 20 May 2024 14:20:45 -0300 Subject: [PATCH] Fix overflow for `div` arguments. (#7081) --- test/cpp/test_aten_xla_tensor_4.cpp | 13 +++++++++++++ torch_xla/csrc/tensor_methods.cpp | 11 ++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index ff6130ca1b95..7a02a1079a6c 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -504,6 +504,19 @@ TEST_F(AtenXlaTensorTest, TestDivScalar) { ExpectCounterChanged("xla::div", cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestDivScalarHalfOverflow) { + torch::Tensor input = torch::rand({3, 4}, torch::TensorOptions(torch::kHalf)); + torch::Scalar other = torch::Scalar(100000); + torch::Tensor out = torch::div(input, other); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input, device); + torch::Tensor xla_out = torch::div(xla_input, other); + AllClose(out, xla_out); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::div", cpp_test::GetIgnoredCounters()); +} + TEST_F(AtenXlaTensorTest, TestDivScalarInPlace) { for (torch::ScalarType scalar_type : {torch::kFloat}) { torch::Tensor a = diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index bad4a7c6c2e1..7baa951c9a6e 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1,5 +1,6 @@ #include "torch_xla/csrc/tensor_methods.h" +#include #include #include #include @@ -1260,10 +1261,14 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) { if (input_is_float) { scalar_type = MaybeUpcastToHostTorchType(input_type); } - torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type); + at::ScalarType op_math_type = at::toOpMathType(scalar_type); + torch::lazy::Value input_value = + torch::lazy::MakeNode(input->GetIrValue(), op_math_type); torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar( - other, GetXlaShape(input_value).element_type(), input->GetDevice()); - return input->CreateFrom(Div(input_value, other_value), scalar_type); + other, XlaTypeFromTorchType(op_math_type), input->GetDevice()); + return input->CreateFrom( + torch::lazy::MakeNode(Div(input_value, other_value), scalar_type), + scalar_type); } XLATensorPtr einsum(const std::string& equation,