From 517e93c5435eb309a9ca332036d2ed0f7abfa393 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 17 May 2023 21:03:26 +0000 Subject: [PATCH] typo --- test/loader/test_prefetch.py | 6 +++--- torch_geometric/loader/__init__.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/loader/test_prefetch.py b/test/loader/test_prefetch.py index 7877a9ebc135..caa7ba3ae314 100644 --- a/test/loader/test_prefetch.py +++ b/test/loader/test_prefetch.py @@ -1,7 +1,6 @@ import torch -from torch_geometric.loader import NeighborLoader -from torch_geometric.loader.prefetch import PrefetchLoader +from torch_geometric.loader import NeighborLoader, PrefetchLoader from torch_geometric.nn import GraphSAGE from torch_geometric.testing import withCUDA @@ -26,7 +25,7 @@ def test_prefetch_loader(device): from tqdm import tqdm parser = argparse.ArgumentParser() - parser.add_argument('--num_workers', type=int, default=8) + parser.add_argument('--num_workers', type=int, default=0) args = parser.parse_args() data = PygNodePropPredDataset('ogbn-products', root='/tmp/ogb')[0] @@ -44,6 +43,7 @@ def test_prefetch_loader(device): num_neighbors=[10, 10], num_workers=args.num_workers, filter_per_worker=True, + persistent_workers=args.num_workers > 0, ) print('Forward pass without prefetching...') diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 600d3483e3fa..494a380023e2 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -18,6 +18,7 @@ from .neighbor_sampler import NeighborSampler from .imbalanced_sampler import ImbalancedSampler from .dynamic_batch_sampler import DynamicBatchSampler +from .prefetch import PrefetchLoader from .mixin import AffinityMixin __all__ = classes = [ @@ -42,6 +43,7 @@ 'NeighborSampler', 'ImbalancedSampler', 'DynamicBatchSampler', + 'PrefetchLoader', 'AffinityMixin', ]