diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2503c1b75887..8e7abb72ccd4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493))
 - Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
 - Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
-- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628))
+- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671))
 - Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))
 - Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))
 - Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441))
@@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Made `FieldStatus` enum picklable to avoid `PicklingError` in a multi-process setting ([#7808](https://github.com/pyg-team/pytorch_geometric/pull/7808))
 - Fixed `edge_label_time` computation in `LinkNeighborLoader` for homogeneous graphs ([#7807](https://github.com/pyg-team/pytorch_geometric/pull/7807))
 - Fixed `edge_label_index` computation in `LinkNeighborLoader` for the homogeneous+`disjoint` mode ([#7791](https://github.com/pyg-team/pytorch_geometric/pull/7791))
 - Fixed `CaptumExplainer` for `binary_classification` tasks ([#7787](https://github.com/pyg-team/pytorch_geometric/pull/7787))
diff --git a/test/data/test_feature_store.py b/test/data/test_feature_store.py
index 95f02a59c640..c9a85096624c 100644
--- a/test/data/test_feature_store.py
+++ b/test/data/test_feature_store.py
@@ -4,14 +4,13 @@
 import torch
 
 from torch_geometric.data import TensorAttr
-from torch_geometric.data.feature_store import AttrView, _field_status
+from torch_geometric.data.feature_store import AttrView, _FieldStatus
 from torch_geometric.testing import MyFeatureStore
 
 
 @dataclass
 class MyTensorAttrNoGroupName(TensorAttr):
-    def __init__(self, attr_name=_field_status.UNSET,
-                 index=_field_status.UNSET):
+    def __init__(self, attr_name=_FieldStatus.UNSET, index=_FieldStatus.UNSET):
         # Treat group_name as optional, and move it to the end
         super().__init__(None, attr_name, index)
 
diff --git a/test/distributed/test_rpc.py b/test/distributed/test_rpc.py
new file mode 100644
index 000000000000..673c53a694a0
--- /dev/null
+++ b/test/distributed/test_rpc.py
@@ -0,0 +1,121 @@
+import socket
+from typing import Dict, List
+
+import torch
+
+import torch_geometric.distributed.rpc as rpc
+from torch_geometric.distributed import LocalFeatureStore
+from torch_geometric.distributed.dist_context import DistContext, DistRole
+from torch_geometric.distributed.rpc import RpcRouter
+
+
+def run_rpc_feature_test(
+    world_size: int,
+    rank: int,
+    feature: LocalFeatureStore,
+    partition_book: torch.Tensor,
+    master_port: int,
+):
+    # 1) Initialize the context info:
+    current_ctx = DistContext(
+        rank=rank,
+        global_rank=rank,
+        world_size=world_size,
+        global_world_size=world_size,
+        group_name='dist-feature-test',
+    )
+    rpc_worker_names: Dict[DistRole, List[str]] = {}
+
+    rpc.init_rpc(
+        current_ctx=current_ctx,
+        rpc_worker_names=rpc_worker_names,
+        master_addr='localhost',
+        master_port=master_port,
+    )
+
+    # 2) Collect all workers:
+    partition_to_workers = rpc.rpc_partition_to_workers(
+        current_ctx, world_size, rank)
+
+    assert partition_to_workers == [
+        ['dist-feature-test-0'],
+        ['dist-feature-test-1'],
+    ]
+
+    # 3) Find the mapping between worker and partition ID:
+    rpc_router = RpcRouter(partition_to_workers)
+
+    assert rpc_router.get_to_worker(partition_idx=0) == 'dist-feature-test-0'
+    assert rpc_router.get_to_worker(partition_idx=1) == 'dist-feature-test-1'
+
+    meta = {
+        'edge_types': None,
+        'is_hetero': False,
+        'node_types': None,
+        'num_parts': 2
+    }
+
+    feature.num_partitions = world_size
+    feature.partition_idx = rank
+    feature.feature_pb = partition_book
+    feature.meta = meta
+    feature.set_local_only(local_only=False)
+    feature.set_rpc_router(rpc_router)
+
+    # Global node IDs:
+    global_id0 = torch.arange(128 * 2)
+    global_id1 = torch.arange(128 * 2) + 128 * 2
+
+    # Look up the features from the local and the remote feature store:
+    tensor0 = feature.lookup_features(global_id0)
+    tensor1 = feature.lookup_features(global_id1)
+
+    # Expected results:
+    cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2])
+    cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)])
+
+    # Verify that the fetched features match the expected results:
+    assert torch.allclose(cpu_tensor0, tensor0.wait())
+    assert torch.allclose(cpu_tensor1, tensor1.wait())
+
+    rpc.shutdown_rpc()
+
+
+def test_dist_feature_lookup():
+    cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2])
+    cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)])
+
+    # Global node IDs:
+    global_id0 = torch.arange(128 * 2)
+    global_id1 = torch.arange(128 * 2) + 128 * 2
+
+    # Set the partition book for two features (partition 0 and 1):
+    partition_book = torch.cat([
+        torch.zeros(128 * 2, dtype=torch.long),
+        torch.ones(128 * 2, dtype=torch.long)
+    ])
+
+    # Put the test tensors into the different feature stores with global IDs:
+    feature0 = LocalFeatureStore()
+    feature0.put_global_id(global_id0, group_name=None)
+    feature0.put_tensor(cpu_tensor0, group_name=None, attr_name='x')
+
+    feature1 = LocalFeatureStore()
+    feature1.put_global_id(global_id1, group_name=None)
+    feature1.put_tensor(cpu_tensor1, group_name=None, attr_name='x')
+
+    mp_context = torch.multiprocessing.get_context('spawn')
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.bind(('127.0.0.1', 0))
+    port = s.getsockname()[1]
+    s.close()
+
+    w0 = mp_context.Process(target=run_rpc_feature_test,
+                            args=(2, 0, feature0, partition_book, port))
+    w1 = mp_context.Process(target=run_rpc_feature_test,
+                            args=(2, 1, feature1, partition_book, port))
+
+    w0.start()
+    w1.start()
+    w0.join()
+    w1.join()
diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py
index fd19b16dbbd3..d8fdadacc2b1 100644
--- a/torch_geometric/data/data.py
+++ b/torch_geometric/data/data.py
@@ -20,7 +20,7 @@
 from torch import Tensor
 
 from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr
-from torch_geometric.data.feature_store import _field_status
+from torch_geometric.data.feature_store import _FieldStatus
 from torch_geometric.data.graph_store import EdgeLayout
 from torch_geometric.data.storage import (
     BaseStorage,
@@ -372,7 +372,7 @@ class DataTensorAttr(TensorAttr):
     r"""Tensor attribute for `Data` without group name."""
     def __init__(
         self,
-        attr_name=_field_status.UNSET,
+        attr_name=_FieldStatus.UNSET,
         index=None,
     ):
         super().__init__(None, attr_name, index)
diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py
index 10c7f1efd90f..946479d0f633 100644
--- a/torch_geometric/data/feature_store.py
+++ b/torch_geometric/data/feature_store.py
@@ -32,13 +32,15 @@
 from torch_geometric.typing import FeatureTensorType, NodeType
 from torch_geometric.utils.mixin import CastMixin
 
-_field_status = Enum("FieldStatus", "UNSET")
-
 # We allow indexing with a tensor, numpy array, Python slicing, or a single
 # integer index.
 IndexType = Union[torch.Tensor, np.ndarray, slice, int]
 
 
+class _FieldStatus(Enum):
+    UNSET = None
+
+
 @dataclass
 class TensorAttr(CastMixin):
     r"""Defines the attributes of a :class:`FeatureStore` tensor.
@@ -52,20 +54,20 @@ class TensorAttr(CastMixin):
     """
 
     # The group name that the tensor corresponds to. Defaults to UNSET.
-    group_name: Optional[NodeType] = _field_status.UNSET
+    group_name: Optional[NodeType] = _FieldStatus.UNSET
 
     # The name of the tensor within its group. Defaults to UNSET.
-    attr_name: Optional[str] = _field_status.UNSET
+    attr_name: Optional[str] = _FieldStatus.UNSET
 
     # The node indices the rows of the tensor correspond to. Defaults to UNSET.
-    index: Optional[IndexType] = _field_status.UNSET
+    index: Optional[IndexType] = _FieldStatus.UNSET
 
     # Convenience methods #####################################################
 
     def is_set(self, key: str) -> bool:
         r"""Whether an attribute is set in :obj:`TensorAttr`."""
         assert key in self.__dataclass_fields__
-        return getattr(self, key) != _field_status.UNSET
+        return getattr(self, key) != _FieldStatus.UNSET
 
     def is_fully_specified(self) -> bool:
         r"""Whether the :obj:`TensorAttr` has no unset fields."""
@@ -137,7 +139,7 @@ def __getattr__(self, key: Any) -> Union['AttrView', FeatureTensorType]:
         # Find the first attribute name that is UNSET:
         attr_name: Optional[str] = None
         for field in out._attr.__dataclass_fields__:
-            if getattr(out._attr, field) == _field_status.UNSET:
+            if getattr(out._attr, field) == _FieldStatus.UNSET:
                 attr_name = field
                 break
 
diff --git a/torch_geometric/distributed/dist_context.py b/torch_geometric/distributed/dist_context.py
new file mode 100644
index 000000000000..5b3e72f733e3
--- /dev/null
+++ b/torch_geometric/distributed/dist_context.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from enum import Enum
+
+
+class DistRole(Enum):
+    WORKER = 1
+
+
+@dataclass
+class DistContext:
+    r"""Context information of the current process."""
+    rank: int
+    global_rank: int
+    world_size: int
+    global_world_size: int
+    group_name: str
+    role: DistRole = DistRole.WORKER
+
+    @property
+    def worker_name(self) -> str:
+        return f'{self.group_name}-{self.rank}'
diff --git a/torch_geometric/distributed/local_feature_store.py b/torch_geometric/distributed/local_feature_store.py
index 9cc5c395e0fb..260087e09a74 100644
--- a/torch_geometric/distributed/local_feature_store.py
+++ b/torch_geometric/distributed/local_feature_store.py
@@ -8,7 +8,7 @@
 from torch import Tensor
 
 from torch_geometric.data import FeatureStore, TensorAttr
-from torch_geometric.data.feature_store import _field_status
+from torch_geometric.data.feature_store import _FieldStatus
 from torch_geometric.typing import EdgeType, NodeType
 
 
@@ -17,8 +17,8 @@ class LocalTensorAttr(TensorAttr):
     r"""Tensor attribute for storing features without :obj:`index`."""
     def __init__(
         self,
-        group_name: Optional[Union[NodeType, EdgeType]] = _field_status.UNSET,
-        attr_name: Optional[str] = _field_status.UNSET,
+        group_name: Optional[Union[NodeType, EdgeType]] = _FieldStatus.UNSET,
+        attr_name: Optional[str] = _FieldStatus.UNSET,
         index=None,
     ):
         super().__init__(group_name, attr_name, index)
diff --git a/torch_geometric/distributed/rpc.py b/torch_geometric/distributed/rpc.py
new file mode 100644
index 000000000000..3f710d9b8b56
--- /dev/null
+++ b/torch_geometric/distributed/rpc.py
@@ -0,0 +1,190 @@
+import atexit
+import logging
+import threading
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.distributed import rpc
+
+from torch_geometric.distributed.dist_context import DistContext, DistRole
+
+_rpc_init_lock = threading.RLock()
+
+
+def rpc_is_initialized() -> bool:
+    return _is_current_rpc_agent_set()
+
+
+@rpc.api._require_initialized
+def global_all_gather(obj, timeout=None):
+    r"""Gathers objects from all workers in a list."""
+    if timeout is None:
+        return rpc.api._all_gather(obj)
+    return rpc.api._all_gather(obj, timeout=timeout)
+
+
+@rpc.api._require_initialized
+def global_barrier(timeout=None):
+    r"""Blocks until all local and remote RPC processes reach this barrier."""
+    try:
+        global_all_gather(obj=None, timeout=timeout)
+    except RuntimeError:
+        logging.error("Failed to respond to global barrier")
+
+
+def init_rpc(
+    current_ctx: DistContext,
+    rpc_worker_names: Dict[DistRole, List[str]],
+    master_addr: str,
+    master_port: int,
+    num_rpc_threads: int = 16,
+    rpc_timeout: float = 240,
+):
+    with _rpc_init_lock:
+        if rpc_is_initialized():
+            return
+
+        if current_ctx is None:
+            raise RuntimeError("'dist_context' has not been set in 'init_rpc'")
+
+        options = rpc.TensorPipeRpcBackendOptions(
+            _transports=['ibv', 'uv'],
+            _channels=['mpt_uv', 'basic'],
+            num_worker_threads=num_rpc_threads,
+            rpc_timeout=rpc_timeout,
+            init_method=f'tcp://{master_addr}:{master_port}',
+        )
+
+        rpc.init_rpc(
+            name=current_ctx.worker_name,
+            rank=current_ctx.global_rank,
+            world_size=current_ctx.global_world_size,
+            rpc_backend_options=options,
+        )
+
+        gathered_results = global_all_gather(
+            obj=(current_ctx.role, current_ctx.world_size, current_ctx.rank),
+            timeout=rpc_timeout,
+        )
+
+        for worker_name, (role, world_size, rank) in gathered_results.items():
+            worker_list = rpc_worker_names.get(role, None)
+            if worker_list is None:
+                worker_list = [None for _ in range(world_size)]
+            else:
+                if len(worker_list) != world_size:
+                    raise RuntimeError(f"Inconsistent world size found in "
+                                       f"'init_rpc' (got {len(worker_list)})")
+
+            worker_list[rank] = worker_name
+            rpc_worker_names[role] = worker_list
+
+        global_barrier(timeout=rpc_timeout)
+
+
+def shutdown_rpc(graceful: bool = True):
+    if rpc_is_initialized():
+        rpc.shutdown(graceful)
+
+
+atexit.register(shutdown_rpc, False)
+
+
+class RpcRouter:
+    r"""A router to get the worker based on the partition ID."""
+    def __init__(self, partition_to_workers: List[List[str]]):
+        for pid, rpc_worker_list in enumerate(partition_to_workers):
+            if len(rpc_worker_list) == 0:
+                raise ValueError("No RPC worker is in worker list")
+        self.partition_to_workers = partition_to_workers
+        self.rpc_worker_indices = [0 for _ in range(len(partition_to_workers))]
+
+    def get_to_worker(self, partition_idx: int) -> str:
+        rpc_worker_list = self.partition_to_workers[partition_idx]
+        worker_idx = self.rpc_worker_indices[partition_idx]
+        router_worker = rpc_worker_list[worker_idx]
+        self.rpc_worker_indices[partition_idx] = ((worker_idx + 1) %
+                                                  len(rpc_worker_list))
+        return router_worker
+
+
+@rpc.api._require_initialized
+def rpc_partition_to_workers(
+    current_ctx: DistContext,
+    num_partitions: int,
+    current_partition_idx: int,
+):
+    r"""Performs an :obj:`all_gather` to get the mapping between partition and
+    workers."""
+    ctx = current_ctx
+    partition_to_workers = [[] for _ in range(num_partitions)]
+    gathered_results = global_all_gather(
+        (ctx.role, num_partitions, current_partition_idx))
+    for worker_name, (role, nparts, idx) in gathered_results.items():
+        partition_to_workers[idx].append(worker_name)
+    return partition_to_workers
+
+
+class RpcCallBase(ABC):
+    r"""A wrapper base class for RPC calls in remote processes."""
+    @abstractmethod
+    def rpc_sync(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def rpc_async(self, *args, **kwargs):
+        pass
+
+
+_rpc_call_lock = threading.RLock()
+_rpc_call_id: int = 0
+_rpc_call_pool: Dict[int, RpcCallBase] = {}
+
+
+@rpc.api._require_initialized
+def rpc_register(call: RpcCallBase) -> int:
+    r"""Registers a call for RPC requests."""
+    global _rpc_call_id, _rpc_call_pool
+
+    with _rpc_call_lock:
+        call_id = _rpc_call_id
+        _rpc_call_id += 1
+        if call_id in _rpc_call_pool:
+            raise RuntimeError("Registered function twice in 'rpc_register'")
+        _rpc_call_pool[call_id] = call
+
+    return call_id
+
+
+def _rpc_async_call(call_id: int, *args, **kwargs):
+    r"""Entry point for asynchronous RPC requests."""
+    return _rpc_call_pool.get(call_id).rpc_async(*args, **kwargs)
+
+
+@rpc.api._require_initialized
+def rpc_async(worker_name: str, call_id: int, args=None, kwargs=None):
+    r"""Performs an asynchronous RPC request and returns a future."""
+    return rpc.rpc_async(
+        to=worker_name,
+        func=_rpc_async_call,
+        args=(call_id, *args),
+        kwargs=kwargs,
+    )
+
+
+def _rpc_sync_call(call_id: int, *args, **kwargs):
+    r"""Entry point for synchronous RPC requests."""
+    return _rpc_call_pool.get(call_id).rpc_sync(*args, **kwargs)
+
+
+@rpc.api._require_initialized
+def rpc_sync(worker_name: str, call_id: int, args=None, kwargs=None):
+    r"""Performs a synchronous RPC request and returns its result."""
+    future = rpc.rpc_async(
+        to=worker_name,
+        func=_rpc_sync_call,
+        args=(call_id, *args),
+        kwargs=kwargs,
+    )
+    return future.wait()
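
Beyond the feature-lookup test, the diff does not show how the `RpcCallBase` / `rpc_register` / `rpc_async` machinery in `torch_geometric/distributed/rpc.py` is meant to be driven directly. The following is a minimal, hypothetical sketch (not part of the diff): `SumCall`, `worker_main`, the group name and the fixed port are invented for illustration; `DistContext`, `init_rpc`, `rpc_register`, `global_barrier`, `rpc_async` and `shutdown_rpc` are the helpers introduced here.

    from typing import Dict, List

    import torch

    import torch_geometric.distributed.rpc as rpc
    from torch_geometric.distributed.dist_context import DistContext, DistRole
    from torch_geometric.distributed.rpc import RpcCallBase


    class SumCall(RpcCallBase):
        # A trivial remote callable: sums whatever tensor the caller sends.
        def rpc_async(self, x):
            return x.sum()

        def rpc_sync(self, x):
            return x.sum()


    def worker_main(rank: int, world_size: int, master_port: int):
        ctx = DistContext(rank=rank, global_rank=rank, world_size=world_size,
                          global_world_size=world_size,
                          group_name='rpc-example')
        worker_names: Dict[DistRole, List[str]] = {}
        rpc.init_rpc(current_ctx=ctx, rpc_worker_names=worker_names,
                     master_addr='localhost', master_port=master_port)

        # Every worker registers the same callable in the same order, so the
        # integer `call_id` refers to the same pooled object on every process:
        call_id = rpc.rpc_register(SumCall())
        rpc.global_barrier()

        if rank == 0:
            # Worker names follow `DistContext.worker_name`, '<group>-<rank>':
            fut = rpc.rpc_async('rpc-example-1', call_id,
                                args=(torch.arange(8.0), ))
            assert fut.wait() == 28

        rpc.shutdown_rpc()


    if __name__ == '__main__':
        # Spawn two workers, mirroring `test/distributed/test_rpc.py`; the
        # port number is assumed to be free:
        mp = torch.multiprocessing.get_context('spawn')
        procs = [
            mp.Process(target=worker_main, args=(rank, 2, 29500))
            for rank in range(2)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()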
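
For context on the test's data layout: `partition_book` is a dense tensor mapping each global node ID to the partition that owns its features, and `RpcRouter` turns a partition ID into a worker name. A hedged illustration of that lookup path follows; the manual indexing is only explanatory, since the test wires both objects into `LocalFeatureStore` via `feature_pb` and `set_rpc_router` and lets `lookup_features` do the routing.

    import torch

    from torch_geometric.distributed.rpc import RpcRouter

    # Global IDs 0..255 live in partition 0, IDs 256..511 in partition 1:
    partition_book = torch.cat([
        torch.zeros(256, dtype=torch.long),
        torch.ones(256, dtype=torch.long),
    ])
    router = RpcRouter([['dist-feature-test-0'], ['dist-feature-test-1']])

    for global_id in [3, 300]:
        part = int(partition_book[global_id])
        worker = router.get_to_worker(partition_idx=part)
        # A request for `global_id` would be sent to `worker` via `rpc_async`:
        print(global_id, part, worker)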