Full codegen asin, asinh, atan, and atanh (#3565)
wonjoolee95 authored May 13, 2022
1 parent: 2cc13c0 · commit: 56a52d5
Showing 9 changed files with 40 additions and 66 deletions.
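This commit moves asin, asinh, atan, and atanh from hand-written ATen bindings to the torchgen "full codegen" path: each op is listed under full_codegen in xla_native_functions.yaml, the build then generates its IR node class and dispatcher binding, and only the per-op lowering (ops_lower_fn.cpp) and output-shape function (ops_xla_shape_fn.cpp) remain hand-written. For orientation, here is a hedged sketch of the node class the generator plausibly emits into LazyIr.h for one of these ops; the real class is produced at build time and may differ in detail:

// Illustrative sketch only; the actual class is generated by torchgen.
class Asin : public XlaNode {
 public:
  Asin(const XlaValue& input)
      : XlaNode(torch::lazy::OpKind(at::aten::asin), {input},
                [&]() { return AsinOutputShape(input); },
                /*num_outputs=*/1) {}

  // Hand-written lowering, added in ops_lower_fn.cpp below.
  XlaOpVector Lower(LoweringContext* loctx) const override;
};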
22 changes: 0 additions & 22 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -709,28 +709,6 @@ const at::Tensor& XLANativeFunctions::as_strided_(
   return self;
 }
 
-at::Tensor XLANativeFunctions::asin(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::asin(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor XLANativeFunctions::asinh(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(
-      XLATensor::asinh(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor XLANativeFunctions::atan(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::atan(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor XLANativeFunctions::atanh(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(
-      XLATensor::atanh(bridge::GetXlaTensor(self)));
-}
-
 at::Tensor XLANativeFunctions::atan2(const at::Tensor& self,
                                      const at::Tensor& other) {
   XLA_FN_COUNTER("xla::");
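The four bindings deleted above are not lost functionality; the codegen emits equivalent dispatcher entries at build time. A plausible sketch of the generated replacement for one op, assuming it mirrors the deleted hand-written form (the actual generated XLANativeFunctions.cpp is not part of this diff):

at::Tensor XLANativeFunctions::asin(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  // Build the codegen'd Asin IR node instead of calling XLATensor::asin.
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  torch::lazy::NodePtr node =
      torch::lazy::MakeNode<Asin>(self_tensor.GetIrValue());
  return bridge::AtenFromXlaTensor(self_tensor.CreateFrom(node));
}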
4 changes: 0 additions & 4 deletions torch_xla/csrc/ops/ops.cpp
@@ -69,12 +69,8 @@ namespace torch_xla {
 
 PTXLA_UNARY_OP(Cos, at::aten::cos, xla::Cos);
 PTXLA_UNARY_OP(Cosh, at::aten::cosh, xla::Cosh);
-PTXLA_UNARY_OP(Asin, at::aten::asin, xla::Asin);
-PTXLA_UNARY_OP(Asinh, at::aten::asinh, xla::Asinh);
 PTXLA_UNARY_OP(Sin, at::aten::sin, xla::Sin);
 PTXLA_UNARY_OP(Sinh, at::aten::sinh, xla::Sinh);
-PTXLA_UNARY_OP(Atan, at::aten::atan, xla::Atan);
-PTXLA_UNARY_OP(Atanh, at::aten::atanh, xla::Atanh);
 PTXLA_UNARY_OP(Tan, at::aten::tan, xla::Tan);
 PTXLA_UNARY_OP(Tanh, at::aten::tanh, xla::Tanh);
 PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg);
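PTXLA_UNARY_OP is the old-style helper these four ops are leaving behind; its definition lives elsewhere in the tree and is not shown in this diff. A minimal sketch of what such a macro plausibly expands to, inferred from the Cos/Sin ops that still use it (the GenericOp argument order here is illustrative):

// Hypothetical expansion; the real macro may differ in detail.
#define PTXLA_UNARY_OP(name, sym, xla_fn)                          \
  torch::lazy::NodePtr name(const XlaValue& input) {               \
    auto lower_fn = [](const XlaNode& node,                        \
                       LoweringContext* loctx) -> XlaOpVector {    \
      xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));  \
      return node.ReturnOp(xla_fn(xla_input), loctx);              \
    };                                                             \
    return GenericOp(torch::lazy::OpKind(sym), {input},            \
                     input.xla_shape(), std::move(lower_fn));      \
  }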
12 changes: 0 additions & 12 deletions torch_xla/csrc/ops/ops.h
@@ -55,26 +55,14 @@ inline torch::lazy::NodePtr GenericOp(torch::lazy::OpKind op, xla::Shape shape,
                    hash_seed);
 }
 
-torch::lazy::NodePtr Acos(const XlaValue& input);
-
-torch::lazy::NodePtr Acosh(const XlaValue& input);
-
 torch::lazy::NodePtr Cos(const XlaValue& input);
 
 torch::lazy::NodePtr Cosh(const XlaValue& input);
 
-torch::lazy::NodePtr Asin(const XlaValue& input);
-
-torch::lazy::NodePtr Asinh(const XlaValue& input);
-
 torch::lazy::NodePtr Sin(const XlaValue& input);
 
 torch::lazy::NodePtr Sinh(const XlaValue& input);
 
-torch::lazy::NodePtr Atan(const XlaValue& input);
-
-torch::lazy::NodePtr Atanh(const XlaValue& input);
-
 torch::lazy::NodePtr Atan2(const XlaValue& input, const XlaValue& other);
 
 torch::lazy::NodePtr Tan(const XlaValue& input);
20 changes: 20 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -21,6 +21,26 @@ torch_xla::XlaOpVector Acosh::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Acosh(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla::Asin(xla_input), loctx);
+}
+
+torch_xla::XlaOpVector Asinh::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla::Asinh(xla_input), loctx);
+}
+
+torch_xla::XlaOpVector Atan::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla::Atan(xla_input), loctx);
+}
+
+torch_xla::XlaOpVector Atanh::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla::Atanh(xla_input), loctx);
+}
+
 torch_xla::XlaOpVector Maximum::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
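The xla::Asin, xla::Asinh, xla::Atan, and xla::Atanh calls in the lowerings above are XLA client-library builders. A self-contained illustration of that underlying API, separate from this commit (assumes the TF-bundled XLA headers of that era):

#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaComputation BuildAsinExample() {
  xla::XlaBuilder builder("asin_example");
  // Asin is elementwise, so the result shape equals the input shape.
  xla::XlaOp x = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {4}), "x");
  xla::XlaOp y = xla::Asin(x);
  return builder.Build(y).ValueOrDie();
}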
8 changes: 8 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -10,6 +10,14 @@ xla::Shape AcosOutputShape(const XlaValue& input) { return input.xla_shape(); }
 
 xla::Shape AcoshOutputShape(const XlaValue& input) { return input.xla_shape(); }
 
+xla::Shape AsinOutputShape(const XlaValue& input) { return input.xla_shape(); }
+
+xla::Shape AsinhOutputShape(const XlaValue& input) { return input.xla_shape(); }
+
+xla::Shape AtanOutputShape(const XlaValue& input) { return input.xla_shape(); }
+
+xla::Shape AtanhOutputShape(const XlaValue& input) { return input.xla_shape(); }
+
 xla::Shape MaximumOutputShape(const XlaValue& input, const XlaValue& other) {
   auto lower_for_shape_fn =
       [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
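Each new shape function is a one-liner because these ops are elementwise: the output shape is exactly the input shape. The MaximumOutputShape function that follows is cut off in the diff; for contrast, here is a hedged sketch of how such a binary shape function typically continues in this file (XlaHelpers::Promote and InferOutputShape follow the repo's usual pattern but their use here is inferred, not shown in this diff):

xla::Shape MaximumOutputShape(const XlaValue& input, const XlaValue& other) {
  auto lower_for_shape_fn =
      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    // Promote the operands to a common type, then lower a Max just to
    // infer the broadcasted result shape.
    auto promoted = XlaHelpers::Promote(operands[0], operands[1]);
    return xla::Max(promoted.first, promoted.second);
  };
  return InferOutputShape({input.xla_shape(), other.xla_shape()},
                          lower_for_shape_fn);
}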
8 changes: 8 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -9,6 +9,14 @@ xla::Shape AcosOutputShape(const XlaValue& input);
 
 xla::Shape AcoshOutputShape(const XlaValue& input);
 
+xla::Shape AsinOutputShape(const XlaValue& input);
+
+xla::Shape AsinhOutputShape(const XlaValue& input);
+
+xla::Shape AtanOutputShape(const XlaValue& input);
+
+xla::Shape AtanhOutputShape(const XlaValue& input);
+
 xla::Shape MaximumOutputShape(const XlaValue& input, const XlaValue& other);
 
 }  // namespace torch_xla
8 changes: 0 additions & 8 deletions torch_xla/csrc/tensor.h
@@ -393,14 +393,6 @@ class XLATensor : public c10::intrusive_ptr_target {
                               std::vector<int64_t> stride,
                               c10::optional<int64_t> storage_offset);
 
-  static XLATensor asin(const XLATensor& input);
-
-  static XLATensor asinh(const XLATensor& input);
-
-  static XLATensor atan(const XLATensor& input);
-
-  static XLATensor atanh(const XLATensor& input);
-
   static XLATensor atan2(
       const XLATensor& input, const XLATensor& other,
       c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
16 changes: 0 additions & 16 deletions torch_xla/csrc/tensor_methods.cpp
@@ -822,22 +822,6 @@ void XLATensor::as_strided_(XLATensor& input, std::vector<int64_t> size,
   }
 }
 
-XLATensor XLATensor::asin(const XLATensor& input) {
-  return input.CreateFrom(Asin(input.GetIrValue()));
-}
-
-XLATensor XLATensor::asinh(const XLATensor& input) {
-  return input.CreateFrom(Asinh(input.GetIrValue()));
-}
-
-XLATensor XLATensor::atan(const XLATensor& input) {
-  return input.CreateFrom(Atan(input.GetIrValue()));
-}
-
-XLATensor XLATensor::atanh(const XLATensor& input) {
-  return input.CreateFrom(Atanh(input.GetIrValue()));
-}
-
 XLATensor XLATensor::atan2(const XLATensor& input, const XLATensor& other,
                            c10::optional<at::ScalarType> logical_element_type) {
   return input.CreateFrom(Atan2(input.GetIrValue(), other.GetIrValue()),
8 changes: 4 additions & 4 deletions xla_native_functions.yaml
@@ -4,6 +4,10 @@ full_codegen:
   - acos
   - acosh
   - abs
+  - asin
+  - asinh
+  - atan
+  - atanh
   - maximum
 supported:
   - __ilshift__.Scalar
@@ -52,11 +56,7 @@ supported:
   - argmin
   - as_strided
   - as_strided_
-  - asin
-  - asinh
-  - atan
   - atan2
-  - atanh
   - avg_pool2d
   - avg_pool2d_backward
   - avg_pool3d
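A change like this is typically validated with the existing op tests plus a counter check that the op no longer falls back to an aten:: kernel. A sketch in the style of the repo's C++ tests (test_aten_xla_tensor.cpp); the helper names follow that test suite's conventions and are not part of this commit:

TEST_F(AtenXlaTensorTest, TestAsin) {
  torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
  torch::Tensor b = torch::asin(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_b = torch::asin(xla_a);
    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
  // The op should now be served by the xla:: kernel, not an aten:: fallback.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::asin", cpp_test::GetIgnoredCounters());
}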
