diff --git a/examples/hetero/to_hetero_mag.py b/examples/hetero/to_hetero_mag.py index 69c2e97ecd740..bbad09c1aa57a 100644 --- a/examples/hetero/to_hetero_mag.py +++ b/examples/hetero/to_hetero_mag.py @@ -5,8 +5,8 @@ import torch import torch.nn.functional as F from torch.nn import ReLU +from torch.profiler import ProfilerActivity, profile from tqdm import tqdm -from torch.profiler import profile, ProfilerActivity import torch_geometric.transforms as T from torch_geometric.datasets import OGB_MAG @@ -16,11 +16,13 @@ parser = argparse.ArgumentParser() parser.add_argument('--use_hgt_loader', action='store_true') parser.add_argument('--inference', type=bool, default=False) -parser.add_argument('--profile', type=bool, default=False) # Currently support profile in inference +parser.add_argument('--profile', type=bool, + default=False) # Currently support profile in inference args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -profile_sort = "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total" +profile_sort = "self_cuda_time_total" if torch.cuda.is_available( +) else "self_cpu_time_total" path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/OGB') transform = T.ToUndirected(merge=True) @@ -53,6 +55,7 @@ ]) model = to_hetero(model, data.metadata(), aggr='sum').to(device) + def trace_handler(p): output = p.key_averages().table(sort_by=profile_sort) print(output) @@ -61,6 +64,7 @@ def trace_handler(p): timeline_file = profile_dir + 'timeline-to-hetero-mag' + '.json' p.export_chrome_trace(timeline_file) + @torch.no_grad() def init_params(): # Initialize lazy parameters via forwarding a single batch to the model: @@ -104,6 +108,7 @@ def test(loader): return total_correct / total_examples + @torch.no_grad() def inference(loader): model.eval() @@ -112,6 +117,7 @@ def inference(loader): batch_size = batch['paper'].batch_size model(batch.x_dict, batch.edge_index_dict) + init_params() # Initialize parameters. if not args.inference: optimizer = torch.optim.Adam(model.parameters(), lr=0.01) @@ -124,11 +130,12 @@ def inference(loader): for epoch in range(1, 21): if epoch == 20: if args.profile: - with profile(activities=[ - ProfilerActivity.CPU, ProfilerActivity.CUDA], - on_trace_ready=trace_handler) as p: - inference(val_loader) - p.step() + with profile( + activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA + ], on_trace_ready=trace_handler) as p: + inference(val_loader) + p.step() else: if torch.cuda.is_available(): torch.cuda.synchronize()