Skip to content

Commit

Permalink
using global node ids for partitioned graph's row/col
Browse files Browse the repository at this point in the history
Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com>
  • Loading branch information
kaixuanliu committed Aug 21, 2023
1 parent deff5a4 commit b5e9809
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions torch_geometric/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,14 @@ def generate_partition(self):
size = (self.data[src].num_nodes, self.data[dst].num_nodes)

mask = part_data.edge_type == i
rows = part_data.edge_index[0, mask]
cols = part_data.edge_index[1, mask]
global_rows = node_id[rows]
global_cols = node_perm[cols]
out[edge_type] = {
'edge_id': edge_id[mask],
'row': part_data.edge_index[0, mask],
'col': part_data.edge_index[1, mask],
'row': global_rows,
'col': global_cols,
'size': size,
}
torch.save(out, osp.join(path, 'graph.pt'))
Expand Down Expand Up @@ -213,12 +217,16 @@ def generate_partition(self):

node_id = node_perm[start:end]
node_map[node_id] = pid
rows = part_data.edge_index[0]
cols = part_data.edge_index[1]
global_rows = node_id[rows]
global_cols = node_perm[cols]

torch.save(
{
'edge_id': edge_id,
'row': part_data.edge_index[0],
'col': part_data.edge_index[1],
'row': global_rows,
'col': global_cols,
'size': (data.num_nodes, data.num_nodes),
}, osp.join(path, 'graph.pt'))

Expand Down

0 comments on commit b5e9809

Please sign in to comment.