Dynamo and cond with free variables creates malformed graph #90469
Comments
I played around with the repro a bit.
With this variant, the generated code actually crashes because it tries to look up a variable
Basically, this means that the
Pull Request resolved: #91981 Fixes #90469. Suppose we call `cond` with a `true_fn` / `false_fn` that captures variables from the scope of the call. The instructions emitted in the graphs for the nested functions naively refer to the names of the nodes in the outer graph as if the nested functions were "inlined," but this can clash with variables declared in the nested function itself. Moreover, the values of these closed variables need to be bound late, at call time. In this diff we propose a fix for this problem by passing around the closure environment as attributes in the graph module. Details of the compiling / calling protocol are in comments. Differential Revision: [D42353499](https://our.internmc.facebook.com/intern/diff/D42353499/)
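The name clash and the late-binding requirement can be illustrated with a toy code generator. This is a minimal sketch under stated assumptions: `emit_branch`, `ref`, and the `(result, op, lhs, rhs)` node format are hypothetical and are not Dynamo's or `torch.fx`'s API. Also, the closure environment is passed here as an extra `env` argument at call time, rather than stored as graph-module attributes as this PR proposes.

```python
# Toy model of how naively "inlining" a cond branch goes wrong.
# Hypothetical code generator; NOT Dynamo's real implementation.

# Each node: (result_name, op, lhs, rhs); operands are (kind, name) pairs
# where kind is "local" (branch-local), "const", or "outer" (captured).
NODES = [
    ("a", "*", ("local", "x"), ("const", "2")),    # branch's own node named "a"
    ("out", "+", ("local", "a"), ("outer", "a")),  # uses the captured outer "a"
]


def ref(operand, closure_aware):
    kind, name = operand
    if kind == "outer" and closure_aware:
        # Read the captured value from the environment supplied at call time.
        return f"env[{name!r}]"
    # Naive: emit the bare name, as if the branch were inlined in the outer graph.
    return name


def emit_branch(nodes, params, output, closure_aware):
    """Generate and compile Python source for a branch from its node list."""
    sig = ", ".join(params + (["env"] if closure_aware else []))
    lines = [f"def branch({sig}):"]
    for result, op, lhs, rhs in nodes:
        lines.append(
            f"    {result} = {ref(lhs, closure_aware)} {op} {ref(rhs, closure_aware)}"
        )
    lines.append(f"    return {output}")
    namespace = {}
    exec("\n".join(lines), namespace)
    return namespace["branch"]


naive = emit_branch(NODES, ["x"], "out", closure_aware=False)
fixed = emit_branch(NODES, ["x"], "out", closure_aware=True)

# With x = 3 and the outer a = 10, the intended result is 3 * 2 + 10 = 16.
print(naive(3))             # 12: the outer "a" was shadowed by the local "a"
print(fixed(3, {"a": 10}))  # 16
```

The naive branch silently computes the wrong value (no crash), while the closure-aware branch reads the captured value only when it is called, so a different environment at call time yields a different, correct result.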
As reported in #90469, the implementation of inlining nested function branches for `cond` doesn't properly handle variables captured from outer scopes. This leads to some examples accidentally working, others generating incorrect code that doesn't crash but does the wrong thing, and still others crashing outright because of references to non-existent variables. Properly supporting closed variables is tricky (see #91981 for an abandoned attempt). While this is definitely something we should be able to support longer term, for now it is better to explicitly error and suggest the fix to the user (rewriting the branches to take closed variables as explicit arguments). Differential Revision: [D45058621](https://our.internmc.facebook.com/intern/diff/D45058621/) Pull Request resolved: #99367 Approved by: https://github.com/ezyang, https://github.com/tugsbayasgalan
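A minimal pure-Python sketch of the "error and suggest the fix" approach, assuming a hypothetical `cond` stand-in (not `torch`'s implementation) that rejects any branch whose `__code__.co_freevars` is non-empty. Note that `co_freevars` only sees variables closed over from enclosing function scopes; globals and attribute lookups would go unnoticed by this particular check.

```python
# Hypothetical stand-in for cond; NOT torch's implementation.
def check_no_free_variables(fn):
    # co_freevars lists names the function closes over from enclosing scopes.
    free = fn.__code__.co_freevars
    if free:
        raise ValueError(
            f"cond branch {fn.__name__!r} closes over {', '.join(free)}; "
            "rewrite the branch to take these as explicit operands"
        )


def cond(pred, true_fn, false_fn, operands):
    for fn in (true_fn, false_fn):
        check_no_free_variables(fn)
    return true_fn(*operands) if pred else false_fn(*operands)


def f(x, y):
    # Both branches capture y from f's scope, so cond rejects them.
    def true_fn(x):
        return x + y

    def false_fn(x):
        return x - y

    return cond(x > 0, true_fn, false_fn, [x])


def f_fixed(x, y):
    # The suggested rewrite: y is passed explicitly as an operand.
    def true_fn(x, y):
        return x + y

    def false_fn(x, y):
        return x - y

    return cond(x > 0, true_fn, false_fn, [x, y])


try:
    f(1, 2)
except ValueError as err:
    print(err)
print(f_fixed(1, 2), f_fixed(-1, 2))
```

In the rewritten version each branch depends only on its own parameters, so nothing has to be resolved against the outer scope when the branch graphs are built.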
Closing, as the issue is fixed. Repro: patch ed's diff into test/dynamo/test_export.py:test_export_with_module_layer. It produces the following graph:
🐛 Describe the bug
Repro:
The graph for the true branch ends up being:
which is badly malformed.
Versions
master
cc @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire