diff --git a/CHANGELOG.md b/CHANGELOG.md index 92c6e48e56b1..30dcf6b2309b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,7 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) -- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768)), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799)) +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768)), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5797](https://github.com/pyg-team/pytorch_geometric/pull/5797), [#5798](https://github.com/pyg-team/pytorch_geometric/pull/5798), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) diff --git a/torch_geometric/datasets/jodie.py b/torch_geometric/datasets/jodie.py index e9546bf74af1..327b8602c0a4 100644 --- a/torch_geometric/datasets/jodie.py +++ b/torch_geometric/datasets/jodie.py @@ -1,4 +1,5 @@ import os.path as osp +from typing import Callable, Optional import torch @@ -9,7 +10,13 @@ class JODIEDataset(InMemoryDataset): url = 'http://snap.stanford.edu/jodie/{}.csv' names = ['reddit', 'wikipedia', 'mooc', 'lastfm'] - def __init__(self, root, name, transform=None, pre_transform=None): + def __init__( + self, + root: str, + name: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + ): self.name = name.lower() assert self.name in self.names @@ -17,19 +24,19 @@ def __init__(self, root, name, transform=None, pre_transform=None): self.data, self.slices = torch.load(self.processed_paths[0]) @property - def raw_dir(self): + def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property - def processed_dir(self): + def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property - def raw_file_names(self): + def raw_file_names(self) -> str: return f'{self.name}.csv' @property - def processed_file_names(self): + def processed_file_names(self) -> str: return 'data.pt' def download(self): @@ -54,5 +61,5 @@ def process(self): torch.save(self.collate([data]), self.processed_paths[0]) - def __repr__(self): + def __repr__(self) -> str: return f'{self.name.capitalize()}()' diff --git a/torch_geometric/datasets/mixhop_synthetic_dataset.py b/torch_geometric/datasets/mixhop_synthetic_dataset.py index 55f46f50a2b7..58102c135c7f 100644 --- a/torch_geometric/datasets/mixhop_synthetic_dataset.py +++ b/torch_geometric/datasets/mixhop_synthetic_dataset.py @@ -1,5 +1,6 @@ import os.path as osp import pickle +from typing import Callable, List, Optional import numpy as np import torch @@ -34,7 +35,13 @@ class MixHopSyntheticDataset(InMemoryDataset): url = ('https://raw.githubusercontent.com/samihaija/mixhop/master/data' '/synthetic') - def __init__(self, root, homophily, transform=None, pre_transform=None): + def __init__( + self, + root: str, + homophily: float, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + ): self.homophily = homophily assert homophily in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] super().__init__(root, transform, pre_transform) @@ -42,20 +49,20 @@ def __init__(self, root, homophily, transform=None, pre_transform=None): self.data, self.slices = torch.load(self.processed_paths[0]) @property - def raw_dir(self): + def raw_dir(self) -> str: return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'raw') @property - def processed_dir(self): + def processed_dir(self) -> str: return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'processed') @property - def raw_file_names(self): + def raw_file_names(self) -> List[str]: name = f'ind.n5000-h{self.homophily:0.1f}-c10' return [f'{name}.allx', f'{name}.ally', f'{name}.graph'] @property - def processed_file_names(self): + def processed_file_names(self) -> str: return 'data.pt' def download(self): diff --git a/torch_geometric/datasets/pcpnet_dataset.py b/torch_geometric/datasets/pcpnet_dataset.py index 3d51e0c5f283..9eec27a543a0 100644 --- a/torch_geometric/datasets/pcpnet_dataset.py +++ b/torch_geometric/datasets/pcpnet_dataset.py @@ -1,5 +1,6 @@ import os import os.path as osp +from typing import Callable, Optional import torch @@ -71,8 +72,15 @@ class PCPNetDataset(InMemoryDataset): 'VarDensityGradient': 'testset_vardensity_gradient.txt' } - def __init__(self, root, category, split='train', transform=None, - pre_transform=None, pre_filter=None): + def __init__( + self, + root: str, + category: str, + split: str = 'train', + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): assert split in ['train', 'val', 'test'] @@ -90,7 +98,7 @@ def __init__(self, root, category, split='train', transform=None, self.data, self.slices = torch.load(self.processed_paths[0]) @property - def raw_file_names(self): + def raw_file_names(self) -> str: if self.split == 'train': return self.category_files_train[self.category] elif self.split == 'val': @@ -99,7 +107,7 @@ def raw_file_names(self): return self.category_files_test[self.category] @property - def processed_file_names(self): + def processed_file_names(self) -> str: return self.split + '_' + self.category + '.pt' def download(self): diff --git a/torch_geometric/datasets/shrec2016.py b/torch_geometric/datasets/shrec2016.py index 55f73bdba519..53704d4744d6 100644 --- a/torch_geometric/datasets/shrec2016.py +++ b/torch_geometric/datasets/shrec2016.py @@ -1,6 +1,7 @@ import glob import os import os.path as osp +from typing import Callable, List, Optional import torch @@ -59,8 +60,16 @@ class SHREC2016(InMemoryDataset): ] partialities = ['holes', 'cuts'] - def __init__(self, root, partiality, category, train=True, transform=None, - pre_transform=None, pre_filter=None): + def __init__( + self, + root: str, + partiality: str, + category: str, + train: bool = True, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): assert partiality.lower() in self.partialities self.part = partiality.lower() assert category.lower() in self.categories @@ -71,18 +80,18 @@ def __init__(self, root, partiality, category, train=True, transform=None, self.data, self.slices = torch.load(path) @property - def ref(self): + def ref(self) -> str: ref = self.__ref__ if self.transform is not None: ref = self.transform(ref) return ref @property - def raw_file_names(self): + def raw_file_names(self) -> List[str]: return ['training', 'test'] @property - def processed_file_names(self): + def processed_file_names(self) -> List[str]: name = f'{self.part}_{self.cat}.pt' return [f'{i}_{name}' for i in ['ref', 'training', 'test']] @@ -140,5 +149,5 @@ def process(self): torch.save(self.collate(test_list), self.processed_paths[2]) def __repr__(self) -> str: - return (f'{self.__class__.name__}({len(self)}, ' + return (f'{self.__class__.__name__}({len(self)}, ' f'partiality={self.part}, category={self.cat})') diff --git a/torch_geometric/datasets/tosca.py b/torch_geometric/datasets/tosca.py index defddd9245d9..429869386561 100644 --- a/torch_geometric/datasets/tosca.py +++ b/torch_geometric/datasets/tosca.py @@ -1,6 +1,7 @@ import glob import os import os.path as osp +from typing import Callable, List, Optional import torch @@ -59,8 +60,14 @@ class TOSCA(InMemoryDataset): 'victoria', 'wolf' ] - def __init__(self, root, categories=None, transform=None, - pre_transform=None, pre_filter=None): + def __init__( + self, + root: str, + categories: Optional[List[str]] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): categories = self.categories if categories is None else categories categories = [cat.lower() for cat in categories] for cat in categories: @@ -70,11 +77,11 @@ def __init__(self, root, categories=None, transform=None, self.data, self.slices = torch.load(self.processed_paths[0]) @property - def raw_file_names(self): + def raw_file_names(self) -> List[str]: return ['cat0.vert', 'cat0.tri'] @property - def processed_file_names(self): + def processed_file_names(self) -> str: name = '_'.join([cat[:2] for cat in self.categories]) return f'{name}.pt' diff --git a/torch_geometric/transforms/add_positional_encoding.py b/torch_geometric/transforms/add_positional_encoding.py index 05aed9387dc0..30bdc3f9a9e1 100644 --- a/torch_geometric/transforms/add_positional_encoding.py +++ b/torch_geometric/transforms/add_positional_encoding.py @@ -69,6 +69,7 @@ def __call__(self, data: Data) -> Data: num_nodes = data.num_nodes edge_index, edge_weight = get_laplacian( data.edge_index, + data.edge_weight, normalization='sym', num_nodes=num_nodes, )