Distributed node-level and edge-level temporal sampling for hetero (#8624)

This PR enables distributed sampling with node-level and edge-level
temporal information for heterogeneous graphs.

**Description:**
- Heterogeneous temporal sampling is analogous to homogeneous temporal
sampling, but additionally takes node types and edge types into account:
node time information is defined per node type and edge time information
per edge type.
- Because the feature store lacks a node_store/edge_store, we determine
whether to use node-level or edge-level temporal sampling based on the
`time_attr` value (`'time'` or `'edge_time'`).
- `seed_time` is mandatory for edge-level temporal sampling.
- A `seed_time` field has been added to the `NodeDict` class to store the
time information of the source nodes.
- The time information of the source nodes for the next layer is derived
from the subgraph ID a given node belongs to (each subgraph has a single
seed time, shared by all of its source nodes); see the sketch below.
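
The following is a minimal, hypothetical sketch (not the actual `DistNeighborSampler` code) of the last two points: selecting the temporal mode from `time_attr` and propagating per-subgraph seed times to the source nodes of the next layer in disjoint mode. All names and values are made up for illustration:

```python
import torch

# Hypothetical illustration only; tensors and values are made up.

# The time attribute name decides the temporal sampling mode:
time_attr = 'time'  # 'time' -> node-level, 'edge_time' -> edge-level
edge_level = time_attr == 'edge_time'

# In disjoint mode, every seed node spans its own subgraph, and `batch[j]`
# holds the subgraph ID of the j-th sampled source node:
seed_time = torch.tensor([10, 20, 30])    # one seed time per subgraph
batch = torch.tensor([0, 0, 1, 2, 2, 2])  # subgraph ID of each sampled node

# Each source node inherits the seed time of the subgraph it belongs to,
# which becomes its time constraint in the next sampling layer:
src_seed_time = seed_time[batch]
print(edge_level, src_seed_time)  # False tensor([10, 10, 20, 30, 30, 30])
```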

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
kgajdamo and rusty1s authored Dec 30, 2023
1 parent d1f305b commit 1b3112f
Showing 5 changed files with 273 additions and 24 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519))
- Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399))
- Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369))
- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375))
- Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), [#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624))
- Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
- Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
219 changes: 218 additions & 1 deletion test/distributed/test_dist_neighbor_sampler.py
@@ -73,7 +73,11 @@ def create_data(rank: int, world_size: int, time_attr: Optional[str] = None):
return (feature_store, graph_store), data


def create_hetero_data(tmp_path: str, rank: int):
def create_hetero_data(
tmp_path: str,
rank: int,
time_attr: Optional[str] = None,
):
graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
(
@@ -89,6 +93,28 @@ def create_hetero_data(tmp_path: str, rank: int):
graph_store.edge_pb = feature_store.edge_feat_pb = edge_pb
graph_store.meta = feature_store.meta = meta

if time_attr == 'time': # Create node-level time data:
feature_store.put_tensor(
tensor=torch.full((len(node_pb['v0']), ), 1, dtype=torch.int64),
group_name='v0',
attr_name=time_attr,
)
feature_store.put_tensor(
tensor=torch.full((len(node_pb['v1']), ), 2, dtype=torch.int64),
group_name='v1',
attr_name=time_attr,
)
elif time_attr == 'edge_time': # Create edge-level time data:
i = 0
for attr, edge_index in graph_store._edge_index.items():
time = torch.full((edge_index.size(1), ), i, dtype=torch.int64)
feature_store.put_tensor(
tensor=time,
group_name=attr[0],
attr_name=time_attr,
)
i += 1

return feature_store, graph_store


@@ -313,6 +339,89 @@ def dist_neighbor_sampler_hetero(
)


def dist_neighbor_sampler_temporal_hetero(
data: FakeHeteroDataset,
tmp_path: str,
world_size: int,
rank: int,
master_port: int,
input_type: str,
seed_time: torch.tensor = None,
temporal_strategy: str = 'uniform',
time_attr: str = 'time',
):
dist_data = create_hetero_data(tmp_path, rank, time_attr)

current_ctx = DistContext(
rank=rank,
global_rank=rank,
world_size=world_size,
global_world_size=world_size,
group_name='dist-sampler-test',
)

dist_sampler = DistNeighborSampler(
data=dist_data,
current_ctx=current_ctx,
rpc_worker_names={},
num_neighbors=[-1, -1],
shuffle=False,
disjoint=True,
temporal_strategy=temporal_strategy,
time_attr=time_attr,
)

# Close RPC & worker group at exit:
atexit.register(shutdown_rpc)

init_rpc(
current_ctx=current_ctx,
rpc_worker_names={},
master_addr='localhost',
master_port=master_port,
)

dist_sampler.init_sampler_instance()
dist_sampler.register_sampler_rpc()
dist_sampler.event_loop = ConcurrentEventLoop(2)
dist_sampler.event_loop.start_loop()

# Create input nodes such that each belongs to a different partition:
node_pb_list = dist_data[1].node_pb[input_type].tolist()
node_0 = node_pb_list.index(0)
node_1 = node_pb_list.index(1)

input_node = torch.tensor([node_0, node_1], dtype=torch.int64)

inputs = NodeSamplerInput(
input_id=None,
node=input_node,
time=seed_time,
input_type=input_type,
)

# Evaluate distributed node sample function:
out_dist = dist_sampler.event_loop.run_task(
coro=dist_sampler.node_sample(inputs))

sampler = NeighborSampler(
data=data,
num_neighbors=[-1, -1],
disjoint=True,
temporal_strategy=temporal_strategy,
time_attr=time_attr,
)

# Evaluate node sample function:
out = node_sample(inputs, sampler._sample)

# Compare distributed output with single machine output:
for k in data.node_types:
assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])
assert torch.equal(out_dist.batch[k].sort()[0], out.batch[k].sort()[0])
assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]


@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('disjoint', [False, True])
@@ -437,3 +546,111 @@ def test_dist_neighbor_sampler_hetero(tmp_path, disjoint):
w1.start()
w0.join()
w1.join()


@withPackage('pyg_lib')
@pytest.mark.parametrize('seed_time', [None, [0, 0], [2, 2]])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_temporal_hetero(
tmp_path,
seed_time,
temporal_strategy,
):
if seed_time is not None:
seed_time = torch.tensor(seed_time)

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]

world_size = 2
data = FakeHeteroDataset(
num_graphs=1,
avg_num_nodes=100,
avg_degree=3,
num_node_types=2,
num_edge_types=4,
edge_dim=2,
)[0]

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

# The partition generation script does not currently support temporal data,
# so the node-level time attributes are added after generating the partitions.
data['v0'].time = torch.full((data.num_nodes_dict['v0'], ), 1,
dtype=torch.int64)
data['v1'].time = torch.full((data.num_nodes_dict['v1'], ), 2,
dtype=torch.int64)

w0 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 0, port, 'v0', seed_time,
temporal_strategy, 'time'),
)

w1 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 1, port, 'v1', seed_time,
temporal_strategy, 'time'),
)

w0.start()
w1.start()
w0.join()
w1.join()


@withPackage('pyg_lib')
@pytest.mark.parametrize('seed_time', [[0, 0], [1, 2]])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_edge_level_temporal_hetero(
tmp_path,
seed_time,
temporal_strategy,
):
seed_time = torch.tensor(seed_time)

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]

world_size = 2
data = FakeHeteroDataset(
num_graphs=1,
avg_num_nodes=100,
avg_degree=3,
num_node_types=2,
num_edge_types=4,
edge_dim=2,
)[0]

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

# The partition generation script does not currently support temporal data,
# so the edge-level time attributes are added after generating the partitions.
for i, edge_type in enumerate(data.edge_types):
data[edge_type].edge_time = torch.full(
(data[edge_type].edge_index.size(1), ), i, dtype=torch.int64)

w0 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 0, port, 'v0', seed_time,
temporal_strategy, 'edge_time'),
)

w1 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 1, port, 'v1', seed_time,
temporal_strategy, 'edge_time'),
)

w0.start()
w1.start()
w0.join()
w1.join()
43 changes: 30 additions & 13 deletions torch_geometric/distributed/dist_neighbor_sampler.py
@@ -126,7 +126,7 @@ def register_sampler_rpc(self) -> None:
rpc_sample_callee = RPCSamplingCallee(self)
self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)

def init_event_loop(self):
def init_event_loop(self) -> None:
if self.event_loop is None:
self.event_loop = ConcurrentEventLoop(self.concurrency)
self.event_loop.start_loop()
@@ -228,16 +228,19 @@ async def node_sample(
else:
raise ValueError("Seed time needs to be specified")

# Heterogeneous Neighborhood Sampling #################################

if self.is_hetero:
if input_type is None:
raise ValueError("Input type should be defined")

seed_dict: Dict[NodeType, Tensor] = {input_type: seed}
seed_time_dict: Dict[NodeType, Tensor] = {input_type: seed_time}

node_dict = NodeDict(self.node_types, self.num_hops)
batch_dict = BatchDict(self.node_types, self.num_hops)

seed_dict: Dict[NodeType, Tensor] = {input_type: seed}
if self.temporal:
node_dict.seed_time[input_type][0] = seed_time.clone()

edge_dict: Dict[EdgeType, Tensor] = {
k: torch.empty(0, dtype=torch.int64)
for k in self.edge_types
@@ -281,8 +284,6 @@ async def node_sample(
num_sampled_edges_dict[edge_type].append(0)
continue

seed_time = (seed_time_dict or {}).get(src, None)

if isinstance(self.num_neighbors, list):
one_hop_num = self.num_neighbors[i]
else:
@@ -292,7 +293,7 @@
out = await self.sample_one_hop(
node_dict.src[src][i],
one_hop_num,
seed_time,
node_dict.seed_time[src][i],
batch_dict.src[src][i],
edge_type,
)
@@ -306,9 +307,9 @@

# Remove duplicates:
(
node_src,
src_node,
node_dict.out[dst],
batch_src,
src_batch,
batch_dict.out[dst],
) = remove_duplicates(
out,
@@ -319,10 +320,10 @@

# Create src nodes for the next layer:
node_dict.src[dst][i + 1] = torch.cat(
[node_dict.src[dst][i + 1], node_src])
[node_dict.src[dst][i + 1], src_node])
if self.disjoint:
batch_dict.src[dst][i + 1] = torch.cat(
[batch_dict.src[dst][i + 1], batch_src])
[batch_dict.src[dst][i + 1], src_batch])

# Save sampled nodes with duplicates to be able to create
# local edge indices:
@@ -336,6 +337,18 @@
batch_dict.with_dupl[dst] = torch.cat(
[batch_dict.with_dupl[dst], out.batch])

if self.temporal and i < self.num_hops - 1:
# Assign seed time based on source node subgraph ID:
src_seed_time = [
seed_time[(seed_batch == batch_idx).nonzero()]
for batch_idx in src_batch
]
src_seed_time = torch.as_tensor(
src_seed_time, dtype=torch.int64)

node_dict.seed_time[dst][i + 1] = torch.cat(
[node_dict.seed_time[dst][i + 1], src_seed_time])

# Collect sampled neighbors per node for each layer:
sampled_nbrs_per_node_dict[edge_type][i] += out.metadata[0]

@@ -371,6 +384,9 @@ async def node_sample(
num_sampled_edges=num_sampled_edges_dict,
metadata=metadata,
)

# Homogeneous Neighborhood Sampling ###################################

else:
src = seed
node = src.clone()
@@ -797,8 +813,9 @@ def _sample_one_hop(
rel_type = '__'.join(edge_type)
colptr = self._sampler.colptr_dict[rel_type]
row = self._sampler.row_dict[rel_type]
node_time = (self.node_time or {}).get(edge_type[2], None)
edge_time = (self.edge_time or {}).get(edge_type[2], None)
# `node_time` is a destination node time:
node_time = (self.node_time or {}).get(edge_type[0], None)
edge_time = (self.edge_time or {}).get(edge_type, None)

out = torch.ops.pyg.dist_neighbor_sample(
colptr,
4 changes: 4 additions & 0 deletions torch_geometric/distributed/utils.py
@@ -29,6 +29,10 @@ def __init__(self, node_types, num_hops):
k: torch.empty(0, dtype=torch.int64)
for k in node_types
}
self.seed_time: Dict[NodeType, List[Tensor]] = {
k: num_hops * [torch.empty(0, dtype=torch.int64)]
for k in node_types
}


class BatchDict:
