Add distributed feature info for distributed training #7715

Merged · 46 commits · Aug 7, 2023

Commits
bbf9611
create new branch of dist_rpc
ZhengHongming888 Jun 30, 2023
e8b6cf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
c554684
create branch distributed_graph
ZhengHongming888 Jul 1, 2023
6213a78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2023
5ad3cb2
minor
ZhengHongming888 Jul 2, 2023
d7ebd90
Merge branch 'distributed_graph' of https://github.com/ZhengHongming8…
ZhengHongming888 Jul 2, 2023
1a69c83
Merge branch 'pyg-team:master' into dist_rpc
ZhengHongming888 Jul 8, 2023
5569bc4
modify on comments
ZhengHongming888 Jul 8, 2023
8224966
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2023
3c16151
Merge branch 'pyg-team:master' into distributed_graph
ZhengHongming888 Jul 8, 2023
8b5274e
modify based comments
ZhengHongming888 Jul 8, 2023
944228e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2023
fdaf186
create branch distributed_graph
ZhengHongming888 Jul 1, 2023
c331a88
minor
ZhengHongming888 Jul 2, 2023
fcaf373
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2023
0cbf3b8
modify based comments
ZhengHongming888 Jul 8, 2023
68eff3e
rebase dist_rpc
ZhengHongming888 Jul 9, 2023
e33a2aa
rebase dist_rpc
ZhengHongming888 Jul 9, 2023
69818a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2023
a93b4fe
Merge branch 'pyg-team:master' into distributed_graph
ZhengHongming888 Jul 9, 2023
dde34a7
create new branch of distributed_feature
ZhengHongming888 Jul 9, 2023
1b8cedb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2023
8660f57
Merge branch 'pyg-team:master' into dist_rpc
ZhengHongming888 Jul 14, 2023
047a249
minor for dist_context/rpc
ZhengHongming888 Jul 14, 2023
04c00d3
Merge branch 'pyg-team:master' into distributed_feature
ZhengHongming888 Jul 15, 2023
ec46e93
minor dist_context/rpc/test
ZhengHongming888 Jul 15, 2023
b0fb563
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2023
c8848ed
Merge branch 'pyg-team:master' into distributed_feature
ZhengHongming888 Jul 18, 2023
6eb996a
comments from manan
ZhengHongming888 Jul 19, 2023
1ef24d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
a225974
Merge branch 'pyg-team:master' into distributed_feature
ZhengHongming888 Jul 25, 2023
132417e
minor
ZhengHongming888 Jul 25, 2023
06668df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2023
3fae492
Merge branch 'master' into distributed_feature
rusty1s Jul 28, 2023
2be91fc
Merge branch 'pyg-team:master' into distributed_feature
ZhengHongming888 Jul 28, 2023
75b3bff
update test
ZhengHongming888 Jul 29, 2023
b405c55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2023
d1c3062
Merge branch 'master' into distributed_feature
rusty1s Aug 4, 2023
6cdbed2
Merge branch 'pyg-team:master' into distributed_feature
ZhengHongming888 Aug 4, 2023
e0a8c2b
Merge branch 'pyg-team:master' into distributed_feature
ZhengHongming888 Aug 4, 2023
10916b3
comments from Matthias
ZhengHongming888 Aug 4, 2023
1824b3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2023
8fe6db2
Merge branch 'pyg-team:master' into distributed_feature
ZhengHongming888 Aug 5, 2023
f7879cc
_FieldStatus in local_feature_store
Aug 6, 2023
1db68d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2023
e256e0b
update
rusty1s Aug 7, 2023
Files changed
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493), [#7764](https://github.com/pyg-team/pytorch_geometric/pull/7764) [#7765](https://github.com/pyg-team/pytorch_geometric/pull/7765))
 - Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
 - Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
-- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846))
+- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846), [#7715](https://github.com/pyg-team/pytorch_geometric/pull/7715))
 - Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))
 - Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))
 - Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441))
6 changes: 3 additions & 3 deletions test/distributed/test_rpc.py
@@ -6,7 +6,7 @@
import torch_geometric.distributed.rpc as rpc
from torch_geometric.distributed import LocalFeatureStore
from torch_geometric.distributed.dist_context import DistContext, DistRole
-from torch_geometric.distributed.rpc import RpcRouter
+from torch_geometric.distributed.rpc import RPCRouter
from torch_geometric.testing import onlyLinux

@@ -44,7 +44,7 @@ def run_rpc_feature_test(
    ]

    # 3) Find the mapping between worker and partition ID:
-    rpc_router = RpcRouter(partition_to_workers)
+    rpc_router = RPCRouter(partition_to_workers)

    assert rpc_router.get_to_worker(partition_idx=0) == 'dist-feature-test-0'
    assert rpc_router.get_to_worker(partition_idx=1) == 'dist-feature-test-1'

@@ -60,7 +60,7 @@ def run_rpc_feature_test(
    feature.partition_idx = rank
    feature.feature_pb = partition_book
    feature.meta = meta
-    feature.set_local_only(local_only=False)
+    feature.local_only = False
    feature.set_rpc_router(rpc_router)

    # Global node IDs:
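For context, `partition_to_workers` maps each partition index to the list of RPC worker names that serve it. A minimal sketch of the routing behavior this test exercises (worker names taken from the assertions above):

partition_to_workers = [
    ['dist-feature-test-0'],  # workers serving partition 0
    ['dist-feature-test-1'],  # workers serving partition 1
]
rpc_router = RPCRouter(partition_to_workers)
rpc_router.get_to_worker(partition_idx=0)  # -> 'dist-feature-test-0'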
2 changes: 1 addition & 1 deletion torch_geometric/data/hetero_data.py
@@ -18,6 +18,7 @@
    EdgeTensorType,
    EdgeType,
    FeatureTensorType,
+    NodeOrEdgeType,
    NodeType,
    QueryType,
    SparseTensor,
@@ -29,7 +30,6 @@
    mask_select,
)

-NodeOrEdgeType = Union[NodeType, EdgeType]
NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]
190 changes: 189 additions & 1 deletion torch_geometric/distributed/local_feature_store.py
@@ -9,7 +9,26 @@

from torch_geometric.data import FeatureStore, TensorAttr
from torch_geometric.data.feature_store import _FieldStatus
-from torch_geometric.typing import EdgeType, NodeType
+from torch_geometric.distributed.rpc import (
+    RPCCallBase,
+    RPCRouter,
+    rpc_async,
+    rpc_register,
+)
+from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType


class RPCCallFeatureLookup(RPCCallBase):
    r"""A wrapper for RPC calls to the feature store."""
    def __init__(self, dist_feature: FeatureStore):
        super().__init__()
        self.dist_feature = dist_feature

    def rpc_async(self, *args, **kwargs):
        return self.dist_feature.rpc_local_feature_get(*args, **kwargs)

    def rpc_sync(self, *args, **kwargs):
        raise NotImplementedError


@dataclass
@@ -38,6 +57,15 @@ def __init__(self):
        # Save the mapping from global node/edge IDs to indices in `_feat`:
        self._global_id_to_index: Dict[Union[NodeType, EdgeType], Tensor] = {}

        # Partition and RPC information for distributed features:
        self.num_partitions = 1
        self.partition_idx = 0
        self.feature_pb: Union[Tensor, Dict[NodeOrEdgeType, Tensor]]
        self.local_only = False
        self.rpc_router: Optional[RPCRouter] = None
        self.meta: Optional[Dict] = None
        self.rpc_call_id: Optional[int] = None

    @staticmethod
    def key(attr: TensorAttr) -> Tuple[str, str]:
        return (attr.group_name, attr.attr_name)
@@ -107,6 +135,166 @@ def _get_tensor_size(self, attr: TensorAttr) -> Tuple[int, ...]:
    def get_all_tensor_attrs(self) -> List[LocalTensorAttr]:
        return [self._tensor_attr_cls.cast(*key) for key in self._feat.keys()]

    def set_rpc_router(self, rpc_router: RPCRouter):
        self.rpc_router = rpc_router

        if not self.local_only:
            if self.rpc_router is None:
                raise ValueError("An RPC router must be provided")
            rpc_call = RPCCallFeatureLookup(self)
            self.rpc_call_id = rpc_register(rpc_call)
        else:
            self.rpc_call_id = None

    def lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[NodeOrEdgeType] = None,
    ) -> torch.futures.Future:
        r"""Lookup of local/remote features."""
        remote_fut = self._remote_lookup_features(index, is_node_feat,
                                                  input_type)
        local_feature = self._local_lookup_features(index, is_node_feat,
                                                    input_type)
        res_fut = torch.futures.Future()

        def when_finish(*_):
            try:
                remote_feature_list = remote_fut.wait()
                # Combine local and remote features into a single tensor:
                result = torch.zeros(index.size(0), local_feature[0].size(1),
                                     dtype=local_feature[0].dtype)
                result[local_feature[1]] = local_feature[0]
                for remote in remote_feature_list:
                    result[remote[1]] = remote[0]
            except Exception as e:
                res_fut.set_exception(e)
            else:
                res_fut.set_result(result)

        remote_fut.add_done_callback(when_finish)
        return res_fut

    def _local_lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[Union[NodeType, EdgeType]] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Looks up the features of local nodes/edges from global IDs."""
        if self.meta['is_hetero']:
            feat = self
            pb = self.feature_pb[input_type]
        else:
            feat = self
            pb = self.feature_pb

        input_order = torch.arange(index.size(0), dtype=torch.long)
        partition_ids = pb[index]

        local_mask = partition_ids == self.partition_idx
        local_ids = torch.masked_select(index, local_mask)
        local_index = torch.masked_select(input_order, local_mask)

        if self.meta["is_hetero"]:
            if is_node_feat:
                kwargs = dict(group_name=input_type, attr_name='x')
                ret_feat = feat.get_tensor_from_global_id(
                    index=local_ids, **kwargs)
            else:
                kwargs = dict(group_name=input_type, attr_name='edge_attr')
                ret_feat = feat.get_tensor_from_global_id(
                    index=local_ids, **kwargs)
        else:
            if is_node_feat:
                kwargs = dict(group_name=None, attr_name='x')
                ret_feat = feat.get_tensor_from_global_id(
                    index=local_ids, **kwargs)
            else:
                kwargs = dict(group_name=(None, None), attr_name='edge_attr')
                ret_feat = feat.get_tensor_from_global_id(
                    index=local_ids, **kwargs)

        return ret_feat, local_index

    def _remote_lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[Union[NodeType, EdgeType]] = None,
    ) -> torch.futures.Future:
        r"""Fetches the features of remote nodes/edges from global IDs."""
        if self.meta["is_hetero"]:
            pb = self.feature_pb[input_type]
        else:
            pb = self.feature_pb

        input_order = torch.arange(index.size(0), dtype=torch.long)
        partition_ids = pb[index]
        futs, indexes = [], []
        for pidx in range(0, self.num_partitions):
            if pidx == self.partition_idx:
                continue
            remote_mask = (partition_ids == pidx)
            remote_ids = index[remote_mask]
            if remote_ids.shape[0] > 0:
                to_worker = self.rpc_router.get_to_worker(pidx)
                futs.append(
                    rpc_async(
                        to_worker,
                        self.rpc_call_id,
                        args=(remote_ids.cpu(), is_node_feat, input_type),
                    ))
                indexes.append(torch.masked_select(input_order, remote_mask))
        collect_fut = torch.futures.collect_all(futs)
        res_fut = torch.futures.Future()

        def when_finish(*_):
            try:
                fut_list = collect_fut.wait()
                result = []
                for i, fut in enumerate(fut_list):
                    result.append((fut.wait(), indexes[i]))
            except Exception as e:
                res_fut.set_exception(e)
            else:
                res_fut.set_result(result)

        collect_fut.add_done_callback(when_finish)
        return res_fut

    def rpc_local_feature_get(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[Union[NodeType, EdgeType]] = None,
    ) -> Tensor:
        r"""Serves feature lookups of local data requested by remote workers."""
        if self.meta['is_hetero']:
            feat = self
            if is_node_feat:
                kwargs = dict(group_name=input_type, attr_name='x')
                ret_feat = feat.get_tensor_from_global_id(
                    index=index, **kwargs)
            else:
                kwargs = dict(group_name=input_type, attr_name='edge_attr')
                ret_feat = feat.get_tensor_from_global_id(
                    index=index, **kwargs)
        else:
            feat = self
            if is_node_feat:
                kwargs = dict(group_name=None, attr_name='x')
                ret_feat = feat.get_tensor_from_global_id(
                    index=index, **kwargs)
            else:
                kwargs = dict(group_name=(None, None), attr_name='edge_attr')
                ret_feat = feat.get_tensor_from_global_id(
                    index=index, **kwargs)

        return ret_feat

    # Initialization ##########################################################

    @classmethod
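Taken together, `lookup_features` dispatches the remote RPC lookups first, performs the local lookup while those are in flight, and scatters both results into one tensor ordered like the caller's `index`. A minimal usage sketch, assuming a `LocalFeatureStore` wired up with a partition book, metadata, and an `RPCRouter` as in `test_rpc.py` above:

import torch

# Hypothetical setup mirroring the test; `feature`, `rank`, `partition_book`,
# `meta`, and `rpc_router` are assumed to be configured as in test_rpc.py:
feature.num_partitions = 2
feature.partition_idx = rank
feature.feature_pb = partition_book
feature.meta = meta
feature.set_rpc_router(rpc_router)

global_ids = torch.tensor([0, 5, 123])  # illustrative global node IDs
fut = feature.lookup_features(index=global_ids, is_node_feat=True)
x = fut.wait()  # row i holds the features of global node global_ids[i]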
40 changes: 36 additions & 4 deletions torch_geometric/distributed/local_graph_store.py
@@ -1,6 +1,6 @@
import json
import os.path as osp
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -18,10 +18,40 @@ def __init__(self):
        self._edge_attr: Dict[Tuple, EdgeAttr] = {}
        self._edge_id: Dict[Tuple, Tensor] = {}

        self.num_partitions = 1
        self.partition_idx = 0
        # Mapping between node ID and partition ID:
        self.node_pb: Union[Tensor, Dict[NodeType, Tensor]] = None
        # Mapping between edge ID and partition ID:
        self.edge_pb: Union[Tensor, Dict[EdgeType, Tensor]] = None
        # Meta information of the partition and graph store:
        self.meta: Optional[Dict[Any, Any]] = None
        # Partition labels:
        self.labels: Union[Tensor, Dict[EdgeType, Tensor]] = None
Review comment (Member): Not sure why we need these. Can you clarify?

Reply (Contributor, author): `self.meta` holds the partition information (e.g., `num_partitions`, `is_hetero`, the edge types, ...), which keeps the downstream code simple: there is no need to check whether a value is a dictionary or not. `self.labels` holds the label information (`y`) of the whole dataset, which likewise simplifies the downstream code; otherwise one more argument would have to be passed from the top-level functions all the way down to the low-level ones.
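A hedged sketch of what these fields might hold for a two-partition homogeneous graph (only `is_hetero` appears verbatim in this diff; the other field names are illustrative):

import torch

meta = {
    'is_hetero': False,   # homogeneous vs. heterogeneous partitioning
    'num_partitions': 2,  # illustrative; total number of graph partitions
}
labels = torch.tensor([0, 1, 1, 0])  # node labels `y` of the whole dataset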


    @staticmethod
    def key(attr: EdgeAttr) -> Tuple:
        return (attr.edge_type, attr.layout.value)

    def get_partition_ids_from_nids(
        self,
        ids: torch.Tensor,
        node_type: Optional[NodeType] = None,
    ) -> Tensor:
        r"""Get the partition IDs of node IDs for a specific node type."""
        if self.meta['is_hetero']:
            assert node_type is not None
            return self.node_pb[node_type][ids]
        return self.node_pb[ids]

    def get_partition_ids_from_eids(self, eids: torch.Tensor,
                                    edge_type: Optional[EdgeType] = None):
        r"""Get the partition IDs of edge IDs for a specific edge type."""
        if self.meta["is_hetero"]:
            assert edge_type is not None
            return self.edge_pb[edge_type][eids]
        return self.edge_pb[eids]

    def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool:
        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
        self._edge_id[self.key(edge_attr)] = edge_id
@@ -126,15 +156,17 @@ def from_partition(cls, root: str, pid: int) -> 'LocalGraphStore':

        if not meta['is_hetero']:
            attr = dict(edge_type=None, layout='coo', size=graph_data['size'])
-            graph_store.put_edge_index((graph_data['row'], graph_data['col']),
-                                       **attr)
+            graph_store.put_edge_index(
+                torch.stack((graph_data['row'], graph_data['col']), dim=0),
+                **attr)
            graph_store.put_edge_id(graph_data['edge_id'], **attr)

        if meta['is_hetero']:
            for edge_type, data in graph_data.items():
                attr = dict(edge_type=edge_type, layout='coo',
                            size=data['size'])
-                graph_store.put_edge_index((data['row'], data['col']), **attr)
+                graph_store.put_edge_index(
+                    torch.stack((data['row'], data['col']), dim=0), **attr)
                graph_store.put_edge_id(data['edge_id'], **attr)

        return graph_store
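A quick usage sketch of the new partition-ID lookups (a hypothetical two-partition homogeneous setup, so `node_pb` is a flat tensor):

import torch
from torch_geometric.distributed import LocalGraphStore  # assumed export

store = LocalGraphStore()
store.meta = {'is_hetero': False}
store.node_pb = torch.tensor([0, 0, 1, 1])  # node i lives in partition node_pb[i]
store.get_partition_ids_from_nids(torch.tensor([1, 3]))  # -> tensor([0, 1])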
8 changes: 4 additions & 4 deletions torch_geometric/distributed/rpc.py
@@ -97,7 +97,7 @@ def shutdown_rpc(graceful: bool = True):
atexit.register(shutdown_rpc, False)


-class RpcRouter:
+class RPCRouter:
    r"""A router to get the worker based on the partition ID."""
    def __init__(self, partition_to_workers: List[List[str]]):
        for pid, rpc_worker_list in enumerate(partition_to_workers):

@@ -132,7 +132,7 @@ def rpc_partition_to_workers(
    return partition_to_workers


-class RpcCallBase(ABC):
+class RPCCallBase(ABC):
    r"""A wrapper base class for RPC calls in remote processes."""
    @abstractmethod
    def rpc_sync(self, *args, **kwargs):

@@ -145,11 +145,11 @@ def rpc_async(self, *args, **kwargs):

_rpc_call_lock = threading.RLock()
_rpc_call_id: int = 0
-_rpc_call_pool: Dict[int, RpcCallBase] = {}
+_rpc_call_pool: Dict[int, RPCCallBase] = {}


@rpc_require_initialized
-def rpc_register(call: RpcCallBase) -> int:
+def rpc_register(call: RPCCallBase) -> int:
    r"""Registers a call for RPC requests."""
    global _rpc_call_id, _rpc_call_pool
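For reference, the renamed classes fit together as follows; a sketch of the registration flow used by `LocalFeatureStore.set_rpc_router` above (assuming RPC has already been initialized on the current worker):

rpc_call = RPCCallFeatureLookup(feature)  # wraps feature.rpc_local_feature_get
rpc_call_id = rpc_register(rpc_call)      # ID later passed to rpc_async(...)
fut = rpc_async(to_worker, rpc_call_id, args=(remote_ids, True, None))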
2 changes: 2 additions & 0 deletions torch_geometric/typing.py
@@ -219,6 +219,8 @@ def t(self) -> Tensor:  # Only support accessing its transpose:
# `data[('author', 'writes', 'paper')]
EdgeType = Tuple[str, str, str]

+NodeOrEdgeType = Union[NodeType, EdgeType]

DEFAULT_REL = 'to'
EDGE_TYPE_STR_SPLIT = '__'