🐛 Describe the bug
When a model calls `add_self_loops` in its forward function and does not explicitly pass the `num_nodes` parameter, tracing the model with `torch.jit.trace` behaves incorrectly: the number of nodes in the example input to `torch.jit.trace` is saved as a constant.
This may lead to wrong output when forwarding a bigger graph through the traced model (because self loops are not added for the extra nodes), and possibly to an error when forwarding a smaller graph through the traced model (because out-of-range self loops are added).
Here's example code, roughly along the lines of `test_message_passing.py`. I trace the model with `(x, edge_index)` and call the traced model with a smaller and a bigger graph. When adding self loops is enabled, the traced model returns wrong outputs on the bigger graph and crashes on the smaller graph. This is because the number of nodes, 4, is traced as a constant for the `add_self_loops` part.
It outputs:
MyConv (no self loops), 4 nodes: conv output and traced_conv output are close
MyConv (no self loops), 8 nodes: conv output and traced_conv output are close
MyConv (no self loops), 2 nodes: conv output and traced_conv output are close
MyConv (with self loops), 4 nodes: conv output and traced_conv output are close
MyConv (with self loops), 8 nodes: conv output and traced_conv output are NOT close
Traceback (most recent call last):
File "~/Documents/edge_update_jit_test/minimal_example.py", line 44, in <module>
test_my_conv_jittable()
File "~/Documents/edge_update_jit_test/minimal_example.py", line 40, in test_my_conv_jittable
out_smaller_traced = traced_conv.forward(x_smaller, edge_index_smaller)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
~/Documents/torch-2.0/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py(272): _lift
~/Documents/torch-2.0/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py(336): _collect
~/Documents/torch-2.0/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py(459): propagate
~/Documents/edge_update_jit_test/minimal_example.py(16): forward
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/jit/_trace.py(1056): trace_module
~/Documents/torch-2.0/lib/python3.10/site-packages/torch/jit/_trace.py(794): trace
~/Documents/edge_update_jit_test/minimal_example.py(31): test_my_conv_jittable
~/Documents/edge_update_jit_test/minimal_example.py(44): <module>
RuntimeError: index out of range in self
Environment
PyG version: 2.3.0
PyTorch version: 2.0.0+cu117
OS: Ubuntu 22.04
Python version: 3.10.6
CUDA/cuDNN version: 12.1
How you installed PyTorch and PyG (conda, pip, source): pip
…loops` (#7330)
Fix of issue #7226.
The problem I described came from the `maybe_num_nodes` function which
computed `int(edge_index.max()) + 1`. The `int()` call made the computed
number of nodes a constant for the `torch.jit.trace` function.
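
The effect of the `int()` call can be reproduced without PyG. The following is a minimal sketch, not the library code: `add_loops_int` and `ring` are hypothetical helpers, where `add_loops_int` infers the node count as described above and appends one self loop per node.

```python
import warnings

import torch


def add_loops_int(edge_index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not PyG code): int() turns the tensor into a plain
    # Python number, so the tracer records the node count as a constant.
    num_nodes = int(edge_index.max()) + 1
    loops = torch.arange(0, num_nodes, dtype=torch.long)
    return torch.cat([edge_index, torch.stack([loops, loops])], dim=1)


def ring(n: int) -> torch.Tensor:
    # Directed ring graph with n nodes and n edges (illustration only).
    src = torch.arange(n)
    return torch.stack([src, (src + 1) % n])


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning about int()
    traced = torch.jit.trace(add_loops_int, (ring(4),))

# Eager mode adds 8 self loops to the 8-node graph; the traced function
# still adds the 4 self loops recorded from the example input.
eager_cols = add_loops_int(ring(8)).size(1)
traced_cols = traced(ring(8)).size(1)
print(eager_cols, traced_cols)
```

On the 8-node ring, the eager call returns 16 columns (8 edges plus 8 self loops), while the traced call returns only 12, which matches the "NOT close" result reported for the bigger graph.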
I fixed it with an if branch that is only executed while tracing. (I
could not get a single code path to work for all cases: the workaround
that is now used only during tracing would also work in normal
operation, but it would break `torch.jit.script`.)
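
A tracing-only branch of this kind might look roughly like the sketch below. It is an illustration under assumptions, not the code from the PR: `infer_num_nodes`, `add_loops`, and `ring` are hypothetical names, a non-empty `edge_index` is assumed, and the sketch is meant for eager and traced use only.

```python
import torch


def infer_num_nodes(edge_index: torch.Tensor):
    # Hypothetical stand-in for the node-count inference, not the PyG source.
    # Assumes edge_index is non-empty.
    if torch.jit.is_tracing():
        # Keep the value as a tensor so the tracer records max() + 1 as an
        # operation on the input instead of baking in a constant.
        return edge_index.max() + 1
    return int(edge_index.max()) + 1


def add_loops(edge_index: torch.Tensor) -> torch.Tensor:
    num_nodes = infer_num_nodes(edge_index)
    loops = torch.arange(0, num_nodes, dtype=torch.long)
    return torch.cat([edge_index, torch.stack([loops, loops])], dim=1)


def ring(n: int) -> torch.Tensor:
    # Directed ring graph with n nodes and n edges (illustration only).
    src = torch.arange(n)
    return torch.stack([src, (src + 1) % n])


traced = torch.jit.trace(add_loops, (ring(4),))

# The traced function now follows the size of the graph it is given.
bigger_cols = traced(ring(8)).size(1)
smaller_cols = traced(ring(2)).size(1)
matches_eager = torch.equal(traced(ring(8)), add_loops(ring(8)))
print(bigger_cols, smaller_cols, matches_eager)
```

The two branches return different types (a tensor while tracing, an `int` otherwise), which is one reason a sketch like this would not be expected to compile under `torch.jit.script`.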
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>