GraphStore: support COO layouts, refactor conversion logic (#4883)
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
mananshah99 and rusty1s authored Jun 29, 2022
1 parent 6c3c235 commit 7f55f41
Showing 11 changed files with 406 additions and 153 deletions.
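The headline change: `GraphStore.put_edge_index` now accepts COO `(row, col)` tensor tuples alongside CSR/CSC, and the layout-conversion helpers have been moved out of `data.py` into the shared `graph_store` module. A minimal sketch of the usage exercised by the new tests (the edge-type name is illustrative; `MyGraphStore` is the reference implementation from PyG's testing utilities):

```python
import torch

from torch_geometric.testing.graph_store import MyGraphStore

# A random graph on 100 nodes with 300 edges, in COO form:
row = torch.randint(100, (300, ), dtype=torch.long)
col = torch.randint(100, (300, ), dtype=torch.long)

graph_store = MyGraphStore()

# New in this commit: `layout='coo'` is accepted alongside 'csr' and 'csc',
# together with an explicit `size` (and an optional `is_sorted` flag):
graph_store.put_edge_index(edge_index=(row, col), edge_type=('v', 'to', 'v'),
                           layout='coo', size=(100, 100))

# The refactored conversion logic exposes every registered edge index in all
# three layouts, as per-edge-type dictionaries plus the permutation (if any)
# applied during conversion:
row_dict, col_dict, perm_dict = graph_store.coo()
rowptr_dict, col_dict, perm_dict = graph_store.csr()
row_dict, colptr_dict, perm_dict = graph_store.csc()
```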
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
2 changes: 1 addition & 1 deletion test/data/test_data.py
@@ -299,7 +299,7 @@ def assert_equal_tensor_tuple(expected, actual):
csc = adj.csc()[-2::-1] # (row, colptr)

# Put:
data.put_edge_index(coo, layout='coo')
data.put_edge_index(coo, layout='coo', size=(3, 3))
data.put_edge_index(csr, layout='csr')
data.put_edge_index(csc, layout='csc')

56 changes: 56 additions & 0 deletions test/data/test_graph_store.py
@@ -4,6 +4,13 @@

from torch_geometric.data.graph_store import EdgeLayout
from torch_geometric.testing.graph_store import MyGraphStore
from torch_geometric.utils.sort_edge_index import sort_edge_index


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
return torch.stack([row, col], dim=0)


def test_graph_store():
@@ -38,3 +45,52 @@ def assert_equal_tensor_tuple(expected, actual):

with pytest.raises(KeyError):
_ = graph_store['edge_2', 'coo']


def test_graph_store_conversion():
graph_store = MyGraphStore()
edge_index = get_edge_index(100, 100, 300)
edge_index = sort_edge_index(edge_index, sort_by_row=False)
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(100, 100))

coo = (edge_index[0], edge_index[1])
csr = adj.csr()[:2]
csc = adj.csc()[-2::-1]

# Put all edge indices:
graph_store.put_edge_index(edge_index=coo, edge_type=('v', '1', 'v'),
layout='coo', size=(100, 100), is_sorted=True)

graph_store.put_edge_index(edge_index=csr, edge_type=('v', '2', 'v'),
layout='csr', size=(100, 100))

graph_store.put_edge_index(edge_index=csc, edge_type=('v', '3', 'v'),
layout='csc', size=(100, 100))

def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor):
assert torch.equal(sort_edge_index(expected), sort_edge_index(actual))

# Convert to COO:
row_dict, col_dict, perm_dict = graph_store.coo()
assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
for key in row_dict.keys():
actual = torch.stack((row_dict[key], col_dict[key]))
assert_edge_index_equal(actual, edge_index)
assert perm_dict[key] is None

# Convert to CSR:
rowptr_dict, col_dict, perm_dict = graph_store.csr()
assert len(rowptr_dict) == len(col_dict) == len(perm_dict) == 3
for key in rowptr_dict:
assert torch.equal(rowptr_dict[key], csr[0])
assert torch.equal(col_dict[key], csr[1])
if key == ('v', '1', 'v'):
assert perm_dict[key] is not None

# Convert to CSC:
row_dict, colptr_dict, perm_dict = graph_store.csc()
assert len(row_dict) == len(colptr_dict) == len(perm_dict) == 3
for key in row_dict:
assert torch.equal(row_dict[key], csc[0])
assert torch.equal(colptr_dict[key], csc[1])
assert perm_dict[key] is None
9 changes: 6 additions & 3 deletions test/data/test_hetero_data.py
@@ -464,9 +464,12 @@ def assert_equal_tensor_tuple(expected, actual):
csc = adj.csc()[-2::-1] # (row, colptr)

# Put:
data.put_edge_index(coo, layout='coo', edge_type=('a', 'to', 'b'))
data.put_edge_index(csr, layout='csr', edge_type=('a', 'to', 'c'))
data.put_edge_index(csc, layout='csc', edge_type=('b', 'to', 'c'))
data.put_edge_index(coo, layout='coo', edge_type=('a', 'to', 'b'),
size=(3, 3))
data.put_edge_index(csr, layout='csr', edge_type=('a', 'to', 'c'),
size=(3, 3))
data.put_edge_index(csc, layout='csc', edge_type=('b', 'to', 'c'),
size=(3, 3))

# Get:
assert_equal_tensor_tuple(
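For heterogeneous graphs, the same `GraphStore` call additionally keys the edge index by `edge_type`, and now forwards the bipartite `size` as well. A minimal sketch mirroring the updated test above (tensors and type names are illustrative):

```python
import torch

from torch_geometric.data import HeteroData

data = HeteroData()
row = torch.tensor([0, 1, 2])
col = torch.tensor([1, 2, 0])

# `HeteroData` implements the `GraphStore` interface; the COO layout now
# carries an explicit (num_src_nodes, num_dst_nodes) size:
data.put_edge_index((row, col), layout='coo', edge_type=('a', 'to', 'b'),
                    size=(3, 3))
```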
31 changes: 16 additions & 15 deletions test/loader/test_neighbor_loader.py
@@ -297,29 +297,30 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)

# Set up edge indices:

# COO:
edge_index = get_edge_index(100, 100, 500)
data['paper', 'to', 'paper'].edge_index = edge_index
graph_store.put_edge_index(
edge_index=SparseTensor.from_edge_index(edge_index).csr()[:2],
edge_type=('paper', 'to', 'paper'),
layout='csr',
)
coo = (edge_index[0], edge_index[1])
graph_store.put_edge_index(edge_index=coo,
edge_type=('paper', 'to', 'paper'),
layout='coo', size=(100, 100))

# CSR:
edge_index = get_edge_index(100, 200, 1000)
data['paper', 'to', 'author'].edge_index = edge_index
graph_store.put_edge_index(
edge_index=SparseTensor.from_edge_index(edge_index).csr()[:2],
edge_type=('paper', 'to', 'author'),
layout='csr',
)
csr = SparseTensor.from_edge_index(edge_index).csr()[:2]
graph_store.put_edge_index(edge_index=csr,
edge_type=('paper', 'to', 'author'),
layout='csr', size=(100, 200))

# CSC:
edge_index = get_edge_index(200, 100, 1000)
data['author', 'to', 'paper'].edge_index = edge_index
graph_store.put_edge_index(
edge_index=SparseTensor.from_edge_index(edge_index).csr()[:2],
edge_type=('author', 'to', 'paper'),
layout='csr',
)
csc = SparseTensor(row=edge_index[1], col=edge_index[0]).csr()[-2::-1]
graph_store.put_edge_index(edge_index=csc,
edge_type=('author', 'to', 'paper'),
layout='csc', size=(200, 100))

# Construct neighbor loaders:
loader1 = NeighborLoader(data, batch_size=20,
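The test above registers the same edge indices both on an in-memory `HeteroData` object and on a `FeatureStore`/`GraphStore` pair before constructing neighbor loaders over each. A hedged sketch of the store-backed call, continuing from the `feature_store` and `graph_store` objects set up in the test (the concrete argument values are illustrative):

```python
from torch_geometric.loader import NeighborLoader

# Instead of a `Data`/`HeteroData` object, the loader also accepts a
# (FeatureStore, GraphStore) tuple as its data argument:
loader = NeighborLoader((feature_store, graph_store), num_neighbors=[-1] * 2,
                        input_nodes='paper', batch_size=20)

for batch in loader:
    ...  # each `batch` is a sampled `HeteroData` mini-batch
```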
83 changes: 30 additions & 53 deletions torch_geometric/data/data.py
@@ -24,7 +24,14 @@
TensorAttr,
_field_status,
)
from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout, GraphStore
from torch_geometric.data.graph_store import (
EDGE_LAYOUT_TO_ATTR_NAME,
EdgeAttr,
EdgeLayout,
GraphStore,
adj_type_to_edge_tensor_type,
edge_tensor_type_to_adj_type,
)
from torch_geometric.data.storage import (
BaseStorage,
EdgeStorage,
@@ -33,7 +40,6 @@
)
from torch_geometric.deprecation import deprecated
from torch_geometric.typing import (
Adj,
EdgeTensorType,
EdgeType,
FeatureTensorType,
@@ -328,10 +334,15 @@ def __init__(self, attr_name=_field_status.UNSET,
class DataEdgeAttr(EdgeAttr):
r"""Edge attribute class for `Data`, which does not require a
`edge_type`."""
def __init__(self, layout: EdgeLayout, is_sorted: bool = False,
edge_type: EdgeType = None):
# Treat group_name as optional, and move it to the end
super().__init__(edge_type, layout, is_sorted)
def __init__(
self,
layout: EdgeLayout,
is_sorted: bool = False,
size: Optional[Tuple[int, int]] = None,
edge_type: EdgeType = None,
):
# Treat edge_type as optional, and move it to the end
super().__init__(edge_type, layout, is_sorted, size)


class Data(BaseData, FeatureStore, GraphStore):
@@ -795,10 +806,20 @@ def __len__(self) -> int:
def _put_edge_index(self, edge_index: EdgeTensorType,
edge_attr: EdgeAttr) -> bool:
r"""Stores `edge_index` in `Data`, in the specified layout."""

# Convert the edge index to a recognizable layout:
attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
setattr(self, attr_name, attr_val)

# Set size, if possible:
size = edge_attr.size
if size is not None:
if size[0] != size[1]:
raise ValueError(
f"'Data' requires size[0] == size[1], but received "
f"the tuple {size}.")
self.num_nodes = size[0]
return True

def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
@@ -818,58 +839,14 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
out = []
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
if attr_name in self:
out.append(EdgeAttr(edge_type=None, layout=layout))
out.append(
EdgeAttr(edge_type=None, layout=layout,
size=(self.num_nodes, self.num_nodes)))
return out


###############################################################################

EDGE_LAYOUT_TO_ATTR_NAME = {
EdgeLayout.COO: 'edge_index',
EdgeLayout.CSR: 'adj',
EdgeLayout.CSC: 'adj_t',
}


def edge_tensor_type_to_adj_type(
attr: EdgeAttr,
tensor_tuple: EdgeTensorType,
) -> Adj:
r"""Converts an EdgeTensorType tensor tuple to a PyG Adj tensor."""
src, dst = tensor_tuple

if attr.layout == EdgeLayout.COO:
# COO: (row, col)
if (src[0].storage().data_ptr() == dst[1].storage().data_ptr()):
# Do not copy if the tensor tuple is constructed from the same
# storage (instead, return a view):
out = torch.empty(0, dtype=src.dtype)
out.set_(src.storage(), storage_offset=0,
size=src.size() + dst.size())
return out.view(2, -1)
return torch.stack(tensor_tuple)
elif attr.layout == EdgeLayout.CSR:
# CSR: (rowptr, col)
return SparseTensor(rowptr=src, col=dst, is_sorted=True)
elif attr.layout == EdgeLayout.CSC:
# CSC: (row, colptr) this is a transposed adjacency matrix, so rowptr
# is the compressed column and col is the uncompressed row.
return SparseTensor(rowptr=dst, col=src, is_sorted=True)
raise ValueError(f"Bad edge layout (got '{attr.layout}')")


def adj_type_to_edge_tensor_type(layout: EdgeLayout,
edge_index: Adj) -> EdgeTensorType:
r"""Converts a PyG Adj tensor to an EdgeTensorType equivalent."""
if isinstance(edge_index, Tensor):
return (edge_index[0], edge_index[1]) # (row, col)
if layout == EdgeLayout.COO:
return edge_index.coo()[:-1] # (row, col)
elif layout == EdgeLayout.CSR:
return edge_index.csr()[:-1] # (rowptr, col)
else:
return edge_index.csr()[-2::-1] # (row, colptr)


def size_repr(key: Any, value: Any, indent: int = 0) -> str:
pad = ' ' * indent
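On the `Data` side, the new `size` field of `EdgeAttr` is validated to be square (a homogeneous graph has a single node set) and its first entry populates `num_nodes`; the conversion helpers formerly defined here now live in `torch_geometric.data.graph_store`. A minimal sketch of the resulting behaviour (the example tensors are illustrative):

```python
import torch

from torch_geometric.data import Data

data = Data()
row = torch.tensor([0, 1, 2])
col = torch.tensor([1, 2, 0])

# `Data` implements the `GraphStore` interface; for homogeneous graphs the
# COO size must be square, and its first entry becomes `num_nodes`:
data.put_edge_index((row, col), layout='coo', size=(3, 3))
assert data.num_nodes == 3

# A non-square size is rejected by `_put_edge_index`:
# data.put_edge_index((row, col), layout='coo', size=(3, 4))  # ValueError
```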