Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to run inference benchmarks on XPU device #7705

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added possibility to run inference benchmarks on XPU device ([#7705](https://github.com/pyg-team/pytorch_geometric/pull/7705))
- Added `HeteroData` support in `to_networkx` ([#7713](https://github.com/pyg-team/pytorch_geometric/pull/7713))
- Added `FlopsCount` support via `fvcore` ([#7693](https://github.com/pyg-team/pytorch_geometric/pull/7693))
- Added back support for PyTorch >= 1.11.0 ([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656))
Expand Down
41 changes: 34 additions & 7 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
save_benchmark_data,
test,
write_to_csv,
xpu_profiler,
)
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv
Expand Down Expand Up @@ -42,11 +43,23 @@ def run(args: argparse.ArgumentParser):
warnings.warn("Cannot write profile data to CSV because profiling is "
"disabled")

# cuda device is not suitable for full batch mode
device = torch.device(
'cuda' if not args.full_batch and torch.cuda.is_available() else 'cpu')
if args.device == 'xpu':
try:
import intel_extension_for_pytorch as ipex
except ImportError:
raise RuntimeError('XPU device requires IPEX to be installed')

if ((args.device == 'cuda' and not torch.cuda.is_available())
or (args.device == 'xpu' and not torch.xpu.is_available())):
raise RuntimeError(f'{args.device.upper()} is not available')

if args.device == 'cuda' and args.full_batch:
raise RuntimeError('CUDA device is not suitable for full batch mode')

device = torch.device(args.device)

print('BENCHMARK STARTS')
print(f'Running on {args.device.upper()}')
for dataset_name in args.datasets:
assert dataset_name in supported_sets.keys(
), f"Dataset {dataset_name} isn't supported."
Expand All @@ -66,11 +79,17 @@ def run(args: argparse.ArgumentParser):
if args.num_layers != [1] and not hetero and args.num_steps != -1:
raise ValueError("Layer-wise inference requires `steps=-1`")

if torch.cuda.is_available():
if args.device == 'cuda':
amp = torch.cuda.amp.autocast(enabled=False)
elif args.device == 'xpu':
amp = torch.xpu.amp.autocast(enabled=False)
else:
amp = torch.cpu.amp.autocast(enabled=args.bf16)

if args.device == 'xpu' and args.warmup < 1:
print('XPU device requires warmup - setting warmup=1')
args.warmup = 1

inputs_channels = data[
'paper'].num_features if dataset_name == 'ogbn-mag' \
else dataset.num_features
Expand Down Expand Up @@ -163,16 +182,22 @@ def run(args: argparse.ArgumentParser):
state_dict = torch.load(args.ckpt_path)
model.load_state_dict(state_dict)
model.eval()
if args.device == 'xpu':
model = ipex.optimize(model)

# Define context manager parameters:
if args.cpu_affinity and with_loader:
cpu_affinity = subgraph_loader.enable_cpu_affinity(
args.loader_cores)
else:
cpu_affinity = nullcontext()
profile = torch_profile(
args.export_chrome_trace, csv_data,
args.write_csv) if args.profile else nullcontext()
if args.profile and args.device == 'xpu':
profile = xpu_profiler(args.export_chrome_trace)
elif args.profile:
profile = torch_profile(args.export_chrome_trace,
csv_data, args.write_csv)
else:
profile = nullcontext()
itt = emit_itt(
) if args.vtune_profile else nullcontext()

Expand Down Expand Up @@ -256,6 +281,8 @@ def run(args: argparse.ArgumentParser):
argparser = argparse.ArgumentParser('GNN inference benchmark')
add = argparser.add_argument

add('--device', choices=['cpu', 'cuda', 'xpu'], default='cuda',
DamianSzwichtenberg marked this conversation as resolved.
Show resolved Hide resolved
help='Device to run benchmark on')
add('--datasets', nargs='+',
default=['ogbn-mag', 'ogbn-products', 'Reddit'], type=str)
add('--use-sparse-tensor', action='store_true',
Expand Down
2 changes: 2 additions & 0 deletions benchmark/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .utils import get_split_masks
from .utils import save_benchmark_data, write_to_csv
from .utils import test
from .utils import xpu_profiler

__all__ = [
'emit_itt',
Expand All @@ -14,4 +15,5 @@
'save_benchmark_data',
'write_to_csv',
'test',
'xpu_profiler',
]
11 changes: 10 additions & 1 deletion benchmark/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path as osp
from contextlib import contextmanager
from datetime import datetime

import torch
Expand All @@ -18,7 +19,6 @@
try:
from torch.autograd.profiler import emit_itt
except ImportError:
from contextlib import contextmanager

@contextmanager
def emit_itt(*args, **kwargs):
Expand Down Expand Up @@ -194,3 +194,12 @@ def test(model, loader, device, hetero, progress_bar=True,
total_examples += batch_size
total_correct += int((pred == batch.y[:batch_size]).sum())
return total_correct / total_examples


@contextmanager
def xpu_profiler(export_chrome_trace=True):
DamianSzwichtenberg marked this conversation as resolved.
Show resolved Hide resolved
with torch.autograd.profiler_legacy.profile(use_xpu=True) as profile:
yield
print(profile.key_averages().table(sort_by='self_xpu_time_total'))
if export_chrome_trace:
profile.export_chrome_trace('timeline.json')