Codegen addcdiv and addcmul (#3768)
* Codegen addcdiv and addcmul

* pin

* Use promoteAdd/Div/Mul

* remove comment

* Convert scalar to the right type

* Delete .torch_pin
JackCaoG authored Aug 16, 2022
1 parent 6639bcc commit 6bfcd24
Showing 9 changed files with 75 additions and 64 deletions.
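
This commit moves addcdiv and addcmul onto the lazy-tensor full_codegen path: the ATen-facing wrappers and IR node classes are now generated from xla_native_functions.yaml, so the hand-written bindings in aten_xla_type.cpp and tensor_methods.cpp are deleted, and only the lowering, shape-inference, and XLA builder helpers below remain hand-written. For orientation, a minimal CPU reference sketch (not part of the diff) of the semantics the new lowerings implement:

import torch

# addcdiv(input, t1, t2, value) = input + value * (t1 / t2)
# addcmul(input, t1, t2, value) = input + value * (t1 * t2)
inp = torch.randn(3)
t1 = torch.randn(3)
t2 = torch.rand(3) + 1.0  # keep the divisor away from zero
value = 0.5

assert torch.allclose(torch.addcdiv(inp, t1, t2, value=value), inp + value * (t1 / t2))
assert torch.allclose(torch.addcmul(inp, t1, t2, value=value), inp + value * (t1 * t2))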
2 changes: 1 addition & 1 deletion scripts/gen_lazy_tensor.py
@@ -50,7 +50,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
base_ctor_value_args = ", ".join(base_ctor_value_args_list)

shape_fn_inputs_list = [
f"{a.name}" for a in schema.positional_args
f"{a.name}" for a in (schema.positional_args + schema.keyword_args)
if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
is_boolean_dtype(a.lazy_type) or a.name == 'reduction')
]
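
The one-line change above widens the argument list forwarded to the generated shape function: value is a keyword argument in the addcdiv/addcmul schemas, and under full_codegen it is materialized as a lazy IR value that the new AddcdivOutputShape/AddcmulOutputShape functions (added further down) must receive. A minimal sketch of what the comprehension now collects, using hypothetical stand-in objects rather than the real torchgen types and keeping only the is_lazy_value condition:

from dataclasses import dataclass

@dataclass
class Arg:  # hypothetical stand-in for torchgen's argument wrapper
    name: str
    is_lazy_value: bool

# addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1)
positional_args = [Arg("self", True), Arg("tensor1", True), Arg("tensor2", True)]
keyword_args = [Arg("value", True)]  # assumption: the Scalar arrives as a lazy value under full_codegen

shape_fn_inputs_list = [
    a.name for a in (positional_args + keyword_args) if a.is_lazy_value
]
print(shape_fn_inputs_list)  # ['self', 'tensor1', 'tensor2', 'value']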
31 changes: 0 additions & 31 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -612,37 +612,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self,
});
}

at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
const at::Tensor& tensor1,
const at::Tensor& tensor2,
const at::Scalar& value) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::addcdiv(
bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
bridge::GetXlaTensor(tensor2)));
}

at::Tensor& XLANativeFunctions::addcdiv_(at::Tensor& self,
const at::Tensor& tensor1,
const at::Tensor& tensor2,
const at::Scalar& value) {
XLA_FN_COUNTER("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
XLATensor::addcdiv_(self_tensor, value, bridge::GetXlaTensor(tensor1),
bridge::GetXlaTensor(tensor2));
return self;
}

at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
const at::Tensor& tensor1,
const at::Tensor& tensor2,
const at::Scalar& value) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::addcmul(
bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
bridge::GetXlaTensor(tensor2)));
}

at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
const at::Tensor& mat1,
const at::Tensor& mat2,
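
The three hand-written bindings above (addcdiv, addcdiv_, addcmul) are the ones replaced by codegen: the generated wrappers construct the new Addcdiv/Addcmul IR nodes directly, and the matching XLATensor::addcdiv/addcdiv_/addcmul methods in tensor_methods.cpp below are removed for the same reason.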
17 changes: 17 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -7,6 +7,7 @@
#include "torch_xla/csrc/matrix.h"
#include "torch_xla/csrc/pooling.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
@@ -69,6 +70,22 @@ torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
return ReturnOp(BuildMinInDims(input, dim, keepdim), loctx);
}

torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
return ReturnOp(BuildAddcdiv(xla_input, xla_t1, xla_t2, xla_val), loctx);
}

torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
return ReturnOp(BuildAddcmul(xla_input, xla_t1, xla_t2, xla_val), loctx);
}

torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::Asin(xla_input), loctx);
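
Unlike the deleted hand-written path, where value arrived as an at::Scalar, the generated Addcdiv/Addcmul nodes carry it as a fourth IR operand, so each Lower method fetches operand(3) alongside the input tensors and delegates to the BuildAddcdiv/BuildAddcmul helpers declared in xla_lower_util.h (hence the new include at the top of the file).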
26 changes: 26 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -7,6 +7,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/pooling.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {
@@ -109,6 +110,31 @@ xla::Shape AllOutputShape(const torch::lazy::Value& input) {
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}

xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& t1,
const torch::lazy::Value& t2,
const torch::lazy::Value& value) {
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildAddcdiv(operands[0], operands[1], operands[2], operands[3]);
};
return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
GetXlaShape(value)},
shape_fn);
}

xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& t1,
const torch::lazy::Value& t2,
const torch::lazy::Value& value) {
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildAddcmul(operands[0], operands[1], operands[2], operands[3]);
};

return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
GetXlaShape(value)},
shape_fn);
}

xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}
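
Each shape function wraps its builder in a lambda and passes it to InferOutputShape together with the operand shapes; InferOutputShape evaluates the lowering on dummy parameters to deduce the result shape, so the inferred output already reflects the broadcasting and dtype promotion performed by BuildAddcdiv/BuildAddcmul.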
10 changes: 10 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -29,6 +29,16 @@ xla::Shape AminOutputShape(const torch::lazy::Value& input,

xla::Shape AllOutputShape(const torch::lazy::Value& input);

xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& t1,
const torch::lazy::Value& t2,
const torch::lazy::Value& value);

xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& t1,
const torch::lazy::Value& t2,
const torch::lazy::Value& value);

xla::Shape AsinOutputShape(const torch::lazy::Value& input);

xla::Shape AsinhOutputShape(const torch::lazy::Value& input);
29 changes: 0 additions & 29 deletions torch_xla/csrc/tensor_methods.cpp
@@ -660,35 +660,6 @@ XLATensorPtr XLATensor::add(
logical_element_type);
}

XLATensorPtr XLATensor::addcdiv(const XLATensorPtr& input,
const at::Scalar& value,
const XLATensorPtr& tensor1,
const XLATensorPtr& tensor2) {
torch::lazy::Value constant = GetIrValueForScalar(
value, tensor1->shape().get().element_type(), input->GetDevice());
torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
return input->CreateFrom(input->GetIrValue() + div * constant);
}

void XLATensor::addcdiv_(XLATensorPtr& input, const at::Scalar& value,
const XLATensorPtr& tensor1,
const XLATensorPtr& tensor2) {
torch::lazy::Value constant = GetIrValueForScalar(
value, tensor1->shape().get().element_type(), input->GetDevice());
torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
input->SetInPlaceIrValue(input->GetIrValue() + div * constant);
}

XLATensorPtr XLATensor::addcmul(const XLATensorPtr& input,
const at::Scalar& value,
const XLATensorPtr& tensor1,
const XLATensorPtr& tensor2) {
torch::lazy::Value constant = GetIrValueForScalar(
value, tensor1->shape().get().element_type(), input->GetDevice());
torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
return input->CreateFrom(input->GetIrValue() + mul * constant);
}

XLATensorPtr XLATensor::addmm(const XLATensorPtr& input,
const XLATensorPtr& weight,
const XLATensorPtr& bias) {
13 changes: 13 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -1012,4 +1012,17 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
return need_flatten ? xla::Reshape(input, input_shape.dimensions()) : input;
}

xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
xla::XlaOp val) {
return XlaHelpers::PromotedAdd(
input, XlaHelpers::PromotedMul(XlaHelpers::PromotedDiv(t1, t2), val));
}

xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
xla::XlaOp val) {
val = MaybeConvertTo(val, XlaHelpers::ShapeOfXlaOp(t1).element_type());
return XlaHelpers::PromotedAdd(
input, XlaHelpers::PromotedMul(XlaHelpers::PromotedMul(t1, t2), val));
}

} // namespace torch_xla
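
BuildAddcdiv composes the promoted helpers directly (input + (t1 / t2) * val, with type promotion and broadcasting, per the "Use promoteAdd/Div/Mul" commit), while BuildAddcmul first casts val to tensor1's element type ("Convert scalar to the right type") so a double Python scalar does not upcast the result. A CPU sketch (not the XLA code path) of the dtype behaviour this preserves; the comment about what would happen without the cast is an assumption:

import torch

inp = torch.randn(3, dtype=torch.float32)
t1 = torch.randn(3, dtype=torch.float32)
t2 = torch.randn(3, dtype=torch.float32)

# The Python scalar 0.5 is a double, yet the eager result stays float32; casting
# val to t1's element type before the promoted multiply mirrors this (assumption:
# without the cast, the f64 scalar constant would promote the XLA result to f64).
out = torch.addcmul(inp, t1, t2, value=0.5)
print(out.dtype)  # torch.float32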
6 changes: 6 additions & 0 deletions torch_xla/csrc/xla_lower_util.h
@@ -119,4 +119,10 @@ xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other);
xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
absl::Span<const int64_t> dims);

xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
xla::XlaOp val);

xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
xla::XlaOp val);

} // namespace torch_xla
5 changes: 2 additions & 3 deletions xla_native_functions.yaml
@@ -7,6 +7,8 @@ full_codegen:
- all
- amax
- amin
- addcdiv
- addcmul
- asin
- asinh
- atan
@@ -95,9 +97,6 @@ supported:
- adaptive_max_pool2d_backward
- add.Scalar
- add.Tensor
- addcdiv
- addcdiv_
- addcmul
- addmm
- alias
- all.dim
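
With addcdiv and addcmul added to full_codegen (and addcdiv, addcdiv_, addcmul dropped from the hand-written supported list), tracing these ops goes through the generated IR nodes. A hedged smoke test, assuming a working torch_xla install with an XLA device available; the exact counter name printed by the metrics report is an assumption:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
inp = torch.randn(4, device=device)
t1 = torch.randn(4, device=device)
t2 = torch.rand(4, device=device) + 1.0

out = torch.addcdiv(inp, t1, t2, value=0.25)
xm.mark_step()  # flush the lazily-built graph for execution

print(met.counter_names())  # expect an "xla::addcdiv" entry among the op counters
print(out.cpu())            # materialize the result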
