From 73f7e978ee166598f3b222dcffe139fab3db07ce Mon Sep 17 00:00:00 2001
From: Emilio Cota
Date: Fri, 26 Jan 2024 02:17:06 -0500
Subject: [PATCH] Reapply "Fix some more core aten ops (#6342)" (#6377) (#6387)

---
 codegen/xla_native_functions.yaml       |  2 +-
 test/test_core_aten_ops.py              | 22 ++++++++++++++++------
 torch_xla/csrc/aten_xla_type.cpp        |  6 ------
 torch_xla/csrc/ops/ops_lower_fn.cpp     | 12 ++++++++++++
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 14 +++++++++++++-
 torch_xla/csrc/ops/ops_xla_shape_fn.h   |  2 ++
 6 files changed, 44 insertions(+), 14 deletions(-)

diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index 09c9a0a0190..4db126feade 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -77,6 +77,7 @@ full_codegen:
   - rsqrt
   - selu
   - sgn
+  - sigmoid
   - sign
   - silu
   - silu_backward
@@ -304,7 +305,6 @@ supported:
   - select_scatter
   - selu_
   - set_.source_Tensor
-  - sigmoid
   - sigmoid_backward
   - slice_copy.Tensor
   - slice_scatter
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 498a17ae90e..9217788f9d1 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -1904,11 +1904,17 @@ def test_aten_gelu_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)
 
-  @unittest.skip
   def test_aten_gelu_1(self):
     args = (torch.randn((10, 10)).to(torch.float16),)
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)
+    run_export_and_compare(
+        self,
+        torch.ops.aten.gelu,
+        args,
+        kwargs,
+        rtol=0.001,
+        atol=0.01,
+    )
 
   def test_aten_glu_0(self):
     args = (
@@ -3082,7 +3088,6 @@ def test_aten_native_group_norm_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)
 
-  @unittest.skip
   def test_aten_native_group_norm_1(self):
     args = (
         torch.randn((1, 3, 2, 10)).to(torch.float16),
@@ -3095,7 +3100,14 @@
         1,
         3,
         20,
         1,
         0.0,
     )
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)
+    run_export_and_compare(
+        self,
+        torch.ops.aten.native_group_norm,
+        args,
+        kwargs,
+        rtol=0.001,
+        atol=0.01,
+    )
 
   def test_aten_native_layer_norm_0(self):
     args = (
@@ -3411,7 +3423,6 @@ def test_aten_reciprocal_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs)
 
-  @unittest.skip
   def test_aten_reciprocal_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
@@ -4009,7 +4020,6 @@ def test_aten_sigmoid_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs)
 
-  @unittest.skip
   def test_aten_sigmoid_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index c6bbd6f5718..89e8921c577 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2807,12 +2807,6 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self,
   return self;
 }
 
-at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
-  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::sigmoid(bridge::GetXlaTensor(self)));
-}
-
 at::Tensor& XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
                                                  const at::Tensor& output) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 64445a6789a..9a765db749a 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -684,6 +684,10 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
+  }
   return ReturnOp(BuildReciprocal(xla_input), loctx);
 }
 
@@ -726,6 +730,14 @@ torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildSgn(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Sigmoid::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla_input = xla::ConvertElementType(xla_input, xla::PrimitiveType::F32);
+  }
+  return ReturnOp(xla::Logistic(xla_input), loctx);
+}
+
 torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildSign(xla_input), loctx);
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index c94dbd0924a..b0133da3ec7 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -762,7 +762,11 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
 }
 
 xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
@@ -804,6 +808,14 @@ xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape SigmoidOutputShape(const torch::lazy::Value& input) {
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
+}
+
 xla::Shape SignOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 6f961f50cde..639edc1679b 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -248,6 +248,8 @@ xla::Shape SeluOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SgnOutputShape(const torch::lazy::Value& input);
 
+xla::Shape SigmoidOutputShape(const torch::lazy::Value& input);
+
 xla::Shape SignOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SiluOutputShape(const torch::lazy::Value& input);
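---

Notes:

Moving sigmoid from the supported: list to the full_codegen: list in
xla_native_functions.yaml switches it from a hand-written XLANativeFunctions
override to a torchgen-generated IR node, which is why this patch deletes
XLANativeFunctions::sigmoid from aten_xla_type.cpp and supplies only
Sigmoid::Lower and SigmoidOutputShape. One way to confirm the op still routes
through the XLA dispatcher is the debug metrics counters; the sketch below is
an assumption-laden illustration, not part of the patch: it assumes a working
torch_xla install with an available XLA device, and the "xla::sigmoid" counter
name follows the usual TORCH_LAZY_FN_COUNTER convention rather than anything
this patch guarantees.

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    device = xm.xla_device()
    x = torch.randint(0, 10, (10, 10), dtype=torch.int32, device=device)
    y = torch.sigmoid(x)  # should dispatch to the codegen'd Sigmoid node
    xm.mark_step()        # force compilation/execution of the pending graph

    print(y.dtype)  # torch.float32: integral input promoted (see next note)
    # Counters whose names mention sigmoid should include the xla:: entry.
    print([c for c in met.counter_names() if "sigmoid" in c])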
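The new Reciprocal and Sigmoid lowerings convert integral inputs to F32 before
calling BuildReciprocal / xla::Logistic, and the matching *OutputShape
functions flip the result element type the same way. That mirrors eager
PyTorch's type promotion for unary floating-point ops; the shape function has
to agree with the lowering, or the lazy tensor's metadata would disagree with
the compiled output. A minimal eager-mode sketch of the behavior being
mirrored (plain CPU torch, no XLA required):

    import torch

    x = torch.randint(0, 10, (10, 10)).to(torch.int32)

    # Unary float ops promote integral inputs to float32 in eager mode;
    # the XLA lowerings in this patch reproduce exactly this dtype.
    print(torch.ops.aten.reciprocal(x).dtype)  # torch.float32
    print(torch.ops.aten.sigmoid(x).dtype)     # torch.float32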
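The two float16 tests are un-skipped but given looser tolerances (rtol=0.001,
atol=0.01) because half precision carries only ~11 significand bits, so
results cannot match a float32 reference at the comparison defaults. A rough
standalone illustration of why the defaults are too strict for fp16; treat it
as an analogy, since run_export_and_compare's internal comparison is not shown
in this patch:

    import torch

    reference = torch.randn(10, 10)
    # Simulate an op computed in float16: round-trip through half precision.
    fp16_result = reference.to(torch.float16).to(torch.float32)

    # Default tolerances (rtol=1e-5, atol=1e-8) reject mere fp16 rounding:
    print(torch.allclose(fp16_result, reference))                         # False
    # The tolerances used by the patched tests absorb it:
    print(torch.allclose(fp16_result, reference, rtol=0.001, atol=0.01))  # True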