FeatureStore.multi_get_tensor implementation #4853

Merged · 13 commits · Jun 24, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
-- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
+- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), [#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853))
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
3 changes: 2 additions & 1 deletion test/data/test_feature_store.py
@@ -41,7 +41,8 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
            return None

        # None indices return the whole tensor:
-        if attr.index is None:
+        if attr.index is None or isinstance(
+                attr.index, slice) and attr.index == slice(None, None, None):
            return tensor

        idx = torch.cat([(index == v).nonzero() for v in attr.index]).view(-1)
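
The new branch above hinges on Python's value semantics for slices, which is easy to second-guess; a quick standalone check (plain Python, independent of this PR):

```python
# Slices compare by value on their (start, stop, step) triple, so the
# "select everything" slice produced by [:]-style indexing equals a
# freshly constructed slice(None, None, None):
assert slice(None, None, None) == slice(None, None, None)
assert slice(0, None, None) != slice(None, None, None)

# Hence `attr.index == slice(None, None, None)` is True exactly when the
# index requests the whole tensor.
```
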
77 changes: 62 additions & 15 deletions torch_geometric/data/feature_store.py
@@ -280,6 +280,17 @@ def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
f"specifying all 'UNSET' fields")
return self._put_tensor(tensor, attr)

+    @staticmethod
+    def _to_type(attr: TensorAttr,
+                 tensor: FeatureTensorType) -> FeatureTensorType:
+        if (isinstance(attr.index, torch.Tensor)
+                and isinstance(tensor, np.ndarray)):
+            return torch.from_numpy(tensor)
+        if (isinstance(attr.index, np.ndarray)
+                and isinstance(tensor, torch.Tensor)):
+            return tensor.numpy()
+        return tensor
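
For intuition, here is a self-contained sketch of the conversion rule `_to_type` applies; the helper name `to_index_type` is illustrative, not part of the PR:

```python
from typing import Union

import numpy as np
import torch

FeatureTensorType = Union[torch.Tensor, np.ndarray]

def to_index_type(index: FeatureTensorType,
                  tensor: FeatureTensorType) -> FeatureTensorType:
    # Align the fetched tensor's type with the type of the index used to
    # request it, mirroring the static method above:
    if isinstance(index, torch.Tensor) and isinstance(tensor, np.ndarray):
        return torch.from_numpy(tensor)
    if isinstance(index, np.ndarray) and isinstance(tensor, torch.Tensor):
        return tensor.numpy()
    return tensor

# A torch index against a numpy-backed store yields a torch.Tensor:
assert isinstance(to_index_type(torch.arange(2), np.zeros(2)), torch.Tensor)
# ...and a numpy index against a torch-backed store yields an ndarray:
assert isinstance(to_index_type(np.arange(2), torch.zeros(2)), np.ndarray)
```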

    @abstractmethod
    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        r"""To be implemented by :class:`FeatureStore` subclasses."""
@@ -299,27 +310,14 @@ def get_tensor(self, *args, **kwargs) -> FeatureTensorType:
                from a :class:`TensorAttr` object.

        Returns:
-            FeatureTensorType: a Tensor of the same type as the index, or
-            :obj:`None` if no tensor was found.
+            FeatureTensorType: a Tensor of the same type as the index.

        Raises:
            KeyError: if the tensor corresponding to attr was not found.
            ValueError: if the input `TensorAttr` is not fully specified.
        """
-        def to_type(tensor: FeatureTensorType) -> FeatureTensorType:
-            if (isinstance(attr.index, torch.Tensor)
-                    and isinstance(tensor, np.ndarray)):
-                return torch.from_numpy(tensor)
-            if (isinstance(attr.index, np.ndarray)
-                    and isinstance(tensor, torch.Tensor)):
-                return tensor.numpy()
-            return tensor

        attr = self._tensor_attr_cls.cast(*args, **kwargs)
-        if isinstance(attr.index, slice):
-            if attr.index.start == attr.index.stop == attr.index.step is None:
-                attr.index = None

        if not attr.is_fully_specified():
            raise ValueError(f"The input TensorAttr '{attr}' is not fully "
                             f"specified. Please fully specify the input by "
@@ -328,7 +326,56 @@ def to_type(tensor: FeatureTensorType) -> FeatureTensorType:
        tensor = self._get_tensor(attr)
        if tensor is None:
            raise KeyError(f"A tensor corresponding to '{attr}' was not found")
-        return to_type(tensor)
+        return FeatureStore._to_type(attr, tensor)
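
In isolation, a fetch through the public API then looks as follows; `store` stands for any concrete :class:`FeatureStore` instance, and the group and attribute names are made up for illustration:

```python
import torch

from torch_geometric.data.feature_store import TensorAttr

# `store` is a concrete FeatureStore subclass instance (illustrative):
attr = TensorAttr(group_name='paper', attr_name='x',
                  index=torch.tensor([0, 2, 4]))
x = store.get_tensor(attr)  # Returns a torch.Tensor, since the index is one.
```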

+    def _multi_get_tensor(
+            self,
+            attrs: List[TensorAttr]) -> List[Optional[FeatureTensorType]]:
+        r"""To be implemented by :class:`FeatureStore` subclasses."""
+
+        # NOTE The default implementation simply iterates over calls to
+        # `get_tensor`: implementor classes that can provide additional, more
+        # performant functionality are recommended to override this method.
+        out: List[Optional[FeatureTensorType]] = []
+        for attr in attrs:
+            out.append(self._get_tensor(attr))
+        return out
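
As a sketch of the kind of override that note invites, a backend whose storage layer supports batched lookups might answer all requests in one pass; `ToyBackend` and its `multi_fetch` are hypothetical, not part of this PR:

```python
from typing import Dict, List, Optional, Tuple

import torch

class ToyBackend:
    """Hypothetical storage layer that deduplicates key lookups, so each
    backing tensor is located once however many attributes reference it."""
    def __init__(self, data: Dict[Tuple[str, str], torch.Tensor]):
        self._data = data

    def multi_fetch(self, attrs) -> List[Optional[torch.Tensor]]:
        cache: Dict[Tuple[str, str], Optional[torch.Tensor]] = {}
        out: List[Optional[torch.Tensor]] = []
        for attr in attrs:
            key = (attr.group_name, attr.attr_name)
            if key not in cache:  # one lookup per distinct key
                cache[key] = self._data.get(key)
            tensor = cache[key]
            out.append(None if tensor is None else tensor[attr.index])
        return out

# A FeatureStore subclass backed by ToyBackend could then override
# `_multi_get_tensor` as simply:
#     def _multi_get_tensor(self, attrs):
#         return self._backend.multi_fetch(attrs)
```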

+    def multi_get_tensor(self,
+                         attrs: List[TensorAttr]) -> List[FeatureTensorType]:
+        r"""Synchronously obtains a :class:`FeatureTensorType` object from
+        the feature store for each tensor associated with the attributes in
+        `attrs`.
+
+        Args:
+            attrs (List[TensorAttr]): a list of :class:`TensorAttr`
+                attributes that identify the tensors to get.
+
+        Returns:
+            List[FeatureTensorType]: a Tensor of the same type as the index
+                for each attribute.
+
+        Raises:
+            KeyError: if a tensor corresponding to an attr was not found.
+            ValueError: if any input `TensorAttr` is not fully specified.
+        """
+        attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs]
+        bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()]
+        if len(bad_attrs) > 0:
+            raise ValueError(
+                f"The input TensorAttr(s) '{bad_attrs}' are not fully "
+                f"specified. Please fully specify the input attrs by "
+                f"specifying all 'UNSET' fields.")
+
+        tensors = self._multi_get_tensor(attrs)
+        if None in tensors:
+            bad_attrs = [attrs[i] for i, v in enumerate(tensors) if v is None]
+            raise KeyError(f"Tensors corresponding to attributes "
+                           f"'{bad_attrs}' were not found")
+
+        return [
+            FeatureStore._to_type(attr, tensor)
+            for attr, tensor in zip(attrs, tensors)
+        ]
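
Callers can then fetch several features in one call; as before, `store` and the attribute names are illustrative:

```python
import torch

from torch_geometric.data.feature_store import TensorAttr

# `store` is a concrete FeatureStore subclass instance (illustrative).
# Every attribute in the list must be fully specified:
x_paper, x_author = store.multi_get_tensor([
    TensorAttr(group_name='paper', attr_name='x', index=torch.arange(4)),
    TensorAttr(group_name='author', attr_name='x', index=torch.arange(2)),
])
```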

    @abstractmethod
    def _remove_tensor(self, attr: TensorAttr) -> bool:
17 changes: 12 additions & 5 deletions torch_geometric/loader/utils.py
@@ -206,12 +206,19 @@ def filter_feature_store(
        data[str_to_edge_type(key)].edge_index = edge_index

    # Filter node storage:
-    for attr in feature_store.get_all_tensor_attrs():
+    attrs = feature_store.get_all_tensor_attrs()
+    required_attrs = []
+    for attr in attrs:
        if attr.group_name in node_dict:
            # If we have sampled nodes from this group, index into the
            # feature store for these nodes' features:
            attr.index = node_dict[attr.group_name]
-            tensor = feature_store.get_tensor(attr)
-            data[attr.group_name][attr.attr_name] = tensor
+            required_attrs.append(attr)

+    # NOTE Here, we utilize `feature_store.multi_get_tensor` to give the
+    # feature store full control over optimizing how it returns features
+    # (since the call is synchronous, this amounts to giving the feature
+    # store control over all iteration).
+    tensors = feature_store.multi_get_tensor(required_attrs)
+    for i, attr in enumerate(required_attrs):
+        data[attr.group_name][attr.attr_name] = tensors[i]

    return data