Skip to content

Commit

Permalink
Fix the interplay between TUDataset and pre_transform that modify node features (#4669)
Browse files Browse the repository at this point in the history

* fix num node attrs

* changelog

* typo
  • Loading branch information
rusty1s authored May 17, 2022
1 parent 6156650 commit be2a463
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Fixed the interplay between `TUDataset` and `pre_transform` functions that modify node features ([#4669](https://github.com/pyg-team/pytorch_geometric/pull/4669))
- Make use of the `pyg_sphinx_theme` documentation template ([#4664](https://github.com/pyg-team/pyg-lib/pull/4664), [#4667](https://github.com/pyg-team/pyg-lib/pull/4667))
- Refactored reading molecular positions from SDF file for QM9 datasets ([#4654](https://github.com/pyg-team/pytorch_geometric/pull/4654))
- Fixed `MLP.jittable()` bug in case `return_emb=True` ([#4645](https://github.com/pyg-team/pytorch_geometric/pull/4645), [#4648](https://github.com/pyg-team/pytorch_geometric/pull/4648))
Expand Down
38 changes: 16 additions & 22 deletions torch_geometric/datasets/tu_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,16 @@ def __init__(self, root: str, name: str,
self.name = name
self.cleaned = cleaned
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])

# Load the processed payload. Files written by this version hold a
# `(data, slices, sizes)` triple; older versions stored only `(data, slices)`.
out = torch.load(self.processed_paths[0])
# Reject anything that is not a 3-tuple. The check must use `or`: with `and`,
# a legacy 2-tuple would pass the guard and fail on the unpacking below with
# an unhelpful `ValueError` instead of this explanatory error.
if not isinstance(out, tuple) or len(out) != 3:
    raise RuntimeError(
        "The 'data' object was created by an older version of PyG. "
        "If this error occurred while loading an already existing "
        "dataset, remove the 'processed/' directory in the dataset's "
        "root folder and try again.")
self.data, self.slices, self.sizes = out

if self.data.x is not None and not use_node_attr:
num_node_attributes = self.num_node_attributes
self.data.x = self.data.x[:, num_node_attributes:]
Expand All @@ -141,34 +150,19 @@ def processed_dir(self) -> str:

@property
def num_node_labels(self) -> int:
if self.data.x is None:
return 0
for i in range(self.data.x.size(1)):
x = self.data.x[:, i:]
if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all():
return self.data.x.size(1) - i
return 0
return self.sizes['num_node_labels']

@property
def num_node_attributes(self) -> int:
if self.data.x is None:
return 0
return self.data.x.size(1) - self.num_node_labels
return self.sizes['num_node_attributes']

@property
def num_edge_labels(self) -> int:
if self.data.edge_attr is None:
return 0
for i in range(self.data.edge_attr.size(1)):
if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0):
return self.data.edge_attr.size(1) - i
return 0
return self.sizes['num_edge_labels']

@property
def num_edge_attributes(self) -> int:
if self.data.edge_attr is None:
return 0
return self.data.edge_attr.size(1) - self.num_edge_labels
return self.sizes['num_edge_attributes']

@property
def raw_file_names(self) -> List[str]:
Expand All @@ -189,7 +183,7 @@ def download(self):
os.rename(osp.join(folder, self.name), self.raw_dir)

def process(self):
self.data, self.slices = read_tu_data(self.raw_dir, self.name)
self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)

if self.pre_filter is not None:
data_list = [self.get(idx) for idx in range(len(self))]
Expand All @@ -201,7 +195,7 @@ def process(self):
data_list = [self.pre_transform(data) for data in data_list]
self.data, self.slices = self.collate(data_list)

torch.save((self.data, self.slices), self.processed_paths[0])
torch.save((self.data, self.slices, sizes), self.processed_paths[0])

def __repr__(self) -> str:
return f'{self.name}({len(self)})'
21 changes: 17 additions & 4 deletions torch_geometric/io/tu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ def read_tu_data(folder, prefix):
edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1

node_attributes = node_labels = None
node_attributes = torch.empty((batch.size(0), 0))
if 'node_attributes' in names:
node_attributes = read_file(folder, prefix, 'node_attributes')

node_labels = torch.empty((batch.size(0), 0))
if 'node_labels' in names:
node_labels = read_file(folder, prefix, 'node_labels', torch.long)
if node_labels.dim() == 1:
Expand All @@ -35,11 +37,12 @@ def read_tu_data(folder, prefix):
node_labels = node_labels.unbind(dim=-1)
node_labels = [F.one_hot(x, num_classes=-1) for x in node_labels]
node_labels = torch.cat(node_labels, dim=-1).to(torch.float)
x = cat([node_attributes, node_labels])

edge_attributes, edge_labels = None, None
edge_attributes = torch.empty((edge_index.size(1), 0))
if 'edge_attributes' in names:
edge_attributes = read_file(folder, prefix, 'edge_attributes')

edge_labels = torch.empty((edge_index.size(1), 0))
if 'edge_labels' in names:
edge_labels = read_file(folder, prefix, 'edge_labels', torch.long)
if edge_labels.dim() == 1:
Expand All @@ -48,6 +51,8 @@ def read_tu_data(folder, prefix):
edge_labels = edge_labels.unbind(dim=-1)
edge_labels = [F.one_hot(e, num_classes=-1) for e in edge_labels]
edge_labels = torch.cat(edge_labels, dim=-1).to(torch.float)

x = cat([node_attributes, node_labels])
edge_attr = cat([edge_attributes, edge_labels])

y = None
Expand All @@ -65,7 +70,14 @@ def read_tu_data(folder, prefix):
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
data, slices = split(data, batch)

return data, slices
sizes = {
'num_node_attributes': node_attributes.size(-1),
'num_node_labels': node_labels.size(-1),
'num_edge_attributes': edge_attributes.size(-1),
'num_edge_labels': edge_labels.size(-1),
}

return data, slices, sizes


def read_file(folder, prefix, name, dtype=None):
Expand All @@ -75,6 +87,7 @@ def read_file(folder, prefix, name, dtype=None):

def cat(seq):
    r"""Concatenates the usable tensors in :obj:`seq` along the last dimension.

    Entries that are :obj:`None` or that contain no elements are skipped, and
    one-dimensional tensors are turned into a single trailing column before
    concatenation. Returns :obj:`None` if no entry remains.
    """
    columns = []
    for item in seq:
        # Skip placeholders: missing entries and zero-element tensors.
        if item is None or item.numel() == 0:
            continue
        columns.append(item.unsqueeze(-1) if item.dim() == 1 else item)

    if not columns:
        return None
    return torch.cat(columns, dim=-1)

Expand Down

0 comments on commit be2a463

Please sign in to comment.