Dynamo and cond with free variables creates malformed graph #90469

Closed
ezyang opened this issue Dec 8, 2022 · 3 comments

ezyang commented Dec 8, 2022

🐛 Describe the bug

repro

diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index b0640f6511..8fb8754c51 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -1443,8 +1443,10 @@ class ExportTests(torch._dynamo.test_case.TestCase):
                 self.linear = torch.nn.Linear(3, 3)
 
             def forward(self, pred, x):
+                y = x * 2
+
                 def true_fn(val):
-                    return self.linear(val) * torch.tensor(2)
+                    return self.linear(val) * torch.tensor(2) * y
 
                 def false_fn(val):
                     return self.linear(val) * torch.tensor(-1)

The graph for the true branch ends up being

def forward(self, x):
    self_linear = self.self_linear(x);  x = None
    tensor = torch.tensor(2)
    mul = self_linear * tensor;  self_linear = tensor = None
    mul_1 = mul * mul;  mul = mul = None
    return mul_1

which is badly wrong: the final multiply computes mul * mul instead of mul * y.

Versions

master

cc @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@ezyang ezyang self-assigned this Dec 8, 2022

ezyang commented Dec 8, 2022

cc @voznesenskym

@soulitzer added the triaged, oncall: pt2, and module: dynamo labels Dec 9, 2022
avikchaudhuri commented

I played around with the repro a bit.

def forward(self, pred, x):
    y = x + 2  # note: add instead of mul

    def true_fn(val):
        return self.linear(val) * torch.tensor(2) * y
 
    def false_fn(val):
        return self.linear(val) * torch.tensor(-1) * y

With this variant, the generated code actually crashes, because it tries to look up a variable `add` that is not in scope (it should be, because `add` is the name given to the node for the addition stored in `y`):

def forward(self, x):
    self_linear = self.self_linear(x);  x = None
    tensor = torch.tensor(2)
    mul = self_linear * tensor;  self_linear = tensor = None
    mul_1 = mul * add;  mul = mul = None
    return mul_1

Basically this means that the mul * mul in the original repro is just a coincidence.
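
To make the coincidence concrete, here is a toy model in plain Python. It is only a sketch under stated assumptions: `ToyGraph` and `trace_true_branch` are hypothetical names, and the naming scheme (op name plus a per-graph counter) is an assumption chosen to match the names seen in the generated code above, not Dynamo's actual implementation. The branch graph refers to the captured value by its outer node name, which either collides with a local name or is undefined.

```python
from collections import Counter


class ToyGraph:
    """Records (name, op, args) triples; names are unique within this graph only."""

    def __init__(self):
        self.nodes = []
        self._seen = Counter()

    def add_node(self, op, *args):
        # First use of an op gets the bare op name, later uses get a suffix.
        count = self._seen[op]
        self._seen[op] += 1
        name = op if count == 0 else f"{op}_{count}"
        self.nodes.append((name, op, args))
        return name

    def undefined_names(self, inputs):
        """Return argument names that are neither graph inputs nor earlier nodes."""
        defined = set(inputs)
        missing = []
        for name, _op, args in self.nodes:
            missing.extend(a for a in args if a not in defined)
            defined.add(name)
        return missing


def trace_true_branch(outer_op):
    # Outer graph: y = x * 2 (or x + 2) yields a node named after its op.
    outer = ToyGraph()
    y = outer.add_node(outer_op, "x", "2")

    # Branch graph: a fresh namer that knows nothing about the outer names.
    inner = ToyGraph()
    lin = inner.add_node("self_linear", "x")
    two = inner.add_node("tensor")
    prod = inner.add_node("mul", lin, two)
    # The free variable is referenced by its *outer* node name.
    inner.add_node("mul", prod, y)
    return inner


mul_variant = trace_true_branch("mul")
add_variant = trace_true_branch("add")

print(mul_variant.nodes[-1])                      # last node of the mul variant
print(mul_variant.undefined_names(inputs=["x"]))  # nothing looks undefined
print(add_variant.nodes[-1])                      # last node of the add variant
print(add_variant.undefined_names(inputs=["x"]))  # the outer name leaks through
```

In the `mul` variant the outer name happens to match a local node, so the graph is well-formed but computes the wrong thing; in the `add` variant the same lookup has nothing to bind to.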

avikchaudhuri added a commit that referenced this issue Jan 11, 2023
Fixes #90469

Let's say we call `cond` with a `true_fn` / `false_fn` that captures variables in the scope of the call. The instructions emitted in the graphs for the nested functions naively refer to the names of the nodes in the outer graph as if the nested functions were "inlined," but unfortunately this can clash with variables declared in the nested function itself. Moreover, the values of these closed variables need to be bound late at call time.

In this diff we propose a fix for this problem by passing around the closure environment as attributes in the graph module. Details of the compiling / calling protocol are in comments.

Differential Revision: [D42353499](https://our.internmc.facebook.com/intern/diff/D42353499/)

[ghstack-poisoned]
avikchaudhuri added a commit that referenced this issue Jan 11, 2023
Pull Request resolved: #91981

Fixes #90469

Let's say we call `cond` with a `true_fn` / `false_fn` that captures variables in the scope of the call. The instructions emitted in the graphs for the nested functions naively refer to the names of the nodes in the outer graph as if the nested functions were "inlined," but unfortunately this can clash with variables declared in the nested function itself. Moreover, the values of these closed variables need to be bound late at call time.

In this diff we propose a fix for this problem by passing around the closure environment as attributes in the graph module. Details of the compiling / calling protocol are in comments.
ghstack-source-id: 177444739

Differential Revision: [D42353499](https://our.internmc.facebook.com/intern/diff/D42353499/)
avikchaudhuri added a commit that referenced this issue Apr 17, 2023
As reported in #90469, the implementation of inlining nested function branches for `cond` doesn't properly handle variables captured from outer scopes. As a result, some examples work by accident, others generate incorrect code that doesn't crash but does the wrong thing, and still others crash outright because of references to non-existent variables.

Properly supporting closed variables is tricky (see #91981 for an abandoned attempt). While this is definitely something we should be able to support longer term, for now it is better to explicitly error and suggest the fix to the user (amounting to rewriting branches to take closed variables explicitly).

Differential Revision: [D45058621](https://our.internmc.facebook.com/intern/diff/D45058621/)

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this issue Apr 19, 2023
As reported in #90469, the implementation of inlining nested function branches for `cond` doesn't properly handle variables captured from outer scopes. As a result, some examples work by accident, others generate incorrect code that doesn't crash but does the wrong thing, and still others crash outright because of references to non-existent variables.

Properly supporting closed variables is tricky (see #91981 for an abandoned attempt). While this is definitely something we should be able to support longer term, for now it is better to explicitly error and suggest the fix to the user (amounting to rewriting branches to take closed variables explicitly).

Differential Revision: [D45058621](https://our.internmc.facebook.com/intern/diff/D45058621/)

Pull Request resolved: #99367
Approved by: https://github.com/ezyang, https://github.com/tugsbayasgalan
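
The suggested user-side fix (rewriting branches to take closed variables explicitly) can be sketched in plain Python. This is an illustration only: `checked_cond` is a hypothetical stand-in with eager semantics, not the real `cond` operator, and the check via `__code__.co_freevars` is an assumption about one way such an error could be detected, not the check the commit actually performs. Note that this naive check would also flag a captured `self`.

```python
def free_variable_names(fn):
    """Names that fn captures from an enclosing function scope."""
    return fn.__code__.co_freevars


def checked_cond(pred, true_fn, false_fn, operands):
    # Hypothetical stand-in: refuse branches that close over outer variables.
    for branch in (true_fn, false_fn):
        captured = free_variable_names(branch)
        if captured:
            raise ValueError(
                f"{branch.__name__} captures {', '.join(captured)}; "
                "pass these as explicit operands instead"
            )
    chosen = true_fn if pred else false_fn
    return chosen(*operands)


def forward_with_closure(pred, x):
    y = x * 2

    def true_fn(val):
        return val * 2 * y  # y is a free variable here

    def false_fn(val):
        return val * -1

    return checked_cond(pred, true_fn, false_fn, [x])


def forward_explicit(pred, x):
    y = x * 2

    def true_fn(val, y):
        return val * 2 * y

    def false_fn(val, y):
        # Both branches receive the same operands, so y is accepted even if unused.
        return val * -1

    return checked_cond(pred, true_fn, false_fn, [x, y])


try:
    forward_with_closure(True, 3.0)
except ValueError as err:
    print(err)

print(forward_explicit(True, 3.0))
print(forward_explicit(False, 3.0))
```

The rewrite keeps the branch bodies unchanged; only the signatures and the operand list grow by the captured value.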

ydwu4 commented Nov 29, 2023

Closing, as the issue is fixed. Repro: apply Ed's diff to test/dynamo/test_export.py:test_export_with_module_layer. It now produces the following graph:

class GraphModule(torch.nn.Module):
    def forward(self, pred, x):
        arg0: "b8[]"; arg1: "f32[3, 3]"; 
    
        arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
        l_pred_ = arg0
        l_x_ = arg1
        
        # File: /home/yidi/local/pytorch/test/dynamo/test_export.py:1535, code: y = x * 2
        y = l_x_ * 2
        
        # File: /home/yidi/local/pytorch/torch/nn/modules/linear.py:116, code: return F.linear(input, self.weight, self.bias)
        l__self___linear_weight = self.L__self___linear_weight
        l__self___linear_bias = self.L__self___linear_bias
        
        # File: /home/yidi/local/pytorch/torch/_higher_order_ops/cond.py:116, code: return cond_op(pred, true_fn, false_fn, operands)
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l__self___linear_bias, l__self___linear_weight, l_x_, y]);  l_pred_ = cond_true_0 = cond_false_0 = l__self___linear_bias = l__self___linear_weight = l_x_ = y = None
        getitem = cond[0];  cond = None
        return pytree.tree_unflatten([getitem], self._out_spec)
        
    class GraphModule(torch.nn.Module):
        def forward(self, l__self___linear_bias, l__self___linear_weight, l_x_, y_true_branch):
            l__self___linear_bias_1 = l__self___linear_bias
            l__self___linear_weight_1 = l__self___linear_weight
            l_x__1 = l_x_
            
            # File: /home/yidi/local/pytorch/torch/nn/modules/linear.py:116, code: return F.linear(input, self.weight, self.bias)
            linear = torch._C._nn.linear(l_x__1, l__self___linear_weight_1, l__self___linear_bias_1);  l_x__1 = l__self___linear_weight_1 = l__self___linear_bias_1 = None
            
            # File: /home/yidi/local/pytorch/test/dynamo/test_export.py:1537, code: return self.linear(val) * torch.tensor(2) * y
            tensor = torch.tensor(2)
            mul = linear * tensor;  linear = tensor = None
            mul_1 = mul * y_true_branch;  mul = y_true_branch = None
            return (mul_1,)
            
    class GraphModule(torch.nn.Module):
        def forward(self, l__self___linear_bias_1, l__self___linear_weight_1, l_x_, y_true_branch):
            l__self___linear_bias_2 = l__self___linear_bias_1
            l__self___linear_weight_2 = l__self___linear_weight_1
            l_x__1 = l_x_
            
            # File: /home/yidi/local/pytorch/torch/nn/modules/linear.py:116, code: return F.linear(input, self.weight, self.bias)
            linear = torch._C._nn.linear(l_x__1, l__self___linear_weight_2, l__self___linear_bias_2);  l_x__1 = l__self___linear_weight_2 = l__self___linear_bias_2 = None
            
            # File: /home/yidi/local/pytorch/test/dynamo/test_export.py:1540, code: return self.linear(val) * torch.tensor(-1)
            tensor = torch.tensor(-1)
            mul = linear * tensor;  linear = tensor = None
            return (mul,)
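
In the graph above, the captured value `y` reaches the branch as an extra operand (`y_true_branch`) rather than by name. A rough pure-Python analogy of that lifting is sketched below. It is not how Dynamo implements it: `lift_free_variables` is a hypothetical helper that rebinds a function's closure cells at call time using `types.CellType` (Python 3.8+), and unlike the graph above it lifts each branch separately instead of giving both branches the same operand list.

```python
import types


def lift_free_variables(fn):
    """Return (lifted_fn, captured_values).

    lifted_fn takes fn's own positional arguments followed by one extra
    argument per captured variable, and rebinds the closure at call time.
    """
    names = fn.__code__.co_freevars
    cells = fn.__closure__ or ()
    captured = [cell.cell_contents for cell in cells]
    n_own = fn.__code__.co_argcount

    def lifted(*args):
        own, extra = args[:n_own], args[n_own:]
        if len(extra) != len(names):
            raise TypeError(f"expected extra operands for {names}")
        # Fresh cells so the values are bound late, at call time.
        fresh_cells = tuple(types.CellType(v) for v in extra)
        rebound = types.FunctionType(
            fn.__code__, fn.__globals__, fn.__name__, fn.__defaults__,
            fresh_cells or None,
        )
        return rebound(*own)

    return lifted, captured


def make_branches(x):
    y = x * 2

    def true_fn(val):
        return val * 2 * y

    def false_fn(val):
        return val * -1

    return true_fn, false_fn


true_fn, false_fn = make_branches(3.0)
lifted_true, captured_true = lift_free_variables(true_fn)
lifted_false, captured_false = lift_free_variables(false_fn)

print(captured_true)                          # values closed over by true_fn
print(lifted_true(3.0, *captured_true))       # same result as true_fn(3.0)
print(lifted_true(3.0, 10.0))                 # a different y supplied at call time
print(lifted_false(3.0, *captured_false))     # false_fn captures nothing
```

Because the captured value travels as an argument, the branch no longer depends on any name from the outer graph, which is what removes the clash described earlier in the thread.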

@ydwu4 ydwu4 closed this as completed Nov 29, 2023