Skip to content

Commit

Permalink
test update
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Dec 21, 2023
1 parent ff919a8 commit af46c11
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions test/distributed/test_dist_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ def create_hetero_data(tmp_path: str, rank: int,

if time_attr == 'time': # Create node-level time data:
feature_store.put_tensor(
tensor=torch.full((len(out[3]['v0']), ), 1, dtype=torch.int64),
tensor=torch.full((len(node_pb['v0']), ), 1, dtype=torch.int64),
group_name='v0',
attr_name=time_attr,
)
feature_store.put_tensor(
tensor=torch.full((len(out[3]['v1']), ), 2, dtype=torch.int64),
tensor=torch.full((len(node_pb['v1']), ), 2, dtype=torch.int64),
group_name='v1',
attr_name=time_attr,
)
Expand Down Expand Up @@ -368,7 +368,7 @@ def dist_neighbor_sampler_temporal_hetero(
)

# Close RPC & worker group at exit:
atexit.register(close_sampler, 0, dist_sampler)
atexit.register(shutdown_rpc)

init_rpc(
current_ctx=current_ctx,
Expand All @@ -377,8 +377,10 @@ def dist_neighbor_sampler_temporal_hetero(
master_port=master_port,
)

dist_sampler.init_sampler_instance()
dist_sampler.register_sampler_rpc()
dist_sampler.init_event_loop()
dist_sampler.event_loop = ConcurrentEventLoop(2)
dist_sampler.event_loop.start_loop()

# Create inputs nodes such that each belongs to a different partition:
node_pb_list = dist_data[1].node_pb[input_type].tolist()
Expand Down

0 comments on commit af46c11

Please sign in to comment.