Skip to content

Commit

Permalink
typo
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 17, 2023
1 parent c4da3e4 commit 517e93c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/loader/test_prefetch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch

from torch_geometric.loader import NeighborLoader
from torch_geometric.loader.prefetch import PrefetchLoader
from torch_geometric.loader import NeighborLoader, PrefetchLoader
from torch_geometric.nn import GraphSAGE
from torch_geometric.testing import withCUDA

Expand All @@ -26,7 +25,7 @@ def test_prefetch_loader(device):
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--num_workers', type=int, default=0)
args = parser.parse_args()

data = PygNodePropPredDataset('ogbn-products', root='/tmp/ogb')[0]
Expand All @@ -44,6 +43,7 @@ def test_prefetch_loader(device):
num_neighbors=[10, 10],
num_workers=args.num_workers,
filter_per_worker=True,
persistent_workers=args.num_workers > 0,
)

print('Forward pass without prefetching...')
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .neighbor_sampler import NeighborSampler
from .imbalanced_sampler import ImbalancedSampler
from .dynamic_batch_sampler import DynamicBatchSampler
from .prefetch import PrefetchLoader
from .mixin import AffinityMixin

__all__ = classes = [
Expand All @@ -42,6 +43,7 @@
'NeighborSampler',
'ImbalancedSampler',
'DynamicBatchSampler',
'PrefetchLoader',
'AffinityMixin',
]

Expand Down

0 comments on commit 517e93c

Please sign in to comment.