diff --git a/CHANGELOG.md b/CHANGELOG.md index 019563cf78f3..15090708aa02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815)) -- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883)) +- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4922](https://github.com/pyg-team/pytorch_geometric/pull/4922)) - Added a `normalize` parameter to `dense_diff_pool` 
([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index 789403f2747e..3afd5223f80c 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -4,6 +4,7 @@ from torch_sparse import SparseTensor from torch_geometric.data import Data, HeteroData +from torch_geometric.data.feature_store import TensorAttr from torch_geometric.loader import NeighborLoader from torch_geometric.nn import GraphConv, to_hetero from torch_geometric.testing import withRegisteredOp @@ -322,6 +323,15 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore): edge_type=('author', 'to', 'paper'), layout='csc', size=(200, 100)) + # COO (sorted): + edge_index = get_edge_index(200, 200, 100) + edge_index = edge_index[:, edge_index[1].argsort()] + data['author', 'to', 'author'].edge_index = edge_index + coo = (edge_index[0], edge_index[1]) + graph_store.put_edge_index(edge_index=coo, + edge_type=('author', 'to', 'author'), + layout='coo', size=(200, 200), is_sorted=True) + # Construct neighbor loaders: loader1 = NeighborLoader(data, batch_size=20, input_nodes=('paper', range(100)), @@ -350,3 +360,47 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore): 'paper', 'to', 'author'].edge_index.size()) assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[ 'author', 'to', 'paper'].edge_index.size()) + + +@withRegisteredOp('torch_sparse.hetero_temporal_neighbor_sample') +@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData]) +@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData]) +def test_temporal_custom_neighbor_loader_on_cora(get_dataset, 
FeatureStore, + GraphStore): + # Initialize dataset (once): + dataset = get_dataset(name='Cora') + data = dataset[0] + + # Initialize feature store, graph store, and reference: + feature_store = FeatureStore() + graph_store = GraphStore() + hetero_data = HeteroData() + + feature_store.put_tensor(data.x, group_name='paper', attr_name='x', + index=None) + hetero_data['paper'].x = data.x + + feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper', + attr_name='time', index=None) + hetero_data['paper'].time = torch.arange(data.num_nodes) + + num_nodes = data.x.size(dim=0) + graph_store.put_edge_index(edge_index=data.edge_index, + edge_type=('paper', 'to', 'paper'), + layout='coo', size=(num_nodes, num_nodes)) + hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index + + loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1], + input_nodes='paper', time_attr='time', + batch_size=128) + + loader2 = NeighborLoader( + (feature_store, graph_store), + num_neighbors=[-1, -1], + input_nodes=TensorAttr(group_name='paper', attr_name='x'), + time_attr='time', + batch_size=128, + ) + + for batch1, batch2 in zip(loader1, loader2): + assert torch.equal(batch1['paper'].time, batch2['paper'].time) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 7a19bdb414a8..2c3200e4ab25 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -842,6 +842,12 @@ def _put_edge_index(self, edge_index: EdgeTensorType, attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index) setattr(self, attr_name, attr_val) + # Set edge attributes: + if not hasattr(self, '_edge_attrs'): + self._edge_attrs = {} + + self._edge_attrs[edge_attr.layout.value] = edge_attr + # Set size, if possible: size = edge_attr.size if size is not None: @@ -866,13 +872,26 @@ def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: def get_all_edge_attrs(self) -> List[EdgeAttr]: r"""Returns `EdgeAttr` objects corresponding to the 
edge indices stored in `Data` and their layouts""" - out = [] + added_attrs = set() + + # Check edges added via _put_edge_index. `_edge_attrs` only exists + # once `_put_edge_index` was called; its values are copied into a + # list so that further attributes can be appended below: + edge_attrs = list(getattr(self, '_edge_attrs', {}).values()) + for attr in edge_attrs: + attr.size = (self.num_nodes, self.num_nodes) + added_attrs.add(attr.layout) + + # Check edges added through regular interface: + # TODO deprecate this and store edge attributes for all edges in + # EdgeStorage for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items(): - if attr_name in self: - out.append( + if attr_name in self and layout not in added_attrs: + edge_attrs.append( EdgeAttr(edge_type=None, layout=layout, size=(self.num_nodes, self.num_nodes))) - return out + + return edge_attrs ############################################################################### diff --git a/torch_geometric/data/graph_store.py b/torch_geometric/data/graph_store.py index 48e66bd17503..8f35792c254f 100644 --- a/torch_geometric/data/graph_store.py +++ b/torch_geometric/data/graph_store.py @@ -117,9 +117,12 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType: Raises: KeyError: if the edge index corresponding to attr was not found. """ + edge_attr = self._edge_attr_cls.cast(*args, **kwargs) edge_attr.layout = EdgeLayout(edge_attr.layout) # Override is_sorted for CSC and CSR: + # TODO treat is_sorted specially in this function, where is_sorted=True + # returns an edge index sorted by column. 
edge_attr.is_sorted = edge_attr.is_sorted or (edge_attr.layout in [ EdgeLayout.CSC, EdgeLayout.CSR ]) @@ -131,9 +134,57 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType: # Layout Conversion ####################################################### + def _edge_to_layout( + self, + attr: EdgeAttr, + layout: EdgeLayout, + ) -> Tuple[Tensor, Tensor, OptTensor]: + from_tuple = self.get_edge_index(attr) + + if layout == EdgeLayout.COO: + if attr.layout == EdgeLayout.CSR: + col = from_tuple[1] + row = torch.ops.torch_sparse.ptr2ind(from_tuple[0], + col.numel()) + else: + row = from_tuple[0] + col = torch.ops.torch_sparse.ptr2ind(from_tuple[1], + row.numel()) + perm = None + + elif layout == EdgeLayout.CSR: + # We convert to CSR by converting to CSC on the transpose + if attr.layout == EdgeLayout.COO: + adj = edge_tensor_type_to_adj_type( + attr, (from_tuple[1], from_tuple[0])) + else: + adj = edge_tensor_type_to_adj_type(attr, from_tuple).t() + + # NOTE we set is_sorted=False here as is_sorted refers to + # the edge_index being sorted by the destination node + # (column), but here we deal with the transpose + attr_copy = copy.copy(attr) + attr_copy.is_sorted = False + attr_copy.size = None if attr.size is None else (attr.size[1], + attr.size[0]) + + # Actually rowptr, col, perm + row, col, perm = to_csc(adj, attr_copy, device='cpu') + + else: + adj = edge_tensor_type_to_adj_type(attr, from_tuple) + + # Actually colptr, row, perm + col, row, perm = to_csc(adj, attr, device='cpu') + + return row, col, perm + # TODO support `replace` to replace the existing edge index. 
- def _to_layout(self, layout: EdgeLayout, - store: bool = False) -> ConversionOutputType: + def _all_edges_to_layout( + self, + layout: EdgeLayout, + store: bool = False, + ) -> ConversionOutputType: # Obtain all edge attributes, grouped by type: edge_attrs = self.get_all_edge_attrs() edge_type_to_attrs: Dict[Any, List[EdgeAttr]] = defaultdict(list) @@ -165,45 +216,7 @@ def _to_layout(self, layout: EdgeLayout, else: from_attr = edge_attrs[edge_layouts.index(EdgeLayout.CSR)] - from_tuple = self.get_edge_index(from_attr) - - # Convert to the new layout: - if layout == EdgeLayout.COO: - if from_attr.layout == EdgeLayout.CSR: - col = from_tuple[1] - row = torch.ops.torch_sparse.ptr2ind( - from_tuple[0], col.numel()) - else: - row = from_tuple[0] - col = torch.ops.torch_sparse.ptr2ind( - from_tuple[1], row.numel()) - perm = None - - elif layout == EdgeLayout.CSR: - # We convert to CSR by converting to CSC on the transpose - if from_attr.layout == EdgeLayout.COO: - adj = edge_tensor_type_to_adj_type( - from_attr, (from_tuple[1], from_tuple[0])) - else: - adj = edge_tensor_type_to_adj_type( - from_attr, from_tuple).t() - - # NOTE we set is_sorted=False here as is_sorted refers to - # the edge_index being sorted by the destination node - # (column), but here we deal with the transpose - from_attr_copy = copy.copy(from_attr) - from_attr_copy.is_sorted = False - from_attr_copy.size = None if from_attr.size is None else ( - from_attr.size[1], from_attr.size[0]) - - # Actually rowptr, col, perm - row, col, perm = to_csc(adj, from_attr_copy, device='cpu') - - else: - adj = edge_tensor_type_to_adj_type(from_attr, from_tuple) - - # Actually colptr, row, perm - col, row, perm = to_csc(adj, from_attr, device='cpu') + row, col, perm = self._edge_to_layout(from_attr, layout) row_dict[from_attr.edge_type] = row col_dict[from_attr.edge_type] = col @@ -235,17 +248,17 @@ def _to_layout(self, layout: EdgeLayout, def coo(self, store: bool = False) -> ConversionOutputType: r"""Converts the 
edge indices in the graph store to COO format, optionally storing the converted edge indices in the graph store.""" - return self._to_layout(EdgeLayout.COO, store) + return self._all_edges_to_layout(EdgeLayout.COO, store) def csr(self, store: bool = False) -> ConversionOutputType: r"""Converts the edge indices in the graph store to CSR format, optionally storing the converted edge indices in the graph store.""" - return self._to_layout(EdgeLayout.CSR, store) + return self._all_edges_to_layout(EdgeLayout.CSR, store) def csc(self, store: bool = False) -> ConversionOutputType: r"""Converts the edge indices in the graph store to CSC format, optionally storing the converted edge indices in the graph store.""" - return self._to_layout(EdgeLayout.CSC, store) + return self._all_edges_to_layout(EdgeLayout.CSC, store) # Additional methods ###################################################### diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index d68b0d7ca880..ebb79dbc2d18 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -695,7 +695,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: out = self._node_store_dict.get(attr.group_name, None) if out: # Group name exists, handle index or create new attribute name: - val = getattr(out, attr.attr_name) + val = getattr(out, attr.attr_name, None) if val is not None: val[attr.index] = tensor else: @@ -754,6 +754,13 @@ def _put_edge_index(self, edge_index: EdgeTensorType, attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index) setattr(self[edge_attr.edge_type], attr_name, attr_val) + # Set edge attributes: + if not hasattr(self[edge_attr.edge_type], '_edge_attrs'): + self[edge_attr.edge_type]._edge_attrs = {} + + self[edge_attr.edge_type]._edge_attrs[ + edge_attr.layout.value] = edge_attr + key = self._to_canonical(edge_attr.edge_type) src, _, dst = key @@ -780,12 +787,30 @@ def get_all_edge_attrs(self) -> 
List[EdgeAttr]: r"""Returns a list of `EdgeAttr` objects corresponding to the edge indices stored in `HeteroData` and their layouts.""" out = [] + added_attrs = set() + + # Check edges added via _put_edge_index: + for edge_type, _ in self.edge_items(): + if not hasattr(self[edge_type], '_edge_attrs'): + continue + edge_attrs = self[edge_type]._edge_attrs.values() + for attr in edge_attrs: + attr.size = self[edge_type].size() + added_attrs.add((attr.edge_type, attr.layout)) + out.extend(edge_attrs) + + # Check edges added through regular interface: + # TODO deprecate this and store edge attributes for all edges in + # EdgeStorage for edge_type, edge_store in self.edge_items(): for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items(): - if attr_name in edge_store: + # Don't double count: + if attr_name in edge_store and ((edge_type, layout) + not in added_attrs): out.append( EdgeAttr(edge_type=edge_type, layout=layout, size=self[edge_type].size())) + return out diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index ff3c0e7b9cfa..cdf89f97b473 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -95,9 +95,21 @@ def __init__( # TODO support `collect` on `FeatureStore` self.node_time_dict = None if time_attr is not None: - raise ValueError( - f"'time_attr' attribute not yet supported for " - f"'{data[0].__class__.__name__}' object") + # We need to obtain all features with 'attr_name=time_attr' + # from the feature store and store them in node_time_dict. 
To + # do so, we make an explicit feature store GET call here with + # the relevant 'TensorAttr's + time_attrs = [ + attr for attr in feature_store.get_all_tensor_attrs() + if attr.attr_name == time_attr + ] + for attr in time_attrs: + attr.index = None + time_tensors = feature_store.multi_get_tensor(time_attrs) + self.node_time_dict = { + attr.group_name: time_tensor + for attr, time_tensor in zip(time_attrs, time_tensors) + } # Obtain all node and edge metadata: node_attrs = feature_store.get_all_tensor_attrs() @@ -475,9 +487,12 @@ def to_index(tensor): if isinstance(input_nodes, Tensor): return None, to_index(input_nodes) + # Can't infer number of nodes from a group_name; need an attr_name if isinstance(input_nodes, str): - num_nodes = feature_store.get_tensor_size(input_nodes)[0] - return input_nodes, range(num_nodes) + raise NotImplementedError( + f"Cannot infer the number of nodes from a single string " + f"(got '{input_nodes}'). Please pass a more explicit " + f"representation. ") if isinstance(input_nodes, (list, tuple)): assert len(input_nodes) == 2 @@ -485,8 +500,10 @@ def to_index(tensor): node_type, input_nodes = input_nodes if input_nodes is None: - num_nodes = feature_store.get_tensor_size(input_nodes)[0] - return input_nodes[0], range(num_nodes) + raise NotImplementedError( + f"Cannot infer the number of nodes from a node type alone " + f"(got '{node_type}'). Please pass a more explicit " + f"representation. ") return node_type, to_index(input_nodes) assert isinstance(input_nodes, TensorAttr)