FeatureStore.multi_get_tensor implementation (#4853)
* init

* CHANGELOG

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* docstring update

* update interface

* None check

* better errors

* comments

* typo

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored Jun 24, 2022
1 parent c40b099 commit e46fbb1
Showing 4 changed files with 82 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), [#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853))
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
77 changes: 62 additions & 15 deletions torch_geometric/data/feature_store.py
@@ -280,6 +280,17 @@ def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
f"specifying all 'UNSET' fields")
return self._put_tensor(tensor, attr)

@staticmethod
def _to_type(attr: TensorAttr,
tensor: FeatureTensorType) -> FeatureTensorType:
if (isinstance(attr.index, torch.Tensor)
and isinstance(tensor, np.ndarray)):
return torch.from_numpy(tensor)
if (isinstance(attr.index, np.ndarray)
and isinstance(tensor, torch.Tensor)):
return tensor.detach().cpu().numpy()
return tensor

@abstractmethod
def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
r"""To be implemented by :class:`FeatureStore` subclasses."""
@@ -299,27 +310,14 @@ def get_tensor(self, *args, **kwargs) -> FeatureTensorType:
from a :class:`TensorAttr` object.
Returns:
FeatureTensorType: a Tensor of the same type as the index, or
:obj:`None` if no tensor was found.
FeatureTensorType: a Tensor of the same type as the index.
Raises:
KeyError: if the tensor corresponding to attr was not found.
ValueError: if the input `TensorAttr` is not fully specified.
"""
def to_type(tensor: FeatureTensorType) -> FeatureTensorType:
if (isinstance(attr.index, torch.Tensor)
and isinstance(tensor, np.ndarray)):
return torch.from_numpy(tensor)
if (isinstance(attr.index, np.ndarray)
and isinstance(tensor, torch.Tensor)):
return tensor.numpy()
return tensor

attr = self._tensor_attr_cls.cast(*args, **kwargs)
if isinstance(attr.index, slice):
if attr.index.start == attr.index.stop == attr.index.step is None:
attr.index = None

if not attr.is_fully_specified():
raise ValueError(f"The input TensorAttr '{attr}' is not fully "
f"specified. Please fully specify the input by "
@@ -328,7 +326,56 @@ def to_type(tensor: FeatureTensorType) -> FeatureTensorType:
tensor = self._get_tensor(attr)
if tensor is None:
raise KeyError(f"A tensor corresponding to '{attr}' was not found")
return to_type(tensor)
return self._to_type(attr, tensor)

    def _multi_get_tensor(
            self,
            attrs: List[TensorAttr]) -> List[Optional[FeatureTensorType]]:
r"""To be implemented by :class:`FeatureStore` subclasses.
.. note::
            The default implementation simply iterates over all calls to
            :meth:`_get_tensor`. Implementor classes that can provide
            additional, more performant functionality are recommended to
            override this method.
"""
return [self._get_tensor(attr) for attr in attrs]

def multi_get_tensor(self,
attrs: List[TensorAttr]) -> List[FeatureTensorType]:
r"""Synchronously obtains a :class:`FeatureTensorType` object from the
feature store for each tensor associated with the attributes in
`attrs`.
Args:
attrs (List[TensorAttr]): a list of :class:`TensorAttr` attributes
that identify the tensors to get.
Returns:
            List[FeatureTensorType]: a list of tensors, each of the same
                type as its corresponding index.
Raises:
KeyError: if a tensor corresponding to an attr was not found.
ValueError: if any input `TensorAttr` is not fully specified.
"""
attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs]
bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()]
if len(bad_attrs) > 0:
raise ValueError(
f"The input TensorAttr(s) '{bad_attrs}' are not fully "
f"specified. Please fully specify them by specifying all "
f"'UNSET' fields")

tensors = self._multi_get_tensor(attrs)
if None in tensors:
bad_attrs = [attrs[i] for i, v in enumerate(tensors) if v is None]
raise KeyError(f"Tensors corresponding to attributes "
f"'{bad_attrs}' were not found")

return [
self._to_type(attr, tensor)
for attr, tensor in zip(attrs, tensors)
]

@abstractmethod
def _remove_tensor(self, attr: TensorAttr) -> bool:
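For context, a usage sketch of the new `multi_get_tensor` API. It runs against the in-memory test store from `torch_geometric/testing/feature_store.py` (its diff appears below); the class name `MyFeatureStore` and the exact `put_tensor` setup are assumptions about that file rather than part of this commit:

```python
import torch

from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.testing.feature_store import MyFeatureStore  # assumed name

store = MyFeatureStore()
# `put_tensor` requires a fully specified attribute; `index` records which
# IDs the rows of the tensor belong to:
store.put_tensor(torch.rand(4, 8), group_name='paper', attr_name='x',
                 index=torch.arange(4))
store.put_tensor(torch.rand(2, 8), group_name='author', attr_name='x',
                 index=torch.arange(2))

# One synchronous call resolves all attributes; `multi_get_tensor` raises a
# ValueError for underspecified attributes and a KeyError for missing
# tensors, and returns each tensor in the same type as its index:
paper_x, author_x = store.multi_get_tensor([
    TensorAttr('paper', 'x', index=torch.tensor([0, 2])),
    TensorAttr('author', 'x', index=torch.tensor([1])),
])
print(paper_x.shape, author_x.shape)  # torch.Size([2, 8]) torch.Size([1, 8])
```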
17 changes: 12 additions & 5 deletions torch_geometric/loader/utils.py
@@ -206,12 +206,19 @@ def filter_feature_store(
data[str_to_edge_type(key)].edge_index = edge_index

# Filter node storage:
for attr in feature_store.get_all_tensor_attrs():
attrs = feature_store.get_all_tensor_attrs()
required_attrs = []
for attr in attrs:
if attr.group_name in node_dict:
# If we have sampled nodes from this group, index into the
# feature store for these nodes' features:
attr.index = node_dict[attr.group_name]
tensor = feature_store.get_tensor(attr)
data[attr.group_name][attr.attr_name] = tensor
required_attrs.append(attr)

    # NOTE Here, we utilize `feature_store.multi_get_tensor` to give the
    # feature store full control over optimizing how it returns features
    # (since the call is synchronous, this amounts to giving the feature
    # store control over all iteration).
tensors = feature_store.multi_get_tensor(required_attrs)
for i, attr in enumerate(required_attrs):
data[attr.group_name][attr.attr_name] = tensors[i]

return data
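To make the NOTE above concrete, here is a minimal sketch of the kind of override it anticipates. `DedupFeatureStore` is hypothetical and builds on the same assumed `MyFeatureStore` test store; a real remote-backed store would instead batch all lookups into a single backend request:

```python
from typing import List

from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.testing.feature_store import MyFeatureStore  # assumed name


class DedupFeatureStore(MyFeatureStore):
    r"""Resolves duplicate attributes against the backend only once."""
    def _multi_get_tensor(self, attrs: List[TensorAttr]) -> List:
        cache, out = {}, []
        for attr in attrs:
            # Tensors are not hashable, so key on a string form of the index:
            key = (attr.group_name, attr.attr_name, str(attr.index))
            if key not in cache:
                cache[key] = self._get_tensor(attr)
            out.append(cache[key])
        return out
```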
8 changes: 7 additions & 1 deletion torch_geometric/testing/feature_store.py
@@ -33,7 +33,13 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
if tensor is None:
return None

if attr.index is None: # None indices return the whole tensor:
# None indices return the whole tensor:
if attr.index is None:
return tensor

        # Full slices (e.g. `[:]`) also return the whole tensor:
if (isinstance(attr.index, slice)
and attr.index == slice(None, None, None)):
return tensor

idx = torch.cat([(index == v).nonzero() for v in attr.index]).view(-1)
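The `torch.cat` lookup above maps each requested ID to its position in the stored index before slicing. In isolation (a standalone illustration, not part of this commit):

```python
import torch

index = torch.tensor([10, 20, 30, 40])  # IDs the tensor rows were stored under
request = torch.tensor([30, 10])        # attr.index: the IDs being queried

idx = torch.cat([(index == v).nonzero() for v in request]).view(-1)
print(idx)  # tensor([2, 0]) -> row positions of the requested IDs
```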
