Skip to content

Commit

Permalink
Merge branch 'master' into type_hints/datasets.S3DIS
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Oct 21, 2022
2 parents dd798cf + 9b8ac11 commit 4d669a3
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 26 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768)), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768)), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5797](https://github.com/pyg-team/pytorch_geometric/pull/5797), [#5798](https://github.com/pyg-team/pytorch_geometric/pull/5798), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
19 changes: 13 additions & 6 deletions torch_geometric/datasets/jodie.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os.path as osp
from typing import Callable, Optional

import torch

Expand All @@ -9,27 +10,33 @@ class JODIEDataset(InMemoryDataset):
url = 'http://snap.stanford.edu/jodie/{}.csv'
names = ['reddit', 'wikipedia', 'mooc', 'lastfm']

def __init__(self, root, name, transform=None, pre_transform=None):
def __init__(
self,
root: str,
name: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
):
self.name = name.lower()
assert self.name in self.names

super().__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_dir(self):
def raw_dir(self) -> str:
return osp.join(self.root, self.name, 'raw')

@property
def processed_dir(self):
def processed_dir(self) -> str:
return osp.join(self.root, self.name, 'processed')

@property
def raw_file_names(self):
def raw_file_names(self) -> str:
return f'{self.name}.csv'

@property
def processed_file_names(self):
def processed_file_names(self) -> str:
return 'data.pt'

def download(self):
Expand All @@ -54,5 +61,5 @@ def process(self):

torch.save(self.collate([data]), self.processed_paths[0])

def __repr__(self):
def __repr__(self) -> str:
return f'{self.name.capitalize()}()'
17 changes: 12 additions & 5 deletions torch_geometric/datasets/mixhop_synthetic_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path as osp
import pickle
from typing import Callable, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -34,28 +35,34 @@ class MixHopSyntheticDataset(InMemoryDataset):
url = ('https://raw.githubusercontent.com/samihaija/mixhop/master/data'
'/synthetic')

def __init__(self, root, homophily, transform=None, pre_transform=None):
def __init__(
self,
root: str,
homophily: float,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
):
self.homophily = homophily
assert homophily in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
super().__init__(root, transform, pre_transform)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_dir(self):
def raw_dir(self) -> str:
return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'raw')

@property
def processed_dir(self):
def processed_dir(self) -> str:
return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'processed')

@property
def raw_file_names(self):
def raw_file_names(self) -> List[str]:
name = f'ind.n5000-h{self.homophily:0.1f}-c10'
return [f'{name}.allx', f'{name}.ally', f'{name}.graph']

@property
def processed_file_names(self):
def processed_file_names(self) -> str:
return 'data.pt'

def download(self):
Expand Down
16 changes: 12 additions & 4 deletions torch_geometric/datasets/pcpnet_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path as osp
from typing import Callable, Optional

import torch

Expand Down Expand Up @@ -71,8 +72,15 @@ class PCPNetDataset(InMemoryDataset):
'VarDensityGradient': 'testset_vardensity_gradient.txt'
}

def __init__(self, root, category, split='train', transform=None,
pre_transform=None, pre_filter=None):
def __init__(
self,
root: str,
category: str,
split: str = 'train',
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
):

assert split in ['train', 'val', 'test']

Expand All @@ -90,7 +98,7 @@ def __init__(self, root, category, split='train', transform=None,
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
def raw_file_names(self) -> str:
if self.split == 'train':
return self.category_files_train[self.category]
elif self.split == 'val':
Expand All @@ -99,7 +107,7 @@ def raw_file_names(self):
return self.category_files_test[self.category]

@property
def processed_file_names(self):
def processed_file_names(self) -> str:
return self.split + '_' + self.category + '.pt'

def download(self):
Expand Down
21 changes: 15 additions & 6 deletions torch_geometric/datasets/shrec2016.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import os
import os.path as osp
from typing import Callable, List, Optional

import torch

Expand Down Expand Up @@ -59,8 +60,16 @@ class SHREC2016(InMemoryDataset):
]
partialities = ['holes', 'cuts']

def __init__(self, root, partiality, category, train=True, transform=None,
pre_transform=None, pre_filter=None):
def __init__(
self,
root: str,
partiality: str,
category: str,
train: bool = True,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
):
assert partiality.lower() in self.partialities
self.part = partiality.lower()
assert category.lower() in self.categories
Expand All @@ -71,18 +80,18 @@ def __init__(self, root, partiality, category, train=True, transform=None,
self.data, self.slices = torch.load(path)

@property
def ref(self):
def ref(self) -> str:
ref = self.__ref__
if self.transform is not None:
ref = self.transform(ref)
return ref

@property
def raw_file_names(self):
def raw_file_names(self) -> List[str]:
return ['training', 'test']

@property
def processed_file_names(self):
def processed_file_names(self) -> List[str]:
name = f'{self.part}_{self.cat}.pt'
return [f'{i}_{name}' for i in ['ref', 'training', 'test']]

Expand Down Expand Up @@ -140,5 +149,5 @@ def process(self):
torch.save(self.collate(test_list), self.processed_paths[2])

def __repr__(self) -> str:
return (f'{self.__class__.name__}({len(self)}, '
return (f'{self.__class__.__name__}({len(self)}, '
f'partiality={self.part}, category={self.cat})')
15 changes: 11 additions & 4 deletions torch_geometric/datasets/tosca.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import os
import os.path as osp
from typing import Callable, List, Optional

import torch

Expand Down Expand Up @@ -59,8 +60,14 @@ class TOSCA(InMemoryDataset):
'victoria', 'wolf'
]

def __init__(self, root, categories=None, transform=None,
pre_transform=None, pre_filter=None):
def __init__(
self,
root: str,
categories: Optional[List[str]] = None,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
):
categories = self.categories if categories is None else categories
categories = [cat.lower() for cat in categories]
for cat in categories:
Expand All @@ -70,11 +77,11 @@ def __init__(self, root, categories=None, transform=None,
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
def raw_file_names(self) -> List[str]:
return ['cat0.vert', 'cat0.tri']

@property
def processed_file_names(self):
def processed_file_names(self) -> str:
name = '_'.join([cat[:2] for cat in self.categories])
return f'{name}.pt'

Expand Down
1 change: 1 addition & 0 deletions torch_geometric/transforms/add_positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __call__(self, data: Data) -> Data:
num_nodes = data.num_nodes
edge_index, edge_weight = get_laplacian(
data.edge_index,
data.edge_weight,
normalization='sym',
num_nodes=num_nodes,
)
Expand Down

0 comments on commit 4d669a3

Please sign in to comment.