Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support runtime assertion for inline constraints #100763

Closed
wants to merge 1 commit into from

Conversation

ydwu4
Copy link
Contributor

@ydwu4 ydwu4 commented May 5, 2023

This pr does the following:

  1. previously, inline constraints is not properly set for tensor output data-dependent ops such as a.nonzero because of its return value is not symint. This pr just uses all the unbacked symbols i.e.those start with "i"/"f" in create_unbacked_sym* functions. Note that these symbols are guaranteed to be a super set of inline user constraints.

  2. add inline assertions support by checking.

Currently, it only deal with tensor, SymInt, SymFloat, SymBool output data-dependent ops and ignore the rest. It's good enough for now as we only have a limited number of data-dependent ops (.item and .nonzero are explicitly tested).

The examples for graph that is added assertions is shown below:

class ExportGraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        nonzero_default: i64[i0, 1] = torch.ops.aten.nonzero.default(arg0);  arg0 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)
        
class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        sym_size: Sym(s0) = torch.ops.aten.sym_size(arg0, 0)
        nonzero_default: i64[i1, 1] = torch.ops.aten.nonzero.default(arg0);  arg0 = None
        sym_size_1: Sym(i1) = torch.ops.aten.sym_size(nonzero_default, 0)
        ge: Sym(i1 >= 3) = sym_size_1 >= 3
        scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(ge);  ge = None
        _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].');  scalar_tensor_default = None
        le: Sym(i1 <= 5) = sym_size_1 <= 5;  sym_size_1 = None
        scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(le);  le = None
        _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].');  scalar_tensor_default_1 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label May 5, 2023
@pytorch-bot
Copy link

pytorch-bot bot commented May 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100763

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 17be12d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

capture_scalar_outputs=True,
)
def test_export_preserve_constraints_as_metadata_tensor(self):
from torch._export.constraints import constrain_as_value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move it at the top of file? We recently ran into circular import issue, so i wanted to make sure it doesn't happen again?

self.assertEqual(num_assert, 3)
self.assertEqual(num_scalar_tensor, 3)

with self.assertRaisesRegex(RuntimeError, r"^_local_scalar_dense_default.*\[2, 5\].$"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error message is bit hard to understand from user's perspective, how about we just dump the stacktrace from user code instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure!

new_inp = torch.tensor([1, 1, 1, 1])
self.assertEqual(mod(new_inp), new_gm(new_inp))

# FIXME: support control flow operators for the pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be fixed after #100836

if upper < math.inf:
self._inser_assert_async(operator.le, proxy, upper, assert_msg)

def _inser_assert_async(self, operator, l, r, assert_msg):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: _insert_assert_async

if "val" in meta:
val = meta["val"]

def add_assertions(val):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you leave a comment on explaining the high level approach? Especially the part why we need to accumulate messages.

cbs, msgs = add_assertions(sym)
for cb, msg in zip(cbs, msgs):
def sym_size_cb(proxy, assert_msg, dim):
dim_proxy = super(AddRuntimeAssertionsForConstraintsPass, self).call_operator(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't it be just super().call_operator()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure.. Seems related to some weird name resolution issue.

@ydwu4
Copy link
Contributor Author

ydwu4 commented May 9, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 9, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo release notes: fx release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants