
Fixed a bug in which distributed_sampling_loader.py mistakenly split data. (#6312)

Apologies for the earlier mistake: the data was not actually being split across GPUs. This is now fixed and tested.

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
LukeLIN-web and rusty1s authored Dec 29, 2022
1 parent e27d935 commit 77d38cb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/multi_gpu/distributed_sampling.py
@@ -72,7 +72,7 @@ def run(rank, world_size, dataset):
     train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]

     kwargs = dict(batch_size=1024, num_workers=4, persistent_workers=True)
-    train_loader = NeighborLoader(data, input_nodes=data.train_mask,
+    train_loader = NeighborLoader(data, input_nodes=train_idx,
                                   num_neighbors=[25, 10], shuffle=True,
                                   drop_last=True, **kwargs)

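For context, a minimal self-contained sketch of the corrected pattern: each rank derives its own shard of the training indices and passes that shard, not the full data.train_mask, as input_nodes to NeighborLoader. The helper name build_train_loader is hypothetical; the loader arguments mirror the example above.

from torch_geometric.loader import NeighborLoader

def build_train_loader(data, rank, world_size):
    # Gather the global training indices, then keep only this rank's shard.
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]

    kwargs = dict(batch_size=1024, num_workers=4, persistent_workers=True)
    # Passing the per-rank train_idx (instead of data.train_mask) ensures
    # each GPU samples neighbors only for its own share of seed nodes.
    return NeighborLoader(data, input_nodes=train_idx,
                          num_neighbors=[25, 10], shuffle=True,
                          drop_last=True, **kwargs)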
