From 81815b07f6b1acf556e0176f3c9a6883a15dfd43 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 25 Oct 2022 06:22:47 +0000 Subject: [PATCH 1/3] update --- torch_geometric/transforms/virtual_node.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torch_geometric/transforms/virtual_node.py b/torch_geometric/transforms/virtual_node.py index eda38524aa6b..76d01319216d 100644 --- a/torch_geometric/transforms/virtual_node.py +++ b/torch_geometric/transforms/virtual_node.py @@ -1,3 +1,5 @@ +import copy + import torch from torch import Tensor @@ -35,22 +37,26 @@ def __call__(self, data: Data) -> Data: new_type = edge_type.new_full((num_nodes, ), int(edge_type.max()) + 1) edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0) - for key, value in data.items(): + old_data = copy.copy(data) + for key, value in old_data.items(): if key == 'edge_index' or key == 'edge_type': continue if isinstance(value, Tensor): - dim = data.__cat_dim__(key, value) + dim = old_data.__cat_dim__(key, value) size = list(value.size()) fill_value = None if key == 'edge_weight': size[dim] = 2 * num_nodes fill_value = 1. - elif data.is_edge_attr(key): + if key == 'batch': + size[dim] = 1 + fill_value = int(value[0]) + elif old_data.is_edge_attr(key): size[dim] = 2 * num_nodes fill_value = 0. - elif data.is_node_attr(key): + elif old_data.is_node_attr(key): size[dim] = 1 fill_value = 0. @@ -62,6 +68,6 @@ def __call__(self, data: Data) -> Data: data.edge_type = edge_type if 'num_nodes' in data: - data.num_nodes = data.num_nodes + 1 + data.num_nodes = old_data.num_nodes + 1 return data From 3b5edb3144674e2b78fb654d0091c3f973cfd4b4 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 25 Oct 2022 06:24:32 +0000 Subject: [PATCH 2/3] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c2ccad5bbfb6..06ec5b84c24c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) ### Changed +- Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features ([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819)) - Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815)) - Fixed `path` in `hetero_conv_dblp.py` example ([#5686](https://github.com/pyg-team/pytorch_geometric/pull/5686)) - Fix `auto_select_device` routine in GraphGym for PyTorch Lightning>=1.7 ([#5677](https://github.com/pyg-team/pytorch_geometric/pull/5677)) From d3b278a70bc072839f755b59a7381fc884dc6920 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 25 Oct 2022 06:29:03 +0000 Subject: [PATCH 3/3] fix --- torch_geometric/transforms/virtual_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/virtual_node.py b/torch_geometric/transforms/virtual_node.py index 76d01319216d..d86acbbafd65 100644 --- a/torch_geometric/transforms/virtual_node.py +++ b/torch_geometric/transforms/virtual_node.py @@ -50,7 +50,7 @@ def __call__(self, data: Data) -> Data: if key == 'edge_weight': size[dim] = 2 * num_nodes fill_value = 1. - if key == 'batch': + elif key == 'batch': size[dim] = 1 fill_value = int(value[0]) elif old_data.is_edge_attr(key):