From 661f4ccb46fea02f6450c04a99d6265a7dcb136f Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sat, 15 Oct 2022 16:44:50 +0800 Subject: [PATCH 01/10] add typehints for random transforms --- torch_geometric/transforms/random_flip.py | 5 +++-- torch_geometric/transforms/random_jitter.py | 6 ++++-- torch_geometric/transforms/random_rotate.py | 6 ++++-- torch_geometric/transforms/random_scale.py | 6 ++++-- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/torch_geometric/transforms/random_flip.py b/torch_geometric/transforms/random_flip.py index 554e9189b1d0..2a1255491241 100644 --- a/torch_geometric/transforms/random_flip.py +++ b/torch_geometric/transforms/random_flip.py @@ -1,5 +1,6 @@ import random +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -14,11 +15,11 @@ class RandomFlip(BaseTransform): p (float, optional): Probability that node positions will be flipped. (default: :obj:`0.5`) """ - def __init__(self, axis, p=0.5): + def __init__(self, axis: int, p=0.5): self.axis = axis self.p = p - def __call__(self, data): + def __call__(self, data) -> Data: if random.random() < self.p: pos = data.pos.clone() pos[..., self.axis] = -pos[..., self.axis] diff --git a/torch_geometric/transforms/random_jitter.py b/torch_geometric/transforms/random_jitter.py index cbf5d69bd05d..636fe884eadb 100644 --- a/torch_geometric/transforms/random_jitter.py +++ b/torch_geometric/transforms/random_jitter.py @@ -1,8 +1,10 @@ import numbers from itertools import repeat +from typing import Sequence, Union import torch +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -21,10 +23,10 @@ class RandomJitter(BaseTransform): If :obj:`translate` is a number instead of a sequence, the same range is used for each dimension. """ - def __init__(self, translate): + def __init__(self, translate: Union[float, int, Sequence]): self.translate = translate - def __call__(self, data): + def __call__(self, data) -> Data: (n, dim), t = data.pos.size(), self.translate if isinstance(t, numbers.Number): t = list(repeat(t, times=dim)) diff --git a/torch_geometric/transforms/random_rotate.py b/torch_geometric/transforms/random_rotate.py index e4d422d68438..ea01e98f1f54 100644 --- a/torch_geometric/transforms/random_rotate.py +++ b/torch_geometric/transforms/random_rotate.py @@ -1,9 +1,11 @@ import math import numbers import random +from typing import Tuple, Union import torch +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform, LinearTransformation @@ -20,14 +22,14 @@ class RandomRotate(BaseTransform): \mathrm{degrees}]`. axis (int, optional): The rotation axis. (default: :obj:`0`) """ - def __init__(self, degrees, axis=0): + def __init__(self, degrees: Union[Tuple[float, float], float], axis=0): if isinstance(degrees, numbers.Number): degrees = (-abs(degrees), abs(degrees)) assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 self.degrees = degrees self.axis = axis - def __call__(self, data): + def __call__(self, data) -> Data: degree = math.pi * random.uniform(*self.degrees) / 180.0 sin, cos = math.sin(degree), math.cos(degree) diff --git a/torch_geometric/transforms/random_scale.py b/torch_geometric/transforms/random_scale.py index dcc0ca0bfc59..e5c0b952f01b 100644 --- a/torch_geometric/transforms/random_scale.py +++ b/torch_geometric/transforms/random_scale.py @@ -1,5 +1,7 @@ import random +from typing import Tuple +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -24,11 +26,11 @@ class RandomScale(BaseTransform): is randomly sampled from the range :math:`a \leq \mathrm{scale} \leq b`. """ - def __init__(self, scales): + def __init__(self, scales: Tuple[float, float]): assert isinstance(scales, (tuple, list)) and len(scales) == 2 self.scales = scales - def __call__(self, data): + def __call__(self, data) -> Data: scale = random.uniform(*self.scales) data.pos = data.pos * scale return data From 84b0f4789194e463940cb401c0068cc0af5cb4f8 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sat, 15 Oct 2022 16:47:13 +0800 Subject: [PATCH 02/10] fix changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d3b159655135..cf578a8166be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) -- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701] (https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710)) +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701] (https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) From 797182b61008d055f1341a139ac52a4ee6b02301 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 15 Oct 2022 12:19:23 +0200 Subject: [PATCH 03/10] Update torch_geometric/transforms/random_flip.py Co-authored-by: Jinu Sunil --- torch_geometric/transforms/random_flip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_flip.py b/torch_geometric/transforms/random_flip.py index 2a1255491241..6d4a458d5b62 100644 --- a/torch_geometric/transforms/random_flip.py +++ b/torch_geometric/transforms/random_flip.py @@ -15,7 +15,7 @@ class RandomFlip(BaseTransform): p (float, optional): Probability that node positions will be flipped. (default: :obj:`0.5`) """ - def __init__(self, axis: int, p=0.5): + def __init__(self, axis: int, p: float=0.5): self.axis = axis self.p = p From ad09f1006632f27ba19034e8314bb58da42d8129 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Oct 2022 10:20:32 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/transforms/random_flip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_flip.py b/torch_geometric/transforms/random_flip.py index 6d4a458d5b62..a9b84efeda69 100644 --- a/torch_geometric/transforms/random_flip.py +++ b/torch_geometric/transforms/random_flip.py @@ -15,7 +15,7 @@ class RandomFlip(BaseTransform): p (float, optional): Probability that node positions will be flipped. (default: :obj:`0.5`) """ - def __init__(self, axis: int, p: float=0.5): + def __init__(self, axis: int, p: float = 0.5): self.axis = axis self.p = p From 8aa67fbf01cbb10ba4e7a157db9dfaceba9ba49b Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 15 Oct 2022 12:20:51 +0200 Subject: [PATCH 05/10] Update torch_geometric/transforms/random_rotate.py Co-authored-by: Jinu Sunil --- torch_geometric/transforms/random_rotate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_rotate.py b/torch_geometric/transforms/random_rotate.py index ea01e98f1f54..4d2cac8bde4d 100644 --- a/torch_geometric/transforms/random_rotate.py +++ b/torch_geometric/transforms/random_rotate.py @@ -22,7 +22,7 @@ class RandomRotate(BaseTransform): \mathrm{degrees}]`. axis (int, optional): The rotation axis. (default: :obj:`0`) """ - def __init__(self, degrees: Union[Tuple[float, float], float], axis=0): + def __init__(self, degrees: Union[Tuple[float, float], float], axis: int=0): if isinstance(degrees, numbers.Number): degrees = (-abs(degrees), abs(degrees)) assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 From 24652887c2d99ed0a7021ca615bdb513c5d151b3 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 15 Oct 2022 12:21:04 +0200 Subject: [PATCH 06/10] Update torch_geometric/transforms/random_rotate.py Co-authored-by: Jinu Sunil --- torch_geometric/transforms/random_rotate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_rotate.py b/torch_geometric/transforms/random_rotate.py index 4d2cac8bde4d..15804a208127 100644 --- a/torch_geometric/transforms/random_rotate.py +++ b/torch_geometric/transforms/random_rotate.py @@ -29,7 +29,7 @@ def __init__(self, degrees: Union[Tuple[float, float], float], axis: int=0): self.degrees = degrees self.axis = axis - def __call__(self, data) -> Data: + def __call__(self, data: Data) -> Data: degree = math.pi * random.uniform(*self.degrees) / 180.0 sin, cos = math.sin(degree), math.cos(degree) From d10ffd63ec6fac860d9620f3c295684908316641 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 15 Oct 2022 12:21:58 +0200 Subject: [PATCH 07/10] Update torch_geometric/transforms/random_scale.py Co-authored-by: Jinu Sunil --- torch_geometric/transforms/random_scale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_scale.py b/torch_geometric/transforms/random_scale.py index e5c0b952f01b..a8ebbe4ea6d5 100644 --- a/torch_geometric/transforms/random_scale.py +++ b/torch_geometric/transforms/random_scale.py @@ -30,7 +30,7 @@ def __init__(self, scales: Tuple[float, float]): assert isinstance(scales, (tuple, list)) and len(scales) == 2 self.scales = scales - def __call__(self, data) -> Data: + def __call__(self, data: Data) -> Data: scale = random.uniform(*self.scales) data.pos = data.pos * scale return data From ce81711e3d09b47c750437c45a0cf47e181cdf57 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Oct 2022 10:22:08 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/transforms/random_rotate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_rotate.py b/torch_geometric/transforms/random_rotate.py index 15804a208127..04bb99d2b4d9 100644 --- a/torch_geometric/transforms/random_rotate.py +++ b/torch_geometric/transforms/random_rotate.py @@ -22,7 +22,8 @@ class RandomRotate(BaseTransform): \mathrm{degrees}]`. axis (int, optional): The rotation axis. (default: :obj:`0`) """ - def __init__(self, degrees: Union[Tuple[float, float], float], axis: int=0): + def __init__(self, degrees: Union[Tuple[float, float], float], + axis: int = 0): if isinstance(degrees, numbers.Number): degrees = (-abs(degrees), abs(degrees)) assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 From aad0dd1ed44fc32c13618901d6a694aff2322eea Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 15 Oct 2022 12:22:36 +0200 Subject: [PATCH 09/10] Update torch_geometric/transforms/random_flip.py --- torch_geometric/transforms/random_flip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_flip.py b/torch_geometric/transforms/random_flip.py index a9b84efeda69..cb9fdc63f233 100644 --- a/torch_geometric/transforms/random_flip.py +++ b/torch_geometric/transforms/random_flip.py @@ -19,7 +19,7 @@ def __init__(self, axis: int, p: float = 0.5): self.axis = axis self.p = p - def __call__(self, data) -> Data: + def __call__(self, data: Data) -> Data: if random.random() < self.p: pos = data.pos.clone() pos[..., self.axis] = -pos[..., self.axis] From 846b9cb6ceb2b5c33b623e8797745a096e610fcc Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 15 Oct 2022 12:22:41 +0200 Subject: [PATCH 10/10] Update torch_geometric/transforms/random_jitter.py --- torch_geometric/transforms/random_jitter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_jitter.py b/torch_geometric/transforms/random_jitter.py index 636fe884eadb..4b39035e9044 100644 --- a/torch_geometric/transforms/random_jitter.py +++ b/torch_geometric/transforms/random_jitter.py @@ -26,7 +26,7 @@ class RandomJitter(BaseTransform): def __init__(self, translate: Union[float, int, Sequence]): self.translate = translate - def __call__(self, data) -> Data: + def __call__(self, data: Data) -> Data: (n, dim), t = data.pos.size(), self.translate if isinstance(t, numbers.Number): t = list(repeat(t, times=dim))