
FeatureStore.multi_get_tensor implementation #4853

Merged: 13 commits, Jun 24, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), [#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853))
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
25 changes: 25 additions & 0 deletions torch_geometric/data/feature_store.py
@@ -330,6 +330,31 @@ def to_type(tensor: FeatureTensorType) -> FeatureTensorType:
            raise KeyError(f"A tensor corresponding to '{attr}' was not found")
        return to_type(tensor)

    def multi_get_tensor(self,
                         attrs: List[TensorAttr]) -> List[FeatureTensorType]:
        r"""Synchronously obtains a :class:`FeatureTensorType` object from the
        feature store for each tensor associated with the attributes in
        :obj:`attrs`.

        Args:
            attrs (List[TensorAttr]): a list of :class:`TensorAttr` attributes
                that identify the tensors to get.

        Returns:
            List[FeatureTensorType]: a tensor of the same type as the index
                for each attribute, in the order the attributes were given.

        Raises:
            KeyError: if a tensor corresponding to any of the given attributes
                was not found.
            ValueError: if any input :class:`TensorAttr` is not fully
                specified.
        """
        # NOTE The default implementation simply iterates over calls to
        # `get_tensor`: implementor classes that can provide additional, more
        # performant functionality are recommended to override this method.
        out: List[FeatureTensorType] = []
        for attr in attrs:
            out.append(self.get_tensor(attr))
        return out
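The default implementation above simply loops over `get_tensor`, so any backend that satisfies the per-attribute contract gets `multi_get_tensor` for free, while backends with batched I/O can override it. A minimal, self-contained sketch of that contract, using plain-Python stand-ins (these are *not* the real `FeatureStore`/`TensorAttr` classes, just hypothetical simplifications):

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class TensorAttr:
    """Simplified stand-in for torch_geometric.data.TensorAttr."""
    group_name: str
    attr_name: str
    index: Optional[List[int]] = None


class InMemoryFeatureStore:
    """Toy dict-backed store mirroring the get_tensor/multi_get_tensor
    contract described above."""

    def __init__(self, tensors: Dict[Tuple[str, str], List]):
        # Tensors are keyed by (group_name, attr_name):
        self._store = tensors

    def get_tensor(self, attr: TensorAttr) -> List:
        key = (attr.group_name, attr.attr_name)
        if key not in self._store:
            raise KeyError(f"A tensor corresponding to '{attr}' was not found")
        tensor = self._store[key]
        # Index into the stored tensor if an index was provided:
        return tensor if attr.index is None else [tensor[i] for i in attr.index]

    def multi_get_tensor(self, attrs: List[TensorAttr]) -> List[List]:
        # Same behavior as the default implementation: one `get_tensor` call
        # per attribute. A backend with batched I/O (e.g. a remote store)
        # could fetch all keys in a single request here instead.
        return [self.get_tensor(attr) for attr in attrs]


store = InMemoryFeatureStore({
    ('paper', 'x'): [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]],
    ('paper', 'y'): [0, 1, 0],
})
x, y = store.multi_get_tensor([
    TensorAttr('paper', 'x', index=[0, 2]),
    TensorAttr('paper', 'y', index=[0, 2]),
])
# x == [[0.0, 0.1], [2.0, 2.1]], y == [0, 0]
```

Note that, like the default implementation, the sketch propagates the `KeyError` from `get_tensor` rather than returning `None` for missing attributes.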

@abstractmethod
def _remove_tensor(self, attr: TensorAttr) -> bool:
r"""To be implemented by :obj:`FeatureStore` subclasses."""
29 changes: 21 additions & 8 deletions torch_geometric/loader/utils.py
@@ -1,5 +1,6 @@
import copy
import math
from collections import defaultdict
from typing import Dict, Optional, Tuple, Union

import torch
@@ -206,12 +207,24 @@ def filter_feature_store(
data[str_to_edge_type(key)].edge_index = edge_index

    # Filter node storage:
    attrs = feature_store.get_all_tensor_attrs()
    attrs_by_group_name = defaultdict(list)
    for attr in attrs:
        attrs_by_group_name[attr.group_name].append(attr)

    # NOTE Here, we utilize `feature_store.multi_get_tensor` by grouping attrs
    # by group name, as many feature store implementations may have efficient
    # implementations of obtaining multiple attr names for the same group name:
    for group_name in node_dict:
        attrs = attrs_by_group_name[group_name]

        # Get tensors at the necessary indices:
        index = node_dict[group_name]
        for attr in attrs:
            attr.index = index
        tensors = feature_store.multi_get_tensor(attrs)

        # Store responses in `data`:
        for i, attr in enumerate(attrs):
            data[group_name][attr.attr_name] = tensors[i]
    return data
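The grouping step above can be sketched in isolation: attributes are bucketed by `group_name` so that each sampled group needs only a single `multi_get_tensor` call rather than one `get_tensor` call per attribute. The `(group_name, attr_name)` tuples below are hypothetical stand-ins for real `TensorAttr` objects:

```python
from collections import defaultdict

# Hypothetical (group_name, attr_name) pairs standing in for TensorAttr
# objects returned by `feature_store.get_all_tensor_attrs()`:
attrs = [('paper', 'x'), ('paper', 'y'), ('author', 'x')]

# Bucket attributes by group name, mirroring the loop above:
attrs_by_group_name = defaultdict(list)
for group_name, attr_name in attrs:
    attrs_by_group_name[group_name].append(attr_name)

# One batched fetch per group instead of one call per attribute:
# {'paper': ['x', 'y'], 'author': ['x']}
```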