MessagePassing.propagate with flow="target_to_source" and size=(...) bug introduced in PyG 2.0.4 #4591

Closed
pimdh opened this issue May 3, 2022 · 5 comments

@pimdh

pimdh commented May 3, 2022

🐛 Describe the bug

With source_to_target we get as expected:

import torch
from torch_geometric.nn import MessagePassing

mp = MessagePassing(flow="source_to_target", node_dim=0)
mp.propagate(torch.tensor([[0], [0]]), size=(100, 10), x=torch.randn(100, 3)).shape
# => (10,3)
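As a point of reference, here is a minimal pure-Python sketch of what this call computes, assuming the base class's default aggr="add" and a message that returns x_j unchanged. The helper name sum_propagate_s2t is hypothetical and not part of PyG. It only illustrates why, under source_to_target, row 0 of edge_index indexes the size[0] source nodes and the output has one row per target node (size[1]).

```python
from typing import List, Sequence, Tuple


def sum_propagate_s2t(
    edge_index: Sequence[Sequence[int]],
    size: Tuple[int, int],
    x: List[List[float]],
) -> List[List[float]]:
    """Hypothetical stand-in for an add-aggregating propagate() with
    flow="source_to_target" and a message that returns x_j unchanged."""
    sources, targets = edge_index  # row 0: sources (j), row 1: targets (i)
    num_sources, num_targets = size
    if len(x) != num_sources:
        raise ValueError(f"expected {num_sources} source rows, got {len(x)}")
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_targets)]  # one row per target node
    for j, i in zip(sources, targets):
        for d in range(dim):
            out[i][d] += x[j][d]  # sum incoming messages into target i
    return out


x = [[float(k)] * 3 for k in range(100)]  # 100 source nodes, 3 features each
out = sum_propagate_s2t([[0], [0]], (100, 10), x)
print(len(out), len(out[0]))  # 10 3
```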

In PyG 2.0.4 with target_to_source, we get:

mp = MessagePassing(flow="target_to_source", node_dim=0)
mp.propagate(torch.tensor([[0], [0]]), size=(10, 100), x=torch.randn(100, 3)).shape
# => (100,3)

Meanwhile, the following call

mp = MessagePassing(flow="target_to_source", node_dim=0)
mp.propagate(torch.tensor([[0], [0]]), size=(100, 10), x=torch.randn(100, 3))

raises:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-20-472fe80fa7f5>", line 2, in <module>
    mp.propagate(torch.tensor([[0], [0]]), size=(100, 10), x=torch.randn(100, 3)).shape
  File "/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 309, in propagate
    coll_dict = self.__collect__(self.__user_args__, edge_index,
  File "/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 201, in __collect__
    self.__set_size__(size, dim, data)
  File "/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 165, in __set_size__
    raise ValueError(
ValueError: Encountered tensor with size 100 in dimension 0, but expected size 10.

I'm not sure which of the two size orderings is intended, but either the second or the third example should return the same shape as the first example, (10, 3).

In PyG 2.0.3, this was working fine. I suspect the issue is related to the change made in #3907. In PyG 2.0.3, we get:

mp = MessagePassing(flow="target_to_source", node_dim=0)
mp.propagate(torch.tensor([[0], [0]]), size=(10, 100), x=torch.randn(100, 3)).shape
# => raises a ValueError like the one above
mp = MessagePassing(flow="target_to_source", node_dim=0)
mp.propagate(torch.tensor([[0], [0]]), size=(100, 10), x=torch.randn(100, 3)).shape
# => (10, 3), as desired

Environment

  • PyG version: 2.0.4
  • PyTorch version: 1.10.2
  • OS: Ubuntu 20.04.3 LTS
  • Python version: 3.8.10
@pimdh pimdh added the bug label May 3, 2022
@rusty1s
Member

rusty1s commented May 3, 2022

Thanks for reporting. It doesn't make much sense for

mp = MessagePassing(flow="target_to_source", node_dim=0)
mp.propagate(torch.tensor([[0], [0]]), size=(100, 10), x=torch.randn(100, 3))

to succeed, since you should pass in a tuple of node features whenever your edge_index describes a bipartite graph whose two node sets differ in size. The following works as expected for me:

mp = MessagePassing(flow="target_to_source", node_dim=0)
out = mp.propagate(torch.tensor([[0], [0]]), size=(100, 10),
                   x=(torch.randn(100, 3), torch.randn(10, 3)))
print(out.shape)
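One way to read this advice: when x is a tuple, each entry describes one side of the bipartite graph, and each entry's row count has to agree with the matching entry of size. The sketch below encodes that as a hypothetical check_bipartite_sizes helper, assuming (as the master-branch outputs later in this thread suggest) that x[k] and size[k] both refer to the nodes indexed by edge_index[k]. It is an illustration only, not PyG's actual validation code.

```python
from typing import Sequence, Tuple


def check_bipartite_sizes(
    size: Tuple[int, int],
    x: Tuple[Sequence, Sequence],
) -> None:
    """Hypothetical check: x[k] must hold exactly size[k] rows, k in {0, 1}."""
    for k in (0, 1):
        if len(x[k]) != size[k]:
            raise ValueError(
                f"x[{k}] has {len(x[k])} rows, but size[{k}] is {size[k]}"
            )


x_a = [[0.0] * 3 for _ in range(100)]  # features for the 100-node side
x_b = [[0.0] * 3 for _ in range(10)]   # features for the 10-node side

check_bipartite_sizes((100, 10), (x_a, x_b))  # consistent: no error

try:
    # Reusing the 100-row features for both sides cannot match size[1] == 10.
    check_bipartite_sizes((100, 10), (x_a, x_a))
except ValueError as err:
    print(err)  # x[1] has 100 rows, but size[1] is 10
```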

@pimdh
Author

pimdh commented May 3, 2022

Thanks for coming back to me. I'm still a bit confused, and suspect a bug really is present in PyG 2.0.4. Imagine the following setting: I want to go from 100 nodes with features arange(100) to 10 nodes via message passing along edges. The only edge is from node 2 to node 1 (0-indexed). I expect to get 10 nodes out, with node 1 set to 2 and the rest zero.
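The expected result can be worked out by hand: with node features arange(100) and a single edge from source node 2 to target node 1, summing the incoming messages leaves 2 at index 1 and zeros elsewhere. A plain-Python version of that arithmetic (no PyG involved, and assuming sum aggregation):

```python
features = list(range(100))  # source node k carries the value k
edges = [(2, 1)]             # (source, target) pairs, 0-indexed
num_targets = 10

expected = [0] * num_targets
for src, dst in edges:
    expected[dst] += features[src]  # add-aggregation into the target node

print(expected)  # [0, 2, 0, 0, 0, 0, 0, 0, 0, 0]
```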

As you suggest, we can add zero features for the 10 nodes and pass both to propagate. This leads to:

mp = MessagePassing(flow="target_to_source", node_dim=0)
out = mp.propagate(torch.tensor([[1], [2]]), size=(100, 10),
                   x=(torch.arange(100), torch.zeros(10)))
print(out)  # unexpectedly gives 10 zeros

If I flip the sizes around, I get:

mp = MessagePassing(flow="target_to_source", node_dim=0)
out = mp.propagate(torch.tensor([[1], [2]]), size=(10, 100),
                   x=(torch.zeros(10), torch.arange(100)))
print(out)  # [0, 2, 0, 0, ...] length 100

With source_to_target, I get the desired behaviour with:

mp = MessagePassing(flow="source_to_target", node_dim=0)
out = mp.propagate(torch.tensor([[2], [1]]), size=(100, 10),
                   x=(torch.arange(100), torch.zeros(10)))
print(out)  # [0, 2, 0, ...] length 10

In PyG 2.0.3, I get as expected:

mp = MessagePassing(flow="target_to_source", node_dim=0)
out = mp.propagate(torch.tensor([[1], [2]]), size=(100, 10),
                   x=(torch.arange(100), torch.zeros(10)))
print(out)  #  [0, 2, 0, ...] length 10

In PyG 2.0.3, I also correctly get the answer without providing the zero tensor, which seems preferable to me:

mp = MessagePassing(flow="target_to_source", node_dim=0)
out = mp.propagate(torch.tensor([[1], [2]]), size=(100, 10),
                   x=torch.arange(100))
print(out)  # [0, 2, 0, ...] length 10

In PyG 2.0.4, I get the aforementioned error.

Thanks!

@rusty1s
Member

rusty1s commented May 4, 2022

Can you upgrade to PyG master? I believe this has already been fixed in #4418.

@rusty1s
Member

rusty1s commented May 4, 2022

That is,

mp = MessagePassing(flow="target_to_source", node_dim=0)
out = mp.propagate(torch.tensor([[1], [2]]), size=(100, 10),
                   x=(torch.arange(100), torch.zeros(10)))
print(out)  # gives 100 zeros

@pimdh
Author

pimdh commented May 4, 2022

Thanks! It is indeed fixed on master. The correct code is as follows (no zero tensor needed):

mp = MessagePassing(flow="target_to_source", node_dim=0)
out = mp.propagate(torch.tensor([[1], [2]]), size=(10, 100),
                   x=torch.arange(100))
print(out)  # [0, 2, 0, 0, ...] length 10
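For comparison, a flow-aware pure-Python sketch that reproduces the master-branch outputs reported above. It assumes aggr="add", a message that returns x_j, and the convention (consistent with the results in this thread) that size[k] counts the nodes indexed by edge_index[k] regardless of flow; under target_to_source, row 0 holds targets and row 1 holds sources. reference_propagate is a hypothetical helper, not a PyG function.

```python
from typing import List, Sequence, Tuple


def reference_propagate(
    edge_index: Sequence[Sequence[int]],
    size: Tuple[int, int],
    x_src: List[float],
    flow: str = "source_to_target",
) -> List[float]:
    """Hypothetical add-aggregation over scalar node features."""
    if flow == "source_to_target":
        j, i = 0, 1  # row 0 = sources, row 1 = targets
    elif flow == "target_to_source":
        j, i = 1, 0  # row 0 = targets, row 1 = sources
    else:
        raise ValueError(f"unknown flow: {flow}")
    if len(x_src) != size[j]:
        raise ValueError(f"expected {size[j]} source rows, got {len(x_src)}")
    out = [0.0] * size[i]  # one entry per target node
    for src, dst in zip(edge_index[j], edge_index[i]):
        out[dst] += x_src[src]  # sum messages into the target node
    return out


x = [float(k) for k in range(100)]

# target_to_source: edge_index[0] = targets, edge_index[1] = sources
print(reference_propagate([[1], [2]], (10, 100), x, flow="target_to_source")[:3])
# [0.0, 2.0, 0.0]
# source_to_target with the rows and sizes swapped gives the same result
print(reference_propagate([[2], [1]], (100, 10), x, flow="source_to_target")[:3])
# [0.0, 2.0, 0.0]
```

Under this convention, switching flow only swaps which edge_index row is read as the source, which is why the two calls agree.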
