diff --git a/CHANGELOG.md b/CHANGELOG.md index fc84c1755963..3f210d30ece1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `CachedLoader` implementation ([#7896](https://github.com/pyg-team/pytorch_geometric/pull/7896)) +- Added `CachedLoader` implementation ([#7896](https://github.com/pyg-team/pytorch_geometric/pull/7896), [#7897](https://github.com/pyg-team/pytorch_geometric/pull/7897)) - Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925)) - Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917)) - Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918)) diff --git a/benchmark/inference/inference_benchmark.py b/benchmark/inference/inference_benchmark.py index 7625735086fd..35e839396d53 100644 --- a/benchmark/inference/inference_benchmark.py +++ b/benchmark/inference/inference_benchmark.py @@ -80,6 +80,9 @@ def run(args: argparse.ArgumentParser): _, _, test_mask = get_split_masks(data, dataset_name) degree = None + if hetero and args.cached_loader: + args.cached_loader = False + print('Disabling CachedLoader, not supported in Hetero models') if args.num_layers != [1] and not hetero and args.num_steps != -1: raise ValueError("Layer-wise inference requires `steps=-1`") @@ -209,7 +212,7 @@ def run(args: argparse.ArgumentParser): data = transformation(data) with cpu_affinity, amp, timeit() as time: - inference_kwargs = {} + inference_kwargs = dict(cache=args.cached_loader) if args.reuse_device_for_embeddings and not hetero: inference_kwargs['embedding_device'] = device for _ in range(args.warmup): @@ -332,4 +335,5 @@ def run(args: argparse.ArgumentParser): help='Write benchmark or PyTorch profile data to CSV') add('--export-chrome-trace', default=True, type=bool, help='Export chrome trace file. Works only with PyTorch profiler') + add('--cached-loader', action='store_true', help='Use CachedLoader') run(argparser.parse_args())