add_self_loops() tracing constant with torch.jit.trace() #7226

Closed
Vuenc opened this issue Apr 24, 2023 · 0 comments · Fixed by #7330

Vuenc commented Apr 24, 2023

🐛 Describe the bug

When a model calls add_self_loops in its forward function without explicitly passing the num_nodes parameter, tracing the model with torch.jit.trace behaves incorrectly: the number of nodes in the example input to torch.jit.trace is saved as a constant.

As a result, forwarding a bigger graph through the traced model can produce wrong output (self loops are not added for the additional nodes), and forwarding a smaller graph can raise an error (out-of-range self loops are added).

Here's example code, roughly along the lines of test_message_passing.py. I trace the model with (x, edge_index) and then call the traced model with a smaller and a bigger graph. When adding self loops is enabled, the traced model returns wrong outputs on the bigger graph and crashes on the smaller graph. This is because the number of nodes (4) is traced as a constant in the add_self_loops call.

from torch import Tensor
from torch_geometric.nn import MessagePassing
import torch
import torch_geometric

class MyConv(MessagePassing):
    def __init__(self, use_self_loops: bool, aggr: str = 'add'):
        super().__init__(aggr=aggr)
        self.use_self_loops = use_self_loops

    def forward(self, x: Tensor, edge_index: torch.Tensor) -> Tensor:        
        if self.use_self_loops:
            edge_index, _ = torch_geometric.utils.add_self_loops(edge_index)

        # propagate_type: (x: Tensor)
        return self.propagate(edge_index, x=x, size=None)
        
    def message(self, x_j: Tensor) -> Tensor:
        return x_j

def test_my_conv_jittable():
    x = torch.randn(4, 16)
    x_smaller = torch.randn(2, 16)
    x_bigger = torch.randn(8, 16)
    edge_index = torch.stack([torch.arange(0, 3), torch.arange(1, 4)])
    edge_index_bigger = torch.stack([torch.arange(0, 7), torch.arange(1, 8)])
    edge_index_smaller = torch.stack([torch.arange(0, 1), torch.arange(1, 2)])

    for conv, name in [(MyConv(use_self_loops=False), "MyConv (no self loops)"), (MyConv(use_self_loops=True), "MyConv (with self loops)")]:
        out = conv.forward(x, edge_index)
        traced_conv = torch.jit.trace(conv, (x, edge_index))
        out_traced = traced_conv.forward(x, edge_index)
        print(f"{name}, 4 nodes: conv output and traced_conv output are {'' if torch.allclose(out, out_traced) else 'NOT '}close")

        out_bigger = conv.forward(x_bigger, edge_index_bigger)
        out_bigger_traced = traced_conv.forward(x_bigger, edge_index_bigger)
        print(f"{name}, 8 nodes: conv output and traced_conv output are {'' if torch.allclose(out_bigger, out_bigger_traced) else 'NOT '}close")

        out_smaller = conv.forward(x_smaller, edge_index_smaller)
        out_smaller_traced = traced_conv.forward(x_smaller, edge_index_smaller)
        print(f"{name}, 2 nodes: conv output and traced_conv output are {'' if torch.allclose(out_smaller, out_smaller_traced) else 'NOT '}close")

if __name__ == "__main__":
    test_my_conv_jittable()

It outputs:

MyConv (no self loops), 4 nodes: conv output and traced_conv output are close
MyConv (no self loops), 8 nodes: conv output and traced_conv output are close
MyConv (no self loops), 2 nodes: conv output and traced_conv output are close
MyConv (with self loops), 4 nodes: conv output and traced_conv output are close
MyConv (with self loops), 8 nodes: conv output and traced_conv output are NOT close
Traceback (most recent call last):
  File "~/Documents/edge_update_jit_test/minimal_example.py", line 44, in <module>
    test_my_conv_jittable()
  File "~/Documents/edge_update_jit_test/minimal_example.py", line 40, in test_my_conv_jittable
    out_smaller_traced = traced_conv.forward(x_smaller, edge_index_smaller)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
~/Documents/torch-2.0/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py(272): _lift
~/Documents/torch-2.0/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py(336): _collect
~/Documents/torch-2.0/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py(459): propagate
~/Documents/edge_update_jit_test/minimal_example.py(16): forward
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/jit/_trace.py(1056): trace_module
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/jit/_trace.py(794): trace
~/Documents/edge_update_jit_test/minimal_example.py(31): test_my_conv_jittable
~/Documents/edge_update_jit_test/minimal_example.py(44): <module>
RuntimeError: index out of range in self
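
A possible workaround, sketched under the assumption that the node count equals x.size(0): pass the node count explicitly, so that it comes from a traced size op rather than from the values of the example input. The helper append_self_loops below is a hypothetical stand-in for add_self_loops, not PyG's implementation.

```python
import torch
from torch import Tensor


def append_self_loops(edge_index: Tensor, num_nodes) -> Tensor:
    # Hypothetical stand-in for add_self_loops; not PyG's implementation.
    loop_index = torch.arange(num_nodes, dtype=torch.long,
                              device=edge_index.device)
    loops = torch.stack([loop_index, loop_index])
    return torch.cat([edge_index, loops], dim=1)


class WithExplicitNumNodes(torch.nn.Module):
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # During tracing, x.size(0) is recorded as a size op on the input,
        # so the node count follows whatever graph is passed at call time.
        return append_self_loops(edge_index, x.size(0))


# Trace with a 4-node graph, as in the report.
x = torch.randn(4, 16)
edge_index = torch.stack([torch.arange(0, 3), torch.arange(1, 4)])
traced = torch.jit.trace(WithExplicitNumNodes(), (x, edge_index))

# Call the traced module with an 8-node graph.
x_bigger = torch.randn(8, 16)
edge_index_bigger = torch.stack([torch.arange(0, 7), torch.arange(1, 8)])
out = traced(x_bigger, edge_index_bigger)
print(out.shape)  # 7 original edges plus 8 self loops
```

With torch_geometric itself, the analogous call would be add_self_loops(edge_index, num_nodes=x.size(0)). Whether that traces cleanly on a given PyG version is worth checking before relying on it.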

Environment

  • PyG version: 2.3.0
  • PyTorch version: 2.0.0+cu117
  • OS: Ubuntu 22.04
  • Python version: 3.10.6
  • CUDA/cuDNN version: 12.1
  • How you installed PyTorch and PyG (conda, pip, source): pip

Vuenc added the bug label Apr 24, 2023
rusty1s added a commit that referenced this issue May 10, 2023
…loops` (#7330)

Fix of issue #7226.

The problem I described came from the `maybe_num_nodes` function which
computed `int(edge_index.max()) + 1`. The `int()` call made the computed
number of nodes a constant for the `torch.jit.trace` function.
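
To illustrate the effect in isolation, here is a minimal sketch that does not use torch_geometric. The two function names are hypothetical and exist only for this comparison: one converts the maximum index with int(), the other keeps it as a tensor.

```python
import warnings

import torch
from torch import Tensor


def num_nodes_via_int(edge_index: Tensor) -> Tensor:
    # int() turns the value into a plain Python number, so the tracer
    # can only record it as a constant.
    num_nodes = int(edge_index.max()) + 1
    return edge_index.new_full((1,), num_nodes)


def num_nodes_via_tensor(edge_index: Tensor) -> Tensor:
    # Staying with tensor ops keeps max() and the addition in the trace.
    return (edge_index.max() + 1).view(1)


edge_index = torch.stack([torch.arange(0, 3), torch.arange(1, 4)])         # 4 nodes
edge_index_bigger = torch.stack([torch.arange(0, 7), torch.arange(1, 8)])  # 8 nodes

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning caused by int()
    traced_int = torch.jit.trace(num_nodes_via_int, (edge_index,))
    traced_tensor = torch.jit.trace(num_nodes_via_tensor, (edge_index,))

print(traced_int(edge_index_bigger))     # tensor([4]): node count of the example input
print(traced_tensor(edge_index_bigger))  # tensor([8]): recomputed from the new input
```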

I fixed it with an if branch that is only executed while tracing. (I
could not get the same code for all cases to work: the workaround used
only for tracing now would also work for normal operation, but then
would break when using `torch.jit.script`).
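
The shape of such a tracing-only branch can be sketched as follows. This is a simplified illustration under stated assumptions, not the code merged in #7330: infer_num_nodes and append_self_loops are hypothetical names, and edge cases such as an empty edge_index are ignored.

```python
import torch
from torch import Tensor


def infer_num_nodes(edge_index: Tensor):
    # Hypothetical helper; a simplified illustration only.
    if torch.jit.is_tracing():
        # Keep the result as a tensor so that max() is recorded in the trace.
        return edge_index.max() + 1
    # Outside of tracing, return a plain Python int as before.
    return int(edge_index.max()) + 1


def append_self_loops(edge_index: Tensor) -> Tensor:
    # Hypothetical stand-in for add_self_loops; not PyG's implementation.
    num_nodes = infer_num_nodes(edge_index)
    loop_index = torch.arange(0, num_nodes, dtype=torch.long,
                              device=edge_index.device)
    loops = torch.stack([loop_index, loop_index])
    return torch.cat([edge_index, loops], dim=1)


# Same graphs as in the report: trace with 4 nodes, call with 8 and 2 nodes.
edge_index = torch.stack([torch.arange(0, 3), torch.arange(1, 4)])
edge_index_bigger = torch.stack([torch.arange(0, 7), torch.arange(1, 8)])
edge_index_smaller = torch.stack([torch.arange(0, 1), torch.arange(1, 2)])

traced = torch.jit.trace(append_self_loops, (edge_index,))
out_bigger = traced(edge_index_bigger)    # 7 edges plus 8 self loops
out_smaller = traced(edge_index_smaller)  # 1 edge plus 2 self loops
```

Because the traced branch computes the node count from the input at call time, the number of appended self loops matches each graph rather than the example used for tracing.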

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>