Skip to content

Commit

Permalink
fix: num_nodes in fixed_points transform (#4394)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Apr 1, 2022
1 parent 36a80d7 commit 5469e16
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
19 changes: 13 additions & 6 deletions test/transforms/test_fixed_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,53 @@ def test_fixed_points():
x=torch.randn(100, 16),
y=torch.randn(1),
edge_attr=torch.randn(100, 3),
num_nodes=100,
)

out = FixedPoints(50, replace=True)(copy(data))
assert len(out) == 4
assert len(out) == 5
assert out.pos.size() == (50, 3)
assert out.x.size() == (50, 16)
assert out.y.size() == (1, )
assert out.edge_attr.size() == (100, 3)
assert out.num_nodes == 50

out = FixedPoints(200, replace=True)(copy(data))
assert len(out) == 4
assert len(out) == 5
assert out.pos.size() == (200, 3)
assert out.x.size() == (200, 16)
assert out.y.size() == (1, )
assert out.edge_attr.size() == (100, 3)
assert out.num_nodes == 200

out = FixedPoints(50, replace=False, allow_duplicates=False)(copy(data))
assert len(out) == 4
assert len(out) == 5
assert out.pos.size() == (50, 3)
assert out.x.size() == (50, 16)
assert out.y.size() == (1, )
assert out.edge_attr.size() == (100, 3)
assert out.num_nodes == 50

out = FixedPoints(200, replace=False, allow_duplicates=False)(copy(data))
assert len(out) == 4
assert len(out) == 5
assert out.pos.size() == (100, 3)
assert out.x.size() == (100, 16)
assert out.y.size() == (1, )
assert out.edge_attr.size() == (100, 3)
assert out.num_nodes == 100

out = FixedPoints(50, replace=False, allow_duplicates=True)(copy(data))
assert len(out) == 4
assert len(out) == 5
assert out.pos.size() == (50, 3)
assert out.x.size() == (50, 16)
assert out.y.size() == (1, )
assert out.edge_attr.size() == (100, 3)
assert out.num_nodes == 50

out = FixedPoints(200, replace=False, allow_duplicates=True)(copy(data))
assert len(out) == 4
assert len(out) == 5
assert out.pos.size() == (200, 3)
assert out.x.size() == (200, 16)
assert out.y.size() == (1, )
assert out.edge_attr.size() == (100, 3)
assert out.num_nodes == 200
8 changes: 5 additions & 3 deletions torch_geometric/transforms/fixed_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def __call__(self, data):
], dim=0)[:self.num]

for key, item in data:
if bool(re.search('edge', key)):
if key == 'num_nodes':
data.num_nodes = choice.size(0)
elif bool(re.search('edge', key)):
continue
if (torch.is_tensor(item) and item.size(0) == num_nodes
and item.size(0) != 1):
elif (torch.is_tensor(item) and item.size(0) == num_nodes
and item.size(0) != 1):
data[key] = item[choice]

return data
Expand Down

0 comments on commit 5469e16

Please sign in to comment.