Skip to content

Commit

Permalink
Reapply "Fix some more core aten ops (#6342)" (#6377) (#6387)
Browse files Browse the repository at this point in the history
  • Loading branch information
cota authored Jan 26, 2024
1 parent c70e4cc commit 9e4db96
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 14 deletions.
2 changes: 1 addition & 1 deletion codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ full_codegen:
- rsqrt
- selu
- sgn
- sigmoid
- sign
- silu
- silu_backward
Expand Down Expand Up @@ -304,7 +305,6 @@ supported:
- select_scatter
- selu_
- set_.source_Tensor
- sigmoid
- sigmoid_backward
- slice_copy.Tensor
- slice_scatter
Expand Down
22 changes: 16 additions & 6 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,11 +1904,17 @@ def test_aten_gelu_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)

@unittest.skip
def test_aten_gelu_1(self):
args = (torch.randn((10, 10)).to(torch.float16),)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)
run_export_and_compare(
self,
torch.ops.aten.gelu,
args,
kwargs,
rtol=0.001,
atol=0.01,
)

def test_aten_glu_0(self):
args = (
Expand Down Expand Up @@ -3082,7 +3088,6 @@ def test_aten_native_group_norm_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)

@unittest.skip
def test_aten_native_group_norm_1(self):
args = (
torch.randn((1, 3, 2, 10)).to(torch.float16),
Expand All @@ -3095,7 +3100,14 @@ def test_aten_native_group_norm_1(self):
0.0,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)
run_export_and_compare(
self,
torch.ops.aten.native_group_norm,
args,
kwargs,
rtol=0.001,
atol=0.01,
)

def test_aten_native_layer_norm_0(self):
args = (
Expand Down Expand Up @@ -3411,7 +3423,6 @@ def test_aten_reciprocal_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs)

@unittest.skip
def test_aten_reciprocal_2(self):
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
kwargs = dict()
Expand Down Expand Up @@ -4009,7 +4020,6 @@ def test_aten_sigmoid_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs)

@unittest.skip
def test_aten_sigmoid_2(self):
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
kwargs = dict()
Expand Down
6 changes: 0 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2791,12 +2791,6 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self,
return self;
}

at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(
tensor_methods::sigmoid(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
const at::Tensor& output) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
Expand Down
12 changes: 12 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,10 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {

torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
}
return ReturnOp(BuildReciprocal(xla_input), loctx);
}

Expand Down Expand Up @@ -726,6 +730,14 @@ torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
return ReturnOp(BuildSgn(xla_input), loctx);
}

torch_xla::XlaOpVector Sigmoid::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla_input = xla::ConvertElementType(xla_input, xla::PrimitiveType::F32);
}
return ReturnOp(xla::Logistic(xla_input), loctx);
}

torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(BuildSign(xla_input), loctx);
Expand Down
14 changes: 13 additions & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,11 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
}

xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
xla::Shape result_shape = GetXlaShape(input);
if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
result_shape.set_element_type(xla::PrimitiveType::F32);
}
return result_shape;
}

xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
Expand Down Expand Up @@ -804,6 +808,14 @@ xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape SigmoidOutputShape(const torch::lazy::Value& input) {
xla::Shape result_shape = GetXlaShape(input);
if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
result_shape.set_element_type(xla::PrimitiveType::F32);
}
return result_shape;
}

xla::Shape SignOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ xla::Shape SeluOutputShape(const torch::lazy::Value& input);

xla::Shape SgnOutputShape(const torch::lazy::Value& input);

xla::Shape SigmoidOutputShape(const torch::lazy::Value& input);

xla::Shape SignOutputShape(const torch::lazy::Value& input);

xla::Shape SiluOutputShape(const torch::lazy::Value& input);
Expand Down

0 comments on commit 9e4db96

Please sign in to comment.