From e46fbb18b004ffc2ae5aece41c039cccbd8e0159 Mon Sep 17 00:00:00 2001
From: Manan Shah
Date: Fri, 24 Jun 2022 00:44:15 -0700
Subject: [PATCH] `FeatureStore.multi_get_tensor` implementation (#4853)

* init

* CHANGELOG

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* update

* docstring update

* update interface

* None check

* better errors

* comments

* typo

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s
---
 CHANGELOG.md                             |  2 +-
 torch_geometric/data/feature_store.py    | 77 +++++++++++++++++++-----
 torch_geometric/loader/utils.py          | 17 ++++--
 torch_geometric/testing/feature_store.py |  8 ++-
 4 files changed, 82 insertions(+), 22 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 51e41cfe27d7..3f391c343056 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
 - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
 - Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
-- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
+- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), [#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853))
 - Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
 - Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
 - Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py
index c1a30596f791..3c616e6e3ca5 100644
--- a/torch_geometric/data/feature_store.py
+++ b/torch_geometric/data/feature_store.py
@@ -280,6 +280,17 @@ def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
                              f"specifying all 'UNSET' fields")
         return self._put_tensor(tensor, attr)

+    @staticmethod
+    def _to_type(attr: TensorAttr,
+                 tensor: FeatureTensorType) -> FeatureTensorType:
+        if (isinstance(attr.index, torch.Tensor)
+                and isinstance(tensor, np.ndarray)):
+            return torch.from_numpy(tensor)
+        if (isinstance(attr.index, np.ndarray)
+                and isinstance(tensor, torch.Tensor)):
+            return tensor.detach().cpu().numpy()
+        return tensor
+
     @abstractmethod
     def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
         r"""To be implemented by :class:`FeatureStore` subclasses."""
@@ -299,27 +310,14 @@ def get_tensor(self, *args, **kwargs) -> FeatureTensorType:
             from a :class:`TensorAttr` object.

         Returns:
-            FeatureTensorType: a Tensor of the same type as the index, or
-                :obj:`None` if no tensor was found.
+            FeatureTensorType: a Tensor of the same type as the index.

         Raises:
             KeyError: if the tensor corresponding to attr was not found.
             ValueError: if the input `TensorAttr` is not fully specified.
""" - def to_type(tensor: FeatureTensorType) -> FeatureTensorType: - if (isinstance(attr.index, torch.Tensor) - and isinstance(tensor, np.ndarray)): - return torch.from_numpy(tensor) - if (isinstance(attr.index, np.ndarray) - and isinstance(tensor, torch.Tensor)): - return tensor.numpy() - return tensor attr = self._tensor_attr_cls.cast(*args, **kwargs) - if isinstance(attr.index, slice): - if attr.index.start == attr.index.stop == attr.index.step is None: - attr.index = None - if not attr.is_fully_specified(): raise ValueError(f"The input TensorAttr '{attr}' is not fully " f"specified. Please fully specify the input by " @@ -328,7 +326,56 @@ def to_type(tensor: FeatureTensorType) -> FeatureTensorType: tensor = self._get_tensor(attr) if tensor is None: raise KeyError(f"A tensor corresponding to '{attr}' was not found") - return to_type(tensor) + return self._to_type(attr, tensor) + + def _multi_get_tensor( + self, attrs: List[TensorAttr]) -> Optional[FeatureTensorType]: + r"""To be implemented by :class:`FeatureStore` subclasses. + + .. note:: + The default implementation simply iterates over all calls to + :meth:`get_tensor`. Implementor classes that can provide + additional, more performant functionality are recommended to + to override this method. + """ + return [self._get_tensor(attr) for attr in attrs] + + def multi_get_tensor(self, + attrs: List[TensorAttr]) -> List[FeatureTensorType]: + r"""Synchronously obtains a :class:`FeatureTensorType` object from the + feature store for each tensor associated with the attributes in + `attrs`. + + Args: + attrs (List[TensorAttr]): a list of :class:`TensorAttr` attributes + that identify the tensors to get. + + Returns: + List[FeatureTensorType]: a Tensor of the same type as the index for + each attribute. + + Raises: + KeyError: if a tensor corresponding to an attr was not found. + ValueError: if any input `TensorAttr` is not fully specified. + """ + attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs] + bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()] + if len(bad_attrs) > 0: + raise ValueError( + f"The input TensorAttr(s) '{bad_attrs}' are not fully " + f"specified. 
+                f"specified. Please fully specify them by specifying all "
+                f"'UNSET' fields")
+
+        tensors = self._multi_get_tensor(attrs)
+        if None in tensors:
+            bad_attrs = [attrs[i] for i, v in enumerate(tensors) if v is None]
+            raise KeyError(f"Tensors corresponding to attributes "
+                           f"'{bad_attrs}' were not found")
+
+        return [
+            self._to_type(attr, tensor)
+            for attr, tensor in zip(attrs, tensors)
+        ]

     @abstractmethod
     def _remove_tensor(self, attr: TensorAttr) -> bool:
diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py
index beeb8f36dde5..d562f7c890af 100644
--- a/torch_geometric/loader/utils.py
+++ b/torch_geometric/loader/utils.py
@@ -206,12 +206,19 @@ def filter_feature_store(
         data[str_to_edge_type(key)].edge_index = edge_index

     # Filter node storage:
-    for attr in feature_store.get_all_tensor_attrs():
+    attrs = feature_store.get_all_tensor_attrs()
+    required_attrs = []
+    for attr in attrs:
         if attr.group_name in node_dict:
-            # If we have sampled nodes from this group, index into the
-            # feature store for these nodes' features:
             attr.index = node_dict[attr.group_name]
-            tensor = feature_store.get_tensor(attr)
-            data[attr.group_name][attr.attr_name] = tensor
+            required_attrs.append(attr)
+
+    # NOTE Here, we utilize `feature_store.multi_get_tensor` to give the
+    # feature store full control over optimizing how it returns features
+    # (since the call is synchronous, this amounts to giving the feature
+    # store control over all iteration):
+    tensors = feature_store.multi_get_tensor(required_attrs)
+    for i, attr in enumerate(required_attrs):
+        data[attr.group_name][attr.attr_name] = tensors[i]

     return data
diff --git a/torch_geometric/testing/feature_store.py b/torch_geometric/testing/feature_store.py
index 1ed680cfdba5..c3b85b51fdf3 100644
--- a/torch_geometric/testing/feature_store.py
+++ b/torch_geometric/testing/feature_store.py
@@ -33,7 +33,13 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
         if tensor is None:
             return None

-        if attr.index is None:  # None indices return the whole tensor:
+        # None indices return the whole tensor:
+        if attr.index is None:
+            return tensor
+
+        # Empty slices return the whole tensor:
+        if (isinstance(attr.index, slice)
+                and attr.index == slice(None, None, None)):
             return tensor

         idx = torch.cat([(index == v).nonzero() for v in attr.index]).view(-1)
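
As a usage sketch (hypothetical, not part of the patch), the snippet below
exercises the new `multi_get_tensor` call against the in-memory test store
patched in the last hunk. The class name `MyFeatureStore` and the group and
attribute names ('paper', 'x'/'y') are assumptions for illustration; only
`put_tensor`, `multi_get_tensor`, and `TensorAttr` come from the patch itself:

    import torch

    from torch_geometric.data.feature_store import TensorAttr
    # Assumed name of the test helper defined in
    # torch_geometric/testing/feature_store.py above:
    from torch_geometric.testing.feature_store import MyFeatureStore

    store = MyFeatureStore()
    index = torch.tensor([0, 1, 2])

    # Write two feature tensors, fully specifying each attribute:
    store.put_tensor(torch.randn(3, 8), group_name='paper', attr_name='x',
                     index=index)
    store.put_tensor(torch.randn(3, 4), group_name='paper', attr_name='y',
                     index=index)

    # One synchronous `multi_get_tensor` call now fetches both tensors in
    # order; stores that can batch their reads may override
    # `_multi_get_tensor` instead of serving one `_get_tensor` per attribute:
    x, y = store.multi_get_tensor([
        TensorAttr(group_name='paper', attr_name='x', index=index),
        TensorAttr(group_name='paper', attr_name='y', index=index),
    ])
    assert x.shape == (3, 8) and y.shape == (3, 4)

This is also the calling convention `filter_feature_store` in
`torch_geometric/loader/utils.py` now relies on: it collects one fully
specified attribute per sampled node group and hands the whole list to the
feature store in a single call.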