Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Nov 2, 2022
1 parent 554453d commit 097291d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ coverage.xml
.venv
*.out
*.pt
*.onnx

!torch_geometric/data/
!test/data/
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
test_requires = [
'pytest',
'pytest-cov',
'onnx',
'onnxruntime',
]

dev_requires = test_requires + [
Expand Down
40 changes: 39 additions & 1 deletion test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import torch.nn.functional as F

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE
from torch_geometric.testing import withPython
from torch_geometric.testing import withPackage, withPython

out_dims = [None, 8]
dropouts = [0.0, 0.5]
Expand Down Expand Up @@ -163,3 +164,40 @@ def test_packaging():
model = pi.load_pickle('models', 'model.pkl')
with torch.no_grad():
assert model(x, edge_index).size() == (3, 16)


@withPackage('onnx', 'onnxruntime')
def test_onnx():
import onnx
import onnxruntime as ort

class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = SAGEConv(8, 16)
self.conv2 = SAGEConv(16, 16)

def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x

model = MyModel()
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

torch.onnx.export(model, (x, edge_index), 'model.onnx',
input_names=('x', 'edge_index'))

model = onnx.load('model.onnx')
onnx.checker.check_model(model)

ort_session = ort.InferenceSession('model.onnx')

out = ort_session.run(None, {
'x': x.numpy(),
'edge_index': edge_index.numpy()
})[0]
assert out.shape == (3, 16)

os.remove('model.onnx')

0 comments on commit 097291d

Please sign in to comment.