RGCNConv has multiple graph breaks #8467
Comments
Would you mind sharing your code and env details to reproduce?
Hi, thanks for the quick response. I hope this gives more details:

Installs:
@sharlinu Thank you, I'll have a look at this tomorrow.
Just FYI, if you use torch_geometric.typing.WITH_INDEX_SORT = False
torch_geometric.typing.WITH_TORCH_SCATTER = False because they're not compatible with
pytorch_geometric/torch_geometric/compile.py Lines 72 to 80 in c13c00f
For the record, users will no longer need to disable these flags manually as of the next release, 2.5.0, thanks to #8698.
A quick fix with
In
@sharlinu Please let me know if this works for you.
Closing this issue as I cannot reproduce it by running this script with the latest PyG and nightly PyTorch release:

import torch
import torch_geometric
torch_geometric.backend.use_segment_matmul = False
device = 'cuda'
conv = torch_geometric.nn.RGCNConv(4, 32, 4, None, None, aggr='sum').to(device)
c_conv = torch.compile(conv)
x = torch.randn(4, 4, device=device)
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3],
[0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1],
], device=device)
edge_type = torch.tensor(
[0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3],
device=device,
)
out1 = conv(x, edge_index, edge_type)
out2 = c_conv(x, edge_index, edge_type)
assert torch.allclose(out1, out2, atol=1e-8)

I'd be happy to assist if the issue still persists. For the record, #8890 is a tracking issue for supporting custom ops with
🚀 RGCNConv gets barely any speedup with torch_geometric.compile()
I recently applied torch_geometric.compile to my nn.Module, which mostly consists of the torch_geometric.nn.conv.rgcn_conv module. Analysing the graph breaks with torch._dynamo.explain(), I found that this module causes 6 graph breaks:
Would it be possible to have a version of this module with fewer graph breaks? Unfortunately, I do not know enough about compilation and graph breaks to understand what causes them, so any help on this would also be greatly appreciated!
Alternatives
No response
Additional context
No response