From 24eb3c133188aac714442a6359ebfab247cdc84e Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Wed, 24 Jan 2024 23:20:43 -0800 Subject: [PATCH] Revert "Fix some more core aten ops (#6342)" (#6377) --- codegen/xla_native_functions.yaml | 2 +- test/test_core_aten_ops.py | 22 ++++++---------------- torch_xla/csrc/aten_xla_type.cpp | 6 ++++++ torch_xla/csrc/ops/ops_lower_fn.cpp | 12 ------------ torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 14 +------------- torch_xla/csrc/ops/ops_xla_shape_fn.h | 2 -- 6 files changed, 14 insertions(+), 44 deletions(-) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index fbe059fb8c9..a7f59e26575 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -77,7 +77,6 @@ full_codegen: - rsqrt - selu - sgn - - sigmoid - sign - silu - silu_backward @@ -303,6 +302,7 @@ supported: - select_scatter - selu_ - set_.source_Tensor + - sigmoid - sigmoid_backward - slice_copy.Tensor - slice_scatter diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index cb445ac7daf..1442e26fe28 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -1904,17 +1904,11 @@ def test_aten_gelu_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs) + @unittest.skip def test_aten_gelu_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.gelu, - args, - kwargs, - rtol=0.001, - atol=0.01, - ) + run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs) def test_aten_glu_0(self): args = ( @@ -3091,6 +3085,7 @@ def test_aten_native_group_norm_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs) + @unittest.skip def test_aten_native_group_norm_1(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float16), @@ -3103,14 +3098,7 @@ def test_aten_native_group_norm_1(self): 0.0, ) kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.native_group_norm, - args, - kwargs, - rtol=0.001, - atol=0.01, - ) + run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs) def test_aten_native_layer_norm_0(self): args = ( @@ -3417,6 +3405,7 @@ def test_aten_reciprocal_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) + @unittest.skip def test_aten_reciprocal_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -4014,6 +4003,7 @@ def test_aten_sigmoid_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) + @unittest.skip def test_aten_sigmoid_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 113fbd276c7..a40982452d4 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2774,6 +2774,12 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self, return self; } +at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + return bridge::AtenFromXlaTensor( + tensor_methods::sigmoid(bridge::GetXlaTensor(self))); +} + at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index 9a765db749a..64445a6789a 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -684,10 +684,6 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const { torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); - if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { - xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); - } return ReturnOp(BuildReciprocal(xla_input), loctx); } @@ -730,14 +726,6 @@ torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const { return ReturnOp(BuildSgn(xla_input), loctx); } -torch_xla::XlaOpVector Sigmoid::Lower(LoweringContext* loctx) const { - xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); - if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { - xla_input = xla::ConvertElementType(xla_input, xla::PrimitiveType::F32); - } - return ReturnOp(xla::Logistic(xla_input), loctx); -} - torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); return ReturnOp(BuildSign(xla_input), loctx); diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index b0133da3ec7..c94dbd0924a 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -762,11 +762,7 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self, } xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) { - xla::Shape result_shape = GetXlaShape(input); - if (xla::primitive_util::IsIntegralType(result_shape.element_type())) { - result_shape.set_element_type(xla::PrimitiveType::F32); - } - return result_shape; + return GetXlaShape(input); } xla::Shape ReluOutputShape(const torch::lazy::Value& input) { @@ -808,14 +804,6 @@ xla::Shape SgnOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } -xla::Shape SigmoidOutputShape(const torch::lazy::Value& input) { - xla::Shape result_shape = GetXlaShape(input); - if (xla::primitive_util::IsIntegralType(result_shape.element_type())) { - result_shape.set_element_type(xla::PrimitiveType::F32); - } - return result_shape; -} - xla::Shape SignOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 639edc1679b..6f961f50cde 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -248,8 +248,6 @@ xla::Shape SeluOutputShape(const torch::lazy::Value& input); xla::Shape SgnOutputShape(const torch::lazy::Value& input); -xla::Shape SigmoidOutputShape(const torch::lazy::Value& input); - xla::Shape SignOutputShape(const torch::lazy::Value& input); xla::Shape SiluOutputShape(const torch::lazy::Value& input);