forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_pytorch_helper.py
68 lines (51 loc) · 2.48 KB
/
test_pytorch_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Some standard imports
import numpy as np
from torch import nn
import torch.onnx
import torch.nn.init as init
from caffe2.python.model_helper import ModelHelper
from pytorch_helper import PyTorchModule
import unittest
from caffe2.python.core import workspace
from test_pytorch_common import skipIfNoLapack
class TestCaffe2Backend(unittest.TestCase):
    """Check that a PyTorch module embedded in a Caffe2 net matches eager PyTorch."""

    @skipIfNoLapack
    def test_helper(self):
        """Embed a super-resolution model in a Caffe2 ModelHelper net and compare.

        The Caffe2 net built here computes ``sigmoid(model(sigmoid(x)))``. The
        same expression is evaluated directly in PyTorch, and the two results
        must agree to 3 decimal places.
        """
        # NOTE(review): skipIfNoLapack presumably guards the orthogonal weight
        # init (a QR decomposition) on builds without LAPACK -- confirm.

        class SuperResolutionNet(nn.Module):
            """Four conv layers followed by a pixel shuffle by ``upscale_factor``."""

            def __init__(self, upscale_factor, inplace=False):
                super(SuperResolutionNet, self).__init__()

                self.relu = nn.ReLU(inplace=inplace)
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                # conv4 emits upscale_factor**2 channels, which PixelShuffle
                # rearranges into an upscale_factor-times larger spatial map.
                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

                self._initialize_weights()

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

            def _initialize_weights(self):
                # Use the in-place ``orthogonal_``. The un-suffixed
                # ``init.orthogonal`` is a deprecated alias that warns on every call.
                init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
                init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
                init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
                init.orthogonal_(self.conv4.weight)

        # Seed before building the model so that the random weight init and
        # the random input are reproducible when the comparison fails.
        torch.manual_seed(0)

        torch_model = SuperResolutionNet(upscale_factor=3)

        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)

        # use ModelHelper to create a C2 net
        helper = ModelHelper(name="test_model")
        start = helper.Sigmoid(['the_input'])
        # Embed the ONNX-converted pytorch net inside it
        toutput, = PyTorchModule(helper, torch_model, (fake_input,), [start])
        output = helper.Sigmoid(toutput)

        workspace.RunNetOnce(helper.InitProto())
        # detach() replaces the discouraged ``.data`` access. numpy() refuses
        # tensors that require grad.
        workspace.FeedBlob('the_input', fake_input.detach().numpy())
        workspace.RunNetOnce(helper.Proto())
        c2_out = workspace.FetchBlob(str(output))

        # Eager reference result. No autograd graph is needed for it.
        with torch.no_grad():
            torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input)))

        np.testing.assert_almost_equal(torch_out.cpu().numpy(), c2_out, decimal=3)
# Allow running this file directly as a script through the unittest runner.
if __name__ == '__main__':
    unittest.main()