diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1592b78f4fda..377ad94d77ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
 ### Changed
+- Drop `SparseTensor` dependency in `GraphStore` ([#5517](https://github.com/pyg-team/pytorch_geometric/pull/5517))
 - Replace `NeighborSampler` with `NeighborLoader` in the distributed sampling example ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6307))
 - Fixed the filtering of node features in `transforms.RemoveIsolatedNodes` ([#6308](https://github.com/pyg-team/pytorch_geometric/pull/6308))
 - Fixed a bug in `DimeNet` that causes an output dimension mismatch ([#6305](https://github.com/pyg-team/pytorch_geometric/pull/6305))
diff --git a/test/data/test_graph_store.py b/test/data/test_graph_store.py
index 52b5cb7a13e3..1b2465b6c206 100644
--- a/test/data/test_graph_store.py
+++ b/test/data/test_graph_store.py
@@ -1,115 +1,75 @@
-from typing import List
-
 import pytest
 import torch
 from torch_sparse import SparseTensor
 
-from torch_geometric.data.graph_store import EdgeLayout
 from torch_geometric.testing import MyGraphStore
-from torch_geometric.typing import OptTensor
-from torch_geometric.utils.sort_edge_index import sort_edge_index
 
 
 def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
     row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
     col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
-    return torch.stack([row, col], dim=0)
+    return row, col
 
 
 def test_graph_store():
     graph_store = MyGraphStore()
-    edge_index = torch.LongTensor([[0, 1], [1, 2]])
-    adj = SparseTensor(row=edge_index[0], col=edge_index[1])
-
-    def assert_equal_tensor_tuple(expected, actual):
-        assert len(expected) == len(actual)
-        for i in range(len(expected)):
-            assert torch.equal(expected[i], actual[i])
-
-    # We put all three tensor types: COO, CSR, and CSC, and we get them back
-    # to confirm that `GraphStore` works as intended.
-    coo = adj.coo()[:-1]
-    csr = adj.csr()[:-1]
-    csc = adj.csc()[-2::-1]  # (row, colptr)
-
-    # Put:
-    graph_store['edge', EdgeLayout.COO] = coo
-    graph_store['edge', 'csr'] = csr
-    graph_store['edge', 'csc'] = csc
-
-    # Get:
-    assert_equal_tensor_tuple(coo, graph_store['edge', 'coo'])
-    assert_equal_tensor_tuple(csr, graph_store['edge', 'csr'])
-    assert_equal_tensor_tuple(csc, graph_store['edge', 'csc'])
-
-    # Get attrs:
-    edge_attrs = graph_store.get_all_edge_attrs()
-    assert len(edge_attrs) == 3
+
+    coo = torch.tensor([0, 1]), torch.tensor([1, 2])
+    csr = torch.tensor([0, 1, 2]), torch.tensor([1, 2])
+    csc = torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2])
+
+    graph_store['edge_type', 'coo'] = coo
+    graph_store['edge_type', 'csr'] = csr
+    graph_store['edge_type', 'csc'] = csc
+
+    assert torch.equal(graph_store['edge_type', 'coo'][0], coo[0])
+    assert torch.equal(graph_store['edge_type', 'coo'][1], coo[1])
+    assert torch.equal(graph_store['edge_type', 'csr'][0], csr[0])
+    assert torch.equal(graph_store['edge_type', 'csr'][1], csr[1])
+    assert torch.equal(graph_store['edge_type', 'csc'][0], csc[0])
+    assert torch.equal(graph_store['edge_type', 'csc'][1], csc[1])
+
+    assert len(graph_store.get_all_edge_attrs()) == 3
 
     with pytest.raises(KeyError):
-        _ = graph_store['edge_2', 'coo']
+        graph_store['edge_type_2', 'coo']
 
 
 def test_graph_store_conversion():
     graph_store = MyGraphStore()
-    edge_index = get_edge_index(100, 100, 300)
-    edge_index = sort_edge_index(edge_index, sort_by_row=False)
-    adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(100, 100))
-
-    coo = (edge_index[0], edge_index[1])
-    csr = adj.csr()[:2]
-    csc = adj.csc()[-2::-1]
 
-    # Put all edge indices:
-    graph_store.put_edge_index(edge_index=coo, edge_type=('v', '1', 'v'),
-                               layout='coo', size=(100, 100), is_sorted=True)
-    graph_store.put_edge_index(edge_index=csr, edge_type=('v', '2', 'v'),
-                               layout='csr', size=(100, 100))
-    graph_store.put_edge_index(edge_index=csc, edge_type=('v', '3', 'v'),
-                               layout='csc', size=(100, 100))
+    coo = (row, col) = get_edge_index(100, 100, 300)
+    adj = SparseTensor(row=row, col=col, sparse_sizes=(100, 100))
+    csr, csc = adj.csr()[:2], adj.csc()[:2][::-1]
 
-    def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor):
-        assert torch.equal(sort_edge_index(expected), sort_edge_index(actual))
+    graph_store.put_edge_index(coo, ('v', '1', 'v'), 'coo', size=(100, 100))
+    graph_store.put_edge_index(csr, ('v', '2', 'v'), 'csr', size=(100, 100))
+    graph_store.put_edge_index(csc, ('v', '3', 'v'), 'csc', size=(100, 100))
 
     # Convert to COO:
     row_dict, col_dict, perm_dict = graph_store.coo()
     assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
-    for key in row_dict.keys():
-        actual = torch.stack((row_dict[key], col_dict[key]))
-        assert_edge_index_equal(actual, edge_index)
-        assert perm_dict[key] is None
+    for row, col, perm in zip(row_dict.values(), col_dict.values(),
+                              perm_dict.values()):
+        assert torch.equal(row.sort()[0], coo[0].sort()[0])
+        assert torch.equal(col.sort()[0], coo[1].sort()[0])
+        assert perm is None
 
     # Convert to CSR:
-    rowptr_dict, col_dict, perm_dict = graph_store.csr()
-    assert len(rowptr_dict) == len(col_dict) == len(perm_dict) == 3
-    for key in rowptr_dict:
-        assert torch.equal(rowptr_dict[key], csr[0])
-        assert torch.equal(col_dict[key], csr[1])
-        if key == ('v', '1', 'v'):
-            assert perm_dict[key] is not None
+    row_dict, col_dict, perm_dict = graph_store.csr()
+    assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
+    for row, col in zip(row_dict.values(), col_dict.values()):
+        assert torch.equal(row, csr[0])
+        assert torch.equal(col.sort()[0], csr[1].sort()[0])
 
     # Convert to CSC:
-    row_dict, colptr_dict, perm_dict = graph_store.csc()
-    assert len(row_dict) == len(colptr_dict) == len(perm_dict) == 3
-    for key in row_dict:
-        assert torch.equal(row_dict[key], csc[0])
-        assert torch.equal(colptr_dict[key], csc[1])
-        assert perm_dict[key] is None
+    row_dict, col_dict, perm_dict = graph_store.csc()
+    assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
+    for row, col in zip(row_dict.values(), col_dict.values()):
+        assert torch.equal(row.sort()[0], csc[0].sort()[0])
+        assert torch.equal(col, csc[1])
 
     # Ensure that 'edge_types' parameters work as intended:
-    def _tensor_eq(expected: List[OptTensor], actual: List[OptTensor]):
-        for tensor_expected, tensor_actual in zip(expected, actual):
-            if tensor_expected is None or tensor_actual is None:
-                return tensor_actual == tensor_expected
-            return torch.equal(tensor_expected, tensor_actual)
-
-    edge_types = [('v', '1', 'v'), ('v', '2', 'v')]
-    assert _tensor_eq(
-        list(graph_store.coo()[0].values())[:-1],
-        graph_store.coo(edge_types=edge_types)[0].values())
-    assert _tensor_eq(
-        list(graph_store.csr()[0].values())[:-1],
-        graph_store.csr(edge_types=edge_types)[0].values())
-    assert _tensor_eq(
-        list(graph_store.csc()[0].values())[:-1],
-        graph_store.csc(edge_types=edge_types)[0].values())
+    out = graph_store.coo([('v', '1', 'v')])
+    assert torch.equal(list(out[0].values())[0], coo[0])
+    assert torch.equal(list(out[1].values())[0], coo[1])
diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py
index 99a8a5847ba5..50ae260fd26c 100644
--- a/torch_geometric/data/data.py
+++ b/torch_geometric/data/data.py
@@ -21,12 +21,7 @@
 from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr
 from torch_geometric.data.feature_store import _field_status
-from torch_geometric.data.graph_store import (
-    EDGE_LAYOUT_TO_ATTR_NAME,
-    EdgeLayout,
-    adj_type_to_edge_tensor_type,
-    edge_tensor_type_to_adj_type,
-)
+from torch_geometric.data.graph_store import EdgeLayout
 from torch_geometric.data.storage import (
     BaseStorage,
     EdgeStorage,
@@ -340,26 +335,28 @@ def contains_self_loops(self) -> bool:
 
 @dataclass
 class DataTensorAttr(TensorAttr):
-    r"""Attribute class for `Data`, which does not require a `group_name`."""
-    def __init__(self, attr_name=_field_status.UNSET,
-                 index=_field_status.UNSET):
-        # Treat group_name as optional, and move it to the end
+    r"""Tensor attribute for `Data` without group name."""
+    def __init__(
+        self,
+        attr_name=_field_status.UNSET,
+        index=None,
+    ):
         super().__init__(None, attr_name, index)
 
 
 @dataclass
 class DataEdgeAttr(EdgeAttr):
-    r"""Edge attribute class for `Data`, which does not require a
-    `edge_type`."""
+    r"""Edge attribute class for `Data` without edge type."""
     def __init__(
         self,
         layout: Optional[EdgeLayout] = None,
         is_sorted: bool = False,
         size: Optional[Tuple[int, int]] = None,
-        edge_type: EdgeType = None,
     ):
-        # Treat edge_type as optional, and move it to the end
-        super().__init__(edge_type, layout, is_sorted, size)
+        super().__init__(None, layout, is_sorted, size)
+
+
+###############################################################################
 
 
 class Data(BaseData, FeatureStore, GraphStore):
@@ -834,27 +831,16 @@ def num_faces(self) -> Optional[int]:
 
     # FeatureStore interface ##################################################
 
-    def items(self):
-        r"""Returns an `ItemsView` over the stored attributes in the `Data`
-        object."""
-        # NOTE this is necessary to override the default `MutableMapping`
-        # items() method.
-        return self._store.items()
-
     def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
-        r"""Stores a feature tensor in node storage."""
-        out = getattr(self, attr.attr_name, None)
+        out = self.get(attr.attr_name)
         if out is not None and attr.index is not None:
-            # Attr name exists, handle index:
             out[attr.index] = tensor
         else:
-            # No attr name (or None index), just store tensor:
+            assert attr.index is None
             setattr(self, attr.attr_name, tensor)
         return True
 
     def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
-        r"""Obtains a feature tensor from node storage."""
-        # Retrieve tensor and index accordingly:
         tensor = getattr(self, attr.attr_name, None)
         if tensor is not None:
             # TODO this behavior is a bit odd, since TensorAttr requires that
@@ -865,15 +851,12 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
             return None
 
     def _remove_tensor(self, attr: TensorAttr) -> bool:
-        r"""Deletes a feature tensor from node storage."""
-        # Remove tensor entirely:
         if hasattr(self, attr.attr_name):
             delattr(self, attr.attr_name)
             return True
         return False
 
     def _get_tensor_size(self, attr: TensorAttr) -> Tuple:
-        r"""Returns the size of the tensor corresponding to `attr`."""
         return self._get_tensor(attr).size()
 
     def get_all_tensor_attrs(self) -> List[TensorAttr]:
@@ -883,70 +866,80 @@ def get_all_tensor_attrs(self) -> List[TensorAttr]:
             if self._store.is_node_attr(name)
         ]
 
-    def __len__(self) -> int:
-        return BaseData.__len__(self)
-
     # GraphStore interface ####################################################
 
     def _put_edge_index(self, edge_index: EdgeTensorType,
                         edge_attr: EdgeAttr) -> bool:
-        r"""Stores `edge_index` in `Data`, in the specified layout."""
-
-        # Convert the edge index to a recognizable layout:
-        attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
-        attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
-        setattr(self, attr_name, attr_val)
-
-        # Set edge attributes:
         if not hasattr(self, '_edge_attrs'):
             self._edge_attrs = {}
-
-        self._edge_attrs[edge_attr.layout.value] = edge_attr
-
-        # Set size, if possible:
-        size = edge_attr.size
-        if size is not None:
-            if size[0] != size[1]:
-                raise ValueError(
-                    f"'Data' requires size[0] == size[1], but received "
-                    f"the tuple {size}.")
-            self.num_nodes = size[0]
+        self._edge_attrs[edge_attr.layout] = edge_attr
+
+        row, col = edge_index
+
+        if edge_attr.layout == EdgeLayout.COO:
+            self.edge_index = torch.stack([row, col], dim=0)
+        elif edge_attr.layout == EdgeLayout.CSR:
+            self.adj = SparseTensor(
+                rowptr=row,
+                col=col,
+                sparse_sizes=edge_attr.size,
+                is_sorted=True,
+                trust_data=True,
+            )
+        else:  # edge_attr.layout == EdgeLayout.CSC:
+            size = edge_attr.size[::-1] if edge_attr.size is not None else None
+            self.adj_t = SparseTensor(
+                rowptr=col,
+                col=row,
+                sparse_sizes=size,
+                is_sorted=True,
+                trust_data=True,
+            )
         return True
 
     def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
-        r"""Obtains the edge index corresponding to `edge_attr` in `Data`,
-        in the specified layout."""
-        # Get the requested layout and the edge tensor type associated with it:
-        attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
-        attr_val = getattr(self._store, attr_name, None)
-        if attr_val is not None:
-            # Convert from Adj type to Tuple[Tensor, Tensor]
-            attr_val = adj_type_to_edge_tensor_type(edge_attr.layout, attr_val)
-        return attr_val
+        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self:
+            row, col = self.edge_index
+            return row, col
+        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self:
+            rowptr, col, _ = self.adj.csr()
+            return rowptr, col
+        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self:
+            colptr, row, _ = self.adj_t.csr()
+            return row, colptr
+        return None
+
+    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
+        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self:
+            del self.edge_index
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop(EdgeLayout.COO, None)
+            return True
+        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self:
+            del self.adj
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop(EdgeLayout.CSR, None)
+            return True
+        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self:
+            del self.adj_t
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop(EdgeLayout.CSC, None)
+            return True
+        return False
 
     def get_all_edge_attrs(self) -> List[EdgeAttr]:
-        r"""Returns `EdgeAttr` objects corresponding to the edge indices stored
-        in `Data` and their layouts"""
-        if not hasattr(self, '_edge_attrs'):
-            return []
-        added_attrs = set()
-
-        # Check edges added via _put_edge_index:
-        edge_attrs = self._edge_attrs.values()
-        for attr in edge_attrs:
-            attr.size = (self.num_nodes, self.num_nodes)
-            added_attrs.add(attr.layout)
-
-        # Check edges added through regular interface:
-        # TODO deprecate this and store edge attributes for all edges in
-        # EdgeStorage
-        for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
-            if attr_name in self and layout not in added_attrs:
-                edge_attrs.append(
-                    EdgeAttr(edge_type=None, layout=layout,
-                             size=(self.num_nodes, self.num_nodes)))
-
-        return edge_attrs
+        edge_attrs = getattr(self, '_edge_attrs', {})
+
+        if 'edge_index' in self and EdgeLayout.COO not in edge_attrs:
+            edge_attrs[EdgeLayout.COO] = DataEdgeAttr('coo', is_sorted=False)
+        if 'adj' in self and EdgeLayout.CSR not in edge_attrs:
+            size = self.adj.sparse_sizes()
+            edge_attrs[EdgeLayout.CSR] = DataEdgeAttr('csr', size=size)
+        if 'adj_t' in self and EdgeLayout.CSC not in edge_attrs:
+            size = self.adj_t.sparse_sizes()[::-1]
+            edge_attrs[EdgeLayout.CSC] = DataEdgeAttr('csc', size=size)
+
+        return list(edge_attrs.values())
 
 
 ###############################################################################
diff --git a/torch_geometric/data/graph_store.py b/torch_geometric/data/graph_store.py
index d16a45a072b2..7c9c3e3ae31f 100644
--- a/torch_geometric/data/graph_store.py
+++ b/torch_geometric/data/graph_store.py
@@ -16,12 +16,8 @@
 concatenate all metadata values with an edge index and use this as a unique
 index in a KV store. More complicated implementations may choose to partition
 the graph in interesting manners based on the provided metadata.
-
-Major TODOs for future implementation:
-* `sample` behind the graph store interface
 """
 import copy
-import warnings
 from abc import abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -31,13 +27,7 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.typing import (
-    Adj,
-    EdgeTensorType,
-    EdgeType,
-    OptTensor,
-    SparseTensor,
-)
+from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor
 from torch_geometric.utils.mixin import CastMixin
 
 # The output of converting between two types in the GraphStore is a Tuple of
@@ -49,8 +39,11 @@
 # dictionaries: row, col, and perm. The dictionaries are keyed by the edge
 # type of the input edge attribute.
 #   * The row dictionary contains the row tensor for COO, the row pointer for
 #     CSR, or the col pointer for CSC
 #   * The perm dictionary contains the permutation of edges that was applied
 #     in converting between formats, if applicable.
-ConversionOutputType = Tuple[Dict[str, Tensor], Dict[str, Tensor],
-                             Dict[str, OptTensor]]
+ConversionOutputType = Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor],
+                             Dict[EdgeType, OptTensor]]
+
+ptr2ind = torch.ops.torch_sparse.ptr2ind
+ind2ptr = torch.ops.torch_sparse.ind2ptr
 
 
 class EdgeLayout(Enum):
@@ -71,30 +64,39 @@ class EdgeAttr(CastMixin):
         :meth:`EdgeAttr.__init__`.
     """
 
-    # The type of the edge
-    edge_type: Optional[EdgeType]
+    # The type of the edge:
+    edge_type: EdgeType
 
-    # The layout of the edge representation
-    layout: Optional[EdgeLayout] = None
+    # The layout of the edge representation:
+    layout: EdgeLayout
 
-    # Whether the edge index is sorted, by destination node. Useful for
+    # Whether the edge index is sorted by destination node. Useful for
     # avoiding sorting costs when performing neighbor sampling, and only
-    # meaningful for COO (CSC and CSR are sorted by definition)
+    # meaningful for COO (CSC is sorted and CSR is not sorted by definition):
     is_sorted: bool = False
 
-    # The number of source and destination nodes in this edge type
+    # The number of source and destination nodes in this edge type:
    size: Optional[Tuple[int, int]] = None
 
     # NOTE we define __init__ to force-cast layout
     def __init__(
         self,
-        edge_type: Optional[Any],
-        layout: Optional[EdgeLayout] = None,
+        edge_type: EdgeType,
+        layout: EdgeLayout,
         is_sorted: bool = False,
         size: Optional[Tuple[int, int]] = None,
     ):
+        layout = EdgeLayout(layout)
+
+        if layout == EdgeLayout.CSR and is_sorted:
+            raise ValueError("Cannot create a 'CSR' edge attribute with "
+                             "option 'is_sorted=True'")
+
+        if layout == EdgeLayout.CSC:
+            is_sorted = True
+
         self.edge_type = edge_type
-        self.layout = EdgeLayout(layout) if layout else None
+        self.layout = layout
         self.is_sorted = is_sorted
         self.size = size
@@ -134,13 +136,6 @@ def put_edge_index(self, edge_index: EdgeTensorType, *args,
             attributes.
         """
         edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
-        assert edge_attr.layout is not None
-        edge_attr.layout = EdgeLayout(edge_attr.layout)
-
-        # Override is_sorted for CSC and CSR:
-        edge_attr.is_sorted = edge_attr.is_sorted or (edge_attr.layout in [
-            EdgeLayout.CSC, EdgeLayout.CSR
-        ])
         return self._put_edge_index(edge_index, edge_attr)
 
     @abstractmethod
@@ -163,158 +158,36 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
             :class:`EdgeAttr` was not found.
         """
         edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
-        assert edge_attr.layout is not None
-        edge_attr.layout = EdgeLayout(edge_attr.layout)
-        # Override is_sorted for CSC and CSR:
-        # TODO treat is_sorted specially in this function, where is_sorted=True
-        # returns an edge index sorted by column.
-        edge_attr.is_sorted = edge_attr.is_sorted or (edge_attr.layout in [
-            EdgeLayout.CSC, EdgeLayout.CSR
-        ])
         edge_index = self._get_edge_index(edge_attr)
         if edge_index is None:
-            raise KeyError(f"An edge corresponding to '{edge_attr}' was not "
-                           f"found")
+            raise KeyError(f"'edge_index' for '{edge_attr}' not found")
         return edge_index
 
     @abstractmethod
-    def get_all_edge_attrs(self) -> List[EdgeAttr]:
-        r"""Obtains all edge attributes stored in the :class:`GraphStore`."""
+    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
+        r"""To be implemented by :class:`GraphStore` subclasses."""
         pass
 
-    # Layout Conversion #######################################################
-
-    def _edge_to_layout(
-        self,
-        attr: EdgeAttr,
-        layout: EdgeLayout,
-    ) -> Tuple[Tensor, Tensor, OptTensor]:
-        r"""Converts an :obj:`edge_index` tuple in the :class:`GraphStore` to
-        the desired output layout by fetching the :obj:`edge_index` and
-        performing in-memory conversion."""
-        import torch_sparse  # noqa
-        from_tuple = self.get_edge_index(attr)
-
-        if layout == EdgeLayout.COO:
-            if attr.layout == EdgeLayout.CSR:
-                col = from_tuple[1]
-                row = torch.ops.torch_sparse.ptr2ind(from_tuple[0],
-                                                     col.numel())
-            else:
-                row = from_tuple[0]
-                col = torch.ops.torch_sparse.ptr2ind(from_tuple[1],
-                                                     row.numel())
-            perm = None
-
-        elif layout == EdgeLayout.CSR:
-            # We convert to CSR by converting to CSC on the transpose
-            if attr.layout == EdgeLayout.COO:
-                adj = edge_tensor_type_to_adj_type(
-                    attr, (from_tuple[1], from_tuple[0]))
-            else:
-                adj = edge_tensor_type_to_adj_type(attr, from_tuple).t()
-
-            # NOTE we set is_sorted=False here as is_sorted refers to
-            # the edge_index being sorted by the destination node
-            # (column), but here we deal with the transpose
-            attr_copy = copy.copy(attr)
-            attr_copy.is_sorted = False
-            attr_copy.size = None if attr.size is None else (attr.size[1],
-                                                             attr.size[0])
-
-            # Actually rowptr, col, perm
-            row, col, perm = to_csc(adj, attr_copy, device='cpu')
-
-        else:
-            adj = edge_tensor_type_to_adj_type(attr, from_tuple)
-
-            # Actually colptr, row, perm
-            col, row, perm = to_csc(adj, attr, device='cpu')
+    def remove_edge_index(self, *args, **kwargs) -> bool:
+        r"""Synchronously deletes an :obj:`edge_index` tuple from the
+        :class:`GraphStore`.
+        Returns whether deletion was successful.
 
-        return row, col, perm
+        Args:
+            **kwargs (EdgeAttr): Any relevant edge attributes that
+                correspond to the :obj:`edge_index` tuple. See the
+                :class:`EdgeAttr` documentation for required and optional
+                attributes.
+        """
+        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
+        return self._remove_edge_index(edge_attr)
 
-    # TODO support `replace` to replace the existing edge index.
-    def _all_edges_to_layout(
-        self,
-        layout: EdgeLayout,
-        edge_types: Optional[List[Any]] = None,
-        store: bool = False,
-    ) -> ConversionOutputType:
-        r"""Converts all edge attributes in the graph store to the desired
-        layout, by fetching all edge indices and performing conversion on
-        the caller instance. Implementations that support conversion within
-        the graph store can override this method."""
-        # Obtain all edge attributes, grouped by type:
-        all_edge_attrs = self.get_all_edge_attrs()
-        edge_type_to_attrs: Dict[Any, List[EdgeAttr]] = defaultdict(list)
-        for attr in all_edge_attrs:
-            edge_type_to_attrs[attr.edge_type].append(attr)
-
-        # Edge types to convert:
-        edge_types = edge_types or [attr.edge_type for attr in all_edge_attrs]
-        for edge_type in edge_types:
-            if edge_type not in edge_type_to_attrs:
-                raise ValueError(
-                    f"The edge index {edge_type} was not found in the graph "
-                    f"store.")
-
-        # Convert layouts for each attribute from its most favorable original
-        # layout to the desired layout. Store permutations of edges if
-        # necessary as part of the conversion:
-        row_dict, col_dict, perm_dict = {}, {}, {}
-        for edge_type in edge_types:
-            edge_type_attrs = edge_type_to_attrs[edge_type]
-            edge_type_layouts = [attr.layout for attr in edge_type_attrs]
-
-            # Ignore if requested layout is already present:
-            if layout in edge_type_layouts:
-                from_attr = edge_type_attrs[edge_type_layouts.index(layout)]
-                row, col = self.get_edge_index(from_attr)
-                perm = None
-
-            # Convert otherwise:
-            else:
-                # Pick the most favorable layout to convert from. We prefer
-                # COO to CSC/CSR:
-                from_attr = None
-                if EdgeLayout.COO in edge_type_layouts:
-                    from_attr = edge_type_attrs[edge_type_layouts.index(
-                        EdgeLayout.COO)]
-                elif EdgeLayout.CSC in edge_type_layouts:
-                    from_attr = edge_type_attrs[edge_type_layouts.index(
-                        EdgeLayout.CSC)]
-                else:
-                    from_attr = edge_type_attrs[edge_type_layouts.index(
-                        EdgeLayout.CSR)]
-
-                row, col, perm = self._edge_to_layout(from_attr, layout)
-
-            row_dict[from_attr.edge_type] = row
-            col_dict[from_attr.edge_type] = col
-            perm_dict[from_attr.edge_type] = perm
-
-            if store and layout not in edge_type_layouts:
-                # We do not store converted edge indices if this conversion
-                # results in a permutation of nodes in the original edge index.
-                # This is to exercise an abundance of caution in the case that
-                # there are edge attributes.
-                if perm is not None:
-                    warnings.warn(f"Edge index {from_attr.edge_type} with "
-                                  f"layout {from_attr.layout} was not sorted "
-                                  f"by destination node, so conversion to "
-                                  f"{layout} resulted in a permutation of "
-                                  f"the order of edges. As a result, the "
-                                  f"converted edge is not being re-stored in "
-                                  f"the graph store. Please sort the edge "
-                                  f"index and set 'is_sorted=True' to avoid "
-                                  f"this warning.")
-                else:
-                    is_sorted = (layout != EdgeLayout.COO)
-                    self.put_edge_index((row, col),
-                                        EdgeAttr(from_attr.edge_type, layout,
-                                                 is_sorted, from_attr.size))
+    @abstractmethod
+    def get_all_edge_attrs(self) -> List[EdgeAttr]:
+        r"""Obtains all edge attributes stored in the :class:`GraphStore`."""
+        pass
 
-        return row_dict, col_dict, perm_dict
+    # Layout Conversion #######################################################
 
     def coo(
         self,
@@ -331,7 +204,7 @@
         store (bool, optional): Whether to store converted edge indices in
             the :class:`GraphStore`. (default: :obj:`False`)
         """
-        return self._all_edges_to_layout(EdgeLayout.COO, edge_types, store)
+        return self._edges_to_layout(EdgeLayout.COO, edge_types, store)
 
     def csr(
         self,
@@ -348,7 +221,7 @@
         store (bool, optional): Whether to store converted edge indices in
             the :class:`GraphStore`. (default: :obj:`False`)
         """
-        return self._all_edges_to_layout(EdgeLayout.CSR, edge_types, store)
+        return self._edges_to_layout(EdgeLayout.CSR, edge_types, store)
 
     def csc(
         self,
@@ -365,128 +238,108 @@
         store (bool, optional): Whether to store converted edge indices in
             the :class:`GraphStore`. (default: :obj:`False`)
         """
-        return self._all_edges_to_layout(EdgeLayout.CSC, edge_types, store)
+        return self._edges_to_layout(EdgeLayout.CSC, edge_types, store)
 
     # Python built-ins ########################################################
 
     def __setitem__(self, key: EdgeAttr, value: EdgeTensorType):
-        key = self._edge_attr_cls.cast(key)
         self.put_edge_index(value, key)
 
     def __getitem__(self, key: EdgeAttr) -> Optional[EdgeTensorType]:
-        key = self._edge_attr_cls.cast(key)
         return self.get_edge_index(key)
 
+    def __delitem__(self, key: EdgeAttr):
+        return self.remove_edge_index(key)
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    # Helper methods ##########################################################
+
+    def _edge_to_layout(
+        self,
+        attr: EdgeAttr,
+        layout: EdgeLayout,
+        store: bool = False,
+    ) -> Tuple[Tensor, Tensor, OptTensor]:
 
-# Data and HeteroData utilities ###############################################
-
-EDGE_LAYOUT_TO_ATTR_NAME = {
-    EdgeLayout.COO: 'edge_index',
-    EdgeLayout.CSR: 'adj',
-    EdgeLayout.CSC: 'adj_t',
-}
-
-
-def edge_tensor_type_to_adj_type(
-    attr: EdgeAttr,
-    tensor_tuple: EdgeTensorType,
-) -> Adj:
-    r"""Converts an EdgeTensorType tensor tuple to a PyG Adj tensor."""
-    src, dst = tensor_tuple
-
-    if attr.layout == EdgeLayout.COO:  # COO: (row, col)
-        assert src.dim() == 1 and dst.dim() == 1 and src.numel() == dst.numel()
-
-        if src.numel() == 0:
-            return torch.empty((2, 0), dtype=torch.long, device=src.device)
-
-        if (src[0].storage().data_ptr() == dst[1].storage().data_ptr()
-                and src.storage_offset() < dst.storage_offset()):
-            # Do not copy if the tensor tuple is constructed from the same
-            # storage (instead, return a view):
-            out = torch.empty(0, dtype=src.dtype)
-            out.set_(src.storage(), storage_offset=src.storage_offset(),
-                     size=(src.size()[0] + dst.size()[0], ))
-            return out.view(2, -1)
-
-        return torch.stack([src, dst], dim=0)
-
-    elif attr.layout == EdgeLayout.CSR:  # CSR: (rowptr, col)
-        return SparseTensor(rowptr=src, col=dst, is_sorted=True,
-                            sparse_sizes=attr.size)
-
-    elif attr.layout == EdgeLayout.CSC:  # CSC: (row, colptr)
-        # CSC is a transposed adjacency matrix, so rowptr is the compressed
-        # column and col is the uncompressed row.
-        sparse_sizes = None if attr.size is None else (attr.size[1],
-                                                       attr.size[0])
-        return SparseTensor(rowptr=dst, col=src, is_sorted=True,
-                            sparse_sizes=sparse_sizes)
-    raise ValueError(f"Bad edge layout (got '{attr.layout}')")
-
-
-def adj_type_to_edge_tensor_type(layout: EdgeLayout,
-                                 edge_index: Adj) -> EdgeTensorType:
-    r"""Converts a PyG Adj tensor to an EdgeTensorType equivalent."""
-    if isinstance(edge_index, Tensor):
-        return (edge_index[0], edge_index[1])  # (row, col)
-    if layout == EdgeLayout.COO:
-        return edge_index.coo()[:-1]  # (row, col)
-    elif layout == EdgeLayout.CSR:
-        return edge_index.csr()[:-1]  # (rowptr, col)
-    else:
-        return edge_index.csr()[-2::-1]  # (row, colptr)
-
-
-###############################################################################
-
-
-def to_csc(
-    adj: Adj,
-    edge_attr: EdgeAttr,
-    device: Optional[torch.device] = None,
-    share_memory: bool = False,
-) -> Tuple[Tensor, Tensor, OptTensor]:
-    import torch_sparse  # noqa
-
-    # Convert the graph data into a suitable format for sampling (CSC format).
-    # Returns the `colptr` and `row` indices of the graph, as well as an
-    # `perm` vector that denotes the permutation of edges.
-    # Since no permutation of edges is applied when using `SparseTensor`,
-    # `perm` can be of type `None`.
-    perm: Optional[Tensor] = None
-    layout = edge_attr.layout
-    is_sorted = edge_attr.is_sorted
-    size = edge_attr.size
-
-    if layout == EdgeLayout.CSR:
-        colptr, row, _ = adj.csc()
-    elif layout == EdgeLayout.CSC:
-        colptr, row, _ = adj.csr()
-    else:
-        if size is None:
-            raise ValueError(
-                f"Edge {edge_attr.edge_type} cannot be converted "
-                f"to a different type without specifying 'size' for "
-                f"the source and destination node types (got {size}). "
-                f"Please specify these parameters for successful execution.")
-        (row, col) = adj
-        if not is_sorted:
-            perm = (col * size[0]).add_(row).argsort()
-            row = row[perm]
-        colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1])
-
-    colptr = colptr.to(device)
-    row = row.to(device)
-    perm = perm.to(device) if perm is not None else None
-
-    if not colptr.is_cuda and share_memory:
-        colptr.share_memory_()
-        row.share_memory_()
-        if perm is not None:
-            perm.share_memory_()
-
-    return colptr, row, perm
+        (row, col), perm = self.get_edge_index(attr), None
+
+        if layout == EdgeLayout.COO:  # COO output requested:
+            if attr.layout == EdgeLayout.CSR:  # CSR->COO
+                row = ptr2ind(row, col.numel())
+            elif attr.layout == EdgeLayout.CSC:  # CSC->COO
+                col = ptr2ind(col, row.numel())
+
+        elif layout == EdgeLayout.CSR:  # CSR output requested:
+            if attr.layout == EdgeLayout.CSC:  # CSC->COO
+                col = ptr2ind(col, row.numel())
+
+            if attr.layout != EdgeLayout.CSR:  # COO->CSR
+                row, perm = row.sort()  # Cannot be sorted by destination.
+                col = col[perm]
+                num_rows = attr.size[0] if attr.size else int(row.max()) + 1
+                row = ind2ptr(row, num_rows)
+
+        else:  # CSC output requested:
+            if attr.layout == EdgeLayout.CSR:  # CSR->COO
+                row = ptr2ind(row, col.numel())
+
+            if attr.layout != EdgeLayout.CSC:  # COO->CSC
+                if not attr.is_sorted:  # Not sorted by destination.
+                    col, perm = col.sort()
+                    row = row[perm]
+                num_cols = attr.size[1] if attr.size else int(col.max()) + 1
+                col = ind2ptr(col, num_cols)
+
+        if attr.layout != layout and store:
+            attr = copy.copy(attr)
+            attr.layout = layout
+            if perm is not None:
+                attr.is_sorted = False
+            self.put_edge_index((row, col), attr)
+
+        return row, col, perm
+
+    def _edges_to_layout(
+        self,
+        layout: EdgeLayout,
+        edge_types: Optional[List[Any]] = None,
+        store: bool = False,
+    ) -> ConversionOutputType:
+
+        # Obtain all edge attributes, grouped by type:
+        edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list)
+        for attr in self.get_all_edge_attrs():
+            edge_type_attrs[attr.edge_type].append(attr)
+
+        # Check that requested edge types exist and filter:
+        if edge_types is not None:
+            for edge_type in edge_types:
+                if edge_type not in edge_type_attrs:
+                    raise ValueError(f"The 'edge_index' of type '{edge_type}' "
+                                     f"was not found in the graph store.")
+
+            edge_type_attrs = {
+                key: attr
+                for key, attr in edge_type_attrs.items() if key in edge_types
+            }
+
+        # Convert layout from its most favorable original layout:
+        row_dict, col_dict, perm_dict = {}, {}, {}
+        for edge_type, attrs in edge_type_attrs.items():
+            layouts = [attr.layout for attr in attrs]
+
+            if layout in layouts:  # No conversion needed.
+                attr = attrs[layouts.index(layout)]
+            elif EdgeLayout.COO in layouts:  # Prefer COO for conversion.
+                attr = attrs[layouts.index(EdgeLayout.COO)]
+            elif EdgeLayout.CSC in layouts:
+                attr = attrs[layouts.index(EdgeLayout.CSC)]
+            elif EdgeLayout.CSR in layouts:
+                attr = attrs[layouts.index(EdgeLayout.CSR)]
+
+            row_dict[edge_type], col_dict[edge_type], perm_dict[edge_type] = (
+                self._edge_to_layout(attr, layout, store))
+
+        return row_dict, col_dict, perm_dict
diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py
index 69b88ea1519b..46aa5ad4194a 100644
--- a/torch_geometric/data/hetero_data.py
+++ b/torch_geometric/data/hetero_data.py
@@ -11,11 +11,7 @@
 from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr
 from torch_geometric.data.data import BaseData, Data, size_repr, warn_or_raise
-from torch_geometric.data.graph_store import (
-    EDGE_LAYOUT_TO_ATTR_NAME,
-    adj_type_to_edge_tensor_type,
-    edge_tensor_type_to_adj_type,
-)
+from torch_geometric.data.graph_store import EdgeLayout
 from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage
 from torch_geometric.typing import (
     EdgeTensorType,
@@ -834,7 +830,6 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
 
     # FeatureStore interface ##################################################
 
     def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
-        r"""Stores a feature tensor in node storage."""
         if not attr.is_set('index'):
             attr.index = None
@@ -845,6 +840,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
             if val is not None:
                 val[attr.index] = tensor
             else:
+                assert attr.index is None
                 setattr(self[attr.group_name], attr.attr_name, tensor)
         else:
             # No node storage found, just store tensor in new one:
@@ -852,7 +848,6 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
         return True
 
     def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
-        r"""Obtains a feature tensor from node storage."""
         # Retrieve tensor and index accordingly:
         tensor = getattr(self[attr.group_name], attr.attr_name, None)
         if tensor is not None:
@@ -864,7 +859,6 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
             return None
 
     def _remove_tensor(self, attr: TensorAttr) -> bool:
-        r"""Deletes a feature tensor from node storage."""
         # Remove tensor entirely:
         if hasattr(self[attr.group_name], attr.attr_name):
             delattr(self[attr.group_name], attr.attr_name)
@@ -872,7 +866,6 @@ def _remove_tensor(self, attr: TensorAttr) -> bool:
         return False
 
     def _get_tensor_size(self, attr: TensorAttr) -> Tuple:
-        r"""Returns the size of the tensor corresponding to `attr`."""
         return self._get_tensor(attr).size()
 
     def get_all_tensor_attrs(self) -> List[TensorAttr]:
@@ -883,81 +876,92 @@ def get_all_tensor_attrs(self) -> List[TensorAttr]:
                 out.append(TensorAttr(group_name, attr_name))
         return out
 
-    def __len__(self) -> int:
-        return BaseData.__len__(self)
-
-    def __iter__(self):
-        raise NotImplementedError
-
     # GraphStore interface ####################################################
 
     def _put_edge_index(self, edge_index: EdgeTensorType,
                         edge_attr: EdgeAttr) -> bool:
-        r"""Stores an edge index in edge storage, in the specified layout."""
-
-        # Convert the edge index to a recognizable layout:
-        attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
-        attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
-        setattr(self[edge_attr.edge_type], attr_name, attr_val)
-
-        # Set edge attributes:
-        if not hasattr(self[edge_attr.edge_type], '_edge_attrs'):
-            self[edge_attr.edge_type]._edge_attrs = {}
-
-        self[edge_attr.edge_type]._edge_attrs[
-            edge_attr.layout.value] = edge_attr
-
-        key = self._to_canonical(edge_attr.edge_type)
-        src, _, dst = key
-
-        # Handle num_nodes, if possible:
-        size = edge_attr.size
-        if size is not None:
-            # TODO better warning in the case of overwriting 'num_nodes'
-            self[src].num_nodes = size[0]
-            self[dst].num_nodes = size[1]
-
+        if not hasattr(self, '_edge_attrs'):
+            self._edge_attrs = {}
+        self._edge_attrs[(edge_attr.edge_type, edge_attr.layout)] = edge_attr
+
+        row, col = edge_index
+        store = self[edge_attr.edge_type]
+
+        if edge_attr.layout == EdgeLayout.COO:
+            store.edge_index = torch.stack([row, col], dim=0)
+        elif edge_attr.layout == EdgeLayout.CSR:
+            store.adj = SparseTensor(
+                rowptr=row,
+                col=col,
+                sparse_sizes=edge_attr.size,
+                is_sorted=True,
+                trust_data=True,
+            )
+        else:  # edge_attr.layout == EdgeLayout.CSC:
+            size = edge_attr.size[::-1] if edge_attr.size is not None else None
+            store.adj_t = SparseTensor(
+                rowptr=col,
+                col=row,
+                sparse_sizes=size,
+                is_sorted=True,
+                trust_data=True,
+            )
         return True
 
     def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
         r"""Gets an edge index from edge storage, in the specified layout."""
-        # Get the requested layout and the Adj tensor associated with it:
-        attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
-        attr_val = getattr(self[edge_attr.edge_type], attr_name, None)
-        if attr_val is not None:
-            # Convert from Adj type to Tuple[Tensor, Tensor]
-            attr_val = adj_type_to_edge_tensor_type(edge_attr.layout, attr_val)
-        return attr_val
-
-    def get_all_edge_attrs(self) -> List[EdgeAttr]:
-        r"""Returns a list of `EdgeAttr` objects corresponding to the edge
-        indices stored in `HeteroData` and their layouts."""
-        out = []
-        added_attrs = set()
+        store = self[edge_attr.edge_type]
+        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in store:
+            row, col = store.edge_index
+            return row, col
+        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in store:
+            rowptr, col, _ = store.adj.csr()
+            return rowptr, col
+        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in store:
+            colptr, row, _ = store.adj_t.csr()
+            return row, colptr
+        return None
 
-        # Check edges added via _put_edge_index:
-        for edge_type, _ in self.edge_items():
-            if not hasattr(self[edge_type], '_edge_attrs'):
-                continue
-            edge_attrs = self[edge_type]._edge_attrs.values()
-            for attr in edge_attrs:
-                attr.size = self[edge_type].size()
-                added_attrs.add((attr.edge_type, attr.layout))
-            out.extend(edge_attrs)
-
-        # Check edges added through regular interface:
-        # TODO deprecate this and store edge attributes for all edges in
-        # EdgeStorage
-        for edge_type, edge_store in self.edge_items():
-            for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
-                # Don't double count:
-                if attr_name in edge_store and ((edge_type, layout)
-                                                not in added_attrs):
-                    out.append(
-                        EdgeAttr(edge_type=edge_type, layout=layout,
-                                 size=self[edge_type].size()))
+    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
+        edge_type = edge_attr.edge_type
+        store = self[edge_type]
+        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in store:
+            del store.edge_index
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop((edge_type, EdgeLayout.COO), None)
+            return True
+        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in store:
+            del store.adj
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop((edge_type, EdgeLayout.CSR), None)
+            return True
+        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in store:
+            del store.adj_t
+            if hasattr(self, '_edge_attrs'):
+                self._edge_attrs.pop((edge_type, EdgeLayout.CSC), None)
+            return True
+        return False
 
-        return out
+    def get_all_edge_attrs(self) -> List[EdgeAttr]:
+        edge_attrs = getattr(self, '_edge_attrs', {})
+
+        for store in self.edge_stores:
+            if ('edge_index' in store
+                    and (store._key, EdgeLayout.COO) not in edge_attrs):
+                edge_attrs[(store._key, EdgeLayout.COO)] = EdgeAttr(
+                    store._key, 'coo', is_sorted=False)
+            if ('adj' in store
+                    and (store._key, EdgeLayout.CSR) not in edge_attrs):
+                size = store.adj.sparse_sizes()
+                edge_attrs[(store._key, EdgeLayout.CSR)] = EdgeAttr(
+                    store._key, 'csr', size=size)
+            if ('adj_t' in store
+                    and (store._key, EdgeLayout.CSC) not in edge_attrs):
+                size = store.adj_t.sparse_sizes()[::-1]
+                edge_attrs[(store._key, EdgeLayout.CSC)] = EdgeAttr(
+                    store._key, 'csc', size=size)
+
+        return list(edge_attrs.values())
 
 
 # Helper functions ############################################################
diff --git a/torch_geometric/testing/graph_store.py b/torch_geometric/testing/graph_store.py
index fb371759ba5f..039f9d44eb8a 100644
--- a/torch_geometric/testing/graph_store.py
+++ b/torch_geometric/testing/graph_store.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 from torch import Tensor
 
@@ -9,18 +9,21 @@ class MyGraphStore(GraphStore):
     def __init__(self):
         super().__init__()
-        self.store: Dict[EdgeAttr, Tuple[Tensor, Tensor]] = {}
+        self.store: Dict[Tuple, Tuple[Tensor, Tensor]] = {}
 
     @staticmethod
-    def key(attr: EdgeAttr) -> str:
+    def key(attr: EdgeAttr) -> Tuple:
         return (attr.edge_type, attr.layout.value, attr.is_sorted, attr.size)
 
     def _put_edge_index(self, edge_index: EdgeTensorType,
                         edge_attr: EdgeAttr) -> bool:
-        self.store[MyGraphStore.key(edge_attr)] = edge_index
+        self.store[self.key(edge_attr)] = edge_index
+        return True
 
     def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
-        return self.store.get(MyGraphStore.key(edge_attr), None)
+        return self.store.get(self.key(edge_attr), None)
 
-    def get_all_edge_attrs(self):
-        return [EdgeAttr(*key) for key in self.store]
+    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
+        return self.store.pop(self.key(edge_attr), None) is not None
+
+    def get_all_edge_attrs(self) -> List[EdgeAttr]:
+        return [EdgeAttr(*key) for key in self.store.keys()]
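
Reviewer note: below is a minimal usage sketch of the reworked `GraphStore` interface from this diff, covering `put_edge_index`, layout conversion via `csc()`, and the new `remove_edge_index` path. It is illustrative only: the edge type `('paper', 'cites', 'paper')` and the tensors are made-up, and it assumes `torch_sparse` is installed, since the conversion helpers call `torch.ops.torch_sparse.ptr2ind`/`ind2ptr`.

```python
import torch

from torch_geometric.testing import MyGraphStore

# A tiny 3-node graph in COO layout, as a (row, col) tuple:
row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 2, 0])

store = MyGraphStore()

# Store the edge index under a hypothetical edge type. The positional
# arguments are cast into an `EdgeAttr`: (edge_type, layout, is_sorted, size).
store.put_edge_index((row, col), ('paper', 'cites', 'paper'), 'coo',
                     size=(3, 3))

# Convert every stored edge index to CSC. Each returned dict is keyed by
# edge type; `perm_dict` carries the permutation applied when sorting by
# destination node, or `None` when no re-ordering was needed:
row_dict, colptr_dict, perm_dict = store.csc()
assert colptr_dict[('paper', 'cites', 'paper')].numel() == 3 + 1

# The new removal path; also reachable via `del store[attr]`. It returns
# whether a matching (edge_type, layout, is_sorted, size) entry was removed:
assert store.remove_edge_index(('paper', 'cites', 'paper'), 'coo',
                               size=(3, 3))
assert len(store.get_all_edge_attrs()) == 0
```

Since `Data` and `HeteroData` implement the same interface in this diff (COO puts land in `edge_index`, CSR in `adj`, CSC in `adj_t`), the same `put_edge_index`/`coo()`/`csr()`/`csc()`/`remove_edge_index` calls should work on data objects directly.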