From b5e9809afc3cf909f1479cb687e29f71ef99bd1f Mon Sep 17 00:00:00 2001 From: "Liu,Kaixuan" Date: Mon, 21 Aug 2023 02:54:03 -0700 Subject: [PATCH] using global node ids for partitioned graph's row/col Signed-off-by: Liu,Kaixuan --- torch_geometric/distributed/partition.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torch_geometric/distributed/partition.py b/torch_geometric/distributed/partition.py index 918d03648de5..13b72bc835f0 100644 --- a/torch_geometric/distributed/partition.py +++ b/torch_geometric/distributed/partition.py @@ -150,10 +150,14 @@ def generate_partition(self): size = (self.data[src].num_nodes, self.data[dst].num_nodes) mask = part_data.edge_type == i + rows = part_data.edge_index[0, mask] + cols = part_data.edge_index[1, mask] + global_rows = node_id[rows] + global_cols = node_perm[cols] out[edge_type] = { 'edge_id': edge_id[mask], - 'row': part_data.edge_index[0, mask], - 'col': part_data.edge_index[1, mask], + 'row': global_rows, + 'col': global_cols, 'size': size, } torch.save(out, osp.join(path, 'graph.pt')) @@ -213,12 +217,16 @@ def generate_partition(self): node_id = node_perm[start:end] node_map[node_id] = pid + rows = part_data.edge_index[0] + cols = part_data.edge_index[1] + global_rows = node_id[rows] + global_cols = node_perm[cols] torch.save( { 'edge_id': edge_id, - 'row': part_data.edge_index[0], - 'col': part_data.edge_index[1], + 'row': global_rows, + 'col': global_cols, 'size': (data.num_nodes, data.num_nodes), }, osp.join(path, 'graph.pt'))