
aot_autograd: avoid using intermediate_base logic unnecessarily #97786

Closed
wants to merge 4 commits

Conversation

bdhirsh
Contributor

@bdhirsh bdhirsh commented Mar 28, 2023

fixes #97691, see the issue for the proposed design. Now that we are employing AOTAutograd's "intermediate base" logic a lot less frequently, we might see some speedups in the benchmark suite.
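
For context, a minimal sketch of the kind of pattern this targets - an output that aliases a graph intermediate which never escapes the graph (the function and shapes here are illustrative, not taken from the PR):

def f(x):
    y = x.mul(2)       # intermediate that never escapes the graph
    return y.view(-1)  # output aliases the intermediate

# Since the base never escapes, the output's view-ness is safe to hide from
# autograd: the output can be returned via aten._unsafe_view instead of going
# through the intermediate-base machinery.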

Stack from ghstack (oldest at bottom):

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@pytorch-bot

pytorch-bot bot commented Mar 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97786

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 164fb34:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Mar 28, 2023
ghstack-source-id: c8cddadc4ec99084a53c7312433feb9eebfc727b
Pull Request resolved: #97786
@bdhirsh bdhirsh requested a review from soulitzer March 28, 2023 15:41
# In cases where we know that an output's view-ness is safe to hide from autograd
# (the output is a view of an intermediate that doesn't escape the graph),
# we hide the view-ness from autograd.
# self.assertEqual(ref_o._is_view(), test_o._is_view())
Collaborator

We should still be able to test this, no?
In particular, check that if ref_o is a view of an input or another output, we preserve that?

Contributor Author

We can no longer test for it unconditionally - one thing I can do is plumb a bool into this test helper so we know when to test for it (I'll probably do that).

Collaborator

You have access to ref_o and ref_inp here, right?
So you should be able to check; the pseudocode I have in mind is:

def get_base(t):
    # A view's ._base is the tensor it aliases; a non-view is its own base.
    return t._base if t._is_view() else t

def is_in_base(t, tensors):
    # True if t shares a base (i.e. aliases the same storage) with any tensor in `tensors`.
    t_base = get_base(t)
    for tensor in tensors:
        if t_base is get_base(tensor):
            return True
    return False

ref_is_view_of_non_interm = is_in_base(ref_o, ref_inps) or is_in_base(ref_o, ref_outs)
test_is_view_of_non_interm = is_in_base(test_o, test_inps) or is_in_base(test_o, test_outs)

assert ref_is_view_of_non_interm == test_is_view_of_non_interm

Contributor Author

that test is definitely better, thanks - I'll add it.

torch/_functorch/aot_autograd.py
if info.output_type == OutputType.alias_of_intermediate_save_as_output:
    intermediate_bases.append(o._base)
elif info.output_type == OutputType.unsafe_view_alias:
    # See Note [Intermediate Bases Optimization]
    outs[i] = torch.ops.aten._unsafe_view(o, o.shape)
Collaborator

nit: always call the specific overload. The overload resolution from JIT used here is dead slow.

Contributor Author

I don't think this matters too much, because this is run at trace time (and we'll still end up baking the overload into the graph). But good to know - I'll update.
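
For reference, a minimal sketch of the specific-overload form (assuming the default overload is the one resolution would pick here):

# Calling the overload directly avoids JIT overload resolution at trace time.
outs[i] = torch.ops.aten._unsafe_view.default(o, o.shape)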

Collaborator

@albanD albanD Mar 28, 2023

Compilation time is still something we want to keep down, no? :D
But yes, most likely not a big issue.

Contributor Author

agreed 😛 (to be fair, this op will usually be called a handful of times per compiled subgraph, so pretty infrequently - at most # graph outputs times, if every output is an alias of a graph intermediate)

…arily"

fixes #97691, see the issue for the proposed design. Now that we are employing AOTAutograd's "intermediate base" logic a lot less frequently, we might see some speedups in the benchmark suite.




[ghstack-poisoned]
…arily"

fixes #97691, see the issue for the proposed design. Now that we are employing AOTAutograd's "intermediate base" logic a lot less frequently, we might see some speedups in the benchmark suite.




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Mar 28, 2023
ghstack-source-id: 0c01eca2f116f2a3add0789997f48832ef497fac
Pull Request resolved: #97786
Contributor

@soulitzer soulitzer left a comment

LGTM, just had a few minor questions

out_test[0].mul_(3)
# Assert that the aliasing relationship was preserved
self.assertEqual(out_ref[0], out_test[0])
self.assertEqual(out_ref[1], out_test[1])
Contributor

Do we test anywhere that the aliasing relationship is preserved in the case of:

def f(a):
    b = a.clone()
    return b.view(-1), b.view(-1)

Contributor Author

I think some other tests cover it more indirectly - but I can add one.
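
A minimal sketch of what such a test could look like (assuming the aot_function/nop helpers from functorch.compile; this is illustrative, not the final test):

import torch
from functorch.compile import aot_function, nop

def f(a):
    b = a.clone()
    return b.view(-1), b.view(-1)

inp_ref = torch.randn(2, 2, requires_grad=True)
inp_test = inp_ref.clone().detach().requires_grad_(True)

out_ref = f(inp_ref)
out_test = aot_function(f, nop)(inp_test)

# Mutate one output: if the two outputs still alias the same intermediate,
# the mutation should be visible through the other output as well.
out_ref[0].mul_(3)
out_test[0].mul_(3)
torch.testing.assert_close(out_ref[1], out_test[1])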

@@ -815,7 +840,7 @@ def inner(*flat_args):
     f_output_tangents = [
         o
         for o, info in zip(flat_f_outs, output_info)
-        if info.output_type == OutputType.non_alias and issubclass(info.raw_type, torch.Tensor)
+        if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias] and issubclass(info.raw_type, torch.Tensor)
Contributor

Is the comment above still up to date?

@@ -953,7 +983,7 @@ def inner_fn(*args):
     # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
     # which we *should* send to grad()
     output_grad_mask = [
-        meta.output_info[i].output_type == OutputType.non_alias
+        meta.output_info[i].output_type in [OutputType.non_alias, OutputType.unsafe_view_alias]
Contributor

Also the comment here

@@ -481,7 +483,7 @@ def __post_init__(self):
         self.aliased_out_indices = aliased_out_indices
         self.num_outputs = len(self.output_info)
         self.num_outputs_non_aliased = len(
-            [x for x in self.output_info if x.output_type == OutputType.non_alias]
+            [x for x in self.output_info if x.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias]]
Contributor

Is it worth updating the name of the "num_outputs_non_aliased" field, now that it also counts outputs that may be aliases?

Contributor Author

My thought here is that for all intents and purposes w.r.t. AOTAutograd, these tensors are non-aliased. I would have liked to not need a new OutputType, and just use non_alias for these tensors - but I needed some way of telling the joint code in AOTAutograd to insert an unsafe_view() later (I could have added a separate metadata, but an extra OutputType felt easier).

Let me know if that doesn't seem reasonable to you though!

fwiw, Alban also pointed out another option - don't insert unsafe_view here in AOTAutograd, and instead require the backend compiler to run the ADInplaceOrView keys inside of its compiled kernel (that's what happens by default today, since inductor just generates as_strided()). I figured I would do the simpler thing first that doesn't involve changing inductor code, and leave that for later work.

jansel added a commit that referenced this pull request Mar 30, 2023
#97786 might be a better fix for this one; hopefully we can land that and then discard this PR


cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
…arily"

fixes #97691, see the issue for the proposed design. Now that we are employing AOTAutograd's "intermediate base" logic a lot less frequently, we might see some speedups in the benchmark suite.




cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
@bdhirsh
Contributor Author

bdhirsh commented Mar 30, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 30, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@bdhirsh
Contributor Author

bdhirsh commented Mar 31, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/402/head branch June 8, 2023 15:47