diff --git a/test/utils/test_augmentation.py b/test/utils/test_augmentation.py index 249a1c578ac4..3988c8431e34 100644 --- a/test/utils/test_augmentation.py +++ b/test/utils/test_augmentation.py @@ -54,20 +54,21 @@ def test_mask_feature(): torch.manual_seed(7) out = mask_feature(x, mode='all') - assert out[0].tolist() == [[1.0, 0.0, 0.0, 4.0], [5.0, 6.0, 7.0, 0.0], - [9.0, 10.0, 0.0, 12.0]] + assert out[0].tolist() == [[0.0, 2.0, 3.0, 0.0], [5.0, 0.0, 0.0, 8.0], + [0.0, 0.0, 0.0, 0.0]] - assert out[1].tolist() == [[True, False, False, True], - [True, True, True, False], - [True, True, False, True]] + assert out[1].tolist() == [[False, True, True, False], + [True, False, False, True], + [False, False, False, False]] torch.manual_seed(7) out = mask_feature(x, mode='all', fill_value=-1) - assert out[0].tolist() == [[1.0, -1., -1., 4.0], [5.0, 6.0, 7.0, -1.], - [9.0, 10.0, -1., 12.0]] - assert out[1].tolist() == [[True, False, False, True], - [True, True, True, False], - [True, True, False, True]] + assert out[0].tolist() == [[-1.0, 2.0, 3.0, -1.0], [5.0, -1.0, -1.0, 8.0], + [-1.0, -1.0, -1.0, -1.0]] + + assert out[1].tolist() == [[False, True, True, False], + [True, False, False, True], + [False, False, False, False]] def test_add_random_edge(): diff --git a/torch_geometric/utils/augmentation.py b/torch_geometric/utils/augmentation.py index 16f0f79ce62e..f1df3028dcc3 100644 --- a/torch_geometric/utils/augmentation.py +++ b/torch_geometric/utils/augmentation.py @@ -143,7 +143,7 @@ def mask_feature(x: Tensor, p: float = 0.5, mode: str = 'col', mask = torch.rand(x.size(1), device=x.device) >= p mask = mask.view(1, -1) else: - mask = x.bernoulli(1 - p).to(torch.bool) + mask = torch.randn_like(x) >= p x = x.masked_fill(~mask, fill_value) return x, mask