Merge branch 'master' into node_layer_norm
lightaime authored Jul 8, 2022
2 parents 400f76a + db5e6d9 commit 4af89d3
Showing 6 changed files with 186 additions and 58 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
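The two `NeighborLoader` entries above are what this merge extends: temporal sampling via `time_attr` now also works when the loader is driven by a `FeatureStore`/`GraphStore` pair. Below is a minimal sketch of that usage, condensed from the new test added further down. It is not part of the diff: `HeteroData` stands in for both stores, the tensors are synthetic, and iterating the loader assumes a `torch-sparse` build that registers `hetero_temporal_neighbor_sample`.

```python
import torch

from torch_geometric.data import HeteroData
from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.loader import NeighborLoader

# `HeteroData` implements both the `FeatureStore` and the `GraphStore`
# interfaces, so two instances can stand in for a remote backend:
feature_store, graph_store = HeteroData(), HeteroData()

feature_store.put_tensor(torch.randn(100, 16), group_name='paper',
                         attr_name='x', index=None)
feature_store.put_tensor(torch.arange(100), group_name='paper',
                         attr_name='time', index=None)

edge_index = torch.randint(0, 100, (2, 500))
graph_store.put_edge_index((edge_index[0], edge_index[1]),
                           edge_type=('paper', 'to', 'paper'),
                           layout='coo', size=(100, 100))

loader = NeighborLoader(
    (feature_store, graph_store),
    num_neighbors=[10, 10],
    input_nodes=TensorAttr(group_name='paper', attr_name='x'),
    time_attr='time',  # node timestamps are fetched from the feature store
    batch_size=32,
)
batch = next(iter(loader))  # a HeteroData mini-batch exposing batch['paper'].time
```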
54 changes: 54 additions & 0 deletions test/loader/test_neighbor_loader.py
@@ -4,6 +4,7 @@
from torch_sparse import SparseTensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphConv, to_hetero
from torch_geometric.testing import withRegisteredOp
@@ -322,6 +323,15 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
edge_type=('author', 'to', 'paper'),
layout='csc', size=(200, 100))

# COO (sorted):
edge_index = get_edge_index(200, 200, 100)
edge_index = edge_index[:, edge_index[1].argsort()]
data['author', 'to', 'author'].edge_index = edge_index
coo = (edge_index[0], edge_index[1])
graph_store.put_edge_index(edge_index=coo,
edge_type=('author', 'to', 'author'),
layout='coo', size=(200, 200), is_sorted=True)

# Construct neighbor loaders:
loader1 = NeighborLoader(data, batch_size=20,
input_nodes=('paper', range(100)),
@@ -350,3 +360,47 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
'paper', 'to', 'author'].edge_index.size())
        assert (batch1['author', 'to', 'paper'].edge_index.size() == batch2[
            'author', 'to', 'paper'].edge_index.size())


@withRegisteredOp('torch_sparse.hetero_temporal_neighbor_sample')
@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
GraphStore):
# Initialize dataset (once):
dataset = get_dataset(name='Cora')
data = dataset[0]

# Initialize feature store, graph store, and reference:
feature_store = FeatureStore()
graph_store = GraphStore()
hetero_data = HeteroData()

feature_store.put_tensor(data.x, group_name='paper', attr_name='x',
index=None)
hetero_data['paper'].x = data.x

feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper',
attr_name='time', index=None)
hetero_data['paper'].time = torch.arange(data.num_nodes)

num_nodes = data.x.size(dim=0)
graph_store.put_edge_index(edge_index=data.edge_index,
edge_type=('paper', 'to', 'paper'),
layout='coo', size=(num_nodes, num_nodes))
hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index

loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
input_nodes='paper', time_attr='time',
batch_size=128)

loader2 = NeighborLoader(
(feature_store, graph_store),
num_neighbors=[-1, -1],
input_nodes=TensorAttr(group_name='paper', attr_name='x'),
time_attr='time',
batch_size=128,
)

for batch1, batch2 in zip(loader1, loader2):
assert torch.equal(batch1['paper'].time, batch2['paper'].time)
27 changes: 23 additions & 4 deletions torch_geometric/data/data.py
@@ -842,6 +842,12 @@ def _put_edge_index(self, edge_index: EdgeTensorType,
attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
setattr(self, attr_name, attr_val)

# Set edge attributes:
if not hasattr(self, '_edge_attrs'):
self._edge_attrs = {}

self._edge_attrs[edge_attr.layout.value] = edge_attr

# Set size, if possible:
size = edge_attr.size
if size is not None:
@@ -866,13 +872,26 @@ def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
def get_all_edge_attrs(self) -> List[EdgeAttr]:
r"""Returns `EdgeAttr` objects corresponding to the edge indices stored
in `Data` and their layouts"""
out = []
if not hasattr(self, '_edge_attrs'):
return []
added_attrs = set()

# Check edges added via _put_edge_index:
        edge_attrs = list(self._edge_attrs.values())
for attr in edge_attrs:
attr.size = (self.num_nodes, self.num_nodes)
added_attrs.add(attr.layout)

# Check edges added through regular interface:
# TODO deprecate this and store edge attributes for all edges in
# EdgeStorage
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
if attr_name in self:
out.append(
if attr_name in self and layout not in added_attrs:
edge_attrs.append(
EdgeAttr(edge_type=None, layout=layout,
size=(self.num_nodes, self.num_nodes)))
return out

return edge_attrs


###############################################################################
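The `Data` changes above make `put_edge_index` record an `EdgeAttr` per layout in `_edge_attrs`, and `get_all_edge_attrs` then skips layouts it has already seen when scanning the regular attributes. A minimal sketch of the intended effect, with illustrative values (not part of the diff); passing `edge_type=None` for the homogeneous case follows the `EdgeAttr(edge_type=None, ...)` construction in the code above.

```python
import torch

from torch_geometric.data import Data

data = Data()
row, col = torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])

# Registering through the `GraphStore` interface also materializes the
# tensor as the regular `edge_index` attribute:
data.put_edge_index((row, col), edge_type=None, layout='coo', size=(3, 3))
print(data.edge_index)

# The COO layout is tracked in `_edge_attrs`, so it is reported exactly once:
print(data.get_all_edge_attrs())
```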
101 changes: 57 additions & 44 deletions torch_geometric/data/graph_store.py
@@ -117,9 +117,12 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
Raises:
KeyError: if the edge index corresponding to attr was not found.
"""

edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
edge_attr.layout = EdgeLayout(edge_attr.layout)
# Override is_sorted for CSC and CSR:
# TODO treat is_sorted specially in this function, where is_sorted=True
# returns an edge index sorted by column.
edge_attr.is_sorted = edge_attr.is_sorted or (edge_attr.layout in [
EdgeLayout.CSC, EdgeLayout.CSR
])
@@ -131,9 +134,57 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:

# Layout Conversion #######################################################

def _edge_to_layout(
self,
attr: EdgeAttr,
layout: EdgeLayout,
) -> Tuple[Tensor, Tensor, OptTensor]:
from_tuple = self.get_edge_index(attr)

if layout == EdgeLayout.COO:
if attr.layout == EdgeLayout.CSR:
col = from_tuple[1]
row = torch.ops.torch_sparse.ptr2ind(from_tuple[0],
col.numel())
else:
row = from_tuple[0]
col = torch.ops.torch_sparse.ptr2ind(from_tuple[1],
row.numel())
perm = None

elif layout == EdgeLayout.CSR:
# We convert to CSR by converting to CSC on the transpose
if attr.layout == EdgeLayout.COO:
adj = edge_tensor_type_to_adj_type(
attr, (from_tuple[1], from_tuple[0]))
else:
adj = edge_tensor_type_to_adj_type(attr, from_tuple).t()

# NOTE we set is_sorted=False here as is_sorted refers to
# the edge_index being sorted by the destination node
# (column), but here we deal with the transpose
attr_copy = copy.copy(attr)
attr_copy.is_sorted = False
attr_copy.size = None if attr.size is None else (attr.size[1],
attr.size[0])

# Actually rowptr, col, perm
row, col, perm = to_csc(adj, attr_copy, device='cpu')

else:
adj = edge_tensor_type_to_adj_type(attr, from_tuple)

# Actually colptr, row, perm
col, row, perm = to_csc(adj, attr, device='cpu')

return row, col, perm

# TODO support `replace` to replace the existing edge index.
def _to_layout(self, layout: EdgeLayout,
store: bool = False) -> ConversionOutputType:
def _all_edges_to_layout(
self,
layout: EdgeLayout,
store: bool = False,
) -> ConversionOutputType:
# Obtain all edge attributes, grouped by type:
edge_attrs = self.get_all_edge_attrs()
edge_type_to_attrs: Dict[Any, List[EdgeAttr]] = defaultdict(list)
@@ -165,45 +216,7 @@ def _to_layout(self, layout: EdgeLayout,
else:
from_attr = edge_attrs[edge_layouts.index(EdgeLayout.CSR)]

from_tuple = self.get_edge_index(from_attr)

# Convert to the new layout:
if layout == EdgeLayout.COO:
if from_attr.layout == EdgeLayout.CSR:
col = from_tuple[1]
row = torch.ops.torch_sparse.ptr2ind(
from_tuple[0], col.numel())
else:
row = from_tuple[0]
col = torch.ops.torch_sparse.ptr2ind(
from_tuple[1], row.numel())
perm = None

elif layout == EdgeLayout.CSR:
# We convert to CSR by converting to CSC on the transpose
if from_attr.layout == EdgeLayout.COO:
adj = edge_tensor_type_to_adj_type(
from_attr, (from_tuple[1], from_tuple[0]))
else:
adj = edge_tensor_type_to_adj_type(
from_attr, from_tuple).t()

# NOTE we set is_sorted=False here as is_sorted refers to
# the edge_index being sorted by the destination node
# (column), but here we deal with the transpose
from_attr_copy = copy.copy(from_attr)
from_attr_copy.is_sorted = False
from_attr_copy.size = None if from_attr.size is None else (
from_attr.size[1], from_attr.size[0])

# Actually rowptr, col, perm
row, col, perm = to_csc(adj, from_attr_copy, device='cpu')

else:
adj = edge_tensor_type_to_adj_type(from_attr, from_tuple)

# Actually colptr, row, perm
col, row, perm = to_csc(adj, from_attr, device='cpu')
row, col, perm = self._edge_to_layout(from_attr, layout)

row_dict[from_attr.edge_type] = row
col_dict[from_attr.edge_type] = col
@@ -235,17 +248,17 @@ def _to_layout(self, layout: EdgeLayout,
def coo(self, store: bool = False) -> ConversionOutputType:
r"""Converts the edge indices in the graph store to COO format,
optionally storing the converted edge indices in the graph store."""
return self._to_layout(EdgeLayout.COO, store)
return self._all_edges_to_layout(EdgeLayout.COO, store)

def csr(self, store: bool = False) -> ConversionOutputType:
r"""Converts the edge indices in the graph store to CSR format,
optionally storing the converted edge indices in the graph store."""
return self._to_layout(EdgeLayout.CSR, store)
return self._all_edges_to_layout(EdgeLayout.CSR, store)

def csc(self, store: bool = False) -> ConversionOutputType:
r"""Converts the edge indices in the graph store to CSC format,
optionally storing the converted edge indices in the graph store."""
return self._to_layout(EdgeLayout.CSC, store)
return self._all_edges_to_layout(EdgeLayout.CSC, store)

# Additional methods ######################################################

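The refactor above splits the per-edge conversion into `_edge_to_layout` and keeps `_all_edges_to_layout` (formerly `_to_layout`) as the loop over all edge types, with the public `coo()`/`csr()`/`csc()` helpers as thin wrappers. A sketch of the resulting call, using `HeteroData` as the graph store and illustrative data; the packing of the returned dictionaries is inferred from the comments in `_edge_to_layout` and may differ in detail.

```python
import torch

from torch_geometric.data import HeteroData

graph_store = HeteroData()  # `HeteroData` implements the `GraphStore` interface

coo = (torch.tensor([0, 1, 2]), torch.tensor([2, 0, 1]))
graph_store.put_edge_index(coo, edge_type=('paper', 'to', 'paper'),
                           layout='coo', size=(3, 3))

# Convert every stored edge index to CSC. The result is three dictionaries
# keyed by edge type; for CSC, `col_dict` holds the compressed column
# pointers (colptr) and `row_dict` the row indices:
row_dict, col_dict, perm_dict = graph_store.csc()
print(col_dict['paper', 'to', 'paper'])
```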
29 changes: 27 additions & 2 deletions torch_geometric/data/hetero_data.py
@@ -695,7 +695,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
out = self._node_store_dict.get(attr.group_name, None)
if out:
# Group name exists, handle index or create new attribute name:
val = getattr(out, attr.attr_name)
val = getattr(out, attr.attr_name, None)
if val is not None:
val[attr.index] = tensor
else:
@@ -754,6 +754,13 @@ def _put_edge_index(self, edge_index: EdgeTensorType,
attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
setattr(self[edge_attr.edge_type], attr_name, attr_val)

# Set edge attributes:
if not hasattr(self[edge_attr.edge_type], '_edge_attrs'):
self[edge_attr.edge_type]._edge_attrs = {}

self[edge_attr.edge_type]._edge_attrs[
edge_attr.layout.value] = edge_attr

key = self._to_canonical(edge_attr.edge_type)
src, _, dst = key

@@ -780,12 +787,30 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
r"""Returns a list of `EdgeAttr` objects corresponding to the edge
indices stored in `HeteroData` and their layouts."""
out = []
added_attrs = set()

# Check edges added via _put_edge_index:
for edge_type, _ in self.edge_items():
if not hasattr(self[edge_type], '_edge_attrs'):
continue
edge_attrs = self[edge_type]._edge_attrs.values()
for attr in edge_attrs:
attr.size = self[edge_type].size()
added_attrs.add((attr.edge_type, attr.layout))
out.extend(edge_attrs)

# Check edges added through regular interface:
# TODO deprecate this and store edge attributes for all edges in
# EdgeStorage
for edge_type, edge_store in self.edge_items():
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
if attr_name in edge_store:
# Don't double count:
if attr_name in edge_store and ((edge_type, layout)
not in added_attrs):
out.append(
EdgeAttr(edge_type=edge_type, layout=layout,
size=self[edge_type].size()))

return out


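A minimal sketch (not from the diff) of the double-counting case the `added_attrs` set above guards against: `put_edge_index` also materializes the tensor as a regular `edge_index` attribute on the edge storage, so without the guard `get_all_edge_attrs` would report the same edge type once per code path.

```python
import torch

from torch_geometric.data import HeteroData

data = HeteroData()
coo = (torch.tensor([0, 1]), torch.tensor([1, 0]))
data.put_edge_index(coo, edge_type=('author', 'to', 'author'),
                    layout='coo', size=(2, 2), is_sorted=True)

# Visible through the regular interface ...
print(data['author', 'to', 'author'].edge_index)
# ... but reported as a single COO `EdgeAttr`, not twice:
print(data.get_all_edge_attrs())
```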
31 changes: 24 additions & 7 deletions torch_geometric/loader/neighbor_loader.py
@@ -95,9 +95,21 @@ def __init__(
# TODO support `collect` on `FeatureStore`
self.node_time_dict = None
if time_attr is not None:
raise ValueError(
f"'time_attr' attribute not yet supported for "
f"'{data[0].__class__.__name__}' object")
# We need to obtain all features with 'attr_name=time_attr'
# from the feature store and store them in node_time_dict. To
# do so, we make an explicit feature store GET call here with
# the relevant 'TensorAttr's
time_attrs = [
attr for attr in feature_store.get_all_tensor_attrs()
if attr.attr_name == time_attr
]
for attr in time_attrs:
attr.index = None
time_tensors = feature_store.multi_get_tensor(time_attrs)
self.node_time_dict = {
time_attr.group_name: time_tensor
for time_attr, time_tensor in zip(time_attrs, time_tensors)
}

# Obtain all node and edge metadata:
node_attrs = feature_store.get_all_tensor_attrs()
@@ -475,18 +487,23 @@ def to_index(tensor):
if isinstance(input_nodes, Tensor):
return None, to_index(input_nodes)

# Can't infer number of nodes from a group_name; need an attr_name
if isinstance(input_nodes, str):
num_nodes = feature_store.get_tensor_size(input_nodes)[0]
return input_nodes, range(num_nodes)
raise NotImplementedError(
f"Cannot infer the number of nodes from a single string "
f"(got '{input_nodes}'). Please pass a more explicit "
f"representation. ")

if isinstance(input_nodes, (list, tuple)):
assert len(input_nodes) == 2
assert isinstance(input_nodes[0], str)

node_type, input_nodes = input_nodes
if input_nodes is None:
num_nodes = feature_store.get_tensor_size(input_nodes)[0]
return input_nodes[0], range(num_nodes)
raise NotImplementedError(
f"Cannot infer the number of nodes from a node type alone "
f"(got '{input_nodes}'). Please pass a more explicit "
f"representation. ")
return node_type, to_index(input_nodes)

assert isinstance(input_nodes, TensorAttr)
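The `input_nodes` handling above becomes stricter for `(FeatureStore, GraphStore)` inputs: a bare node-type string, or a `(node_type, None)` tuple, can no longer be resolved to a node count and now raises `NotImplementedError`. Below is a sketch of the two explicit forms that remain supported, with illustrative shapes; the second mirrors the `TensorAttr` usage in the new temporal test.

```python
import torch

from torch_geometric.data import HeteroData
from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.loader import NeighborLoader

store = HeteroData()  # acts as both the feature store and the graph store
store.put_tensor(torch.randn(50, 8), group_name='paper', attr_name='x',
                 index=None)
row = torch.randint(0, 50, (200, ))
col = torch.randint(0, 50, (200, ))
store.put_edge_index((row, col), edge_type=('paper', 'to', 'paper'),
                     layout='coo', size=(50, 50))

# Explicit node indices for the input node type:
loader = NeighborLoader((store, store), num_neighbors=[5],
                        input_nodes=('paper', torch.arange(50)),
                        batch_size=10)

# Or a `TensorAttr` whose leading dimension defines the input nodes:
loader = NeighborLoader((store, store), num_neighbors=[5],
                        input_nodes=TensorAttr(group_name='paper',
                                               attr_name='x'),
                        batch_size=10)
```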
