Skip to content

Commit

Permalink
Handle node-level and edge-level temporal information when generating…
Browse files Browse the repository at this point in the history
… partitions (#8718)

**Description**

**Temporal data definition:**
Temporal data can be added to the `LocalFeatureStore` in two ways:
- it can be loaded directly from a partition using the
`LocalFeatureStore.from_partition()` function,
or
- you can add it yourself by calling the `put_tensor()` function on the
`LocalFeatureStore` object.


- Node-level temporal data: every partition must hold the same time
vector, which is global and whose size equals the number of nodes in
the whole graph.
Why:
Because we operate on global node IDs.

- Edge-level temporal data: each partition has its own time vector,
which is local to that partition and whose size equals the number of
edges in that partition (`part_data.edge_index.size(1)`).
Why:
Each partition has its own unique edge_index in COO format, which is
later converted to a CSR/CSC matrix in the neighbor sampler. So we do
not have information about the global edge IDs when sampling and we
would not be able to find the correct time information for a specific
edge. Therefore, this information must be local.

**How to distinguish node-level or edge-level temporal data:**
- `time_attr='time'` for node-level temporal sampling.
- `time_attr='edge_time'` for edge-level temporal sampling. This
differs from the single-machine case, where both edge time and node time
use `time_attr='time'`. It is handled this way because the feature store
has no node_store/edge_store, so at the moment we determine whether to
use node-level or edge-level temporal sampling based on the attribute
name.

**Where temporal data is saved:**
- `time` has been added to the node features -> node_feats.pt
- `edge_time` has been added to the edge features -> edge_feats.pt

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
kgajdamo and rusty1s authored Jan 22, 2024
1 parent b26c034 commit 81fdeaf
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 55 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for graph partitioning for temporal data in `torch_geometric.distributed` ([#8718](https://github.com/pyg-team/pytorch_geometric/pull/8718))
- Added `TreeGraph` and `GridMotif` generators ([#8736](https://github.com/pyg-team/pytorch_geometric/pull/8736))
- Added an example for edge-level temporal sampling on a heterogeneous graph ([#8383](https://github.com/pyg-team/pytorch_geometric/pull/8383))
- Added the `num_graphs` option to the `StochasticBlockModelDataset` ([#8648](https://github.com/pyg-team/pytorch_geometric/pull/8648))
Expand Down
41 changes: 7 additions & 34 deletions test/distributed/test_dist_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def create_data(rank: int, world_size: int, time_attr: Optional[str] = None):
def create_hetero_data(
tmp_path: str,
rank: int,
time_attr: Optional[str] = None,
):
graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
Expand All @@ -93,28 +92,6 @@ def create_hetero_data(
graph_store.edge_pb = feature_store.edge_feat_pb = edge_pb
graph_store.meta = feature_store.meta = meta

if time_attr == 'time': # Create node-level time data:
feature_store.put_tensor(
tensor=torch.full((len(node_pb['v0']), ), 1, dtype=torch.int64),
group_name='v0',
attr_name=time_attr,
)
feature_store.put_tensor(
tensor=torch.full((len(node_pb['v1']), ), 2, dtype=torch.int64),
group_name='v1',
attr_name=time_attr,
)
elif time_attr == 'edge_time': # Create edge-level time data:
i = 0
for attr, edge_index in graph_store._edge_index.items():
time = torch.full((edge_index.size(1), ), i, dtype=torch.int64)
feature_store.put_tensor(
tensor=time,
group_name=attr[0],
attr_name=time_attr,
)
i += 1

return feature_store, graph_store


Expand Down Expand Up @@ -350,7 +327,7 @@ def dist_neighbor_sampler_temporal_hetero(
temporal_strategy: str = 'uniform',
time_attr: str = 'time',
):
dist_data = create_hetero_data(tmp_path, rank, time_attr)
dist_data = create_hetero_data(tmp_path, rank)

current_ctx = DistContext(
rank=rank,
Expand Down Expand Up @@ -575,16 +552,14 @@ def test_dist_neighbor_sampler_temporal_hetero(
edge_dim=2,
)[0]

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

# The partition generation script does not currently support temporal data.
# Therefore, it needs to be added after generating partitions.
data['v0'].time = torch.full((data.num_nodes_dict['v0'], ), 1,
dtype=torch.int64)
data['v1'].time = torch.full((data.num_nodes_dict['v1'], ), 2,
dtype=torch.int64)

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

w0 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 0, port, 'v0', seed_time,
Expand Down Expand Up @@ -629,15 +604,13 @@ def test_dist_neighbor_sampler_edge_level_temporal_hetero(
edge_dim=2,
)[0]

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

# The partition generation script does not currently support temporal data.
# Therefore, it needs to be added after generating partitions.
for i, edge_type in enumerate(data.edge_types):
data[edge_type].edge_time = torch.full(
(data[edge_type].edge_index.size(1), ), i, dtype=torch.int64)

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

w0 = mp_context.Process(
target=dist_neighbor_sampler_temporal_hetero,
args=(data, tmp_path, world_size, 0, port, 'v0', seed_time,
Expand Down
224 changes: 216 additions & 8 deletions test/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_partition_data(tmp_path):
data = FakeDataset()[0]
num_parts = 2

partitioner = Partitioner(data, num_parts, tmp_path)
partitioner = Partitioner(data, num_parts=2, root=tmp_path)
partitioner.generate_partition()

node_map_path = osp.join(tmp_path, 'node_map.pt')
Expand Down Expand Up @@ -72,9 +71,9 @@ def test_partition_data(tmp_path):
@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_partition_hetero_data(tmp_path):
data = FakeHeteroDataset()[0]
num_parts = 2

partitioner = Partitioner(data, num_parts, tmp_path)
num_parts = 2
partitioner = Partitioner(data, num_parts=num_parts, root=tmp_path)
partitioner.generate_partition()

meta_path = osp.join(tmp_path, 'META.json')
Expand Down Expand Up @@ -103,12 +102,105 @@ def test_partition_hetero_data(tmp_path):
assert osp.exists(edge_feats_path)


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_partition_data_temporal(tmp_path):
    """Check that node-level `time` is saved in each partition's node feats.

    A homogeneous graph with a per-node `time` vector is split into two
    partitions; the `time` entry of every `node_feats.pt` file must equal
    the full, unpartitioned vector.
    """
    graph = FakeDataset()[0]
    graph.time = torch.arange(graph.num_nodes)

    Partitioner(graph, num_parts=2, root=tmp_path).generate_partition()

    # Load the node features written for both partitions.
    stored = []
    for part_dir in ('part_0', 'part_1'):
        feats_path = osp.join(tmp_path, part_dir, 'node_feats.pt')
        assert osp.exists(feats_path)
        stored.append(torch.load(feats_path))

    # Node-level time is compared against the whole vector, not a slice.
    for node_feats in stored:
        assert torch.equal(graph.time, node_feats['time'])


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_partition_data_edge_level_temporal(tmp_path):
    """Check that edge-level `edge_time` is saved per partition.

    A homogeneous graph with a per-edge `edge_time` vector is split into two
    partitions; the `edge_time` entry of every `edge_feats.pt` file must
    equal the original vector indexed by that partition's `global_id`.
    """
    graph = FakeDataset(edge_dim=2)[0]
    graph.edge_time = torch.arange(graph.num_edges)

    Partitioner(graph, num_parts=2, root=tmp_path).generate_partition()

    # Load the edge features written for both partitions.
    stored = []
    for part_dir in ('part_0', 'part_1'):
        feats_path = osp.join(tmp_path, part_dir, 'edge_feats.pt')
        assert osp.exists(feats_path)
        stored.append(torch.load(feats_path))

    # Each partition only keeps the time stamps of its own edges.
    for edge_feats in stored:
        expected = graph.edge_time[edge_feats['global_id']]
        assert torch.equal(expected, edge_feats['edge_time'])


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_partition_hetero_data_temporal(tmp_path):
    """Check that per-node-type `time` is saved in each partition.

    Every node type of a heterogeneous graph gets a `time` vector; after
    partitioning into two parts, `node_feats.pt` of each part must hold the
    full vector under `[node_type]['time']`.
    """
    graph = FakeHeteroDataset()[0]

    for node_type in graph.node_types:
        graph[node_type].time = torch.arange(graph[node_type].num_nodes)

    Partitioner(graph, num_parts=2, root=tmp_path).generate_partition()

    # Load the node features written for both partitions.
    stored = []
    for part_dir in ('part_0', 'part_1'):
        feats_path = osp.join(tmp_path, part_dir, 'node_feats.pt')
        assert osp.exists(feats_path)
        stored.append(torch.load(feats_path))

    for node_type in graph.node_types:
        for node_feats in stored:
            assert torch.equal(graph[node_type].time,
                               node_feats[node_type]['time'])


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_partition_hetero_data_edge_level_temporal(tmp_path):
    """Check that per-edge-type `edge_time` is saved in each partition.

    Every edge type of a heterogeneous graph gets an `edge_time` vector;
    after partitioning into two parts, `edge_feats.pt` of each part must
    hold, per edge type, the original vector indexed by that part's
    `global_id`.
    """
    graph = FakeHeteroDataset(edge_dim=2)[0]

    for edge_type in graph.edge_types:
        graph[edge_type].edge_time = torch.arange(graph[edge_type].num_edges)

    Partitioner(graph, num_parts=2, root=tmp_path).generate_partition()

    # Load the edge features written for both partitions.
    stored = []
    for part_dir in ('part_0', 'part_1'):
        feats_path = osp.join(tmp_path, part_dir, 'edge_feats.pt')
        assert osp.exists(feats_path)
        stored.append(torch.load(feats_path))

    for edge_type in graph.edge_types:
        for edge_feats in stored:
            local = edge_feats[edge_type]
            assert torch.equal(
                graph[edge_type].edge_time[local['global_id']],
                local['edge_time'],
            )


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_from_partition_data(tmp_path):
data = FakeDataset()[0]
num_parts = 2

partitioner = Partitioner(data, num_parts, tmp_path)
partitioner = Partitioner(data, num_parts=2, root=tmp_path)
partitioner.generate_partition()

graph_store1 = LocalGraphStore.from_partition(tmp_path, pid=0)
Expand Down Expand Up @@ -141,9 +233,8 @@ def test_from_partition_data(tmp_path):
@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_from_partition_hetero_data(tmp_path):
data = FakeHeteroDataset()[0]
num_parts = 2

partitioner = Partitioner(data, num_parts, tmp_path)
partitioner = Partitioner(data, num_parts=2, root=tmp_path)
partitioner.generate_partition()

graph_store1 = LocalGraphStore.from_partition(tmp_path, pid=0)
Expand All @@ -164,3 +255,120 @@ def test_from_partition_hetero_data(tmp_path):
node_types.add(attr.edge_type[0])
node_types.add(attr.edge_type[2])
assert node_types == set(data.node_types)


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_from_partition_temporal_data(tmp_path):
    """Check that `from_partition` exposes node-level `time` in the store.

    A homogeneous graph with a per-node `time` vector is split into two
    partitions. The feature store loaded from each partition must contain
    exactly one tensor attribute named `'time'`, and its tensor must equal
    the full, unpartitioned vector.
    """
    data = FakeDataset()[0]
    data.time = torch.arange(data.num_nodes)

    partitioner = Partitioner(data, num_parts=2, root=tmp_path)
    partitioner.generate_partition()

    times = []
    for pid in (0, 1):
        feat_store = LocalFeatureStore.from_partition(tmp_path, pid=pid)
        # Look the attribute up by name rather than by list position, so the
        # test does not depend on how many other tensors the store holds or
        # on the order in which they were inserted.
        time_attrs = [
            attr for attr in feat_store.get_all_tensor_attrs()
            if attr.attr_name == 'time'
        ]
        assert len(time_attrs) == 1
        times.append(feat_store.get_tensor(time_attrs[0]))

    for time in times:
        assert time.size(0) == data.num_nodes
        assert torch.equal(time, data.time)


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_from_partition_edge_level_temporal_data(tmp_path):
    """Check that `from_partition` exposes edge-level `edge_time`.

    A homogeneous graph with a per-edge `edge_time` vector is split into two
    partitions. The feature store loaded from each partition must contain
    exactly one tensor attribute named `'edge_time'`; the two tensors
    together must cover every edge, and each must equal the original vector
    indexed by that store's global edge ids.
    """
    data = FakeDataset(edge_dim=2)[0]
    data.edge_time = torch.arange(data.num_edges)

    partitioner = Partitioner(data, num_parts=2, root=tmp_path)
    partitioner.generate_partition()

    times, edge_ids = [], []
    for pid in (0, 1):
        feat_store = LocalFeatureStore.from_partition(tmp_path, pid=pid)
        # Look the attribute up by name rather than by list position, so the
        # test does not depend on how many other tensors the store holds or
        # on the order in which they were inserted.
        time_attrs = [
            attr for attr in feat_store.get_all_tensor_attrs()
            if attr.attr_name == 'edge_time'
        ]
        assert len(time_attrs) == 1
        times.append(feat_store.get_tensor(time_attrs[0]))
        edge_ids.append(feat_store.get_global_id(group_name=(None, None)))

    assert sum(time.size(0) for time in times) == data.edge_index.size(1)
    for edge_id, time in zip(edge_ids, times):
        assert torch.equal(data.edge_time[edge_id], time)


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_from_partition_hetero_temporal_data(tmp_path):
    """Check that `from_partition` exposes `time` for every node type.

    Every node type of a heterogeneous graph gets a `time` vector; the
    feature store loaded from each of the two partitions must return, per
    node type, a `'time'` tensor equal to the full vector.
    """
    graph = FakeHeteroDataset()[0]

    for node_type in graph.node_types:
        graph[node_type].time = torch.arange(graph[node_type].num_nodes)

    Partitioner(graph, num_parts=2, root=tmp_path).generate_partition()

    stores = [
        LocalFeatureStore.from_partition(tmp_path, pid=pid) for pid in (0, 1)
    ]
    all_attrs = [store.get_all_tensor_attrs() for store in stores]

    # One {node_type: time tensor} mapping per partition.
    times_per_part = [{
        attr.group_name: store.get_tensor(attr)
        for attr in attrs if attr.attr_name == 'time'
    } for store, attrs in zip(stores, all_attrs)]

    for node_type in graph.node_types:
        expected = graph[node_type]
        for times in times_per_part:
            assert times[node_type].size(0) == expected.num_nodes
        for times in times_per_part:
            assert torch.equal(times[node_type], expected.time)


@pytest.mark.skipif(not WITH_METIS, reason='Not compiled with METIS support')
def test_from_partition_hetero_edge_level_temporal_data(tmp_path):
    """Check that `from_partition` exposes `edge_time` for every edge type.

    Every edge type of a heterogeneous graph gets an `edge_time` vector.
    For each edge type, the `'edge_time'` tensors of the two loaded feature
    stores must together cover all edges, and each must equal the original
    vector indexed by that store's global edge ids.
    """
    graph = FakeHeteroDataset(edge_dim=2)[0]

    for edge_type in graph.edge_types:
        graph[edge_type].edge_time = torch.arange(graph[edge_type].num_edges)

    Partitioner(graph, num_parts=2, root=tmp_path).generate_partition()

    stores = [
        LocalFeatureStore.from_partition(tmp_path, pid=pid) for pid in (0, 1)
    ]
    all_attrs = [store.get_all_tensor_attrs() for store in stores]

    # One {edge_type: edge_time tensor} mapping per partition.
    times_per_part = [{
        attr.group_name: store.get_tensor(attr)
        for attr in attrs if attr.attr_name == 'edge_time'
    } for store, attrs in zip(stores, all_attrs)]

    for edge_type in graph.edge_types:
        global_ids = [
            store.get_global_id(group_name=edge_type) for store in stores
        ]
        num_stored = sum(
            times[edge_type].size(0) for times in times_per_part)
        assert num_stored == graph[edge_type].num_edges
        for edge_id, times in zip(global_ids, times_per_part):
            assert torch.equal(graph[edge_type].edge_time[edge_id],
                               times[edge_type])
15 changes: 15 additions & 0 deletions torch_geometric/distributed/local_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,20 @@ def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore':
feat_store.put_global_id(node_feats['global_id'], group_name=None)
for key, value in node_feats['feats'].items():
feat_store.put_tensor(value, group_name=None, attr_name=key)
if 'time' in node_feats:
feat_store.put_tensor(node_feats['time'], group_name=None,
attr_name='time')

if not meta['is_hetero'] and edge_feats is not None:
feat_store.put_global_id(edge_feats['global_id'],
group_name=(None, None))
for key, value in edge_feats['feats'].items():
feat_store.put_tensor(value, group_name=(None, None),
attr_name=key)
if 'edge_time' in edge_feats:
feat_store.put_tensor(edge_feats['edge_time'],
group_name=(None, None),
attr_name='edge_time')

if meta['is_hetero'] and node_feats is not None:
for node_type, node_feat in node_feats.items():
Expand All @@ -431,6 +438,10 @@ def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore':
for key, value in node_feat['feats'].items():
feat_store.put_tensor(value, group_name=node_type,
attr_name=key)
if 'time' in node_feat:
feat_store.put_tensor(node_feat['time'],
group_name=node_type,
attr_name='time')

if meta['is_hetero'] and edge_feats is not None:
for edge_type, edge_feat in edge_feats.items():
Expand All @@ -439,5 +450,9 @@ def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore':
for key, value in edge_feat['feats'].items():
feat_store.put_tensor(value, group_name=edge_type,
attr_name=key)
if 'edge_time' in edge_feat:
feat_store.put_tensor(edge_feat['edge_time'],
group_name=edge_type,
attr_name='edge_time')

return feat_store
Loading

0 comments on commit 81fdeaf

Please sign in to comment.