Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend torch.compile benchmark to dynamic=True; Add assert_indirect_indexing = False #8220

Merged
merged 4 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))

### Deprecated

### Fixed
Expand Down
22 changes: 18 additions & 4 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path as osp
import random
import sys
import warnings

Expand Down Expand Up @@ -388,11 +389,24 @@ def test_basic_gnn_cache():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--backward', action='store_true')
parser.add_argument('--dynamic', action='store_true')
args = parser.parse_args()

num_nodes, num_edges = 10_000, 200_000
x = torch.randn(num_nodes, 64, device=args.device)
edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)
if args.dynamic:
min_num_nodes, max_num_nodes = 10_000, 15_000
min_num_edges, max_num_edges = 200_000, 300_000
else:
min_num_nodes, max_num_nodes = 10_000, 10_000
min_num_edges, max_num_edges = 200_000, 200_000

def gen_args():
    """Sample a fresh ``(x, edge_index)`` pair for one benchmark step.

    Node/edge counts are drawn uniformly from the bounds configured
    above; when ``--dynamic`` is off the bounds coincide, so every
    call produces the same sizes (but fresh random tensors).
    """
    num_nodes = random.randint(min_num_nodes, max_num_nodes)
    num_edges = random.randint(min_num_edges, max_num_edges)

    node_features = torch.randn(num_nodes, 64, device=args.device)
    edge_index = torch.randint(num_nodes, (2, num_edges),
                               device=args.device)

    return node_features, edge_index

for Model in [GCN, GraphSAGE, GIN, EdgeCNN]:
print(f'Model: {Model.__name__}')
Expand All @@ -403,7 +417,7 @@ def test_basic_gnn_cache():
benchmark(
funcs=[model, compiled_model],
func_names=['Vanilla', 'Compiled'],
args=(x, edge_index),
args=gen_args,
num_steps=50 if args.device == 'cpu' else 500,
num_warmups=10 if args.device == 'cpu' else 100,
backward=args.backward,
Expand Down
6 changes: 6 additions & 0 deletions torch_geometric/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def compile(model: Optional[Callable] = None, *args, **kwargs) -> Callable:
jittable instances
(see :meth:`torch_geometric.nn.conv.MessagePassing.jittable`)
3. disables generation of device asserts during fused gather/scatter calls
to avoid performance impacts
.. note::
Without these adjustments, :meth:`torch.compile` may currently fail to
correctly optimize your :pyg:`PyG` model.
Expand Down Expand Up @@ -89,6 +92,9 @@ def fn(model: Callable) -> Callable:
# Replace instances of `MessagePassing` by their jittable version:
model = to_jittable(model)

# Do not generate device asserts which may slow down model execution:
torch._inductor.config.triton.assert_indirect_indexing = False

# Finally, run `torch.compile` to create an optimized version:
out = torch.compile(model, *args, **kwargs)

Expand Down
7 changes: 5 additions & 2 deletions torch_geometric/profile/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def benchmark(
args ((Any, ) or [(Any, )]): The arguments to pass to the functions.
Can be a list of arguments for each function in :obj:`funcs` in
case their headers differ.
Alternatively, you can pass in functions that generate arguments
on-the-fly (e.g., useful for benchmarking models on various sizes).
num_steps (int): The number of steps to run the benchmark.
func_names ([str], optional): The names of the functions. If not given,
will try to infer the name from the function itself.
Expand Down Expand Up @@ -69,17 +71,18 @@ def benchmark(
f"'func_names' (got {len(func_names)}) must be equal")

# Zero-copy `args` for each function (if necessary):
args_list = [args] * len(funcs) if isinstance(args, tuple) else args
args_list = [args] * len(funcs) if not isinstance(args, list) else args

iterator = zip(funcs, args_list, func_names)
if progress_bar:
from tqdm import tqdm
iterator = tqdm(iterator, total=len(funcs))

ts: List[List[str]] = []
for func, args, name in iterator:
for func, inputs, name in iterator:
t_forward = t_backward = 0
for i in range(num_warmups + num_steps):
args = inputs() if callable(inputs) else inputs
args = require_grad(args, backward)

if torch.cuda.is_available():
Expand Down