From 097291df52ce14a217bf52844680731a1b720631 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 2 Nov 2022 07:52:08 +0000 Subject: [PATCH] update --- .gitignore | 1 + setup.py | 2 ++ test/nn/models/test_basic_gnn.py | 40 +++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index df1ce1a898e1..2909363549bf 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ coverage.xml .venv *.out *.pt +*.onnx !torch_geometric/data/ !test/data/ diff --git a/setup.py b/setup.py index 23aa510514bb..aece1634d69a 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,8 @@ test_requires = [ 'pytest', 'pytest-cov', + 'onnx', + 'onnxruntime', ] dev_requires = test_requires + [ diff --git a/test/nn/models/test_basic_gnn.py b/test/nn/models/test_basic_gnn.py index 8f23bf905905..d002a0783967 100644 --- a/test/nn/models/test_basic_gnn.py +++ b/test/nn/models/test_basic_gnn.py @@ -7,8 +7,9 @@ import torch.nn.functional as F from torch_geometric.loader import NeighborLoader +from torch_geometric.nn import SAGEConv from torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE -from torch_geometric.testing import withPython +from torch_geometric.testing import withPackage, withPython out_dims = [None, 8] dropouts = [0.0, 0.5] @@ -163,3 +164,40 @@ def test_packaging(): model = pi.load_pickle('models', 'model.pkl') with torch.no_grad(): assert model(x, edge_index).size() == (3, 16) + + +@withPackage('onnx', 'onnxruntime') +def test_onnx(): + import onnx + import onnxruntime as ort + + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = SAGEConv(8, 16) + self.conv2 = SAGEConv(16, 16) + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index).relu() + x = self.conv2(x, edge_index) + return x + + model = MyModel() + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + torch.onnx.export(model, (x, edge_index), 'model.onnx', + input_names=('x', 'edge_index')) + + model = onnx.load('model.onnx') + onnx.checker.check_model(model) + + ort_session = ort.InferenceSession('model.onnx') + + out = ort_session.run(None, { + 'x': x.numpy(), + 'edge_index': edge_index.numpy() + })[0] + assert out.shape == (3, 16) + + os.remove('model.onnx')