Add possibility to use CachedLoader in inference benchmarks #7897

Merged (8 commits) on Aug 28, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `CachedLoader` implementation ([#7896](https://github.com/pyg-team/pytorch_geometric/pull/7896))
- Added `CachedLoader` implementation ([#7896](https://github.com/pyg-team/pytorch_geometric/pull/7896), [#7897](https://github.com/pyg-team/pytorch_geometric/pull/7897))
- Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925))
- Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917))
- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))
6 changes: 5 additions & 1 deletion benchmark/inference/inference_benchmark.py
@@ -80,6 +80,9 @@ def run(args: argparse.ArgumentParser):
_, _, test_mask = get_split_masks(data, dataset_name)
degree = None

if hetero and args.cached_loader:
args.cached_loader = False
print('Disabling CachedLoader; it is not supported for heterogeneous models')
if args.num_layers != [1] and not hetero and args.num_steps != -1:
raise ValueError("Layer-wise inference requires `steps=-1`")

@@ -209,7 +212,7 @@ def run(args: argparse.ArgumentParser):
data = transformation(data)

with cpu_affinity, amp, timeit() as time:
inference_kwargs = {}
inference_kwargs = dict(cache=args.cached_loader)
if args.reuse_device_for_embeddings and not hetero:
inference_kwargs['embedding_device'] = device
for _ in range(args.warmup):
@@ -332,4 +335,5 @@ def run(args: argparse.ArgumentParser):
help='Write benchmark or PyTorch profile data to CSV')
add('--export-chrome-trace', default=True, type=bool,
help='Export chrome trace file. Works only with PyTorch profiler')
add('--cached-loader', action='store_true', help='Use CachedLoader')
run(argparser.parse_args())
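The idea behind the new `--cached-loader` flag is that repeated inference passes (warmup plus measured runs) can replay already-materialized batches instead of re-sampling them each epoch. The sketch below illustrates that caching pattern in plain Python; `SimpleCachedLoader` is a hypothetical stand-in for illustration only, not PyG's actual `CachedLoader` implementation from #7896.

```python
from typing import Any, Iterable, Iterator, List


class SimpleCachedLoader:
    """Illustrative sketch of a caching loader: the first full pass
    stores every batch, and later passes replay the stored batches,
    skipping the sampling/transfer cost on repeated inference runs."""

    def __init__(self, loader: Iterable[Any]) -> None:
        self.loader = loader
        self._cache: List[Any] = []
        self._filled = False

    def __iter__(self) -> Iterator[Any]:
        if self._filled:
            # Later epochs: serve batches straight from the cache.
            yield from self._cache
            return
        # First epoch: iterate the underlying loader and record batches.
        for batch in self.loader:
            self._cache.append(batch)
            yield batch
        self._filled = True


# Usage: the first iteration fills the cache, the second replays it.
loader = SimpleCachedLoader(range(3))
first = list(loader)
second = list(loader)
print(first == second == [0, 1, 2])  # → True
```

In the benchmark itself, the flag is simply threaded through as `inference_kwargs = dict(cache=args.cached_loader)`, and the heterogeneous path disables it up front since caching is only wired up for homogeneous models.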