Commit

Merge branch 'master' into update_data
rusty1s authored Dec 29, 2022
2 parents 2f64cb1 + e27d935 commit 2992fd6
Showing 6 changed files with 348 additions and 534 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
 ### Changed
+- Drop `SparseTensor` dependency in `GraphStore` ([#5517](https://github.com/pyg-team/pytorch_geometric/pull/5517))
 - Replace `NeighborSampler` with `NeighborLoader` in the distributed sampling example ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6307))
 - Fixed the filtering of node features in `transforms.RemoveIsolatedNodes` ([#6308](https://github.com/pyg-team/pytorch_geometric/pull/6308))
 - Fixed a bug in `DimeNet` that causes an output dimension mismatch ([#6305](https://github.com/pyg-team/pytorch_geometric/pull/6305))
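As a quick illustration of the `to_dense_batch` entry above (a minimal sketch, assuming the `torch_geometric.utils.to_dense_batch` API; this example is not part of the diff):

```python
import torch
from torch_geometric.utils import to_dense_batch

# Two graphs with three nodes each, two features per node:
x = torch.arange(12.0).view(6, 2)
batch = torch.tensor([0, 0, 0, 1, 1, 1])

# With `max_num_nodes=2`, nodes beyond `max_num_nodes` are now dropped
# per graph (previously this case was unsupported):
out, mask = to_dense_batch(x, batch, max_num_nodes=2)
assert out.size() == (2, 2, 2)  # [batch_size, max_num_nodes, num_features]
assert int(mask.sum()) == 4     # two valid nodes remain per graph
```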
126 changes: 43 additions & 83 deletions test/data/test_graph_store.py
@@ -1,115 +1,75 @@
-from typing import List
-
 import pytest
 import torch
 from torch_sparse import SparseTensor

 from torch_geometric.data.graph_store import EdgeLayout
 from torch_geometric.testing import MyGraphStore
-from torch_geometric.typing import OptTensor
 from torch_geometric.utils.sort_edge_index import sort_edge_index


 def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
     row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
     col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
-    return torch.stack([row, col], dim=0)
+    return row, col


 def test_graph_store():
     graph_store = MyGraphStore()
-    edge_index = torch.LongTensor([[0, 1], [1, 2]])
-    adj = SparseTensor(row=edge_index[0], col=edge_index[1])
-
-    def assert_equal_tensor_tuple(expected, actual):
-        assert len(expected) == len(actual)
-        for i in range(len(expected)):
-            assert torch.equal(expected[i], actual[i])
-
-    # We put all three tensor types: COO, CSR, and CSC, and we get them back
-    # to confirm that `GraphStore` works as intended.
-    coo = adj.coo()[:-1]
-    csr = adj.csr()[:-1]
-    csc = adj.csc()[-2::-1]  # (row, colptr)
-
-    # Put:
-    graph_store['edge', EdgeLayout.COO] = coo
-    graph_store['edge', 'csr'] = csr
-    graph_store['edge', 'csc'] = csc
-
-    # Get:
-    assert_equal_tensor_tuple(coo, graph_store['edge', 'coo'])
-    assert_equal_tensor_tuple(csr, graph_store['edge', 'csr'])
-    assert_equal_tensor_tuple(csc, graph_store['edge', 'csc'])
-
-    # Get attrs:
-    edge_attrs = graph_store.get_all_edge_attrs()
-    assert len(edge_attrs) == 3
+
+    coo = torch.tensor([0, 1]), torch.tensor([1, 2])
+    csr = torch.tensor([0, 1, 2]), torch.tensor([1, 2])
+    csc = torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2])
+
+    graph_store['edge_type', 'coo'] = coo
+    graph_store['edge_type', 'csr'] = csr
+    graph_store['edge_type', 'csc'] = csc
+
+    assert torch.equal(graph_store['edge_type', 'coo'][0], coo[0])
+    assert torch.equal(graph_store['edge_type', 'coo'][1], coo[1])
+    assert torch.equal(graph_store['edge_type', 'csr'][0], csr[0])
+    assert torch.equal(graph_store['edge_type', 'csr'][1], csr[1])
+    assert torch.equal(graph_store['edge_type', 'csc'][0], csc[0])
+    assert torch.equal(graph_store['edge_type', 'csc'][1], csc[1])
+
+    assert len(graph_store.get_all_edge_attrs()) == 3

     with pytest.raises(KeyError):
-        _ = graph_store['edge_2', 'coo']
+        graph_store['edge_type_2', 'coo']


 def test_graph_store_conversion():
     graph_store = MyGraphStore()
-    edge_index = get_edge_index(100, 100, 300)
-    edge_index = sort_edge_index(edge_index, sort_by_row=False)
-    adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(100, 100))
-
-    coo = (edge_index[0], edge_index[1])
-    csr = adj.csr()[:2]
-    csc = adj.csc()[-2::-1]
-
-    # Put all edge indices:
-    graph_store.put_edge_index(edge_index=coo, edge_type=('v', '1', 'v'),
-                               layout='coo', size=(100, 100), is_sorted=True)
-    graph_store.put_edge_index(edge_index=csr, edge_type=('v', '2', 'v'),
-                               layout='csr', size=(100, 100))
-    graph_store.put_edge_index(edge_index=csc, edge_type=('v', '3', 'v'),
-                               layout='csc', size=(100, 100))
-
-    def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor):
-        assert torch.equal(sort_edge_index(expected), sort_edge_index(actual))
+
+    coo = (row, col) = get_edge_index(100, 100, 300)
+    adj = SparseTensor(row=row, col=col, sparse_sizes=(100, 100))
+    csr, csc = adj.csr()[:2], adj.csc()[:2][::-1]
+
+    graph_store.put_edge_index(coo, ('v', '1', 'v'), 'coo', size=(100, 100))
+    graph_store.put_edge_index(csr, ('v', '2', 'v'), 'csr', size=(100, 100))
+    graph_store.put_edge_index(csc, ('v', '3', 'v'), 'csc', size=(100, 100))

     # Convert to COO:
     row_dict, col_dict, perm_dict = graph_store.coo()
     assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
-    for key in row_dict.keys():
-        actual = torch.stack((row_dict[key], col_dict[key]))
-        assert_edge_index_equal(actual, edge_index)
-        assert perm_dict[key] is None
+    for row, col, perm in zip(row_dict.values(), col_dict.values(),
+                              perm_dict.values()):
+        assert torch.equal(row.sort()[0], coo[0].sort()[0])
+        assert torch.equal(col.sort()[0], coo[1].sort()[0])
+        assert perm is None

     # Convert to CSR:
-    rowptr_dict, col_dict, perm_dict = graph_store.csr()
-    assert len(rowptr_dict) == len(col_dict) == len(perm_dict) == 3
-    for key in rowptr_dict:
-        assert torch.equal(rowptr_dict[key], csr[0])
-        assert torch.equal(col_dict[key], csr[1])
-        if key == ('v', '1', 'v'):
-            assert perm_dict[key] is not None
+    row_dict, col_dict, perm_dict = graph_store.csr()
+    assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
+    for row, col in zip(row_dict.values(), col_dict.values()):
+        assert torch.equal(row, csr[0])
+        assert torch.equal(col.sort()[0], csr[1].sort()[0])

     # Convert to CSC:
-    row_dict, colptr_dict, perm_dict = graph_store.csc()
-    assert len(row_dict) == len(colptr_dict) == len(perm_dict) == 3
-    for key in row_dict:
-        assert torch.equal(row_dict[key], csc[0])
-        assert torch.equal(colptr_dict[key], csc[1])
-        assert perm_dict[key] is None
+    row_dict, col_dict, perm_dict = graph_store.csc()
+    assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
+    for row, col in zip(row_dict.values(), col_dict.values()):
+        assert torch.equal(row.sort()[0], csc[0].sort()[0])
+        assert torch.equal(col, csc[1])

-    # Ensure that 'edge_types' parameters work as intended:
-    def _tensor_eq(expected: List[OptTensor], actual: List[OptTensor]):
-        for tensor_expected, tensor_actual in zip(expected, actual):
-            if tensor_expected is None or tensor_actual is None:
-                return tensor_actual == tensor_expected
-            return torch.equal(tensor_expected, tensor_actual)
-
-    edge_types = [('v', '1', 'v'), ('v', '2', 'v')]
-    assert _tensor_eq(
-        list(graph_store.coo()[0].values())[:-1],
-        graph_store.coo(edge_types=edge_types)[0].values())
-    assert _tensor_eq(
-        list(graph_store.csr()[0].values())[:-1],
-        graph_store.csr(edge_types=edge_types)[0].values())
-    assert _tensor_eq(
-        list(graph_store.csc()[0].values())[:-1],
-        graph_store.csc(edge_types=edge_types)[0].values())
+    out = graph_store.coo([('v', '1', 'v')])
+    assert torch.equal(list(out[0].values())[0], coo[0])
+    assert torch.equal(list(out[1].values())[0], coo[1])
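For context on the tests above: `MyGraphStore` comes from `torch_geometric.testing` and only needs to implement the abstract `GraphStore` hooks. A minimal in-memory backend might look like the following (a hypothetical sketch for illustration, not the actual `MyGraphStore` code):

```python
from typing import Dict, List, Optional, Tuple

from torch_geometric.data import EdgeAttr, GraphStore
from torch_geometric.typing import EdgeTensorType


class InMemoryGraphStore(GraphStore):
    r"""Hypothetical minimal backend storing plain tensor tuples."""
    def __init__(self):
        super().__init__()
        # Maps (edge_type, layout) to the attribute and its tensor tuple:
        self._data: Dict[Tuple, Tuple[EdgeAttr, EdgeTensorType]] = {}

    @staticmethod
    def _key(edge_attr: EdgeAttr) -> Tuple:
        return (edge_attr.edge_type, edge_attr.layout.value)

    def _put_edge_index(self, edge_index: EdgeTensorType,
                        edge_attr: EdgeAttr) -> bool:
        self._data[self._key(edge_attr)] = (edge_attr, edge_index)
        return True

    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
        entry = self._data.get(self._key(edge_attr))
        return entry[1] if entry is not None else None

    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
        return self._data.pop(self._key(edge_attr), None) is not None

    def get_all_edge_attrs(self) -> List[EdgeAttr]:
        return [attr for attr, _ in self._data.values()]
```

The indexing syntax in the tests (`graph_store['edge_type', 'coo'] = coo`) routes through `GraphStore.__setitem__`/`__getitem__`, which cast the key to an `EdgeAttr` and dispatch to these `_put_edge_index`/`_get_edge_index` hooks; after this change they carry plain `(row, col)`-style tensor tuples rather than `SparseTensor` objects.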
161 changes: 77 additions & 84 deletions torch_geometric/data/data.py
@@ -21,12 +21,7 @@

 from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr
 from torch_geometric.data.feature_store import _field_status
-from torch_geometric.data.graph_store import (
-    EDGE_LAYOUT_TO_ATTR_NAME,
-    EdgeLayout,
-    adj_type_to_edge_tensor_type,
-    edge_tensor_type_to_adj_type,
-)
+from torch_geometric.data.graph_store import EdgeLayout
 from torch_geometric.data.storage import (
     BaseStorage,
     EdgeStorage,
@@ -340,26 +335,28 @@ def contains_self_loops(self) -> bool:

 @dataclass
 class DataTensorAttr(TensorAttr):
-    r"""Attribute class for `Data`, which does not require a `group_name`."""
-    def __init__(self, attr_name=_field_status.UNSET,
-                 index=_field_status.UNSET):
-        # Treat group_name as optional, and move it to the end
+    r"""Tensor attribute for `Data` without group name."""
+    def __init__(
+        self,
+        attr_name=_field_status.UNSET,
+        index=None,
+    ):
         super().__init__(None, attr_name, index)


 @dataclass
 class DataEdgeAttr(EdgeAttr):
-    r"""Edge attribute class for `Data`, which does not require a
-    `edge_type`."""
+    r"""Edge attribute class for `Data` without edge type."""
     def __init__(
         self,
         layout: Optional[EdgeLayout] = None,
         is_sorted: bool = False,
         size: Optional[Tuple[int, int]] = None,
-        edge_type: EdgeType = None,
     ):
-        # Treat edge_type as optional, and move it to the end
-        super().__init__(edge_type, layout, is_sorted, size)
+        super().__init__(None, layout, is_sorted, size)


 ###############################################################################


 class Data(BaseData, FeatureStore, GraphStore):
@@ -834,27 +831,16 @@ def num_faces(self) -> Optional[int]:

     # FeatureStore interface ##################################################

-    def items(self):
-        r"""Returns an `ItemsView` over the stored attributes in the `Data`
-        object."""
-        # NOTE this is necessary to override the default `MutableMapping`
-        # items() method.
-        return self._store.items()
-
     def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
-        r"""Stores a feature tensor in node storage."""
-        out = getattr(self, attr.attr_name, None)
+        out = self.get(attr.attr_name)
         if out is not None and attr.index is not None:
-            # Attr name exists, handle index:
             out[attr.index] = tensor
         else:
-            # No attr name (or None index), just store tensor:
             assert attr.index is None
             setattr(self, attr.attr_name, tensor)
         return True

     def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
-        r"""Obtains a feature tensor from node storage."""
-        # Retrieve tensor and index accordingly:
         tensor = getattr(self, attr.attr_name, None)
         if tensor is not None:
             # TODO this behavior is a bit odd, since TensorAttr requires that
@@ -865,15 +851,12 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
         return None

     def _remove_tensor(self, attr: TensorAttr) -> bool:
-        r"""Deletes a feature tensor from node storage."""
-        # Remove tensor entirely:
         if hasattr(self, attr.attr_name):
             delattr(self, attr.attr_name)
             return True
         return False

     def _get_tensor_size(self, attr: TensorAttr) -> Tuple:
-        r"""Returns the size of the tensor corresponding to `attr`."""
         return self._get_tensor(attr).size()

     def get_all_tensor_attrs(self) -> List[TensorAttr]:
@@ -883,70 +866,80 @@ def get_all_tensor_attrs(self) -> List[TensorAttr]:
             if self._store.is_node_attr(name)
         ]

+    def __len__(self) -> int:
+        return BaseData.__len__(self)
+
     # GraphStore interface ####################################################

     def _put_edge_index(self, edge_index: EdgeTensorType,
                         edge_attr: EdgeAttr) -> bool:
-        r"""Stores `edge_index` in `Data`, in the specified layout."""
-
-        # Convert the edge index to a recognizable layout:
-        attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
-        attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
-        setattr(self, attr_name, attr_val)
-
         # Set edge attributes:
         if not hasattr(self, '_edge_attrs'):
             self._edge_attrs = {}
-
-        self._edge_attrs[edge_attr.layout.value] = edge_attr
-
-        # Set size, if possible:
-        size = edge_attr.size
-        if size is not None:
-            if size[0] != size[1]:
-                raise ValueError(
-                    f"'Data' requires size[0] == size[1], but received "
-                    f"the tuple {size}.")
-            self.num_nodes = size[0]
+        self._edge_attrs[edge_attr.layout] = edge_attr
+
+        row, col = edge_index
+
+        if edge_attr.layout == EdgeLayout.COO:
+            self.edge_index = torch.stack([row, col], dim=0)
+        elif edge_attr.layout == EdgeLayout.CSR:
+            self.adj = SparseTensor(
+                rowptr=row,
+                col=col,
+                sparse_sizes=edge_attr.size,
+                is_sorted=True,
+                trust_data=True,
+            )
+        else:  # edge_attr.layout == EdgeLayout.CSC:
+            size = edge_attr.size[::-1] if edge_attr.size is not None else None
+            self.adj_t = SparseTensor(
+                rowptr=col,
+                col=row,
+                sparse_sizes=size,
+                is_sorted=True,
+                trust_data=True,
+            )
         return True

     def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
-        r"""Obtains the edge index corresponding to `edge_attr` in `Data`,
-        in the specified layout."""
-        # Get the requested layout and the edge tensor type associated with it:
-        attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
-        attr_val = getattr(self._store, attr_name, None)
-        if attr_val is not None:
-            # Convert from Adj type to Tuple[Tensor, Tensor]
-            attr_val = adj_type_to_edge_tensor_type(edge_attr.layout, attr_val)
-        return attr_val
+        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self:
+            row, col = self.edge_index
+            return row, col
+        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self:
+            rowptr, col, _ = self.adj.csr()
+            return rowptr, col
+        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self:
+            colptr, row, _ = self.adj_t.csr()
+            return row, colptr
+        return None

+    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
+        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self:
+            del self.edge_index
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop(EdgeLayout.COO, None)
+            return True
+        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self:
+            del self.adj
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop(EdgeLayout.CSR, None)
+            return True
+        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self:
+            del self.adj_t
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop(EdgeLayout.CSC, None)
+            return True
+        return False

     def get_all_edge_attrs(self) -> List[EdgeAttr]:
-        r"""Returns `EdgeAttr` objects corresponding to the edge indices stored
-        in `Data` and their layouts"""
-        if not hasattr(self, '_edge_attrs'):
-            return []
-        added_attrs = set()
-
-        # Check edges added via _put_edge_index:
-        edge_attrs = self._edge_attrs.values()
-        for attr in edge_attrs:
-            attr.size = (self.num_nodes, self.num_nodes)
-            added_attrs.add(attr.layout)
-
-        # Check edges added through regular interface:
-        # TODO deprecate this and store edge attributes for all edges in
-        # EdgeStorage
-        for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
-            if attr_name in self and layout not in added_attrs:
-                edge_attrs.append(
-                    EdgeAttr(edge_type=None, layout=layout,
-                             size=(self.num_nodes, self.num_nodes)))
-
-        return edge_attrs
+        edge_attrs = getattr(self, '_edge_attrs', {})
+
+        if 'edge_index' in self and EdgeLayout.COO not in edge_attrs:
+            edge_attrs[EdgeLayout.COO] = DataEdgeAttr('coo', is_sorted=False)
+        if 'adj' in self and EdgeLayout.CSR not in edge_attrs:
+            size = self.adj.sparse_sizes()
+            edge_attrs[EdgeLayout.CSR] = DataEdgeAttr('csr', size=size)
+        if 'adj_t' in self and EdgeLayout.CSC not in edge_attrs:
+            size = self.adj_t.sparse_sizes()[::-1]
+            edge_attrs[EdgeLayout.CSC] = DataEdgeAttr('csc', size=size)
+
+        return list(edge_attrs.values())


 ###############################################################################
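A short usage sketch of the new `Data`-side behavior (assuming the public `put_edge_index`/`get_edge_index` wrappers forward their arguments to a `DataEdgeAttr`, mirroring how the tests above call `MyGraphStore`; this example is not part of the diff):

```python
import torch
from torch_geometric.data import Data

data = Data()
row = torch.tensor([0, 0, 1])
col = torch.tensor([1, 2, 2])

# COO input is stored as `data.edge_index`:
data.put_edge_index((row, col), layout='coo', size=(3, 3))
assert data.edge_index.tolist() == [[0, 0, 1], [1, 2, 2]]

# CSR input lands in `data.adj` as a `torch_sparse.SparseTensor`:
rowptr = torch.tensor([0, 2, 3, 3])
data.put_edge_index((rowptr, col), layout='csr', size=(3, 3))

# Every layout is returned as a plain tensor tuple:
out_row, out_col = data.get_edge_index(layout='coo')
assert torch.equal(out_row, row) and torch.equal(out_col, col)
assert len(data.get_all_edge_attrs()) == 2  # COO + CSR
```

CSC input is handled analogously via `data.adj_t`, with the sparse sizes reversed, as the `_put_edge_index` hunk above shows.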
