Skip to content

Commit

Permalink
Add inference and profile for to_hetero_mag
Browse files Browse the repository at this point in the history
  • Loading branch information
yanbing-j committed Jul 1, 2022
1 parent 0a0d349 commit 5eb0491
Showing 1 changed file with 46 additions and 5 deletions.
51 changes: 46 additions & 5 deletions examples/hetero/to_hetero_mag.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F
from torch.nn import ReLU
from tqdm import tqdm
from torch.profiler import profile, ProfilerActivity

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
Expand All @@ -13,9 +15,12 @@

parser = argparse.ArgumentParser()
parser.add_argument('--use_hgt_loader', action='store_true')
parser.add_argument('--inference', type=bool, default=False)
parser.add_argument('--profile', type=bool, default=False) # Currently support profile in inference
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
profile_sort = "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total"

path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/OGB')
transform = T.ToUndirected(merge=True)
Expand Down Expand Up @@ -48,6 +53,13 @@
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)

def trace_handler(p):
output = p.key_averages().table(sort_by=profile_sort)
print(output)
import pathlib
profile_dir = str(pathlib.Path.cwd()) + '/'
timeline_file = profile_dir + 'timeline-to-hetero-mag' + '.json'
p.export_chrome_trace(timeline_file)

@torch.no_grad()
def init_params():
Expand Down Expand Up @@ -92,11 +104,40 @@ def test(loader):

return total_correct / total_examples

@torch.no_grad()
def inference(loader):
model.eval()
for batch in tqdm(loader):
batch = batch.to(device, 'edge_index')
batch_size = batch['paper'].batch_size
model(batch.x_dict, batch.edge_index_dict)

init_params() # Initialize parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
if not args.inference:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 21):
loss = train()
val_acc = test(val_loader)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
for epoch in range(1, 21):
loss = train()
val_acc = test(val_loader)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
else:
for epoch in range(1, 21):
if epoch == 20:
if args.profile:
with profile(activities=[
ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
inference(val_loader)
p.step()
else:
if torch.cuda.is_available():
torch.cuda.synchronize()
t_start = time.time()
inference(val_loader)
if torch.cuda.is_available():
torch.cuda.synchronize()
t_end = time.time()
duration = t_end - t_start
print("End-to-End time: {} s".format(duration), flush=True)
else:
inference(val_loader)

0 comments on commit 5eb0491

Please sign in to comment.