Skip to content

Commit

Permalink
reset
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 13, 2022
1 parent e6ffafe commit fea0f43
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 90 deletions.
47 changes: 5 additions & 42 deletions examples/hetero/to_hetero_mag.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F
from torch.nn import ReLU
from torch.profiler import ProfilerActivity, profile
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import HGTLoader, NeighborLoader
from torch_geometric.nn import Linear, SAGEConv, Sequential, to_hetero
from torch_geometric.profile import rename_profile_file, trace_handler

parser = argparse.ArgumentParser()
parser.add_argument('--use_hgt_loader', action='store_true')
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down Expand Up @@ -98,42 +93,10 @@ def test(loader):
return total_correct / total_examples


@torch.no_grad()
def inference(loader):
model.eval()
for batch in tqdm(loader):
batch = batch.to(device, 'edge_index')
model(batch.x_dict, batch.edge_index_dict)


init_params() # Initialize parameters.
if not args.inference:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 21):
loss = train()
val_acc = test(val_loader)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
else:
for epoch in range(1, 21):
if epoch == 20:
if args.profile:
with profile(
activities=[
ProfilerActivity.CPU, ProfilerActivity.CUDA
], on_trace_ready=trace_handler) as p:
inference(val_loader)
p.step()
rename_profile_file('to_hetero_mag')
else:
if torch.cuda.is_available():
torch.cuda.synchronize()
t_start = time.time()
inference(val_loader)
if torch.cuda.is_available():
torch.cuda.synchronize()
t_end = time.time()
duration = t_end - t_start
print("End-to-End time: {} s".format(duration), flush=True)
else:
inference(val_loader)
for epoch in range(1, 21):
loss = train()
val_acc = test(val_loader)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
55 changes: 7 additions & 48 deletions examples/pna.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.profiler import ProfilerActivity, profile

from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
from torch_geometric.profile import rename_profile_file, trace_handler
from torch_geometric.utils import degree

parser = argparse.ArgumentParser()
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
args = parser.parse_args()

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC')
train_dataset = ZINC(path, subset=True, split='train')
val_dataset = ZINC(path, subset=True, split='val')
Expand Down Expand Up @@ -109,42 +100,10 @@ def test(loader):
return total_error / len(loader.dataset)


@torch.no_grad()
def inference(loader):
model.eval()
for data in loader:
data = data.to(device)
model(data.x, data.edge_index, data.edge_attr, data.batch)


if not args.inference:
for epoch in range(1, 301):
loss = train(epoch)
val_mae = test(val_loader)
test_mae = test(test_loader)
scheduler.step(val_mae)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
f'Test: {test_mae:.4f}')
else:
for epoch in range(1, 301):
if epoch == 300:
if args.profile:
with profile(
activities=[
ProfilerActivity.CPU, ProfilerActivity.CUDA
], on_trace_ready=trace_handler) as p:
inference(test_loader)
p.step()
rename_profile_file('pna')
else:
if torch.cuda.is_available():
torch.cuda.synchronize()
t_start = time.time()
inference(test_loader)
if torch.cuda.is_available():
torch.cuda.synchronize()
t_end = time.time()
duration = t_end - t_start
print("End-to-End time: {} s".format(duration), flush=True)
else:
inference(test_loader)
for epoch in range(1, 301):
loss = train(epoch)
val_mae = test(val_loader)
test_mae = test(test_loader)
scheduler.step(val_mae)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
f'Test: {test_mae:.4f}')

0 comments on commit fea0f43

Please sign in to comment.