Properly lower add and mul (#6731) #6744

Merged · 1 commit · Mar 13, 2024
4 changes: 2 additions & 2 deletions test/spmd/test_xla_sharding.py
@@ -709,7 +709,7 @@ def test_xla_sharded_hlo_dump(self):
partition_spec)
xst2 = xst1 + 5
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
self.assertIn('%p1.4 = f32[1,8]{1,0} parameter(1), sharding', hlo)
self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
# scalar 5 should be implicitly replicated, so the pre-optimization HLO
# shouldn't mark it with sharding.
self.assertNotIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)
@@ -826,7 +826,7 @@ def test_mark_sharding_ir(self):
actual += 0
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)',
'%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
hlo)

self.assertTrue(torch.allclose(expected, actual.cpu()))
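Note: these assertions grep exact substrings of the pre-optimization HLO, and the op/parameter numbering they pin down is precisely what shifts with the new lowering. A minimal sketch of how such a dump can be reproduced, assuming a configured XLA device (shapes and the added scalar are illustrative):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.ones(1, 128, device=device)
out = t + 5  # stays lazy until the graph is inspected or executed

# Same entry point the tests use; prints the HLO built so far for `out`.
hlo = torch_xla._XLAC._get_xla_tensors_hlo([out])
print(hlo)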
24 changes: 24 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -500,4 +500,28 @@ xla::XlaOp BuildSub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha) {
return sub_result;
}

xla::XlaOp BuildAdd(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha) {
// Three-way shape and value promotion
std::tie(input, other) = XlaHelpers::Promote(input, other);
std::tie(input, alpha) = XlaHelpers::Promote(input, alpha);
std::tie(input, other) = XlaHelpers::Promote(input, other);

xla::XlaOp multiplied =
xla::Mul(other, alpha, XlaHelpers::getBroadcastDimensions(other, alpha));
xla::XlaOp add_result = xla::Add(
input, multiplied, XlaHelpers::getBroadcastDimensions(input, multiplied));

return add_result;
}

xla::XlaOp BuildMul(xla::XlaOp input, xla::XlaOp other) {
// Shape and value promotion
std::tie(input, other) = XlaHelpers::Promote(input, other);

xla::XlaOp mul_result =
xla::Mul(input, other, XlaHelpers::getBroadcastDimensions(input, other));

return mul_result;
}

} // namespace torch_xla
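BuildAdd promotes input, other, and alpha pairwise so that mixed dtypes and broadcastable shapes land on a common type before the multiply-add. A CPU-only illustration of the promotion and alpha semantics the builder mirrors (the XLA-side rules themselves live in XlaHelpers::Promote):

import torch

a = torch.tensor([1, 2], dtype=torch.int32)
b = torch.tensor([0.5, 1.5])            # float32
out = torch.add(a, b, alpha=2)          # out = a + 2 * b
print(out, out.dtype)                   # tensor([2., 5.]) torch.float32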
12 changes: 10 additions & 2 deletions torch_xla/csrc/elementwise.h
@@ -117,14 +117,22 @@ xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,
// based on a scalar or tensor weight and returns the resulting out tensor.
xla::XlaOp BuildLerp(xla::XlaOp start, xla::XlaOp end, xla::XlaOp weight);

// Compuate the rsub function. Subtracts input, scaled by alpha, from other.
// Computes the rsub function. Subtracts input, scaled by alpha, from other.
// out = other − alpha * input
xla::XlaOp BuildRsub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);

// Compuate the sub function. Subtracts other, scaled by alpha, from input.
// Computes the sub function. Subtracts other, scaled by alpha, from input.
// out = input − alpha * other
xla::XlaOp BuildSub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);

// Computes the add function. Adds other, scaled by alpha, to input.
// out = input + alpha * other
xla::XlaOp BuildAdd(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);

// Computes the mul function.
// out = input * other
xla::XlaOp BuildMul(xla::XlaOp input, xla::XlaOp other);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_ELEMENTWISE_H_
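Each of these builders is documented by a single formula; the formulas match the aten-level semantics and can be sanity-checked on CPU:

import torch

x = torch.tensor([1.0, 2.0])
y = torch.tensor([10.0, 20.0])

print(torch.rsub(x, y, alpha=3))  # y - 3 * x  -> tensor([ 7., 14.])
print(torch.sub(x, y, alpha=3))   # x - 3 * y  -> tensor([-29., -58.])
print(torch.add(x, y, alpha=3))   # x + 3 * y  -> tensor([31., 62.])
print(torch.mul(x, y))            # x * y      -> tensor([10., 40.])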
51 changes: 51 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -1017,4 +1017,55 @@ torch::lazy::NodePtr Sub(const torch::lazy::Value& input,
std::move(lower_fn));
}

torch::lazy::NodePtr Add(const torch::lazy::Value& input,
const torch::lazy::Value& other,
const torch::lazy::Value& alpha) {
torch::lazy::ScopePusher ir_scope(at::aten::add.toQualString());
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
xla::XlaOp xla_alpha = loctx->GetOutputOp(node.operand(2));
xla::XlaOp xla_output = BuildAdd(xla_input, xla_other, xla_alpha);
return node.ReturnOp(xla_output, loctx);
};
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 3) << "Unexpected number of operands";
return BuildAdd(operands[0], operands[1], operands[2]);
};
return GenericOp(
torch::lazy::OpKind(at::aten::add), {input, other, alpha},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other), GetXlaShape(alpha)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr Mul(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
torch::lazy::ScopePusher ir_scope(at::aten::mul.toQualString());
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
xla::XlaOp xla_output = BuildMul(xla_input, xla_other);
return node.ReturnOp(xla_output, loctx);
};
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
return BuildMul(operands[0], operands[1]);
};
return GenericOp(
torch::lazy::OpKind(at::aten::mul), {input, other},
[&]() {
return InferOutputShape({GetXlaShape(input), GetXlaShape(other)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

} // namespace torch_xla
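Both constructors follow the usual torch_xla GenericOp pattern: a lowering lambda that calls the Build* helper, plus a shape function so InferOutputShape can derive the output shape from the operand shapes. A hedged sketch for inspecting the resulting lazy IR from Python, assuming an XLA device is available (the exact dump text is not guaranteed):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(2, 2, device=device)
y = torch.randn(2, 2, device=device)

# IR text for the two pending results, now built via the Add/Mul nodes above.
print(torch_xla._XLAC._get_xla_tensors_text([x + y, x * y]))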
7 changes: 7 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -253,6 +253,13 @@ torch::lazy::NodePtr Sub(const torch::lazy::Value& input,
const torch::lazy::Value& other,
const torch::lazy::Value& alpha);

torch::lazy::NodePtr Add(const torch::lazy::Value& input,
const torch::lazy::Value& other,
const torch::lazy::Value& alpha);

torch::lazy::NodePtr Mul(const torch::lazy::Value& input,
const torch::lazy::Value& other);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_OPS_H_
12 changes: 7 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
@@ -767,8 +767,9 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
sym_int_elements, logical_element_type, device);
}

return input->CreateFrom(input->GetIrValue() + other->GetIrValue() * constant,
logical_element_type);
return input->CreateFrom(
Add(input->GetIrValue(), other->GetIrValue(), constant),
logical_element_type);
}

XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
@@ -787,8 +788,9 @@ XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(input->dtype(), &device)),
logical_element_type, device);

return input->CreateFrom(
input->GetIrValue() + other_constant * alpha_constant,
Add(input->GetIrValue(), other_constant, alpha_constant),
logical_element_type);
}

@@ -1877,7 +1879,7 @@ XLATensorPtr mse_loss_backward(const XLATensorPtr& grad_output,

XLATensorPtr mul(const XLATensorPtr& input, const XLATensorPtr& other,
c10::optional<at::ScalarType> logical_element_type) {
return input->CreateFrom(input->GetIrValue() * other->GetIrValue(),
return input->CreateFrom(Mul(input->GetIrValue(), other->GetIrValue()),
logical_element_type);
}

@@ -1889,7 +1891,7 @@ XLATensorPtr mul(const XLATensorPtr& input, const at::Scalar& other,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(input->dtype(), &device)),
logical_element_type, device);
return input->CreateFrom(input->GetIrValue() * constant,
return input->CreateFrom(Mul(input->GetIrValue(), constant),
logical_element_type);
}

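Taken together, the tensor_methods changes reroute both add overloads (tensor other and scalar other) and both mul overloads through the new Add/Mul IR nodes instead of the raw IR arithmetic operators; numerical results should be unchanged. A small end-to-end check, assuming an XLA device is configured:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.arange(4.0, device=device)
b = torch.ones(4, device=device)

out1 = torch.add(a, b, alpha=2)  # tensor + tensor with alpha
out2 = torch.add(a, 5)           # tensor + scalar
out3 = a * b                     # tensor * tensor
xm.mark_step()                   # compile and execute the accumulated graph

print(out1.cpu(), out2.cpu(), out3.cpu())
# expected: [2., 3., 4., 5.]  [5., 6., 7., 8.]  [0., 1., 2., 3.]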