Skip to content

Commit

Permalink
Merge branch 'master' into test/loader
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta authored Jul 28, 2023
2 parents 5bf8ba7 + 4889a1e commit ddd41d6
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 16 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493))
- Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
- Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628))
- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671))
- Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))
- Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))
- Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441))
Expand Down Expand Up @@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Made `FieldStatus` enum picklable to avoid `PicklingError` in a multi-process setting ([#7808](https://github.com/pyg-team/pytorch_geometric/pull/7808))
- Fixed `edge_label_time` computation in `LinkNeighborLoader` for homogeneous graphs ([#7807](https://github.com/pyg-team/pytorch_geometric/pull/7807))
- Fixed `edge_label_index` computation in `LinkNeighborLoader` for the homogeneous+`disjoint` mode ([#7791](https://github.com/pyg-team/pytorch_geometric/pull/7791))
- Fixed `CaptumExplainer` for `binary_classification` tasks ([#7787](https://github.com/pyg-team/pytorch_geometric/pull/7787))
Expand Down
5 changes: 2 additions & 3 deletions test/data/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import torch

from torch_geometric.data import TensorAttr
from torch_geometric.data.feature_store import AttrView, _field_status
from torch_geometric.data.feature_store import AttrView, _FieldStatus
from torch_geometric.testing import MyFeatureStore


@dataclass
class MyTensorAttrNoGroupName(TensorAttr):
def __init__(self, attr_name=_field_status.UNSET,
index=_field_status.UNSET):
def __init__(self, attr_name=_FieldStatus.UNSET, index=_FieldStatus.UNSET):
# Treat group_name as optional, and move it to the end
super().__init__(None, attr_name, index)

Expand Down
121 changes: 121 additions & 0 deletions test/distributed/test_rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import socket
from typing import Dict, List

import torch

import torch_geometric.distributed.rpc as rpc
from torch_geometric.distributed import LocalFeatureStore
from torch_geometric.distributed.dist_context import DistContext, DistRole
from torch_geometric.distributed.rpc import RpcRouter


def run_rpc_feature_test(
    world_size: int,
    rank: int,
    feature: LocalFeatureStore,
    partition_book: torch.Tensor,
    master_port: int,
):
    r"""Runs one RPC worker of the distributed feature lookup test.

    The worker joins the :obj:`'dist-feature-test'` RPC group, checks the
    partition-to-worker routing, configures :obj:`feature` for remote lookups
    and verifies the features fetched for both global ID ranges.

    Args:
        world_size (int): The number of workers in the group.
        rank (int): The rank of this worker, also used as its partition index.
        feature (LocalFeatureStore): The feature store owned by this worker.
        partition_book (torch.Tensor): The tensor assigned to
            :obj:`feature.feature_pb`.
        master_port (int): The port passed to :meth:`rpc.init_rpc`.
    """
    # 1) Describe this process and join the RPC group:
    ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-feature-test',
    )
    worker_names: Dict[DistRole, List[str]] = {}
    rpc.init_rpc(
        current_ctx=ctx,
        rpc_worker_names=worker_names,
        master_addr='localhost',
        master_port=master_port,
    )

    # 2) Gather the workers responsible for each partition:
    partition_to_workers = rpc.rpc_partition_to_workers(
        ctx, world_size, rank)
    expected_workers = [
        ['dist-feature-test-0'],
        ['dist-feature-test-1'],
    ]
    assert partition_to_workers == expected_workers

    # 3) Route partition IDs to worker names:
    router = RpcRouter(partition_to_workers)
    assert router.get_to_worker(partition_idx=0) == 'dist-feature-test-0'
    assert router.get_to_worker(partition_idx=1) == 'dist-feature-test-1'

    # 4) Configure the store so that lookups may leave the local partition:
    feature.num_partitions = world_size
    feature.partition_idx = rank
    feature.feature_pb = partition_book
    feature.meta = {
        'edge_types': None,
        'is_hetero': False,
        'node_types': None,
        'num_parts': 2
    }
    feature.set_local_only(local_only=False)
    feature.set_rpc_router(rpc_router=router) if False else \
        feature.set_rpc_router(router)

    # 5) Issue both lookups (IDs 0..255 and 256..511) before waiting on them:
    ids_first = torch.arange(256)
    ids_second = torch.arange(256, 512)
    future_first = feature.lookup_features(ids_first)
    future_second = feature.lookup_features(ids_second)

    # 6) Compare against the values stored by `test_dist_feature_lookup`:
    expected_first = torch.cat([
        torch.ones(128, 1024),
        torch.full((128, 1024), 2.0),
    ])
    expected_second = torch.zeros(256, 1024)
    assert torch.allclose(expected_first, future_first.wait())
    assert torch.allclose(expected_second, future_second.wait())

    rpc.shutdown_rpc()


def test_dist_feature_lookup():
    r"""Tests feature lookup across two spawned RPC workers.

    Two :class:`LocalFeatureStore` instances are filled with disjoint global
    ID ranges (:obj:`0..255` and :obj:`256..511`) and handed to one worker
    process each, which runs :meth:`run_rpc_feature_test`. The test fails if
    any worker exits with a non-zero exit code.
    """
    cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2])
    cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)])

    # Global node IDs:
    global_id0 = torch.arange(128 * 2)
    global_id1 = torch.arange(128 * 2) + 128 * 2

    # Set the partition book for two features (partition 0 and 1):
    partition_book = torch.cat([
        torch.zeros(128 * 2, dtype=torch.long),
        torch.ones(128 * 2, dtype=torch.long)
    ])

    # Put the test tensor into the different feature stores with IDs:
    feature0 = LocalFeatureStore()
    feature0.put_global_id(global_id0, group_name=None)
    feature0.put_tensor(cpu_tensor0, group_name=None, attr_name='x')

    feature1 = LocalFeatureStore()
    feature1.put_global_id(global_id1, group_name=None)
    feature1.put_tensor(cpu_tensor1, group_name=None, attr_name='x')

    mp_context = torch.multiprocessing.get_context('spawn')

    # Let the OS pick a free port. The socket is closed by the `with` block
    # even if `bind` raises.
    # NOTE(review): a TCP socket is probed since the port is handed to
    # `rpc.init_rpc` as `master_port`, which presumably sets up a TCP
    # rendezvous -- confirm against `rpc.init_rpc`.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        port = s.getsockname()[1]

    workers = [
        mp_context.Process(
            target=run_rpc_feature_test,
            args=(2, rank, feature, partition_book, port),
        ) for rank, feature in enumerate([feature0, feature1])
    ]

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Assertions inside the workers only affect the child's exit code, so
    # surface failures in the parent process explicitly:
    for rank, worker in enumerate(workers):
        assert worker.exitcode == 0, (
            f"RPC worker {rank} exited with code {worker.exitcode}")
4 changes: 2 additions & 2 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor

from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr
from torch_geometric.data.feature_store import _field_status
from torch_geometric.data.feature_store import _FieldStatus
from torch_geometric.data.graph_store import EdgeLayout
from torch_geometric.data.storage import (
BaseStorage,
Expand Down Expand Up @@ -372,7 +372,7 @@ class DataTensorAttr(TensorAttr):
r"""Tensor attribute for `Data` without group name."""
def __init__(
self,
attr_name=_field_status.UNSET,
attr_name=_FieldStatus.UNSET,
index=None,
):
super().__init__(None, attr_name, index)
Expand Down
16 changes: 9 additions & 7 deletions torch_geometric/data/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
from torch_geometric.typing import FeatureTensorType, NodeType
from torch_geometric.utils.mixin import CastMixin

_field_status = Enum("FieldStatus", "UNSET")

# We allow indexing with a tensor, numpy array, Python slicing, or a single
# integer index.
IndexType = Union[torch.Tensor, np.ndarray, slice, int]


class _FieldStatus(Enum):
    r"""Sentinel enum whose :obj:`UNSET` member marks a field that has not
    been given a value.
    """
    UNSET = None


@dataclass
class TensorAttr(CastMixin):
r"""Defines the attributes of a :class:`FeatureStore` tensor.
Expand All @@ -52,20 +54,20 @@ class TensorAttr(CastMixin):
"""

# The group name that the tensor corresponds to. Defaults to UNSET.
group_name: Optional[NodeType] = _field_status.UNSET
group_name: Optional[NodeType] = _FieldStatus.UNSET

# The name of the tensor within its group. Defaults to UNSET.
attr_name: Optional[str] = _field_status.UNSET
attr_name: Optional[str] = _FieldStatus.UNSET

# The node indices the rows of the tensor correspond to. Defaults to UNSET.
index: Optional[IndexType] = _field_status.UNSET
index: Optional[IndexType] = _FieldStatus.UNSET

# Convenience methods #####################################################

def is_set(self, key: str) -> bool:
r"""Whether an attribute is set in :obj:`TensorAttr`."""
assert key in self.__dataclass_fields__
return getattr(self, key) != _field_status.UNSET
return getattr(self, key) != _FieldStatus.UNSET

def is_fully_specified(self) -> bool:
r"""Whether the :obj:`TensorAttr` has no unset fields."""
Expand Down Expand Up @@ -137,7 +139,7 @@ def __getattr__(self, key: Any) -> Union['AttrView', FeatureTensorType]:
# Find the first attribute name that is UNSET:
attr_name: Optional[str] = None
for field in out._attr.__dataclass_fields__:
if getattr(out._attr, field) == _field_status.UNSET:
if getattr(out._attr, field) == _FieldStatus.UNSET:
attr_name = field
break

Expand Down
21 changes: 21 additions & 0 deletions torch_geometric/distributed/dist_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from enum import Enum


class DistRole(Enum):
    r"""Role of a process in a distributed setup."""
    # The only role that is currently defined.
    WORKER = 1


@dataclass
class DistContext:
    r"""Holds the context information of the current process: its ranks,
    world sizes, group name and role.
    """
    # NOTE(review): presumably `rank`/`world_size` are relative to the group
    # while the `global_*` variants span all groups -- confirm with callers.
    rank: int
    global_rank: int
    world_size: int
    global_world_size: int
    group_name: str
    role: DistRole = DistRole.WORKER

    @property
    def worker_name(self) -> str:
        r"""The name of this process, given by :obj:`"{group_name}-{rank}"`."""
        return '{}-{}'.format(self.group_name, self.rank)
6 changes: 3 additions & 3 deletions torch_geometric/distributed/local_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import Tensor

from torch_geometric.data import FeatureStore, TensorAttr
from torch_geometric.data.feature_store import _field_status
from torch_geometric.data.feature_store import _FieldStatus
from torch_geometric.typing import EdgeType, NodeType


Expand All @@ -17,8 +17,8 @@ class LocalTensorAttr(TensorAttr):
r"""Tensor attribute for storing features without :obj:`index`."""
def __init__(
self,
group_name: Optional[Union[NodeType, EdgeType]] = _field_status.UNSET,
attr_name: Optional[str] = _field_status.UNSET,
group_name: Optional[Union[NodeType, EdgeType]] = _FieldStatus.UNSET,
attr_name: Optional[str] = _FieldStatus.UNSET,
index=None,
):
super().__init__(group_name, attr_name, index)
Expand Down
Loading

0 comments on commit ddd41d6

Please sign in to comment.