Properly lower div.scalar and div.tensor #6669

Merged
merged 3 commits on Mar 14, 2024
Changes from 2 commits

torch_xla/csrc/elementwise.cpp: 7 additions & 0 deletions

@@ -261,6 +261,13 @@ std::vector<xla::XlaOp> BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input,
 
 xla::XlaOp BuildSigmoid(xla::XlaOp input) { return xla::Logistic(input); }
 
+xla::XlaOp BuildDiv(xla::XlaOp input, xla::XlaOp divisor) {
+  // Shape and value promotion.
+  std::tie(input, divisor) = XlaHelpers::Promote(input, divisor);
+  xla::XlaOp div_result = xla::Div(input, divisor);
+  return div_result;
+}
+
 xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
   const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
   xla::XlaOp one = xla::One(input.builder(), shape.element_type());
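
For context: XlaHelpers::Promote broadcasts the two operands to a common shape and promotes them to a common element type, so that xla::Div sees matching operands. A rough standalone sketch of the dtype half of that step, with plain C++ scalar types standing in for XLA arrays (PromoteAndDivide is a hypothetical analogue, not the torch_xla API):

#include <iostream>
#include <type_traits>

// Hypothetical analogue of the promotion step in BuildDiv: pick a common
// type for both operands before dividing. The real XlaHelpers::Promote
// also broadcasts shapes, which plain scalars cannot show.
template <typename A, typename B>
auto PromoteAndDivide(A a, B b) {
  using Common = std::common_type_t<A, B>;  // the promoted "element type"
  return static_cast<Common>(a) / static_cast<Common>(b);
}

int main() {
  std::cout << PromoteAndDivide(7, 2.0f) << "\n";  // 3.5: int promoted to float
  std::cout << PromoteAndDivide(7, 2) << "\n";     // 3: stays integral; in the
                                                   // tensor path below, integer
                                                   // inputs are floated earlier
                                                   // by GetFloatingIrValue
  return 0;
}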

torch_xla/csrc/elementwise.h: 3 additions & 0 deletions

@@ -100,6 +100,9 @@ std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input);
 // If eps is given, the input is clamped between eps and 1-eps.
 xla::XlaOp BuildLogit(xla::XlaOp input, c10::optional<double> eps);
 
+// Computes the division of input by the divisor.
+xla::XlaOp BuildDiv(xla::XlaOp input, xla::XlaOp divisor);
+
 // Computes the backward of LogSigmoid.
 xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
                                    xla::XlaOp buffer);

torch_xla/csrc/ops/ops.cpp: 12 additions & 0 deletions

@@ -717,6 +717,18 @@ torch::lazy::NodePtr Remainder(const torch::lazy::Value& input,
                    ScalarOp(0, GetXlaShape(input)));
 }
 
+torch::lazy::NodePtr Div(const torch::lazy::Value& input,
+                         const torch::lazy::Value& divisor) {
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_divisor = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildDiv(xla_input, xla_divisor), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::div), {input, divisor},
+                   GetXlaShape(input), std::move(lower_fn));
+}
+
 torch::lazy::NodePtr MaxUnary(const torch::lazy::Value& input) {
   auto lower_fn = [](const XlaNode& node,
                      LoweringContext* loctx) -> XlaOpVector {
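
The GenericOp pattern above only records lower_fn when the node is built; the lambda runs later, when the lazy graph is compiled, which is when BuildDiv emits the actual XLA op. A minimal standalone sketch of that deferred-lowering shape (Op, Node, and MakeDivNode are simplified stand-ins, not the torch::lazy or XLA APIs):

#include <functional>
#include <iostream>
#include <vector>

struct Op { double value; };  // stand-in for xla::XlaOp

struct Node {
  std::vector<Op> operands;                            // stand-in for IR inputs
  std::function<Op(const std::vector<Op>&)> lower_fn;  // deferred lowering
};

Node MakeDivNode(Op input, Op divisor) {
  // Captured now, invoked only at "compile" time, like lower_fn in Div().
  auto lower_fn = [](const std::vector<Op>& ops) {
    return Op{ops[0].value / ops[1].value};  // plays the role of BuildDiv
  };
  return Node{{input, divisor}, lower_fn};
}

int main() {
  Node div = MakeDivNode(Op{7.0}, Op{2.0});  // graph construction
  Op result = div.lower_fn(div.operands);    // later: lowering
  std::cout << result.value << "\n";         // 3.5
  return 0;
}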

torch_xla/csrc/ops/ops.h: 3 additions & 0 deletions

@@ -209,6 +209,9 @@ torch::lazy::NodePtr Rshift(const torch::lazy::Value& input,
 torch::lazy::NodePtr Rshift(const torch::lazy::Value& input,
                             const torch::lazy::Value& other);
 
+torch::lazy::NodePtr Div(const torch::lazy::Value& input,
+                         const torch::lazy::Value& divisor);
+
 torch::lazy::NodePtr Remainder(const torch::lazy::Value& input,
                                const torch::lazy::Value& divisor);
 

torch_xla/csrc/tensor_methods.cpp: 2 additions & 3 deletions

@@ -1149,8 +1149,7 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   // divide and trunc divide.
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = GetFloatingIrValue(other, scalar_type);
-  torch::lazy::Value res = input_value / other_value;
-
+  torch::lazy::Value res = Div(input_value, other_value);
   if (rounding_mode.has_value()) {
     if (*rounding_mode == "trunc") {
       res = torch::lazy::MakeNode<Trunc>(res);
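
Note the ordering here: GetFloatingIrValue converts integer inputs to floating point first, so the new Div node always performs true division, and the rounding mode is applied afterwards as a separate node (the "trunc" branch is visible above; torch.div also accepts "floor", handled outside the visible hunk). A standalone sketch of that ordering, with plain doubles standing in for lazy IR values (DivWithRounding is hypothetical):

#include <cmath>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical model: true-divide first, then round per the requested mode.
double DivWithRounding(double a, double b,
                       std::optional<std::string> rounding_mode) {
  double res = a / b;                                   // like Div(input, other)
  if (rounding_mode == "trunc") res = std::trunc(res);  // like MakeNode<Trunc>
  if (rounding_mode == "floor") res = std::floor(res);  // the other torch mode
  return res;
}

int main() {
  std::cout << DivWithRounding(-7, 2, std::nullopt) << "\n";  // -3.5
  std::cout << DivWithRounding(-7, 2, "trunc") << "\n";       // -3 (toward zero)
  std::cout << DivWithRounding(-7, 2, "floor") << "\n";       // -4 (toward -inf)
  return 0;
}
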
@@ -1193,7 +1192,7 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) {
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
       other, GetXlaShape(input_value).element_type(), input->GetDevice());
-  return input->CreateFrom(input_value / other_value, scalar_type);
+  return input->CreateFrom(Div(input_value, other_value), scalar_type);
 }
 
 XLATensorPtr einsum(const std::string& equation,
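
The scalar overload reuses the same lowering: GetIrValueForScalar materializes the scalar at the input's element type, after which div.Tensor and div.Scalar funnel into one Div node. A minimal sketch of that shared path (IrValue, WrapScalar, DivNode, and the two entry points are hypothetical stand-ins):

#include <iostream>

using IrValue = double;  // stand-in for torch::lazy::Value

IrValue WrapScalar(double scalar) { return scalar; }     // like GetIrValueForScalar
IrValue DivNode(IrValue a, IrValue b) { return a / b; }  // like Div(...)

// Both entry points converge on the same division node.
IrValue DivTensor(IrValue a, IrValue b) { return DivNode(a, b); }
IrValue DivScalar(IrValue a, double s) { return DivNode(a, WrapScalar(s)); }

int main() {
  std::cout << DivTensor(7.0, 2.0) << "\n";  // 3.5
  std::cout << DivScalar(7.0, 2.0) << "\n";  // 3.5, same lowering path
  return 0;
}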