From 7787d43973cdf8765ad66986136661644148000a Mon Sep 17 00:00:00 2001 From: xionggziran <45009031+xionggziran@users.noreply.github.com> Date: Wed, 6 Apr 2022 01:39:03 +0800 Subject: [PATCH] Update message_passing.py When using a bipartite graph for reverse message passing(set flow == 'target_to_source'), the message construction fails due to incorrect dim_size, although it doesn't matter, and different adjacency matrices can be used to achieve the same effect. --- torch_geometric/nn/conv/message_passing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 3f254171888b..6fb50a9c0948 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -241,8 +241,8 @@ def __collect__(self, args, edge_index, size, kwargs): out['index'] = out['edge_index_i'] out['size'] = size - out['size_i'] = size[1] if size[1] is not None else size[0] - out['size_j'] = size[0] if size[0] is not None else size[1] + out['size_i'] = size[i] if size[i] is not None else size[j] + out['size_j'] = size[j] if size[j] is not None else size[i] out['dim_size'] = out['size_i'] return out