Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 1, 2022
1 parent dd80d4c commit 334cec0
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions examples/hetero/to_hetero_mag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
import torch.nn.functional as F
from torch.nn import ReLU
from torch.profiler import ProfilerActivity, profile
from tqdm import tqdm
from torch.profiler import profile, ProfilerActivity

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
Expand All @@ -16,11 +16,13 @@
parser = argparse.ArgumentParser()
parser.add_argument('--use_hgt_loader', action='store_true')
parser.add_argument('--inference', type=bool, default=False)
parser.add_argument('--profile', type=bool, default=False) # Currently support profile in inference
parser.add_argument('--profile', type=bool,
default=False) # Currently support profile in inference
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
profile_sort = "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total"
profile_sort = "self_cuda_time_total" if torch.cuda.is_available(
) else "self_cpu_time_total"

path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/OGB')
transform = T.ToUndirected(merge=True)
Expand Down Expand Up @@ -53,6 +55,7 @@
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)


def trace_handler(p):
output = p.key_averages().table(sort_by=profile_sort)
print(output)
Expand All @@ -61,6 +64,7 @@ def trace_handler(p):
timeline_file = profile_dir + 'timeline-to-hetero-mag' + '.json'
p.export_chrome_trace(timeline_file)


@torch.no_grad()
def init_params():
# Initialize lazy parameters via forwarding a single batch to the model:
Expand Down Expand Up @@ -104,6 +108,7 @@ def test(loader):

return total_correct / total_examples


@torch.no_grad()
def inference(loader):
model.eval()
Expand All @@ -112,6 +117,7 @@ def inference(loader):
batch_size = batch['paper'].batch_size
model(batch.x_dict, batch.edge_index_dict)


init_params() # Initialize parameters.
if not args.inference:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
Expand All @@ -124,11 +130,12 @@ def inference(loader):
for epoch in range(1, 21):
if epoch == 20:
if args.profile:
with profile(activities=[
ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
inference(val_loader)
p.step()
with profile(
activities=[
ProfilerActivity.CPU, ProfilerActivity.CUDA
], on_trace_ready=trace_handler) as p:
inference(val_loader)
p.step()
else:
if torch.cuda.is_available():
torch.cuda.synchronize()
Expand Down

0 comments on commit 334cec0

Please sign in to comment.