diff --git a/scripts/gen_lazy_tensor.py b/scripts/gen_lazy_tensor.py
index b52bbd1eac0..b3a0eccbfe6 100644
--- a/scripts/gen_lazy_tensor.py
+++ b/scripts/gen_lazy_tensor.py
@@ -50,7 +50,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
     base_ctor_value_args = ", ".join(base_ctor_value_args_list)
 
     shape_fn_inputs_list = [
-        f"{a.name}" for a in schema.positional_args
+        f"{a.name}" for a in (schema.positional_args + schema.keyword_args)
         if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
             is_boolean_dtype(a.lazy_type) or a.name == 'reduction')
     ]
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index a5ac7dcb6ba..7599b649731 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -612,37 +612,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self,
   });
 }
 
-at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcdiv(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
-at::Tensor& XLANativeFunctions::addcdiv_(at::Tensor& self,
-                                         const at::Tensor& tensor1,
-                                         const at::Tensor& tensor2,
-                                         const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::addcdiv_(self_tensor, value, bridge::GetXlaTensor(tensor1),
-                      bridge::GetXlaTensor(tensor2));
-  return self;
-}
-
-at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcmul(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
 at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
                                      const at::Tensor& mat1,
                                      const at::Tensor& mat2,
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index d562883b38e..d8ba7a99fc2 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -7,6 +7,7 @@
 #include "torch_xla/csrc/matrix.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/reduction.h"
+#include "torch_xla/csrc/xla_lower_util.h"
 
 namespace torch_xla {
 torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
@@ -69,6 +70,22 @@ torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildMinInDims(input, dim, keepdim), loctx);
 }
 
+torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(BuildAddcdiv(xla_input, xla_t1, xla_t2, xla_val), loctx);
+}
+
+torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(BuildAddcmul(xla_input, xla_t1, xla_t2, xla_val), loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 7872dfb09bd..55e4736fdf2 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -7,6 +7,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/reduction.h"
+#include "torch_xla/csrc/xla_lower_util.h"
 
 namespace torch_xla {
 namespace {
@@ -109,6 +110,31 @@ xla::Shape AllOutputShape(const torch::lazy::Value& input) {
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildAddcdiv(operands[0], operands[1], operands[2], operands[3]);
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildAddcmul(operands[0], operands[1], operands[2], operands[3]);
+  };
+
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 329990cde3c..e5d8ab55ef2 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -29,6 +29,16 @@ xla::Shape AminOutputShape(const torch::lazy::Value& input,
 
 xla::Shape AllOutputShape(const torch::lazy::Value& input);
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AsinhOutputShape(const torch::lazy::Value& input);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 7679f27228c..66a7895917b 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -660,35 +660,6 @@ XLATensorPtr XLATensor::add(
       logical_element_type);
 }
 
-XLATensorPtr XLATensor::addcdiv(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + div * constant);
-}
-
-void XLATensor::addcdiv_(XLATensorPtr& input, const at::Scalar& value,
-                         const XLATensorPtr& tensor1,
-                         const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  input->SetInPlaceIrValue(input->GetIrValue() + div * constant);
-}
-
-XLATensorPtr XLATensor::addcmul(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + mul * constant);
-}
-
 XLATensorPtr XLATensor::addmm(const XLATensorPtr& input,
                               const XLATensorPtr& weight,
                               const XLATensorPtr& bias) {
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 6171b78384c..01299aa0890 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1012,4 +1012,17 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
   return need_flatten ? xla::Reshape(input, input_shape.dimensions()) : input;
 }
 
+xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val) {
+  return XlaHelpers::PromotedAdd(
+      input, XlaHelpers::PromotedMul(XlaHelpers::PromotedDiv(t1, t2), val));
+}
+
+xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val) {
+  val = MaybeConvertTo(val, XlaHelpers::ShapeOfXlaOp(t1).element_type());
+  return XlaHelpers::PromotedAdd(
+      input, XlaHelpers::PromotedMul(XlaHelpers::PromotedMul(t1, t2), val));
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index 731095ca8df..be39d3de013 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -119,4 +119,10 @@ xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other);
 xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
                      absl::Span<const int64_t> dims);
 
+xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val);
+
+xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val);
+
 }  // namespace torch_xla
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index 27c7ab2e1e7..11a4d958bd1 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -7,6 +7,8 @@ full_codegen:
   - all
   - amax
   - amin
+  - addcdiv
+  - addcmul
   - asin
   - asinh
   - atan
@@ -95,9 +97,6 @@ supported:
   - adaptive_max_pool2d_backward
   - add.Scalar
   - add.Tensor
-  - addcdiv
-  - addcdiv_
-  - addcmul
   - addmm
   - alias
   - all.dim