Skip to content

Commit

Permalink
change partition.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Sep 9, 2024
1 parent b2907b4 commit 24522e1
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,9 +1314,9 @@ def get_homogeneous(g, balance_ntypes):
for name in g.edges[etype].data:
if name in [EID, "inner_edge"]:
continue
edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
F.gather_row(g.edges[etype].data[name], local_edges)
)
edge_feats[
_etype_tuple_to_str(etype) + "/" + name
] = F.gather_row(g.edges[etype].data[name], local_edges)
else:
for ntype in g.ntypes:
if len(g.ntypes) > 1:
Expand Down Expand Up @@ -1351,9 +1351,9 @@ def get_homogeneous(g, balance_ntypes):
for name in g.edges[etype].data:
if name in [EID, "inner_edge"]:
continue
edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
F.gather_row(g.edges[etype].data[name], local_edges)
)
edge_feats[
_etype_tuple_to_str(etype) + "/" + name
] = F.gather_row(g.edges[etype].data[name], local_edges)
# delete `orig_id` from ndata/edata
del part.ndata["orig_id"]
del part.edata["orig_id"]
Expand All @@ -1372,22 +1372,35 @@ def get_homogeneous(g, balance_ntypes):
}
sort_etypes = len(g.etypes) > 1
part = _process_partitions(part, graph_formats, sort_etypes)
if use_graphbolt:
# save FusedCSCSamplingGraph
kwargs["graph_formats"] = graph_formats
kwargs.pop("n_jobs", None)
_partition_to_graphbolt(
part_i=part_id,
part_config=part_config,
part_metadata=part_metadata,
parts=parts,
**kwargs,
)
else:

# transmit to graphbolt and save graph
if use_graphbolt:
    # save FusedCSCSamplingGraph
    kwargs["graph_formats"] = graph_formats
    # `n_jobs` only sizes the worker pool here; it is removed from `kwargs`
    # so it is not forwarded to `_partition_to_graphbolt`.
    n_jobs = kwargs.pop("n_jobs", 1)
    mp_ctx = mp.get_context("spawn")
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=min(num_parts, n_jobs),
        mp_context=mp_ctx,
    ) as executor:
        # Hand the callable and its arguments to `submit` separately.
        # Writing `executor.submit(_partition_to_graphbolt(...))` would run
        # the conversion serially in this process and then submit its
        # return value as if it were the task.
        futures = [
            executor.submit(
                _partition_to_graphbolt,
                part_i=part_id,
                part_config=part_config,
                part_metadata=part_metadata,
                parts=parts,
                **kwargs,
            )
            for part_id in range(num_parts)
        ]
        # `result()` re-raises any exception raised in a worker; without it
        # a failed partition would go unnoticed on an unread future.
        for future in futures:
            future.result()
else:
for part_id, part in parts.items():
part_dir = os.path.join(out_path, "part" + str(part_id))
part_graph_file = os.path.join(part_dir, "graph.dgl")

Check warning on line 1399 in python/dgl/distributed/partition.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
part_metadata["part-{}".format(part_id)]["part_graph"] = (
os.path.relpath(part_graph_file, out_path)
)
part_metadata[
"part-{}".format(part_id)][
"part_graph"
] = os.path.relpath(part_graph_file, out_path)
# save DGLGraph
_save_dgl_graphs(
part_graph_file,
Expand Down

0 comments on commit 24522e1

Please sign in to comment.