aot_autograd: avoid using intermediate_base logic unnecessarily #97786
Conversation
Dr. CI: ✅ No failures as of commit 164fb34. Artifacts and rendered test results at hud.pytorch.org/pr/97786. Note: links to docs will display an error until the docs builds have completed.
ghstack-source-id: c8cddadc4ec99084a53c7312433feb9eebfc727b Pull Request resolved: #97786
test/functorch/test_aotdispatch.py (Outdated)
# In cases where we know that an output's view-ness is safe to hide from autograd
# (the output is a view of an intermediate that doesn't escape the graph),
# we hide the view-ness from autograd.
# self.assertEqual(ref_o._is_view(), test_o._is_view())
We should still be able to test this, no?
In particular, can we check that if ref_o is a view of an input or another output, we preserve that?
We can no longer test for it unconditionally - one thing I can do is plumb a bool into this test helper so we know when to test for it (I'll probably do that).
You have access to ref_o and ref_inp here, right?
So you should be able to check; the pseudocode I have in mind is:
def get_base(t):
    return t._base if t._is_view() else t

def is_in_base(t, tensors):
    t_base = get_base(t)
    for tensor in tensors:
        if t_base is get_base(tensor):
            return True
    return False

ref_is_view_of_non_interm = is_in_base(ref_o, ref_inps) or is_in_base(ref_o, ref_outs)
test_is_view_of_non_interm = is_in_base(test_o, test_inps) or is_in_base(test_o, test_outs)
assert ref_is_view_of_non_interm == test_is_view_of_non_interm
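For reference, the base-matching idea works directly on eager tensors via `._base` and `._is_view()`; a minimal runnable sketch (the tensor names `x`, `v`, `c` are illustrative, not from the test):

```python
import torch

def get_base(t):
    # A view's ._base is the tensor it aliases; a non-view is its own base.
    return t._base if t._is_view() else t

def is_in_base(t, tensors):
    t_base = get_base(t)
    return any(t_base is get_base(other) for other in tensors)

x = torch.ones(4)
v = x.view(2, 2)        # a view of the input x
c = x.clone().view(-1)  # a view of an intermediate (the clone), not of x

assert is_in_base(v, [x])      # v aliases an input
assert not is_in_base(c, [x])  # c's base is the intermediate clone
```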
that test is definitely better, thanks - I'll add it.
torch/_functorch/aot_autograd.py (Outdated)
if info.output_type == OutputType.alias_of_intermediate_save_as_output:
    intermediate_bases.append(o._base)
elif info.output_type == OutputType.unsafe_view_alias:
    # See Note [Intermediate Bases Optimization]
    outs[i] = torch.ops.aten._unsafe_view(o, o.shape)
nit: always call the specific overload. The overload resolution from JIT used here is dead slow.
I don't think this matters too much, because this is run at trace time (and we'll still end up baking the overload into the graph). But good to know - I'll update.
Compilation time is still something we want to keep down no? :D
But yes most likely not a big issue
agreed 😛 (to be fair, this op will usually be called a handful of times per compiled subgraph, so pretty infrequently: at most # graph outputs, if every output is an alias of a graph intermediate)
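For context, "calling the specific overload" means going through the `OpOverload` (e.g. `.default`) rather than the `OpOverloadPacket`; a small sketch, where both forms should produce the same result:

```python
import torch

x = torch.randn(2, 3)

# Packet call: resolves the overload on every invocation.
y_packet = torch.ops.aten._unsafe_view(x, (6,))
# Specific overload: skips the overload-resolution step.
y_default = torch.ops.aten._unsafe_view.default(x, (6,))

assert torch.equal(y_packet, y_default)
assert tuple(y_default.shape) == (6,)
```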
…arily" fixes #97691, see the issue for the proposed design. Now that we are employing AOTAutograd's "intermediate base" logic a lot less frequently, we might see some speedups in the benchmark suite. [ghstack-poisoned]
LGTM, just had a few minor questions
out_test[0].mul_(3)
# Assert that the aliasing relationship was preserved
self.assertEqual(out_ref[0], out_test[0])
self.assertEqual(out_ref[1], out_test[1])
Do we test anywhere that the aliasing relationship is preserved in the case of:

def f(a):
    b = a.clone()
    return b.view(-1), b.view(-1)
I think some other tests test for it more indirectly - but I can add one
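A direct eager-mode check of that case could look like the following (wiring it through aot_function/the compiled path is omitted here):

```python
import torch

def f(a):
    b = a.clone()
    return b.view(-1), b.view(-1)

out1, out2 = f(torch.ones(2, 2))
out1.mul_(3)

# Both outputs are views of the same intermediate, so mutating one
# must be visible through the other, and they must share a base.
assert torch.equal(out1, out2)
assert out1._base is out2._base
```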
@@ -815,7 +840,7 @@ def inner(*flat_args):
     f_output_tangents = [
         o
         for o, info in zip(flat_f_outs, output_info)
-        if info.output_type == OutputType.non_alias and issubclass(info.raw_type, torch.Tensor)
+        if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias] and issubclass(info.raw_type, torch.Tensor)
Is the comment above still up to date?
@@ -953,7 +983,7 @@ def inner_fn(*args):
     # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
     # which we *should* send to grad()
     output_grad_mask = [
-        meta.output_info[i].output_type == OutputType.non_alias
+        meta.output_info[i].output_type in [OutputType.non_alias, OutputType.unsafe_view_alias]
Also the comment here
@@ -481,7 +483,7 @@ def __post_init__(self):
     self.aliased_out_indices = aliased_out_indices
     self.num_outputs = len(self.output_info)
     self.num_outputs_non_aliased = len(
-        [x for x in self.output_info if x.output_type == OutputType.non_alias]
+        [x for x in self.output_info if x.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias]]
Is it worth updating the name of the "num_outputs_non_aliased" field, now that it also counts outputs that may be aliases?
My thought here is that, for all intents and purposes w.r.t. AOTAutograd, these tensors are non-aliased. I would have liked to avoid a new OutputType and just use non_alias for these tensors - but I needed some way of telling the joint code in AOTAutograd to insert an unsafe_view() later (I could have added separate metadata, but an extra OutputType felt easier).
Let me know if that doesn't seem reasonable to you, though!
FWIW, Alban also pointed out another option: don't insert unsafe_view here in AOTAutograd, and instead require the backend compiler to run the ADInplaceOrView keys inside of its compiled kernel (that's what happens by default today, since inductor just generates as_strided()). I figured I would do the simpler thing first that doesn't involve changing inductor code, and leave that for later work.
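To illustrate the bookkeeping being discussed: outputs tagged unsafe_view_alias are treated like plain outputs when building the grad mask. A small self-contained sketch, where the enum is a hypothetical stand-in for AOTAutograd's OutputType with only the members mentioned in this diff:

```python
from enum import Enum, auto

# Hypothetical mirror of AOTAutograd's OutputType (member names from the diff above).
class OutputType(Enum):
    non_alias = auto()
    alias_of_intermediate_save_as_output = auto()
    unsafe_view_alias = auto()

def output_grad_mask(output_types):
    # unsafe_view_alias outputs behave like non-aliased outputs w.r.t. grad():
    # their view-ness is hidden from autograd via _unsafe_view.
    return [
        t in (OutputType.non_alias, OutputType.unsafe_view_alias)
        for t in output_types
    ]

mask = output_grad_mask([
    OutputType.non_alias,
    OutputType.alias_of_intermediate_save_as_output,
    OutputType.unsafe_view_alias,
])
assert mask == [True, False, True]
```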
#97786 might be a better fix for this one, hopefully we can land that then discard this PR cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Fixes #97691 (see the issue for the proposed design). Now that we employ AOTAutograd's "intermediate base" logic a lot less frequently, we might see some speedups in the benchmark suite.
Stack from ghstack (oldest at bottom):
cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire