Skip to content

Commit

Permalink
Data.validate() and HeteroData.validate() (#4885)
Browse files: browse the repository at this point in the history
* update

* changelog

* update

* typo
  • Loading branch information
rusty1s authored Jun 30, 2022
1 parent 7f55f41 commit 3d6eb74
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 2 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `LinkeNeighborLoader` support to lightning datamodule ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
- Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
Expand Down
1 change: 1 addition & 0 deletions test/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_data():
x = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float).t()
edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]])
data = Data(x=x, edge_index=edge_index).to(torch.device('cpu'))
data.validate(raise_on_error=True)

N = data.num_nodes
assert N == 3
Expand Down
1 change: 1 addition & 0 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_init_hetero_data():
data['paper', 'paper'].edge_index = edge_index_paper_paper
data['paper', 'author'].edge_index = edge_index_paper_author
data['author', 'paper'].edge_index = edge_index_author_paper
data.validate(raise_on_error=True)

assert len(data) == 2
assert data.node_types == ['v1', 'paper', 'author']
Expand Down
36 changes: 36 additions & 0 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import warnings
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (
Expand Down Expand Up @@ -514,6 +515,34 @@ def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
else:
return 0

def validate(self, raise_on_error: bool = True) -> bool:
    r"""Checks the stored graph for structural problems.

    Three conditions are verified, in this order: the number of nodes is
    defined, a non-empty 'edge_index' holds no negative entries, and no
    entry of 'edge_index' reaches or exceeds the number of nodes.

    Args:
        raise_on_error (bool, optional): Passed through to
            ``warn_or_raise`` together with every problem message.
            (default: :obj:`True`)

    Returns:
        bool: :obj:`True` when every check passed, :obj:`False` otherwise.
    """
    owner = self.__class__.__name__
    ok = True

    node_count = self.num_nodes
    if node_count is None:
        ok = False
        warn_or_raise(f"'num_nodes' is undefined in '{owner}'",
                      raise_on_error)

    # Nothing more to verify without a (non-empty) edge index.
    if 'edge_index' not in self:
        return ok
    edge_index = self.edge_index
    if edge_index.numel() == 0:
        return ok

    lowest = edge_index.min()
    if lowest < 0:
        ok = False
        message = (f"'edge_index' contains negative indices in "
                   f"'{owner}' (found {int(lowest)})")
        warn_or_raise(message, raise_on_error)

    # The upper bound can only be checked against a known node count.
    if node_count is not None:
        highest = edge_index.max()
        if highest >= node_count:
            ok = False
            message = (f"'edge_index' contains larger indices than the "
                       f"number of nodes ({node_count}) in '{owner}' "
                       f"(found {int(highest)})")
            warn_or_raise(message, raise_on_error)

    return ok

def debug(self):
    # Placeholder: no debugging logic is implemented yet, so this is a no-op.
    pass # TODO

Expand Down Expand Up @@ -879,3 +908,10 @@ def size_repr(key: Any, value: Any, indent: int = 0) -> str:
return f'{pad}\033[1m{key}\033[0m={out}'
else:
return f'{pad}{key}={out}'


def warn_or_raise(msg: str, raise_on_error: bool = True):
    """Report a problem either as an exception or as a warning.

    Args:
        msg: Text describing the problem.
        raise_on_error: When true, ``ValueError(msg)`` is raised; otherwise
            ``msg`` is emitted through ``warnings.warn`` and the function
            returns normally.

    Raises:
        ValueError: If ``raise_on_error`` is true.
    """
    if not raise_on_error:
        warnings.warn(msg)
        return
    raise ValueError(msg)
55 changes: 54 additions & 1 deletion torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data.data import BaseData, Data, size_repr
from torch_geometric.data.data import BaseData, Data, size_repr, warn_or_raise
from torch_geometric.data.feature_store import FeatureStore, TensorAttr
from torch_geometric.data.graph_store import (
EDGE_LAYOUT_TO_ATTR_NAME,
Expand Down Expand Up @@ -325,6 +325,59 @@ def is_undirected(self) -> bool:
edge_index, _, _ = to_homogeneous_edge_index(self)
return is_undirected(edge_index, num_nodes=self.num_nodes)

def validate(self, raise_on_error: bool = True) -> bool:
    r"""Validates the correctness of the data.

    For every edge type, checks that both endpoint node types report a
    number of nodes, and that a non-empty 'edge_index' contains neither
    negative entries nor source/destination indices that reach or exceed
    the number of nodes of the corresponding node type.

    Args:
        raise_on_error (bool, optional): Forwarded to ``warn_or_raise``
            together with every problem message. (default: :obj:`True`)

    Returns:
        bool: :obj:`True` if no check failed, :obj:`False` otherwise.
    """
    cls_name = self.__class__.__name__
    status = True

    for edge_type, store in self._edge_store_dict.items():
        src, _, dst = edge_type

        num_src_nodes = self[src].num_nodes
        num_dst_nodes = self[dst].num_nodes
        if num_src_nodes is None:
            status = False
            warn_or_raise(
                f"'num_nodes' is undefined in node type '{src}' of "
                f"'{cls_name}'", raise_on_error)

        if num_dst_nodes is None:
            status = False
            warn_or_raise(
                f"'num_nodes' is undefined in node type '{dst}' of "
                f"'{cls_name}'", raise_on_error)

        # Index range checks only make sense for a non-empty edge index.
        if 'edge_index' not in store or store.edge_index.numel() == 0:
            continue
        edge_index = store.edge_index

        min_index = edge_index.min()
        if min_index < 0:
            status = False
            warn_or_raise(
                f"'edge_index' of edge type {edge_type} contains "
                f"negative indices in '{cls_name}' "
                f"(found {int(min_index)})", raise_on_error)

        # Row 0 holds source indices; it is bounded by the source node
        # type's node count (skipped when that count is unknown).
        if num_src_nodes is not None:
            max_src = edge_index[0].max()
            if max_src >= num_src_nodes:
                status = False
                # Adjacent f-string pieces need explicit separating
                # spaces, otherwise the words run together.
                warn_or_raise(
                    f"'edge_index' of edge type {edge_type} contains "
                    f"larger source indices than the number of nodes "
                    f"({num_src_nodes}) of this node type in "
                    f"'{cls_name}' (found {int(max_src)})",
                    raise_on_error)

        # Row 1 holds destination indices; same check against the
        # destination node type.
        if num_dst_nodes is not None:
            max_dst = edge_index[1].max()
            if max_dst >= num_dst_nodes:
                status = False
                warn_or_raise(
                    f"'edge_index' of edge type {edge_type} contains "
                    f"larger destination indices than the number of "
                    f"nodes ({num_dst_nodes}) of this node type in "
                    f"'{cls_name}' (found {int(max_dst)})",
                    raise_on_error)

    return status

def debug(self):
    # Placeholder: no debugging logic is implemented yet, so this is a no-op.
    pass # TODO

Expand Down

0 comments on commit 3d6eb74

Please sign in to comment.