diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index e6b28be8c83a..2ecce809ff18 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -1902,7 +1902,25 @@ def create_hetero_graph(dense=False, empty=False): part_config = tmpdir / "test_sampling.json" - dgl.distributed.initialize("rpc_ip_config.txt") + pserver_list = [] + ctx = mp.get_context("spawn") + for i in range(num_server): + p = ctx.Process( + target=start_server, + args=( + i, + tmpdir, + num_server > 1, + "test_sampling", + ["csc", "coo"], + True, + ), + ) + p.start() + time.sleep(1) + pserver_list.append(p) + + dgl.distributed.initialize("rpc_ip_config.txt", use_graphbolt=True) dist_graph = DistGraph("test_sampling", part_config=part_config) os.environ["DGL_DIST_DEBUG"] = "1"