Add distributed feature info for distributed training (#7715)
This code is part of the overall distributed training support for PyG.
(This PR replaces #7678.)

This PR was originally designed around a separate `DistFeature` class, which
is now merged into `LocalFeatureStore`:

- Add partition/RPC information to `LocalFeatureStore`, such as
  `num_partitions`, `partition_idx`, `feature_pb` (the feature partition
  book), partition metadata, and the `RPCRouter`.
- Add a new class (`RPCCallFeatureLookup`) that performs the actual remote
  RPC feature lookup.
- Add an API (`.lookup_features()`) that looks up features on the local
  node and on remote nodes from sampled global node/edge IDs, built on top
  of the `torch.distributed.rpc` APIs (see the usage sketch below).
- Add a unit test under `test/distributed/` that verifies local and remote
  feature lookup.

The local feature store and the distributed feature properties (partition
information and remote RPC access APIs) are now combined into a single
feature store. A follow-up PR will rename `LocalFeatureStore` to
`PartitionFeatureStore`.
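As a minimal usage sketch (adapted from the new unit test in
`test/distributed/test_rpc.py`; `rank`, `partition_book`, `meta`, and
`global_node_ids` are placeholders for values produced during graph
partitioning, and `torch.distributed.rpc` must already be initialized on
every worker):

```python
from torch_geometric.distributed import LocalFeatureStore
from torch_geometric.distributed.rpc import RPCRouter

# Two partitions, each served by one RPC worker (names are illustrative):
partition_to_workers = [
    ['dist-feature-test-0'],
    ['dist-feature-test-1'],
]
rpc_router = RPCRouter(partition_to_workers)

feature = LocalFeatureStore()
feature.num_partitions = 2
feature.partition_idx = rank         # This worker's partition index.
feature.feature_pb = partition_book  # Maps global node ID -> partition ID.
feature.meta = meta                  # Metadata, e.g., {'is_hetero': False}.
feature.local_only = False
feature.set_rpc_router(rpc_router)   # Registers `RPCCallFeatureLookup`.

# Look up features for sampled global node IDs: IDs owned by this partition
# are read locally; all others are fetched asynchronously via RPC:
fut = feature.lookup_features(global_node_ids)
feats = fut.wait()
```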

Any comments, please let us know. Thanks.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: root <root@skyocean.sh.intel.com>
4 people authored Aug 7, 2023
1 parent 7b1ed2a commit aadb135
Showing 7 changed files with 236 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493), [#7764](https://github.com/pyg-team/pytorch_geometric/pull/7764) [#7765](https://github.com/pyg-team/pytorch_geometric/pull/7765))
- Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
- Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
-- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846))
+- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846), [#7715](https://github.com/pyg-team/pytorch_geometric/pull/7715))
- Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))
- Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))
- Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441))
6 changes: 3 additions & 3 deletions test/distributed/test_rpc.py
@@ -6,7 +6,7 @@
import torch_geometric.distributed.rpc as rpc
from torch_geometric.distributed import LocalFeatureStore
from torch_geometric.distributed.dist_context import DistContext, DistRole
-from torch_geometric.distributed.rpc import RpcRouter
+from torch_geometric.distributed.rpc import RPCRouter
from torch_geometric.testing import onlyLinux


@@ -44,7 +44,7 @@ def run_rpc_feature_test(
    ]

    # 3) Find the mapping between worker and partition ID:
-    rpc_router = RpcRouter(partition_to_workers)
+    rpc_router = RPCRouter(partition_to_workers)

    assert rpc_router.get_to_worker(partition_idx=0) == 'dist-feature-test-0'
    assert rpc_router.get_to_worker(partition_idx=1) == 'dist-feature-test-1'
@@ -60,7 +60,7 @@ def run_rpc_feature_test(
    feature.partition_idx = rank
    feature.feature_pb = partition_book
    feature.meta = meta
-    feature.set_local_only(local_only=False)
+    feature.local_only = False
    feature.set_rpc_router(rpc_router)

    # Global node IDs:
2 changes: 1 addition & 1 deletion torch_geometric/data/hetero_data.py
@@ -18,6 +18,7 @@
    EdgeTensorType,
    EdgeType,
    FeatureTensorType,
+    NodeOrEdgeType,
    NodeType,
    QueryType,
    SparseTensor,
@@ -29,7 +30,6 @@
    mask_select,
)

-NodeOrEdgeType = Union[NodeType, EdgeType]
NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]


190 changes: 189 additions & 1 deletion torch_geometric/distributed/local_feature_store.py
@@ -9,7 +9,26 @@

from torch_geometric.data import FeatureStore, TensorAttr
from torch_geometric.data.feature_store import _FieldStatus
-from torch_geometric.typing import EdgeType, NodeType
+from torch_geometric.distributed.rpc import (
+    RPCCallBase,
+    RPCRouter,
+    rpc_async,
+    rpc_register,
+)
+from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType


class RPCCallFeatureLookup(RPCCallBase):
    r"""A wrapper for RPC calls to the feature store."""
    def __init__(self, dist_feature: FeatureStore):
        super().__init__()
        self.dist_feature = dist_feature

    def rpc_async(self, *args, **kwargs):
        return self.dist_feature.rpc_local_feature_get(*args, **kwargs)

    def rpc_sync(self, *args, **kwargs):
        raise NotImplementedError


@dataclass
@@ -38,6 +57,15 @@ def __init__(self):
        # Save the mapping from global node/edge IDs to indices in `_feat`:
        self._global_id_to_index: Dict[Union[NodeType, EdgeType], Tensor] = {}

        # Partition and RPC information for distributed feature lookup:
        self.num_partitions = 1
        self.partition_idx = 0
        self.feature_pb: Union[Tensor, Dict[NodeOrEdgeType, Tensor]]
        self.local_only = False
        self.rpc_router: Optional[RPCRouter] = None
        self.meta: Optional[Dict] = None
        self.rpc_call_id: Optional[int] = None

    @staticmethod
    def key(attr: TensorAttr) -> Tuple[str, str]:
        return (attr.group_name, attr.attr_name)
@@ -107,6 +135,166 @@ def _get_tensor_size(self, attr: TensorAttr) -> Tuple[int, ...]:
    def get_all_tensor_attrs(self) -> List[LocalTensorAttr]:
        return [self._tensor_attr_cls.cast(*key) for key in self._feat.keys()]

    def set_rpc_router(self, rpc_router: RPCRouter):
        self.rpc_router = rpc_router

        if not self.local_only:
            if self.rpc_router is None:
                raise ValueError("An RPC router must be provided")
            rpc_call = RPCCallFeatureLookup(self)
            self.rpc_call_id = rpc_register(rpc_call)
        else:
            self.rpc_call_id = None

    def lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[NodeOrEdgeType] = None,
    ) -> torch.futures.Future:
        r"""Lookup of local/remote features."""
        remote_fut = self._remote_lookup_features(index, is_node_feat,
                                                  input_type)
        local_feature = self._local_lookup_features(index, is_node_feat,
                                                    input_type)
        res_fut = torch.futures.Future()

        def when_finish(*_):
            try:
                remote_feature_list = remote_fut.wait()
                # Combine the remote and local features into one tensor:
                result = torch.zeros(index.size(0), local_feature[0].size(1),
                                     dtype=local_feature[0].dtype)
                result[local_feature[1]] = local_feature[0]
                for remote in remote_feature_list:
                    result[remote[1]] = remote[0]
            except Exception as e:
                res_fut.set_exception(e)
            else:
                res_fut.set_result(result)

        remote_fut.add_done_callback(when_finish)
        return res_fut

    def _local_lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[NodeOrEdgeType] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Looks up the features of local nodes/edges via global IDs."""
        feat = self
        if self.meta['is_hetero']:
            pb = self.feature_pb[input_type]
        else:
            pb = self.feature_pb

        input_order = torch.arange(index.size(0), dtype=torch.long)
        partition_ids = pb[index]

        local_mask = partition_ids == self.partition_idx
        local_ids = torch.masked_select(index, local_mask)
        local_index = torch.masked_select(input_order, local_mask)

        attr_name = 'x' if is_node_feat else 'edge_attr'
        if self.meta['is_hetero']:
            group_name = input_type
        else:
            group_name = None if is_node_feat else (None, None)

        ret_feat = feat.get_tensor_from_global_id(
            group_name=group_name, attr_name=attr_name, index=local_ids)

        return ret_feat, local_index

    def _remote_lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[NodeOrEdgeType] = None,
    ) -> torch.futures.Future:
        r"""Fetches features from remote partitions via RPC, based on global
        node/edge IDs."""
        if self.meta['is_hetero']:
            pb = self.feature_pb[input_type]
        else:
            pb = self.feature_pb

        input_order = torch.arange(index.size(0), dtype=torch.long)
        partition_ids = pb[index]
        futs, indexes = [], []
        for pidx in range(self.num_partitions):
            if pidx == self.partition_idx:
                continue
            remote_mask = (partition_ids == pidx)
            remote_ids = index[remote_mask]
            if remote_ids.shape[0] > 0:
                to_worker = self.rpc_router.get_to_worker(pidx)
                futs.append(
                    rpc_async(
                        to_worker,
                        self.rpc_call_id,
                        args=(remote_ids.cpu(), is_node_feat, input_type),
                    ))
                indexes.append(torch.masked_select(input_order, remote_mask))
        collect_fut = torch.futures.collect_all(futs)
        res_fut = torch.futures.Future()

        def when_finish(*_):
            try:
                fut_list = collect_fut.wait()
                result = []
                for i, fut in enumerate(fut_list):
                    result.append((fut.wait(), indexes[i]))
            except Exception as e:
                res_fut.set_exception(e)
            else:
                res_fut.set_result(result)

        collect_fut.add_done_callback(when_finish)
        return res_fut

    def rpc_local_feature_get(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[NodeOrEdgeType] = None,
    ) -> Tensor:
        r"""Serves feature lookups requested by remote workers via RPC."""
        attr_name = 'x' if is_node_feat else 'edge_attr'
        if self.meta['is_hetero']:
            group_name = input_type
        else:
            group_name = None if is_node_feat else (None, None)

        return self.get_tensor_from_global_id(
            group_name=group_name, attr_name=attr_name, index=index)

    # Initialization ##########################################################

    @classmethod
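For reference, a hedged sketch of the two lookup modes of the new API:
`is_node_feat` switches between node features (`'x'`) and edge features
(`'edge_attr'`), while `input_type` selects the node/edge type in the
heterogeneous case (the edge type below is illustrative):

```python
# Homogeneous graph, node features:
node_feats = feature.lookup_features(node_ids).wait()

# Heterogeneous graph, edge features of one edge type:
edge_feats = feature.lookup_features(
    edge_ids,
    is_node_feat=False,
    input_type=('paper', 'cites', 'paper'),
).wait()
```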
40 changes: 36 additions & 4 deletions torch_geometric/distributed/local_graph_store.py
@@ -1,6 +1,6 @@
import json
import os.path as osp
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -18,10 +18,40 @@ def __init__(self):
        self._edge_attr: Dict[Tuple, EdgeAttr] = {}
        self._edge_id: Dict[Tuple, Tensor] = {}

        self.num_partitions = 1
        self.partition_idx = 0
        # Mapping between node ID and partition ID:
        self.node_pb: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None
        # Mapping between edge ID and partition ID:
        self.edge_pb: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None
        # Meta information of the partition and graph store:
        self.meta: Optional[Dict[Any, Any]] = None
        # Partition labels:
        self.labels: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None

    @staticmethod
    def key(attr: EdgeAttr) -> Tuple:
        return (attr.edge_type, attr.layout.value)

    def get_partition_ids_from_nids(
        self,
        ids: torch.Tensor,
        node_type: Optional[NodeType] = None,
    ) -> Tensor:
        r"""Get the partition IDs of node IDs for a specific node type."""
        if self.meta['is_hetero']:
            assert node_type is not None
            return self.node_pb[node_type][ids]
        return self.node_pb[ids]

    def get_partition_ids_from_eids(
        self,
        eids: torch.Tensor,
        edge_type: Optional[EdgeType] = None,
    ) -> Tensor:
        r"""Get the partition IDs of edge IDs for a specific edge type."""
        if self.meta['is_hetero']:
            assert edge_type is not None
            return self.edge_pb[edge_type][eids]
        return self.edge_pb[eids]

    def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool:
        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
        self._edge_id[self.key(edge_attr)] = edge_id
@@ -126,15 +156,17 @@ def from_partition(cls, root: str, pid: int) -> 'LocalGraphStore':

        if not meta['is_hetero']:
            attr = dict(edge_type=None, layout='coo', size=graph_data['size'])
-            graph_store.put_edge_index((graph_data['row'], graph_data['col']),
-                                       **attr)
+            graph_store.put_edge_index(
+                torch.stack((graph_data['row'], graph_data['col']), dim=0),
+                **attr)
            graph_store.put_edge_id(graph_data['edge_id'], **attr)

        if meta['is_hetero']:
            for edge_type, data in graph_data.items():
                attr = dict(edge_type=edge_type, layout='coo',
                            size=data['size'])
-                graph_store.put_edge_index((data['row'], data['col']), **attr)
+                graph_store.put_edge_index(
+                    torch.stack((data['row'], data['col']), dim=0), **attr)
                graph_store.put_edge_id(data['edge_id'], **attr)
        return graph_store
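A short sketch of the new partition-lookup helpers, assuming `graph_store`
was loaded via `LocalGraphStore.from_partition()` (the IDs and the edge type
below are illustrative):

```python
import torch

# Homogeneous case: `node_pb` maps each global node ID to its partition ID.
node_parts = graph_store.get_partition_ids_from_nids(torch.tensor([0, 7, 42]))

# Heterogeneous case: the node/edge type must be given.
edge_parts = graph_store.get_partition_ids_from_eids(
    torch.tensor([3, 5]), edge_type=('paper', 'cites', 'paper'))
```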
8 changes: 4 additions & 4 deletions torch_geometric/distributed/rpc.py
@@ -97,7 +97,7 @@ def shutdown_rpc(graceful: bool = True):
atexit.register(shutdown_rpc, False)


-class RpcRouter:
+class RPCRouter:
r"""A router to get the worker based on the partition ID."""
def __init__(self, partition_to_workers: List[List[str]]):
for pid, rpc_worker_list in enumerate(partition_to_workers):
@@ -132,7 +132,7 @@ def rpc_partition_to_workers(
    return partition_to_workers


-class RpcCallBase(ABC):
+class RPCCallBase(ABC):
r"""A wrapper base class for RPC calls in remote processes."""
@abstractmethod
def rpc_sync(self, *args, **kwargs):
@@ -145,11 +145,11 @@ def rpc_async(self, *args, **kwargs):

_rpc_call_lock = threading.RLock()
_rpc_call_id: int = 0
-_rpc_call_pool: Dict[int, RpcCallBase] = {}
+_rpc_call_pool: Dict[int, RPCCallBase] = {}


@rpc_require_initialized
-def rpc_register(call: RpcCallBase) -> int:
+def rpc_register(call: RPCCallBase) -> int:
r"""Registers a call for RPC requests."""
global _rpc_call_id, _rpc_call_pool

2 changes: 2 additions & 0 deletions torch_geometric/typing.py
@@ -219,6 +219,8 @@ def t(self) -> Tensor:  # Only support accessing its transpose:
# `data[('author', 'writes', 'paper')]
EdgeType = Tuple[str, str, str]

+NodeOrEdgeType = Union[NodeType, EdgeType]

DEFAULT_REL = 'to'
EDGE_TYPE_STR_SPLIT = '__'

