Codegen addcdiv and addcmul #3768

Merged
merged 6 commits into from Aug 16, 2022

Conversation

JackCaoG (Collaborator) commented Jul 26, 2022

Fix #3765,
Fix #3766,
Fix #3767

An example PR of a codegen op that takes an at::Scalar.

The current codegen scalar upload does not take a scalar type, so the scalar is always uploaded with the default type (f64). To fix this, we need to identify the Value that was an at::Scalar in the original aten_xla_type.cpp and cast that value to the correct dtype.
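For illustration only, a minimal sketch of that idea in the native function; the Cast IR node and its (Value, at::ScalarType) constructor are assumptions here, not necessarily the exact code this PR generates:

// Hypothetical sketch: build the IR value for the at::Scalar, then cast it to
// self's dtype so the constant is not left at the default f64.
torch::lazy::Value scalar_value =
    torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
        value, *common_device);
// Assumes a Cast IR node taking (Value, at::ScalarType); the real fix may
// thread the dtype through differently.
scalar_value = torch::lazy::MakeNode<Cast>(scalar_value, self.scalar_type());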

LazyIR

class Addcdiv : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::addcdiv);
  }

  Addcdiv(const torch::lazy::Value& self, const torch::lazy::Value& tensor1,
          const torch::lazy::Value& tensor2, const torch::lazy::Value& value,
          std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(
            torch::lazy::OpKind(at::aten::addcdiv),
            {self, tensor1, tensor2, value}, std::move(shapes),
            [&]() { return AddcdivOutputShape(self, tensor1, tensor2, value); },
            /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& tensor1,
                   const torch::lazy::Value& tensor2,
                   const torch::lazy::Value& value) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class Addcmul : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::addcmul);
  }

  Addcmul(const torch::lazy::Value& self, const torch::lazy::Value& tensor1,
          const torch::lazy::Value& tensor2, const torch::lazy::Value& value,
          std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(
            torch::lazy::OpKind(at::aten::addcmul),
            {self, tensor1, tensor2, value}, std::move(shapes),
            [&]() { return AddcmulOutputShape(self, tensor1, tensor2, value); },
            /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& tensor1,
                   const torch::lazy::Value& tensor2,
                   const torch::lazy::Value& value) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

XLANativeFunction

at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
                                       const at::Tensor& tensor1,
                                       const at::Tensor& tensor2,
                                       const at::Scalar& value) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, tensor1, tensor2);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor1 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor1,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor2 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor2,
                                                              *common_device);
  auto node_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          value, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Addcdiv>(
      lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
      lazy_tensor2->GetIrValue(), node_value);
  if (!node) {
    auto self_meta = to_meta(self);
    auto tensor1_meta = to_meta(tensor1);
    auto tensor2_meta = to_meta(tensor2);
    auto out_meta =
        at::meta::addcdiv(self_meta, tensor1_meta, tensor2_meta, value);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, tensor1, tensor2, value};
      const char* schema_str =
          "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, "
          "Scalar value=1) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<Addcdiv>(
        lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
        lazy_tensor2->GetIrValue(), node_value, std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};

at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
                                       const at::Tensor& tensor1,
                                       const at::Tensor& tensor2,
                                       const at::Scalar& value) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, tensor1, tensor2);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor1 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor1,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_tensor2 =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(tensor2,
                                                              *common_device);
  auto node_value =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          value, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Addcmul>(
      lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
      lazy_tensor2->GetIrValue(), node_value);
  if (!node) {
    auto self_meta = to_meta(self);
    auto tensor1_meta = to_meta(tensor1);
    auto tensor2_meta = to_meta(tensor2);
    auto out_meta =
        at::meta::addcmul(self_meta, tensor1_meta, tensor2_meta, value);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, tensor1, tensor2, value};
      const char* schema_str =
          "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, "
          "Scalar value=1) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<Addcmul>(
        lazy_self->GetIrValue(), lazy_tensor1->GetIrValue(),
        lazy_tensor2->GetIrValue(), node_value, std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};

@JackCaoG changed the title from "Codegen addcdiv and addcmul" to "[BLOCKED]Codegen addcdiv and addcmul" on Jul 26, 2022
@miladm added the BLOCKED label on Jul 28, 2022
@JackCaoG force-pushed the code_gen_addcdiv/mul branch from 85ce2b0 to 549a19b on August 8, 2022 23:54
@JackCaoG changed the title from "[BLOCKED]Codegen addcdiv and addcmul" to "Codegen addcdiv and addcmul" on Aug 8, 2022
@JackCaoG force-pushed the code_gen_addcdiv/mul branch from 549a19b to 71ba7c3 on August 9, 2022 00:05
@JackCaoG removed the BLOCKED label on Aug 9, 2022
JackCaoG (Collaborator, Author) commented Aug 9, 2022

pytorch/pytorch#82970 (review) fixed the device issue.

@JackCaoG mentioned this pull request on Aug 9, 2022
@JackCaoG force-pushed the code_gen_addcdiv/mul branch from b734a75 to 9972815 on August 9, 2022 04:31
JackCaoG (Collaborator, Author) commented

The issue comes from trying to add f64 and f32:

2 root error(s) found.
  (0) INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %add.122 = f64[1,16]{1,0} add(f64[1,16]{1,0} %dot.119, f32[1,16]{1,0} %broadcast.121), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="forward@linear.py" source_line=114}, but mixed precision is disallowed.
	 [[{{node XRTCompile}}]]
	 [[XRTCompile_G3]]

JackCaoG (Collaborator, Author) commented Aug 10, 2022

The issue was that when we upload an at::Scalar to the device, the current codegen upload does not take a scalar type, so the scalar is always uploaded with the default type (f64). To fix this, we need to identify the Value that was an at::Scalar in the original aten_xla_type.cpp and cast that value to the correct dtype. FYI @wonjoolee95
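Equivalently, the cast can be thought of on the lowering side. A minimal sketch, assuming a helper like XlaHelpers::TypeOfXlaOp and ignoring the broadcasting/type-promotion handling that the real lowering helpers perform:

torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_self = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_tensor1 = loctx->GetOutputOp(operand(1));
  xla::XlaOp xla_tensor2 = loctx->GetOutputOp(operand(2));
  xla::XlaOp xla_value = loctx->GetOutputOp(operand(3));
  // The scalar operand was uploaded as f64 by default; convert it to self's
  // element type so XLA never sees a mixed-precision add.
  xla_value =
      xla::ConvertElementType(xla_value, XlaHelpers::TypeOfXlaOp(xla_self));
  // addcdiv(self, tensor1, tensor2, value) = self + value * (tensor1 / tensor2)
  xla::XlaOp result = xla::Add(
      xla_self, xla::Mul(xla_value, xla::Div(xla_tensor1, xla_tensor2)));
  return ReturnOp(result, loctx);
}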

@JackCaoG force-pushed the code_gen_addcdiv/mul branch from f30c43a to 5de4435 on August 11, 2022 23:48
@JackCaoG requested a review from wonjoolee95 on August 12, 2022 23:46
JackCaoG (Collaborator, Author) commented

@wonjoolee95 I think this one is ready for review.

wonjoolee95 (Collaborator) left a comment

Thanks! 👍
