Revert "Fix some more core aten ops (#6342)" (#6377)
wonjoolee95 authored and bhavya01 committed Apr 22, 2024
1 parent f31a14a commit 24eb3c1
Showing 6 changed files with 14 additions and 44 deletions.
2 changes: 1 addition & 1 deletion codegen/xla_native_functions.yaml
@@ -77,7 +77,6 @@ full_codegen:
   - rsqrt
   - selu
   - sgn
-  - sigmoid
   - sign
   - silu
   - silu_backward
@@ -303,6 +302,7 @@ supported:
   - select_scatter
   - selu_
   - set_.source_Tensor
+  - sigmoid
   - sigmoid_backward
   - slice_copy.Tensor
   - slice_scatter
22 changes: 6 additions & 16 deletions test/test_core_aten_ops.py
@@ -1904,17 +1904,11 @@ def test_aten_gelu_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)

+  @unittest.skip
   def test_aten_gelu_1(self):
     args = (torch.randn((10, 10)).to(torch.float16),)
     kwargs = dict()
-    run_export_and_compare(
-        self,
-        torch.ops.aten.gelu,
-        args,
-        kwargs,
-        rtol=0.001,
-        atol=0.01,
-    )
+    run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)

   def test_aten_glu_0(self):
     args = (
@@ -3091,6 +3085,7 @@ def test_aten_native_group_norm_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)

+  @unittest.skip
   def test_aten_native_group_norm_1(self):
     args = (
         torch.randn((1, 3, 2, 10)).to(torch.float16),
@@ -3103,14 +3098,7 @@ def test_aten_native_group_norm_1(self):
         0.0,
     )
     kwargs = dict()
-    run_export_and_compare(
-        self,
-        torch.ops.aten.native_group_norm,
-        args,
-        kwargs,
-        rtol=0.001,
-        atol=0.01,
-    )
+    run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)

   def test_aten_native_layer_norm_0(self):
     args = (
@@ -3417,6 +3405,7 @@ def test_aten_reciprocal_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs)

+  @unittest.skip
   def test_aten_reciprocal_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
@@ -4014,6 +4003,7 @@ def test_aten_sigmoid_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs)

+  @unittest.skip
   def test_aten_sigmoid_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
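The decorators restored above re-skip the float16 tolerance cases (gelu, native_group_norm) and the integer-input cases (reciprocal, sigmoid). For the integer cases, eager PyTorch type-promotes int32 inputs to float32, which is the behavior the reverted XLA lowerings approximated by upcasting to F32. A minimal eager-mode sketch of the re-skipped inputs, illustrative only and runnable on CPU without torch_xla:

import torch

# Same input construction as the re-skipped test_aten_reciprocal_2 and
# test_aten_sigmoid_2 cases above.
x = torch.randint(0, 10, (10, 10)).to(torch.int32)

# Eager PyTorch promotes the integral input to float32 for these unary ops.
print(torch.ops.aten.reciprocal(x).dtype)  # torch.float32
print(torch.ops.aten.sigmoid(x).dtype)     # torch.float32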
6 changes: 6 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2774,6 +2774,12 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self,
   return self;
 }

+at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+  return bridge::AtenFromXlaTensor(
+      tensor_methods::sigmoid(bridge::GetXlaTensor(self)));
+}
+
 at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
                                                 const at::Tensor& output) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
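With sigmoid back under supported:, XLA tensors reach this hand-written kernel instead of a generated lowering, and TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::") records a counter that should surface as "xla::sigmoid". A hedged sketch of how dispatch can be confirmed; it assumes a working torch_xla install, an available XLA device, and that the counter name is the macro's prefix plus the function name:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = torch.sigmoid(x)  # expected to route through XLANativeFunctions::sigmoid
xm.mark_step()        # materialize the pending lazy graph

# Counter name assumed to be the "xla::" prefix plus the op name.
print(met.counter_value("xla::sigmoid"))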
12 changes: 0 additions & 12 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -684,10 +684,6 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {

 torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
-  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
-    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
-  }
   return ReturnOp(BuildReciprocal(xla_input), loctx);
 }

@@ -730,14 +726,6 @@ torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildSgn(xla_input), loctx);
 }

-torch_xla::XlaOpVector Sigmoid::Lower(LoweringContext* loctx) const {
-  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
-  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
-    xla_input = xla::ConvertElementType(xla_input, xla::PrimitiveType::F32);
-  }
-  return ReturnOp(xla::Logistic(xla_input), loctx);
-}
-
 torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildSign(xla_input), loctx);
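With these upcasts removed, the codegen Reciprocal lowering once again operates on the input's original element type, and the codegen Sigmoid lowering goes away entirely now that sigmoid is handled by the hand-written kernel above. For integer inputs, reciprocal therefore stays integral rather than being promoted to float32 the way eager PyTorch promotes it, which is presumably part of why test_aten_reciprocal_2 and test_aten_sigmoid_2 are re-skipped. An illustrative eager-mode sketch of that gap for reciprocal, where integer arithmetic collapses almost every value to zero:

import torch

x = torch.tensor([1, 2, 5], dtype=torch.int32)

# Roughly what a reciprocal kept in integer arithmetic would yield (1 // x)...
print(torch.ones_like(x) // x)  # tensor([1, 0, 0], dtype=torch.int32)

# ...versus the float32 result eager PyTorch produces via type promotion.
print(torch.reciprocal(x))      # tensor([1.0000, 0.5000, 0.2000])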
14 changes: 1 addition & 13 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -762,11 +762,7 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
 }

 xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
-  xla::Shape result_shape = GetXlaShape(input);
-  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
-    result_shape.set_element_type(xla::PrimitiveType::F32);
-  }
-  return result_shape;
+  return GetXlaShape(input);
 }

 xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
@@ -808,14 +804,6 @@ xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }

-xla::Shape SigmoidOutputShape(const torch::lazy::Value& input) {
-  xla::Shape result_shape = GetXlaShape(input);
-  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
-    result_shape.set_element_type(xla::PrimitiveType::F32);
-  }
-  return result_shape;
-}
-
 xla::Shape SignOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
2 changes: 0 additions & 2 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -248,8 +248,6 @@ xla::Shape SeluOutputShape(const torch::lazy::Value& input);

 xla::Shape SgnOutputShape(const torch::lazy::Value& input);

-xla::Shape SigmoidOutputShape(const torch::lazy::Value& input);
-
 xla::Shape SignOutputShape(const torch::lazy::Value& input);

 xla::Shape SiluOutputShape(const torch::lazy::Value& input);
