Skip to content

Commit

Permalink
C++ Models (#728)
Browse files Browse the repository at this point in the history
* Added the existing code

* Added squeezenet and fixed some stuff in the other models

* Wrote DenseNet and a part of InceptionV3

Going to clean and check all of the models and finish inception

* Fixed some errors in the models

Next step is writing inception and comparing with python code again.

* Completed inception and changed models directory

* Fixed and wrote some stuff

* fixed maxpoool2d and avgpool2d and adaptiveavgpool2d

* Fixed a few stuff

Moved cmakelists to root and changed the namespace to vision and wrote weight initialization in inception

* Added models namespace and changed cmakelists

the project is now installable

* Removed some comments

* Changed style to pytorch style, added some comments and fixed some minor errors

* Removed truncated normal init

* Changed classes to structs and fixed a few errors

* Replaced modelsimpl structs with functional wherever possible

* Changed adaptive average pool from struct to function

* Wrote a max_pool2d wrapper and added some comments

* Replaced xavier init with kaiming init

* Fixed an error in kaiming inits

* Added model conversion and tests

* Fixed a typo in alexnet and removed tests from cmake

* Made an extension of tests and added module names to Densenet

* Added python tests

* Added MobileNet and GoogLeNet models

* Added tests and conversions for new models and fixed a few errors

* Updated Alexnet ad VGG

* Updated Densenet, Squeezenet and Inception

* Added ResNexts and their conversions

* Added tests for ResNexts

* Wrote tools nessesary to write ShuffleNet

* Added ShuffleNetV2

* Fixed some errors in ShuffleNetV2

* Added conversions for shufflenetv2

* Fixed the errors in test_models.cpp

* Updated setup.py

* Fixed flake8 error on test_cpp_models.py

* Changed view to reshape in forward of ResNet

* Updated ShuffleNetV2

* Split extensions to tests and ops

* Fixed test extension

* Fixed image path in test_cpp_models.py

* Fixed image path in test_cpp_models.py

* Fixed a few things in test_cpp_models.py

* Put the test models in evaluation mode

* Fixed registering error in GoogLeNet

* Updated setup.py

* write test_cpp_models.py with unittest

* Fixed a problem with pytest in test_cpp_models.py

* Fixed a lint problem
  • Loading branch information
ShahriarSS authored and fmassa committed Jun 11, 2019
1 parent 394de98 commit b5db97b
Show file tree
Hide file tree
Showing 26 changed files with 2,779 additions and 0 deletions.
24 changes: 24 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
cmake_minimum_required(VERSION 2.8)
project(torchvision)
set(CMAKE_CXX_STANDARD 11)

find_package(Torch REQUIRED)

file(GLOB_RECURSE HEADERS torchvision/csrc/vision.h)
file(GLOB_RECURSE MODELS_HEADERS torchvision/csrc/models/*.h)
file(GLOB_RECURSE MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)

add_library (${PROJECT_NAME} SHARED ${MODELS_SOURCES})
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")

add_executable(convertmodels torchvision/csrc/convert_models/convert_models.cpp)
target_link_libraries(convertmodels "${PROJECT_NAME}")
target_link_libraries(convertmodels "${TORCH_LIBRARIES}")

#add_executable(testmodels test/test_models.cpp)
#target_link_libraries(testmodels "${PROJECT_NAME}")
#target_link_libraries(testmodels "${TORCH_LIBRARIES}")

install(TARGETS ${PROJECT_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME})
install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/models)
17 changes: 17 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ def get_extensions():
sources = main_file + source_cpu
extension = CppExtension

test_dir = os.path.join(this_dir, 'test')
models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
source_models = glob.glob(os.path.join(models_dir, '*.cpp'))

test_file = [os.path.join(test_dir, s) for s in test_file]
source_models = [os.path.join(models_dir, s) for s in source_models]
tests = test_file + source_models

define_macros = []

extra_compile_args = {}
Expand All @@ -109,6 +118,7 @@ def get_extensions():
sources = [os.path.join(extensions_dir, s) for s in sources]

include_dirs = [extensions_dir]
tests_include_dirs = [test_dir, models_dir]

ext_modules = [
extension(
Expand All @@ -117,6 +127,13 @@ def get_extensions():
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
),
extension(
'torchvision._C_tests',
tests,
include_dirs=tests_include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]

Expand Down
122 changes: 122 additions & 0 deletions test/test_cpp_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import torch
import os
import unittest
from torchvision import models, transforms, _C_tests

from PIL import Image
import torchvision.transforms.functional as F


def process_model(model, tensor, func, name):
model.eval()
traced_script_module = torch.jit.trace(model, tensor)
traced_script_module.save("model.pt")

py_output = model.forward(tensor)
cpp_output = func("model.pt", tensor)

assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models'


def read_image1():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
image = Image.open(image_path)
image = image.resize((224, 224))
x = F.to_tensor(image)
return x.view(1, 3, 224, 224)


def read_image2():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
image = Image.open(image_path)
image = image.resize((299, 299))
x = F.to_tensor(image)
x = x.view(1, 3, 299, 299)
return torch.cat([x, x], 0)


class Tester(unittest.TestCase):
pretrained = False
image = read_image1()

def test_alexnet(self):
process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet')

def test_vgg11(self):
process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11')

def test_vgg13(self):
process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13')

def test_vgg16(self):
process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16')

def test_vgg19(self):
process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19')

def test_vgg11_bn(self):
process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN')

def test_vgg13_bn(self):
process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN')

def test_vgg16_bn(self):
process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN')

def test_vgg19_bn(self):
process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN')

def test_resnet18(self):
process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18')

def test_resnet34(self):
process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34')

def test_resnet50(self):
process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50')

def test_resnet101(self):
process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101')

def test_resnet152(self):
process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152')

def test_resnext50_32x4d(self):
process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d')

def test_resnext101_32x8d(self):
process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')

def test_squeezenet1_0(self):
process_model(models.squeezenet1_0(self.pretrained), self.image,
_C_tests.forward_squeezenet1_0, 'Squeezenet1.0')

def test_squeezenet1_1(self):
process_model(models.squeezenet1_1(self.pretrained), self.image,
_C_tests.forward_squeezenet1_1, 'Squeezenet1.1')

def test_densenet121(self):
process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121')

def test_densenet169(self):
process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169')

def test_densenet201(self):
process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201')

def test_densenet161(self):
process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161')

def test_mobilenet_v2(self):
process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet')

def test_googlenet(self):
process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet')

def test_inception_v3(self):
self.image = read_image2()
process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3')


if __name__ == '__main__':
unittest.main()
173 changes: 173 additions & 0 deletions test/test_models.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>

#include "../torchvision/csrc/models/models.h"

using namespace vision::models;

template <typename Model>
torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) {
Model network;
torch::load(network, input_path);
network->eval();
return network->forward(x);
}

torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) {
return forward_model<AlexNet>(input_path, x);
}

torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG11>(input_path, x);
}
torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG13>(input_path, x);
}
torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG16>(input_path, x);
}
torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG19>(input_path, x);
}

torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG11BN>(input_path, x);
}
torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG13BN>(input_path, x);
}
torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG16BN>(input_path, x);
}
torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) {
return forward_model<VGG19BN>(input_path, x);
}

torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) {
return forward_model<ResNet18>(input_path, x);
}
torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) {
return forward_model<ResNet34>(input_path, x);
}
torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) {
return forward_model<ResNet50>(input_path, x);
}
torch::Tensor forward_resnet101(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNet101>(input_path, x);
}
torch::Tensor forward_resnet152(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNet152>(input_path, x);
}
torch::Tensor forward_resnext50_32x4d(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNext50_32x4d>(input_path, x);
}
torch::Tensor forward_resnext101_32x8d(
const std::string& input_path,
torch::Tensor x) {
return forward_model<ResNext101_32x8d>(input_path, x);
}

torch::Tensor forward_squeezenet1_0(
const std::string& input_path,
torch::Tensor x) {
return forward_model<SqueezeNet1_0>(input_path, x);
}
torch::Tensor forward_squeezenet1_1(
const std::string& input_path,
torch::Tensor x) {
return forward_model<SqueezeNet1_1>(input_path, x);
}

torch::Tensor forward_densenet121(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet121>(input_path, x);
}
torch::Tensor forward_densenet169(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet169>(input_path, x);
}
torch::Tensor forward_densenet201(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet201>(input_path, x);
}
torch::Tensor forward_densenet161(
const std::string& input_path,
torch::Tensor x) {
return forward_model<DenseNet161>(input_path, x);
}

torch::Tensor forward_mobilenetv2(
const std::string& input_path,
torch::Tensor x) {
return forward_model<MobileNetV2>(input_path, x);
}

torch::Tensor forward_googlenet(
const std::string& input_path,
torch::Tensor x) {
GoogLeNet network;
torch::load(network, input_path);
network->eval();
return network->forward(x).output;
}
torch::Tensor forward_inceptionv3(
const std::string& input_path,
torch::Tensor x) {
InceptionV3 network;
torch::load(network, input_path);
network->eval();
return network->forward(x).output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_alexnet", &forward_alexnet, "forward_alexnet");

m.def("forward_vgg11", &forward_vgg11, "forward_vgg11");
m.def("forward_vgg13", &forward_vgg13, "forward_vgg13");
m.def("forward_vgg16", &forward_vgg16, "forward_vgg16");
m.def("forward_vgg19", &forward_vgg19, "forward_vgg19");

m.def("forward_vgg11bn", &forward_vgg11bn, "forward_vgg11bn");
m.def("forward_vgg13bn", &forward_vgg13bn, "forward_vgg13bn");
m.def("forward_vgg16bn", &forward_vgg16bn, "forward_vgg16bn");
m.def("forward_vgg19bn", &forward_vgg19bn, "forward_vgg19bn");

m.def("forward_resnet18", &forward_resnet18, "forward_resnet18");
m.def("forward_resnet34", &forward_resnet34, "forward_resnet34");
m.def("forward_resnet50", &forward_resnet50, "forward_resnet50");
m.def("forward_resnet101", &forward_resnet101, "forward_resnet101");
m.def("forward_resnet152", &forward_resnet152, "forward_resnet152");
m.def(
"forward_resnext50_32x4d",
&forward_resnext50_32x4d,
"forward_resnext50_32x4d");
m.def(
"forward_resnext101_32x8d",
&forward_resnext101_32x8d,
"forward_resnext101_32x8d");

m.def(
"forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
m.def(
"forward_squeezenet1_1", &forward_squeezenet1_1, "forward_squeezenet1_1");

m.def("forward_densenet121", &forward_densenet121, "forward_densenet121");
m.def("forward_densenet169", &forward_densenet169, "forward_densenet169");
m.def("forward_densenet201", &forward_densenet201, "forward_densenet201");
m.def("forward_densenet161", &forward_densenet161, "forward_densenet161");

m.def("forward_mobilenetv2", &forward_mobilenetv2, "forward_mobilenetv2");

m.def("forward_googlenet", &forward_googlenet, "forward_googlenet");
m.def("forward_inceptionv3", &forward_inceptionv3, "forward_inceptionv3");
}
Loading

0 comments on commit b5db97b

Please sign in to comment.