Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add profiling support for benchmark/kernel and benchmark/inference #5073

Merged
merged 24 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3659947
Add inference test for benchmark/kernel
yanbing-j Jul 19, 2022
cc45413
Add gcn + ogbn-products in benchmark/kernel
yanbing-j Jul 19, 2022
6054a6c
Merge GCN and SAGE into one
yanbing-j Jul 29, 2022
51b287c
Add profile support in benchmark/inference
yanbing-j Aug 1, 2022
e31a5f8
Merge branch 'master' into yanbing/benchmark
yanbing-j Aug 1, 2022
96b868c
Add decorator for torch.profile and move dataloader outside of run_tr…
yanbing-j Aug 3, 2022
f7f25c4
Merge branch 'master' into yanbing/benchmark
yanbing-j Aug 3, 2022
8fa35e9
Fix code coverage
yanbing-j Aug 3, 2022
11db83f
Add changelog and fix bug
yanbing-j Aug 4, 2022
07ad269
Update
yanbing-j Aug 5, 2022
feda5b4
Merge branch 'master' into yanbing/benchmark
yanbing-j Aug 5, 2022
5677e49
remove gcn+ogbn from benchmark/kernel, add GraphSage in benchmark/inf…
yanbing-j Aug 11, 2022
6e48b43
Merge branch 'master' into yanbing/benchmark
yanbing-j Aug 11, 2022
d3e7d18
Merge timeit and e2e_time
yanbing-j Aug 15, 2022
78ad7ff
Merge branch 'master' into yanbing/benchmark
yanbing-j Aug 15, 2022
1789e6c
Update test_profile.py
yanbing-j Aug 15, 2022
1cf7d7e
Merge branch 'master' into yanbing/benchmark
yanbing-j Aug 17, 2022
b1a15bb
Update timeit and torch_profile in citation and points
yanbing-j Aug 17, 2022
4fef86a
Merge branch 'master' into yanbing/benchmark
yanbing-j Aug 24, 2022
32c6ff3
Update and add log argument in timeit
yanbing-j Aug 24, 2022
42df421
Update benchmark/inference/utils.py
rusty1s Aug 24, 2022
a5e2ae6
Update benchmark/kernel/train_eval.py
rusty1s Aug 24, 2022
06e9d52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2022
872cdec
Merge branch 'master' into yanbing/benchmark
rusty1s Aug 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions benchmark/citation/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import torch.nn.functional as F
from torch import tensor
from torch.optim import Adam
from torch.profiler import ProfilerActivity, profile

from torch_geometric.profile import trace_handler
from torch_geometric.profile import timeit, torch_profile
from torch_geometric.utils import index_to_mask

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down Expand Up @@ -102,24 +101,14 @@ def run_inference(dataset, model, epochs, profiling, permute_masks=None,

for epoch in range(1, epochs + 1):
if epoch == epochs:
if torch.cuda.is_available():
torch.cuda.synchronize()
t_start = time.time()

inference(model, data)

if epoch == epochs:
if torch.cuda.is_available():
torch.cuda.synchronize()
t_end = time.time()
duration = t_end - t_start
print(f'End-to-End Inference Time: {duration:.8f}s', flush=True)
with timeit():
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
inference(model, data)
else:
inference(model, data)

if profiling:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
with torch_profile():
inference(model, data)
p.step()


def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
Expand Down
30 changes: 21 additions & 9 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import argparse
from timeit import default_timer

import torch
from utils import get_dataset, get_model

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv
from torch_geometric.profile import rename_profile_file, timeit, torch_profile

supported_sets = {
'ogbn-mag': ['rgat', 'rgcn'],
'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna'],
'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna'],
'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna', 'graph_sage'],
'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'graph_sage'],
}


Expand Down Expand Up @@ -92,11 +92,21 @@ def run(args: argparse.ArgumentParser) -> None:
model = model.to(device)
model.eval()

start = default_timer()
model.inference(subgraph_loader, device,
progress_bar=True)
stop = default_timer()
print(f'Inference time={stop-start:.3f} seconds\n')
for _ in range(args.warmup):
model.inference(subgraph_loader, device,
progress_bar=True)
with timeit():
model.inference(subgraph_loader, device,
progress_bar=True)

if args.profile:
with torch_profile():
model.inference(subgraph_loader, device,
progress_bar=True)
rename_profile_file(
model_name, dataset_name, str(batch_size),
str(layers), str(hidden_channels),
str(subgraph_loader.num_neighbors))


if __name__ == '__main__':
Expand All @@ -120,7 +130,9 @@ def run(args: argparse.ArgumentParser) -> None:
argparser.add_argument(
'--hetero-num-neighbors', default=-1, type=int,
help='number of neighbors to sample per layer for hetero workloads')
argparser.add_argument('--num-workers', default=2, type=int)
argparser.add_argument('--num-workers', default=0, type=int)
argparser.add_argument('--warmup', default=1, type=int)
argparser.add_argument('--profile', action='store_true')

args = argparser.parse_args()

Expand Down
9 changes: 8 additions & 1 deletion benchmark/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG, Reddit
from torch_geometric.nn.models.basic_gnn import GAT, GCN, PNA, EdgeCNN
from torch_geometric.nn.models.basic_gnn import (
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
GAT,
GCN,
PNA,
EdgeCNN,
GraphSAGE,
)

models_dict = {
'edge_cnn': EdgeCNN,
'gat': GAT,
'gcn': GCN,
'pna': PNA,
'graph_sage': GraphSAGE,
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
'rgat': HeteroGAT,
'rgcn': HeteroGraphSAGE,
}
Expand Down
100 changes: 76 additions & 24 deletions benchmark/kernel/main_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
from gcn import GCN
from gin import GIN
from graph_sage import GraphSAGE
from train_eval import eval_acc, train
from train_eval import eval_acc, inference_run, train

from torch_geometric import seed_everything
from torch_geometric.loader import DataLoader
from torch_geometric.profile import get_stats_summary, profileit, timeit
from torch_geometric.profile import (
get_stats_summary,
profileit,
rename_profile_file,
timeit,
torch_profile,
)

seed_everything(0)

Expand All @@ -22,6 +28,8 @@
help='Skip the first few runs')
parser.add_argument('--goal_accuracy', type=int, default=1,
help='The goal test accuracy')
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
args = parser.parse_args()

layers = [1, 2, 3]
Expand All @@ -37,11 +45,8 @@

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Decorate train and eval functions:
train = profileit(print_layer_stats=False)(train)
eval_acc = timeit()(eval_acc)

for dataset_name, Net in product(datasets, nets):
def prepare_dataloader(dataset_name):
dataset = get_dataset(dataset_name, sparse=True)
num_train = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
Expand All @@ -56,21 +61,68 @@
shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
shuffle=False)

for num_layers, hidden in product(layers, hiddens):
print(f'--\n{dataset_name} - {Net.__name__} - {num_layers} - {hidden}')

model = Net(dataset, num_layers, hidden).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

stats_list = []
for epoch in range(1, args.epochs + 1):
loss, stats = train(model, optimizer, train_loader)
val_acc, val_time = eval_acc(model, val_loader)
test_acc, test_time = eval_acc(model, test_loader)

if epoch >= args.warmup_profile:
stats_list.append(stats)

stats_summary = get_stats_summary(stats_list)
print(stats_summary)
return dataset, train_loader, val_loader, test_loader


def run_train():
for dataset_name, Net in product(datasets, nets):
dataset, train_loader, val_loader, test_loader = prepare_dataloader(
dataset_name)

for num_layers, hidden in product(layers, hiddens):
print("--\n{} - {} - {} - {}".format(dataset_name, Net.__name__,
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
num_layers, hidden))

model = Net(dataset, num_layers, hidden).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

stats_list = []
acc_list = []
for epoch in range(1, args.epochs + 1):
loss, stats = train(model, optimizer, train_loader)
with timeit() as t:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
val_acc = eval_acc(model, val_loader)
val_time = t.duration
with timeit() as t:
test_acc = eval_acc(model, test_loader)
test_time = t.duration

if epoch >= args.warmup_profile:
stats_list.append(stats)
acc_list.append([val_acc, val_time, test_acc, test_time])
rusty1s marked this conversation as resolved.
Show resolved Hide resolved

stats_summary = get_stats_summary(stats_list)
print(stats_summary)


@torch.no_grad()
def run_inference():
for dataset_name, Net in product(datasets, nets):
dataset, _, _, test_loader = prepare_dataloader(dataset_name)

for num_layers, hidden in product(layers, hiddens):
print("--\n{} - {} - {} - {}".format(dataset_name, Net.__name__,
num_layers, hidden))

model = Net(dataset, num_layers, hidden).to(device)

for epoch in range(1, args.epochs + 1):
if epoch == args.epochs:
with timeit():
inference_run(model, test_loader)
else:
inference_run(model, test_loader)

if args.profile:
with torch_profile():
inference_run(model, test_loader)
rename_profile_file(Net.__name__, dataset_name,
str(num_layers), str(hidden))


if not args.inference:
# Decorate train functions:
train = profileit()(train)
run_train()
else:
run_inference()
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 8 additions & 0 deletions benchmark/kernel/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,11 @@ def eval_loss(model, loader):
out = model(data)
loss += F.nll_loss(out, data.y.view(-1), reduction='sum').item()
return loss / len(loader.dataset)


def inference_run(model, loader):
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
model.eval()
for data in loader:
data = data.to(device)
with torch.no_grad():
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
model(data)
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
24 changes: 7 additions & 17 deletions benchmark/points/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.profiler import ProfilerActivity, profile

from torch_geometric.loader import DataLoader
from torch_geometric.profile import trace_handler
from torch_geometric.profile import timeit, torch_profile

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down Expand Up @@ -47,25 +46,16 @@ def run_inference(test_dataset, model, epochs, batch_size, profiling):
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

for epoch in range(1, epochs + 1):
print("Epoch: ", epoch)
if epoch == epochs:
if torch.cuda.is_available():
torch.cuda.synchronize()
t_start = time.time()

inference(model, test_loader, device)

if epoch == epochs:
if torch.cuda.is_available():
torch.cuda.synchronize()
t_end = time.time()
duration = t_end - t_start
print(f'End-to-End Inference Time: {duration:.8f}s', flush=True)
with timeit():
inference(model, test_loader, device)
else:
inference(model, test_loader, device)

if profiling:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
with torch_profile():
inference(model, test_loader, device)
p.step()


def run(train_dataset, test_dataset, model, epochs, batch_size, lr,
Expand Down
19 changes: 12 additions & 7 deletions test/profile/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile

from torch_geometric.nn import GraphSAGE
from torch_geometric.profile import (
get_stats_summary,
profileit,
rename_profile_file,
timeit,
trace_handler,
)
from torch_geometric.profile.profile import torch_profile
from torch_geometric.testing import onlyFullTest, withCUDA


Expand Down Expand Up @@ -69,18 +68,24 @@ def test(model, x, edge_index, y):


@onlyFullTest
def test_trace_handler(get_dataset):
def test_torch_profile(get_dataset):
dataset = get_dataset(name='PubMed')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device)
model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,
out_channels=dataset.num_classes).to(device)
model.eval()

@timeit()
def inference_e2e(model, data):
model(data.x, data.edge_index)

@torch_profile()
def inference_profile(model, data):
model(data.x, data.edge_index)

for epoch in range(3):
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
model(data.x, data.edge_index)
p.step()
inference_e2e(model, data)
inference_profile(model, data)
rename_profile_file('test_profile')
assert os.path.exists('profile-test_profile.json')
3 changes: 2 additions & 1 deletion torch_geometric/profile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .profile import profileit, timeit, get_stats_summary
from .profile import trace_handler, rename_profile_file
from .profile import trace_handler, rename_profile_file, torch_profile
from .utils import count_parameters
from .utils import get_model_size
from .utils import get_data_size
Expand All @@ -13,6 +13,7 @@
'get_stats_summary',
'trace_handler',
'rename_profile_file',
'torch_profile',
'count_parameters',
'get_model_size',
'get_data_size',
Expand Down
Loading