diff --git a/examples/hetero/to_hetero_mag.py b/examples/hetero/to_hetero_mag.py index 6605038c9af3..69c2e97ecd74 100644 --- a/examples/hetero/to_hetero_mag.py +++ b/examples/hetero/to_hetero_mag.py @@ -1,10 +1,12 @@ import argparse import os.path as osp +import time import torch import torch.nn.functional as F from torch.nn import ReLU from tqdm import tqdm +from torch.profiler import profile, ProfilerActivity import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG @@ -13,9 +15,12 @@ parser = argparse.ArgumentParser() parser.add_argument('--use_hgt_loader', action='store_true') +parser.add_argument('--inference', type=bool, default=False) +parser.add_argument('--profile', type=bool, default=False) # Currently support profile in inference args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +profile_sort = "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total" path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/OGB') transform = T.ToUndirected(merge=True) @@ -48,6 +53,13 @@ ]) model = to_hetero(model, data.metadata(), aggr='sum').to(device) +def trace_handler(p): + output = p.key_averages().table(sort_by=profile_sort) + print(output) + import pathlib + profile_dir = str(pathlib.Path.cwd()) + '/' + timeline_file = profile_dir + 'timeline-to-hetero-mag' + '.json' + p.export_chrome_trace(timeline_file) @torch.no_grad() def init_params(): @@ -92,11 +104,40 @@ def test(loader): return total_correct / total_examples +@torch.no_grad() +def inference(loader): + model.eval() + for batch in tqdm(loader): + batch = batch.to(device, 'edge_index') + batch_size = batch['paper'].batch_size + model(batch.x_dict, batch.edge_index_dict) init_params() # Initialize parameters. -optimizer = torch.optim.Adam(model.parameters(), lr=0.01) +if not args.inference: + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) -for epoch in range(1, 21): - loss = train() - val_acc = test(val_loader) - print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}') + for epoch in range(1, 21): + loss = train() + val_acc = test(val_loader) + print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}') +else: + for epoch in range(1, 21): + if epoch == 20: + if args.profile: + with profile(activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA], + on_trace_ready=trace_handler) as p: + inference(val_loader) + p.step() + else: + if torch.cuda.is_available(): + torch.cuda.synchronize() + t_start = time.time() + inference(val_loader) + if torch.cuda.is_available(): + torch.cuda.synchronize() + t_end = time.time() + duration = t_end - t_start + print("End-to-End time: {} s".format(duration), flush=True) + else: + inference(val_loader)