diff --git a/test/transforms/test_fixed_points.py b/test/transforms/test_fixed_points.py index 0f7d36327e59..badea4aa7fe2 100644 --- a/test/transforms/test_fixed_points.py +++ b/test/transforms/test_fixed_points.py @@ -14,46 +14,53 @@ def test_fixed_points(): x=torch.randn(100, 16), y=torch.randn(1), edge_attr=torch.randn(100, 3), + num_nodes=100, ) out = FixedPoints(50, replace=True)(copy(data)) - assert len(out) == 4 + assert len(out) == 5 assert out.pos.size() == (50, 3) assert out.x.size() == (50, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) + assert out.num_nodes == 50 out = FixedPoints(200, replace=True)(copy(data)) - assert len(out) == 4 + assert len(out) == 5 assert out.pos.size() == (200, 3) assert out.x.size() == (200, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) + assert out.num_nodes == 200 out = FixedPoints(50, replace=False, allow_duplicates=False)(copy(data)) - assert len(out) == 4 + assert len(out) == 5 assert out.pos.size() == (50, 3) assert out.x.size() == (50, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) + assert out.num_nodes == 50 out = FixedPoints(200, replace=False, allow_duplicates=False)(copy(data)) - assert len(out) == 4 + assert len(out) == 5 assert out.pos.size() == (100, 3) assert out.x.size() == (100, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) + assert out.num_nodes == 100 out = FixedPoints(50, replace=False, allow_duplicates=True)(copy(data)) - assert len(out) == 4 + assert len(out) == 5 assert out.pos.size() == (50, 3) assert out.x.size() == (50, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) + assert out.num_nodes == 50 out = FixedPoints(200, replace=False, allow_duplicates=True)(copy(data)) - assert len(out) == 4 + assert len(out) == 5 assert out.pos.size() == (200, 3) assert out.x.size() == (200, 16) assert out.y.size() == (1, ) assert out.edge_attr.size() == (100, 3) + assert out.num_nodes == 200 diff --git a/torch_geometric/transforms/fixed_points.py b/torch_geometric/transforms/fixed_points.py index 711963aa73c2..85f86202aaf2 100644 --- a/torch_geometric/transforms/fixed_points.py +++ b/torch_geometric/transforms/fixed_points.py @@ -46,10 +46,12 @@ def __call__(self, data): ], dim=0)[:self.num] for key, item in data: - if bool(re.search('edge', key)): + if key == 'num_nodes': + data.num_nodes = choice.size(0) + elif bool(re.search('edge', key)): continue - if (torch.is_tensor(item) and item.size(0) == num_nodes - and item.size(0) != 1): + elif (torch.is_tensor(item) and item.size(0) == num_nodes + and item.size(0) != 1): data[key] = item[choice] return data