Distributed node-level and edge-level temporal sampling for hetero #8624

Merged: 7 commits on Dec 30, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519))
- Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399))
- Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369))
- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375))
- Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), [#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624))
- Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
- Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
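
As a rough sketch of the feature tracked by the CHANGELOG entry above (mirroring the new tests in the file below; the `FakeHeteroDataset` arguments and the `v0`/`v1` node types are taken from those tests), node-level timestamps are attached to each node type before partitioning; edge-level timestamps are attached analogously via an `edge_time` attribute per edge type:

import torch
from torch_geometric.datasets import FakeHeteroDataset

# Sketch only (mirrors the test setup below): build a synthetic hetero graph
# and attach a node-level 'time' attribute to every node type.
data = FakeHeteroDataset(num_graphs=1, avg_num_nodes=100, avg_degree=3,
                         num_node_types=2, num_edge_types=4, edge_dim=2)[0]
data['v0'].time = torch.full((data.num_nodes_dict['v0'], ), 1, dtype=torch.int64)
data['v1'].time = torch.full((data.num_nodes_dict['v1'], ), 2, dtype=torch.int64)
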
219 changes: 218 additions & 1 deletion test/distributed/test_dist_neighbor_sampler.py
@@ -73,7 +73,11 @@ def create_data(rank: int, world_size: int, time_attr: Optional[str] = None):
return (feature_store, graph_store), data


def create_hetero_data(tmp_path: str, rank: int):
def create_hetero_data(
tmp_path: str,
rank: int,
time_attr: Optional[str] = None,
):
graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
(
@@ -89,6 +93,28 @@ def create_hetero_data(tmp_path: str, rank: int):
graph_store.edge_pb = feature_store.edge_feat_pb = edge_pb
graph_store.meta = feature_store.meta = meta

if time_attr == 'time': # Create node-level time data:
feature_store.put_tensor(
tensor=torch.full((len(node_pb['v0']), ), 1, dtype=torch.int64),
group_name='v0',
attr_name=time_attr,
)
feature_store.put_tensor(
tensor=torch.full((len(node_pb['v1']), ), 2, dtype=torch.int64),
group_name='v1',
attr_name=time_attr,
)
elif time_attr == 'edge_time': # Create edge-level time data:
i = 0
for attr, edge_index in graph_store._edge_index.items():
time = torch.full((edge_index.size(1), ), i, dtype=torch.int64)
feature_store.put_tensor(
tensor=time,
group_name=attr[0],
attr_name=time_attr,
)
i += 1

return feature_store, graph_store


@@ -313,6 +339,89 @@ def dist_neighbor_sampler_hetero(
)


def dist_neighbor_sampler_temporal_hetero(
data: FakeHeteroDataset,
tmp_path: str,
world_size: int,
rank: int,
master_port: int,
input_type: str,
    seed_time: Optional[torch.Tensor] = None,
temporal_strategy: str = 'uniform',
time_attr: str = 'time',
):
dist_data = create_hetero_data(tmp_path, rank, time_attr)

current_ctx = DistContext(
rank=rank,
global_rank=rank,
world_size=world_size,
global_world_size=world_size,
group_name='dist-sampler-test',
)

dist_sampler = DistNeighborSampler(
data=dist_data,
current_ctx=current_ctx,
rpc_worker_names={},
num_neighbors=[-1, -1],
shuffle=False,
disjoint=True,
temporal_strategy=temporal_strategy,
time_attr=time_attr,
)

# Close RPC & worker group at exit:
atexit.register(shutdown_rpc)

init_rpc(
current_ctx=current_ctx,
rpc_worker_names={},
master_addr='localhost',
master_port=master_port,
)

dist_sampler.init_sampler_instance()
dist_sampler.register_sampler_rpc()
dist_sampler.event_loop = ConcurrentEventLoop(2)
dist_sampler.event_loop.start_loop()

    # Create input nodes such that each belongs to a different partition:
node_pb_list = dist_data[1].node_pb[input_type].tolist()
node_0 = node_pb_list.index(0)
node_1 = node_pb_list.index(1)

input_node = torch.tensor([node_0, node_1], dtype=torch.int64)

inputs = NodeSamplerInput(
input_id=None,
node=input_node,
time=seed_time,
input_type=input_type,
)

# Evaluate distributed node sample function:
out_dist = dist_sampler.event_loop.run_task(
coro=dist_sampler.node_sample(inputs))

sampler = NeighborSampler(
data=data,
num_neighbors=[-1, -1],
disjoint=True,
temporal_strategy=temporal_strategy,
time_attr=time_attr,
)

# Evaluate node sample function:
out = node_sample(inputs, sampler._sample)

# Compare distributed output with single machine output:
for k in data.node_types:
assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])
assert torch.equal(out_dist.batch[k].sort()[0], out.batch[k].sort()[0])
assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]


@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('disjoint', [False, True])
@@ -437,3 +546,111 @@ def test_dist_neighbor_sampler_hetero(tmp_path, disjoint):
w1.start()
w0.join()
w1.join()


@withPackage('pyg_lib')
@pytest.mark.parametrize('seed_time', [None, [0, 0], [2, 2]])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_temporal_hetero(
tmp_path,
seed_time,
temporal_strategy,
):
if seed_time is not None:
seed_time = torch.tensor(seed_time)

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]

world_size = 2
data = FakeHeteroDataset(
num_graphs=1,
avg_num_nodes=100,
avg_degree=3,
num_node_types=2,
num_edge_types=4,
edge_dim=2,
)[0]

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

# The partition generation script does not currently support temporal data.
# Therefore, it needs to be added after generating partitions.
data['v0'].time = torch.full((data.num_nodes_dict['v0'], ), 1,
dtype=torch.int64)
data['v1'].time = torch.full((data.num_nodes_dict['v1'], ), 2,
dtype=torch.int64)

w0 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 0, port, 'v0', seed_time,
temporal_strategy, 'time'),
)

w1 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 1, port, 'v1', seed_time,
temporal_strategy, 'time'),
)

w0.start()
w1.start()
w0.join()
w1.join()


@withPackage('pyg_lib')
@pytest.mark.parametrize('seed_time', [[0, 0], [1, 2]])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_edge_level_temporal_hetero(
tmp_path,
seed_time,
temporal_strategy,
):
seed_time = torch.tensor(seed_time)

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]

world_size = 2
data = FakeHeteroDataset(
num_graphs=1,
avg_num_nodes=100,
avg_degree=3,
num_node_types=2,
num_edge_types=4,
edge_dim=2,
)[0]

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

# The partition generation script does not currently support temporal data.
# Therefore, it needs to be added after generating partitions.
for i, edge_type in enumerate(data.edge_types):
data[edge_type].edge_time = torch.full(
(data[edge_type].edge_index.size(1), ), i, dtype=torch.int64)

w0 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 0, port, 'v0', seed_time,
temporal_strategy, 'edge_time'),
)

w1 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 1, port, 'v1', seed_time,
temporal_strategy, 'edge_time'),
)

w0.start()
w1.start()
w0.join()
w1.join()
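
For intuition on the `seed_time` values parametrized above: a neighbor is only eligible when its timestamp does not exceed the seed's time (node-level sampling compares the neighbor's node time, edge-level sampling compares the connecting edge's time). A toy check of that rule with made-up timestamps, not library code:

import torch

# Toy illustration of the temporal constraint (made-up values):
seed_time = torch.tensor(1)
neighbor_time = torch.tensor([0, 1, 2])  # hypothetical node (or edge) times
eligible = neighbor_time <= seed_time
print(eligible)  # tensor([ True,  True, False])
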
43 changes: 30 additions & 13 deletions torch_geometric/distributed/dist_neighbor_sampler.py
@@ -126,7 +126,7 @@ def register_sampler_rpc(self) -> None:
rpc_sample_callee = RPCSamplingCallee(self)
self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)

def init_event_loop(self):
def init_event_loop(self) -> None:
if self.event_loop is None:
self.event_loop = ConcurrentEventLoop(self.concurrency)
self.event_loop.start_loop()
@@ -228,16 +228,19 @@ async def node_sample(
else:
raise ValueError("Seed time needs to be specified")

# Heterogeneous Neighborhood Sampling #################################

if self.is_hetero:
if input_type is None:
raise ValueError("Input type should be defined")

seed_dict: Dict[NodeType, Tensor] = {input_type: seed}
seed_time_dict: Dict[NodeType, Tensor] = {input_type: seed_time}

node_dict = NodeDict(self.node_types, self.num_hops)
batch_dict = BatchDict(self.node_types, self.num_hops)

seed_dict: Dict[NodeType, Tensor] = {input_type: seed}
if self.temporal:
node_dict.seed_time[input_type][0] = seed_time.clone()

edge_dict: Dict[EdgeType, Tensor] = {
k: torch.empty(0, dtype=torch.int64)
for k in self.edge_types
@@ -281,8 +284,6 @@ async def node_sample(
num_sampled_edges_dict[edge_type].append(0)
continue

seed_time = (seed_time_dict or {}).get(src, None)

if isinstance(self.num_neighbors, list):
one_hop_num = self.num_neighbors[i]
else:
@@ -292,7 +293,7 @@ async def node_sample(
out = await self.sample_one_hop(
node_dict.src[src][i],
one_hop_num,
seed_time,
node_dict.seed_time[src][i],
batch_dict.src[src][i],
edge_type,
)
@@ -306,9 +307,9 @@ async def node_sample(

# Remove duplicates:
(
node_src,
src_node,
node_dict.out[dst],
batch_src,
src_batch,
batch_dict.out[dst],
) = remove_duplicates(
out,
@@ -319,10 +320,10 @@ async def node_sample(

# Create src nodes for the next layer:
node_dict.src[dst][i + 1] = torch.cat(
[node_dict.src[dst][i + 1], node_src])
[node_dict.src[dst][i + 1], src_node])
if self.disjoint:
batch_dict.src[dst][i + 1] = torch.cat(
[batch_dict.src[dst][i + 1], batch_src])
[batch_dict.src[dst][i + 1], src_batch])

# Save sampled nodes with duplicates to be able to create
# local edge indices:
@@ -336,6 +337,18 @@ async def node_sample(
batch_dict.with_dupl[dst] = torch.cat(
[batch_dict.with_dupl[dst], out.batch])

if self.temporal and i < self.num_hops - 1:
# Assign seed time based on source node subgraph ID:
src_seed_time = [
seed_time[(seed_batch == batch_idx).nonzero()]
for batch_idx in src_batch
]
src_seed_time = torch.as_tensor(
src_seed_time, dtype=torch.int64)

node_dict.seed_time[dst][i + 1] = torch.cat(
[node_dict.seed_time[dst][i + 1], src_seed_time])

# Collect sampled neighbors per node for each layer:
sampled_nbrs_per_node_dict[edge_type][i] += out.metadata[0]

@@ -371,6 +384,9 @@ async def node_sample(
num_sampled_edges=num_sampled_edges_dict,
metadata=metadata,
)

# Homogeneous Neighborhood Sampling ###################################

else:
src = seed
node = src.clone()
@@ -797,8 +813,9 @@ def _sample_one_hop(
rel_type = '__'.join(edge_type)
colptr = self._sampler.colptr_dict[rel_type]
row = self._sampler.row_dict[rel_type]
node_time = (self.node_time or {}).get(edge_type[2], None)
edge_time = (self.edge_time or {}).get(edge_type[2], None)
        # `node_time` is the destination node's time:
node_time = (self.node_time or {}).get(edge_type[0], None)
edge_time = (self.edge_time or {}).get(edge_type, None)

out = torch.ops.pyg.dist_neighbor_sample(
colptr,
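
The seed-time propagation added to `node_sample` above gives each newly sampled source node the seed time of the disjoint subgraph (batch) it belongs to. A small standalone sketch of that lookup with made-up values:

import torch

# Sketch only (not the library code): map each sampled node's batch ID back
# to the seed it originated from, and reuse that seed's timestamp.
seed_time = torch.tensor([10, 20])   # one timestamp per seed node
seed_batch = torch.tensor([0, 1])    # disjoint subgraph ID of each seed
src_batch = torch.tensor([1, 0, 1])  # subgraph IDs of sampled source nodes

pos = torch.tensor([(seed_batch == b).nonzero().item() for b in src_batch])
src_seed_time = seed_time[pos]
print(src_seed_time)  # tensor([20, 10, 20])
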
4 changes: 4 additions & 0 deletions torch_geometric/distributed/utils.py
@@ -29,6 +29,10 @@ def __init__(self, node_types, num_hops):
k: torch.empty(0, dtype=torch.int64)
for k in node_types
}
self.seed_time: Dict[NodeType, List[Tensor]] = {
k: num_hops * [torch.empty(0, dtype=torch.int64)]
for k in node_types
}


class BatchDict:
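
The `seed_time` field added to `NodeDict` above keeps, per node type, one buffer per sampling hop: hop 0 of the input node type starts from the (cloned) input seed times, and later hops are filled as sampling proceeds. A minimal sketch of that layout (node types and hop count are hypothetical):

import torch

# Hypothetical node types and hop count, mirroring the structure added above:
node_types = ['v0', 'v1']
num_hops = 2
seed_time = {
    k: num_hops * [torch.empty(0, dtype=torch.int64)]
    for k in node_types
}

# Hop 0 of the input node type is seeded with the input seed times:
seed_time['v0'][0] = torch.tensor([1, 1], dtype=torch.int64)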