Skip to content

Commit

Permalink
Fix test_mask_feature CI failure (#5672)
Browse files Browse the repository at this point in the history
To address #5670.
Fall back to `torch.rand`.
  • Loading branch information
EdisonLeeeee authored Oct 13, 2022
1 parent 6bda075 commit 5f1e0a8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
21 changes: 11 additions & 10 deletions test/utils/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,21 @@ def test_mask_feature():

torch.manual_seed(7)
out = mask_feature(x, mode='all')
assert out[0].tolist() == [[1.0, 0.0, 0.0, 4.0], [5.0, 6.0, 7.0, 0.0],
[9.0, 10.0, 0.0, 12.0]]
assert out[0].tolist() == [[0.0, 2.0, 3.0, 0.0], [5.0, 0.0, 0.0, 8.0],
[0.0, 0.0, 0.0, 0.0]]

assert out[1].tolist() == [[True, False, False, True],
[True, True, True, False],
[True, True, False, True]]
assert out[1].tolist() == [[False, True, True, False],
[True, False, False, True],
[False, False, False, False]]

torch.manual_seed(7)
out = mask_feature(x, mode='all', fill_value=-1)
assert out[0].tolist() == [[1.0, -1., -1., 4.0], [5.0, 6.0, 7.0, -1.],
[9.0, 10.0, -1., 12.0]]
assert out[1].tolist() == [[True, False, False, True],
[True, True, True, False],
[True, True, False, True]]
assert out[0].tolist() == [[-1.0, 2.0, 3.0, -1.0], [5.0, -1.0, -1.0, 8.0],
[-1.0, -1.0, -1.0, -1.0]]

assert out[1].tolist() == [[False, True, True, False],
[True, False, False, True],
[False, False, False, False]]


def test_add_random_edge():
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/utils/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def mask_feature(x: Tensor, p: float = 0.5, mode: str = 'col',
mask = torch.rand(x.size(1), device=x.device) >= p
mask = mask.view(1, -1)
else:
mask = x.bernoulli(1 - p).to(torch.bool)
mask = torch.randn_like(x) >= p

x = x.masked_fill(~mask, fill_value)
return x, mask
Expand Down

0 comments on commit 5f1e0a8

Please sign in to comment.