diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 3f254171888b..6fb50a9c0948 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -241,8 +241,8 @@ def __collect__(self, args, edge_index, size, kwargs): out['index'] = out['edge_index_i'] out['size'] = size - out['size_i'] = size[1] if size[1] is not None else size[0] - out['size_j'] = size[0] if size[0] is not None else size[1] + out['size_i'] = size[i] if size[i] is not None else size[j] + out['size_j'] = size[j] if size[j] is not None else size[i] out['dim_size'] = out['size_i'] return out