Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ Models #728

Merged
merged 50 commits into from
Jun 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
bf23f2f
Added the existing code
ShahriarSS Jan 17, 2019
de9a842
Added squeezenet and fixed some stuff in the other models
ShahriarSS Jan 17, 2019
eb5c277
Wrote DenseNet and a part of InceptionV3
ShahriarSS Jan 21, 2019
5da78be
Fixed some errors in the models
ShahriarSS Jan 21, 2019
de08cca
Completed inception and changed models directory
ShahriarSS Jan 22, 2019
df91685
Fixed and wrote some stuff
ShahriarSS Jan 23, 2019
d1d8327
fixed maxpoool2d and avgpool2d and adaptiveavgpool2d
ShahriarSS Jan 23, 2019
24b543b
Fixed a few stuff
ShahriarSS Jan 25, 2019
a30b4b0
Added models namespace and changed cmakelists
ShahriarSS Jan 25, 2019
19eb406
Removed some comments
ShahriarSS Jan 26, 2019
6deac40
Changed style to pytorch style, added some comments and fixed some mi…
ShahriarSS Feb 12, 2019
dd2420e
Removed truncated normal init
ShahriarSS Feb 12, 2019
45347c0
Changed classes to structs and fixed a few errors
ShahriarSS Feb 14, 2019
91501b0
Replaced modelsimpl structs with functional wherever possible
ShahriarSS Feb 14, 2019
33f045d
Changed adaptive average pool from struct to function
ShahriarSS Feb 14, 2019
031deed
Wrote a max_pool2d wrapper and added some comments
ShahriarSS Feb 14, 2019
a46df6c
Replaced xavier init with kaiming init
ShahriarSS Feb 16, 2019
03a3863
Fixed an error in kaiming inits
ShahriarSS Feb 19, 2019
d0e119e
Added model conversion and tests
ShahriarSS Mar 28, 2019
90b01be
Fixed a typo in alexnet and removed tests from cmake
ShahriarSS Apr 1, 2019
f698fd4
Made an extension of tests and added module names to Densenet
ShahriarSS Apr 3, 2019
b31e104
Added python tests
ShahriarSS Apr 15, 2019
943e621
Added MobileNet and GoogLeNet models
ShahriarSS Apr 18, 2019
c321259
Added tests and conversions for new models and fixed a few errors
ShahriarSS Apr 19, 2019
e7196e4
Updated Alexnet ad VGG
ShahriarSS Apr 20, 2019
e3ca869
Updated Densenet, Squeezenet and Inception
ShahriarSS Apr 20, 2019
277a74b
Added ResNexts and their conversions
ShahriarSS Apr 20, 2019
ecf6e5c
Added tests for ResNexts
ShahriarSS Apr 20, 2019
6ff7184
Wrote tools nessesary to write ShuffleNet
ShahriarSS Apr 30, 2019
dc9dbbd
Added ShuffleNetV2
ShahriarSS Apr 30, 2019
d18a30a
Fixed some errors in ShuffleNetV2
ShahriarSS May 1, 2019
45b0452
Added conversions for shufflenetv2
ShahriarSS May 7, 2019
289be31
Fixed the errors in test_models.cpp
ShahriarSS May 7, 2019
744790c
Updated setup.py
ShahriarSS May 8, 2019
de149ec
Merge branch 'master' into cppmodels
fmassa May 10, 2019
94c1674
Fixed flake8 error on test_cpp_models.py
ShahriarSS May 10, 2019
06b5071
Changed view to reshape in forward of ResNet
ShahriarSS May 11, 2019
4d8959c
Updated ShuffleNetV2
ShahriarSS May 11, 2019
2a0d007
Split extensions to tests and ops
ShahriarSS May 17, 2019
fcd9e8c
Fixed test extension
ShahriarSS May 17, 2019
65b2075
Fixed image path in test_cpp_models.py
ShahriarSS May 17, 2019
975f494
Fixed image path in test_cpp_models.py
ShahriarSS May 17, 2019
159e5e7
Fixed a few things in test_cpp_models.py
ShahriarSS May 17, 2019
b25b3c2
Put the test models in evaluation mode
ShahriarSS May 19, 2019
930d0b3
Fixed registering error in GoogLeNet
ShahriarSS May 23, 2019
4dff6b8
Updated setup.py
ShahriarSS Jun 6, 2019
35749a5
Merge branch 'master' into cppmodels
ShahriarSS Jun 6, 2019
6acb64e
write test_cpp_models.py with unittest
ShahriarSS Jun 10, 2019
d98e7cf
Fixed a problem with pytest in test_cpp_models.py
ShahriarSS Jun 10, 2019
a6873d9
Fixed a lint problem
ShahriarSS Jun 10, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
cmake_minimum_required(VERSION 2.8)
project(torchvision)
set(CMAKE_CXX_STANDARD 11)

find_package(Torch REQUIRED)

file(GLOB_RECURSE HEADERS torchvision/csrc/vision.h)
file(GLOB_RECURSE MODELS_HEADERS torchvision/csrc/models/*.h)
file(GLOB_RECURSE MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)

add_library (${PROJECT_NAME} SHARED ${MODELS_SOURCES})
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")

add_executable(convertmodels torchvision/csrc/convert_models/convert_models.cpp)
target_link_libraries(convertmodels "${PROJECT_NAME}")
target_link_libraries(convertmodels "${TORCH_LIBRARIES}")

#add_executable(testmodels test/test_models.cpp)
#target_link_libraries(testmodels "${PROJECT_NAME}")
#target_link_libraries(testmodels "${TORCH_LIBRARIES}")

install(TARGETS ${PROJECT_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME})
install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/models)
17 changes: 17 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ def get_extensions():
sources = main_file + source_cpu
ShahriarRezghi marked this conversation as resolved.
Show resolved Hide resolved
extension = CppExtension

test_dir = os.path.join(this_dir, 'test')
models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
source_models = glob.glob(os.path.join(models_dir, '*.cpp'))

test_file = [os.path.join(test_dir, s) for s in test_file]
source_models = [os.path.join(models_dir, s) for s in source_models]
tests = test_file + source_models

define_macros = []

extra_compile_args = {}
Expand All @@ -109,6 +118,7 @@ def get_extensions():
sources = [os.path.join(extensions_dir, s) for s in sources]

include_dirs = [extensions_dir]
tests_include_dirs = [test_dir, models_dir]

ext_modules = [
extension(
Expand All @@ -117,6 +127,13 @@ def get_extensions():
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
),
extension(
'torchvision._C_tests',
tests,
include_dirs=tests_include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]

Expand Down
122 changes: 122 additions & 0 deletions test/test_cpp_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import torch
import os
import unittest
from torchvision import models, transforms, _C_tests

from PIL import Image
import torchvision.transforms.functional as F


def process_model(model, tensor, func, name):
model.eval()
traced_script_module = torch.jit.trace(model, tensor)
traced_script_module.save("model.pt")

py_output = model.forward(tensor)
cpp_output = func("model.pt", tensor)

assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models'


def read_image1():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
image = Image.open(image_path)
image = image.resize((224, 224))
x = F.to_tensor(image)
return x.view(1, 3, 224, 224)


def read_image2():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
image = Image.open(image_path)
image = image.resize((299, 299))
x = F.to_tensor(image)
x = x.view(1, 3, 299, 299)
return torch.cat([x, x], 0)


class Tester(unittest.TestCase):
pretrained = False
image = read_image1()

def test_alexnet(self):
process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet')

def test_vgg11(self):
process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11')

def test_vgg13(self):
process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13')

def test_vgg16(self):
process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16')

def test_vgg19(self):
process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19')

def test_vgg11_bn(self):
process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN')

def test_vgg13_bn(self):
process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN')

def test_vgg16_bn(self):
process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN')

def test_vgg19_bn(self):
process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN')

def test_resnet18(self):
process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18')

def test_resnet34(self):
process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34')

def test_resnet50(self):
process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50')

def test_resnet101(self):
process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101')

def test_resnet152(self):
process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152')

def test_resnext50_32x4d(self):
process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d')

def test_resnext101_32x8d(self):
process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')

def test_squeezenet1_0(self):
process_model(models.squeezenet1_0(self.pretrained), self.image,
_C_tests.forward_squeezenet1_0, 'Squeezenet1.0')

def test_squeezenet1_1(self):
process_model(models.squeezenet1_1(self.pretrained), self.image,
_C_tests.forward_squeezenet1_1, 'Squeezenet1.1')

def test_densenet121(self):
process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121')

def test_densenet169(self):
process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169')

def test_densenet201(self):
process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201')

def test_densenet161(self):
process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161')

def test_mobilenet_v2(self):
process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet')

def test_googlenet(self):
process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet')

def test_inception_v3(self):
self.image = read_image2()
process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3')


if __name__ == '__main__':
unittest.main()
173 changes: 173 additions & 0 deletions test/test_models.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>

#include "../torchvision/csrc/models/models.h"

using namespace vision::models;

template <typename Model>
torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) {
Model network;
torch::load(network, input_path);
network->eval();
return network->forward(x);
}

torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) {
return forward_model<AlexNet>(input_path, x);
}

torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG11>(input_path, x);
}
torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG13>(input_path, x);
}
torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG16>(input_path, x);
}
torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG19>(input_path, x);
}

torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG11BN>(input_path, x);
}
torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG13BN>(input_path, x);
}
torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG16BN>(input_path, x);
}
torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG19BN>(input_path, x);
}

torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) {
return forward_model<ResNet18>(input_path, x);
}
torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) {
return forward_model<ResNet34>(input_path, x);
}
torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) {
return forward_model<ResNet50>(input_path, x);
}
torch::Tensor forward_resnet101(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNet101>(input_path, x);
}
torch::Tensor forward_resnet152(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNet152>(input_path, x);
}
torch::Tensor forward_resnext50_32x4d(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNext50_32x4d>(input_path, x);
}
torch::Tensor forward_resnext101_32x8d(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNext101_32x8d>(input_path, x);
}

torch::Tensor forward_squeezenet1_0(
const std::string& input_path,
torch::Tensor x) {
return forward_model<SqueezeNet1_0>(input_path, x);
}
torch::Tensor forward_squeezenet1_1(
const std::string& input_path,
torch::Tensor x) {
return forward_model<SqueezeNet1_1>(input_path, x);
}

torch::Tensor forward_densenet121(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet121>(input_path, x);
}
torch::Tensor forward_densenet169(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet169>(input_path, x);
}
torch::Tensor forward_densenet201(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet201>(input_path, x);
}
torch::Tensor forward_densenet161(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet161>(input_path, x);
}

torch::Tensor forward_mobilenetv2(
const std::string& input_path,
torch::Tensor x) {
return forward_model<MobileNetV2>(input_path, x);
}

torch::Tensor forward_googlenet(
const std::string& input_path,
torch::Tensor x) {
GoogLeNet network;
torch::load(network, input_path);
network->eval();
return network->forward(x).output;
}
torch::Tensor forward_inceptionv3(
const std::string& input_path,
torch::Tensor x) {
InceptionV3 network;
torch::load(network, input_path);
network->eval();
return network->forward(x).output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_alexnet", &forward_alexnet, "forward_alexnet");

m.def("forward_vgg11", &forward_vgg11, "forward_vgg11");
m.def("forward_vgg13", &forward_vgg13, "forward_vgg13");
m.def("forward_vgg16", &forward_vgg16, "forward_vgg16");
m.def("forward_vgg19", &forward_vgg19, "forward_vgg19");

m.def("forward_vgg11bn", &forward_vgg11bn, "forward_vgg11bn");
m.def("forward_vgg13bn", &forward_vgg13bn, "forward_vgg13bn");
m.def("forward_vgg16bn", &forward_vgg16bn, "forward_vgg16bn");
m.def("forward_vgg19bn", &forward_vgg19bn, "forward_vgg19bn");

m.def("forward_resnet18", &forward_resnet18, "forward_resnet18");
m.def("forward_resnet34", &forward_resnet34, "forward_resnet34");
m.def("forward_resnet50", &forward_resnet50, "forward_resnet50");
m.def("forward_resnet101", &forward_resnet101, "forward_resnet101");
m.def("forward_resnet152", &forward_resnet152, "forward_resnet152");
m.def(
"forward_resnext50_32x4d",
&forward_resnext50_32x4d,
"forward_resnext50_32x4d");
m.def(
"forward_resnext101_32x8d",
&forward_resnext101_32x8d,
"forward_resnext101_32x8d");

m.def(
"forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
m.def(
"forward_squeezenet1_1", &forward_squeezenet1_1, "forward_squeezenet1_1");

m.def("forward_densenet121", &forward_densenet121, "forward_densenet121");
m.def("forward_densenet169", &forward_densenet169, "forward_densenet169");
m.def("forward_densenet201", &forward_densenet201, "forward_densenet201");
m.def("forward_densenet161", &forward_densenet161, "forward_densenet161");

m.def("forward_mobilenetv2", &forward_mobilenetv2, "forward_mobilenetv2");

m.def("forward_googlenet", &forward_googlenet, "forward_googlenet");
m.def("forward_inceptionv3", &forward_inceptionv3, "forward_inceptionv3");
}
Loading