Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TypeHints] Random Transforms #5714

Merged
merged 10 commits into from
Oct 15, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701] (https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701] (https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/transforms/random_flip.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -14,11 +15,11 @@ class RandomFlip(BaseTransform):
p (float, optional): Probability that node positions will be flipped.
(default: :obj:`0.5`)
"""
def __init__(self, axis, p=0.5):
def __init__(self, axis: int, p: float=0.5):
self.axis = axis
self.p = p

def __call__(self, data):
def __call__(self, data) -> Data:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
if random.random() < self.p:
pos = data.pos.clone()
pos[..., self.axis] = -pos[..., self.axis]
Expand Down
6 changes: 4 additions & 2 deletions torch_geometric/transforms/random_jitter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numbers
from itertools import repeat
from typing import Sequence, Union

import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -21,10 +23,10 @@ class RandomJitter(BaseTransform):
If :obj:`translate` is a number instead of a sequence, the same
range is used for each dimension.
"""
def __init__(self, translate):
def __init__(self, translate: Union[float, int, Sequence]):
self.translate = translate

def __call__(self, data):
def __call__(self, data) -> Data:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
(n, dim), t = data.pos.size(), self.translate
if isinstance(t, numbers.Number):
t = list(repeat(t, times=dim))
Expand Down
6 changes: 4 additions & 2 deletions torch_geometric/transforms/random_rotate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import math
import numbers
import random
from typing import Tuple, Union

import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, LinearTransformation

Expand All @@ -20,14 +22,14 @@ class RandomRotate(BaseTransform):
\mathrm{degrees}]`.
axis (int, optional): The rotation axis. (default: :obj:`0`)
"""
def __init__(self, degrees, axis=0):
def __init__(self, degrees: Union[Tuple[float, float], float], axis=0):
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(degrees, numbers.Number):
degrees = (-abs(degrees), abs(degrees))
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
self.degrees = degrees
self.axis = axis

def __call__(self, data):
def __call__(self, data) -> Data:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
degree = math.pi * random.uniform(*self.degrees) / 180.0
sin, cos = math.sin(degree), math.cos(degree)

Expand Down
6 changes: 4 additions & 2 deletions torch_geometric/transforms/random_scale.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import random
from typing import Tuple

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -24,11 +26,11 @@ class RandomScale(BaseTransform):
is randomly sampled from the range
:math:`a \leq \mathrm{scale} \leq b`.
"""
def __init__(self, scales):
def __init__(self, scales: Tuple[float, float]):
assert isinstance(scales, (tuple, list)) and len(scales) == 2
self.scales = scales

def __call__(self, data):
def __call__(self, data) -> Data:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
scale = random.uniform(*self.scales)
data.pos = data.pos * scale
return data
Expand Down