Skip to content

Commit

Permalink
Full codegen addcmul
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed May 13, 2022
1 parent 56a52d5 commit 66e0da5
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 2 deletions.
9 changes: 9 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ torch_xla::XlaOpVector Acosh::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::Acosh(xla_input), loctx);
}

torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_tensor_1 = loctx->GetOutputOp(operand(1));
xla::XlaOp xla_tensor_2 = loctx->GetOutputOp(operand(2));
xla::XlaOp xla_constant = loctx->GetOutputOp(operand(3));
xla::XlaOp mul = xla_tensor_1 * xla_tensor_2;
return ReturnOp(xla_input + mul * xla_constant, loctx);
}

torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::Asin(xla_input), loctx);
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ xla::Shape AcosOutputShape(const XlaValue& input) { return input.xla_shape(); }

xla::Shape AcoshOutputShape(const XlaValue& input) { return input.xla_shape(); }

xla::Shape AddcmulOutputShape(const XlaValue& input, const XlaValue& tensor1, const XlaValue& tensor2, const XlaValue& value) { return input.xla_shape(); }

xla::Shape AsinOutputShape(const XlaValue& input) { return input.xla_shape(); }

xla::Shape AsinhOutputShape(const XlaValue& input) { return input.xla_shape(); }
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ xla::Shape AcosOutputShape(const XlaValue& input);

xla::Shape AcoshOutputShape(const XlaValue& input);

xla::Shape AddcmulOutputShape(const XlaValue& input, const XlaValue& tensor1, const XlaValue& tensor2, const XlaValue& value);

xla::Shape AsinOutputShape(const XlaValue& input);

xla::Shape AsinhOutputShape(const XlaValue& input);
Expand Down
4 changes: 2 additions & 2 deletions xla_native_functions.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
backend: XLA
cpp_namespace: torch_xla
full_codegen:
- abs
- acos
- acosh
- abs
- addcmul
- asin
- asinh
- atan
Expand Down Expand Up @@ -42,7 +43,6 @@ supported:
- add.Tensor
- addcdiv
- addcdiv_
- addcmul
- addmm
- alias
- all
Expand Down

0 comments on commit 66e0da5

Please sign in to comment.