
[Fori_loop|While_loop] Enable fori_loop with add/sub test case #6603

Merged: 64 commits merged into master on Mar 8, 2024

Conversation

@ManfeiBai (Collaborator) commented Feb 23, 2024:

As part of the fori_loop implementation with while_loop, this PR lowers the body/cond functions to replace their formal placeholders.

This is the step-two PR; parent PR: #6532, child PR: #6529, source PR: #6563.


Some issues fixed:

  • the body fn was `-`, not a torch func; will test more later; tried torch.sub(a, b) too, and it also passed locally
  • the current code changed a lot of the lowering logic; let's move that logic into a new function so the existing functions are unaffected
  • inputs are limited to list/tuple, matching what torch._higher_order_ops.while_loop requires (see the sketch after this list)
  • inputs were transformed from a list to a non-list after torch.compile; TODO: reuse the same input-handling logic as torch.compile instead of creating a duplicated tensor in fori_loop.py
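
A minimal sketch of that tuple-input requirement (assumptions: the torch._higher_order_ops.while_loop import path matches this torch version, and torch_xla provides the XLA device; this is illustrative, not the PR's test code):

# Hedged sketch: carried inputs must be packed in a tuple/list, and
# body_fn must return a tuple with matching shapes/dtypes.
import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(x):
  # Loop while the carried value stays small; returns a scalar bool tensor.
  return x.sum() <= 10

def body_fn(x):
  # Output structure mirrors the carried inputs: a tuple of tensors.
  return (torch.add(x, 1),)

init = torch.ones(1, dtype=torch.int32, device=device)
result = while_loop(cond_fn, body_fn, (init,))  # inputs packed as a tuple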

ManfeiBai changed the title to "[Do Not Merge] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py" on Feb 23, 2024
ManfeiBai added the DO_NOT_MERGE_YET (for PRs which cannot be merged, despite tests passing) and fori_loop labels on Feb 23, 2024
ManfeiBai changed the title to "[Do Not Merge][Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py" on Feb 24, 2024
if (!root_tuple_.empty() && (root_tuple_.size() > 1)) {
  xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
  xla = builder()->Build(root);
} else if (!root_tuple_.empty() && (root_tuple_.size() == 1)) {
@ManfeiBai (Collaborator, Author) commented Feb 26, 2024:

Explanation: we need to skip the tuple when creating the cond/body computations so they match the xla::While format check for cond (see the error log).
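
For intuition, xla::While requires the cond computation to produce a single scalar predicate, which is why the non-tuple build path above exists; a hedged Python-side illustration:

def cond_fn(x):
  # The lowered cond root must be one scalar PRED value, not a
  # one-element tuple, to pass xla::While's signature check.
  return x.sum() <= 10       # OK: lowers to a scalar predicate
  # return (x.sum() <= 10,)  # a tuple root here would fail the check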

ManfeiBai marked this pull request as ready for review on February 26, 2024 at 19:34
@ManfeiBai (Collaborator, Author) commented Feb 26, 2024:

Hi @JackCaoG, since this PR adds a new function to PyLoweringContext, do we want to request a review from AWS too?

ManfeiBai changed the title to "[Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py" on Feb 26, 2024
@JackCaoG (Collaborator) commented:

@amithrm FYI

@ManfeiBai (Collaborator, Author) commented:

The kokoro failure should be fixed on the master branch; let's skip it for now.

ManfeiBai removed the DO_NOT_MERGE_YET label on Feb 27, 2024
ManfeiBai requested a review from yeounoh on February 29, 2024 at 18:07
@yeounoh (Contributor) reviewed:

Left some suggestions.

ManfeiBai requested a review from yeounoh on February 29, 2024 at 22:27
ManfeiBai force-pushed the ManfeiBai-patch-73 branch from c35ac53 to c7f09d5 on March 4, 2024 at 17:47
@@ -1027,7 +1033,9 @@ void BuildLoweringContextSubmodule(py::module* m) {
       .def("hlo_json", &PyLoweringContext::GetHloJsonText)
       .def("parameter_id_tensor_mapping",
            &PyLoweringContext::GetParameterIdTensorMapping)
-      .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
+      .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
+      .def("set_name_string", &PyLoweringContext::SetNameString)
A contributor commented:

Good, thanks!
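
For context, a hedged sketch of how the new binding might be driven from Python, modeled on this PR's cond/body lowering flow (the torch_xla._XLAC.lowering module path and the build()/hlo_json() bindings are assumptions inferred from the surrounding diff and may differ by torch_xla version):

# Hypothetical Python-side use of PyLoweringContext.set_name_string;
# names and call sequence are assumptions, not the PR's exact code.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.ones(1, dtype=torch.int32, device=device)
cond_result = x.sum() <= 10  # traced cond output (scalar bool tensor)

cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.set_name_string("condctx")  # tag this context as the cond computation
cond_ctx.build([cond_result])        # BuildXla() takes the single-op, non-tuple path
print(cond_ctx.hlo_json())           # inspect the lowered HLO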

  xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
  xla = builder()->Build(root);
} else if (!root_tuple_.empty() && (root_tuple_.size() == 1)) {
A contributor commented:

I think we should condition on get_name_string(). Add a check at the top and build for the while loop if `get_name_string() == "condctx" or get_name_string() == "bodyctx"`; otherwise, keep the original build logic.

Put the while-loop build logic in a separate private method, and call it when that condition is true. That keeps BuildXla() simple.

@ManfeiBai (Collaborator, Author) commented:

Thanks, that makes sense; updated in the newest commit. Since the while-loop build logic is a single line of code, we run it directly instead of wrapping it in a separate private method.

@@ -68,7 +72,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
   xla::XlaOp GetOutputOp(const torch::lazy::Output& output);

   // Build the XLA computation capturing all the operations created with the
-  // embedded XLA builder (returned by the builder() API).
+  // embedded XLA builder (returned by the builder() API) with check whether
A contributor commented:

I think we can keep the original comment and add the while-specific comment on the new private method, as described above.

@ManfeiBai (Collaborator, Author) commented:

Thanks, updated.

def cond_fn(x):
  return x.sum() <= 10

ten = torch.ones(1, dtype=torch.int32, device=device)
A collaborator commented:

Cosmetic nit: `ten` can be read as the number 10; please pick a clearer name.

@ManfeiBai (Collaborator, Author) commented:

Thanks, updated.

ManfeiBai force-pushed the ManfeiBai-patch-73 branch from 96ce688 to 173ff44 on March 8, 2024 at 18:44
ManfeiBai merged commit 6170df5 into master on Mar 8, 2024
18 checks passed