Commit
* Added the existing code
* Added SqueezeNet and fixed some issues in the other models
* Wrote DenseNet and part of InceptionV3; going to clean and check all of the models and finish Inception
* Fixed some errors in the models; next step is writing Inception and comparing with the Python code again
* Completed Inception and changed the models directory
* Fixed and wrote some things
* Fixed max_pool2d, avg_pool2d and adaptive_avg_pool2d
* Fixed a few things; moved CMakeLists to root, changed the namespace to vision and wrote weight initialization in Inception
* Added the models namespace and changed CMakeLists; the project is now installable
* Removed some comments
* Changed style to PyTorch style, added some comments and fixed some minor errors
* Removed truncated normal init
* Changed classes to structs and fixed a few errors
* Replaced modelsimpl structs with functional wherever possible
* Changed adaptive average pool from a struct to a function
* Wrote a max_pool2d wrapper and added some comments
* Replaced Xavier init with Kaiming init
* Fixed an error in Kaiming inits
* Added model conversion and tests
* Fixed a typo in AlexNet and removed tests from CMake
* Made an extension of the tests and added module names to DenseNet
* Added Python tests
* Added MobileNet and GoogLeNet models
* Added tests and conversions for the new models and fixed a few errors
* Updated AlexNet and VGG
* Updated DenseNet, SqueezeNet and Inception
* Added ResNexts and their conversions
* Added tests for ResNexts
* Wrote tools necessary to write ShuffleNet
* Added ShuffleNetV2
* Fixed some errors in ShuffleNetV2
* Added conversions for ShuffleNetV2
* Fixed the errors in test_models.cpp
* Updated setup.py
* Fixed flake8 error in test_cpp_models.py
* Changed view to reshape in the forward of ResNet
* Updated ShuffleNetV2
* Split extensions into tests and ops
* Fixed the test extension
* Fixed image path in test_cpp_models.py
* Fixed image path in test_cpp_models.py
* Fixed a few things in test_cpp_models.py
* Put the test models in evaluation mode
* Fixed registering error in GoogLeNet
* Updated setup.py
* Rewrote test_cpp_models.py with unittest
* Fixed a problem with pytest in test_cpp_models.py
* Fixed a lint problem
1 parent 394de98, commit b5db97b
Showing 26 changed files with 2,779 additions and 0 deletions.
CMakeLists.txt
@@ -0,0 +1,24 @@
cmake_minimum_required(VERSION 2.8)
project(torchvision)
set(CMAKE_CXX_STANDARD 11)

find_package(Torch REQUIRED)

file(GLOB_RECURSE HEADERS torchvision/csrc/vision.h)
file(GLOB_RECURSE MODELS_HEADERS torchvision/csrc/models/*.h)
file(GLOB_RECURSE MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)

add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES})
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")

add_executable(convertmodels torchvision/csrc/convert_models/convert_models.cpp)
target_link_libraries(convertmodels "${PROJECT_NAME}")
target_link_libraries(convertmodels "${TORCH_LIBRARIES}")

#add_executable(testmodels test/test_models.cpp)
#target_link_libraries(testmodels "${PROJECT_NAME}")
#target_link_libraries(testmodels "${TORCH_LIBRARIES}")

install(TARGETS ${PROJECT_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME})
install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/models)
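For orientation, here is a minimal sketch (not part of this commit) of how a downstream C++ program might consume the installed library, assuming the install prefix above is on the compiler's include and linker paths (so models.h ends up at include/torchvision/models/models.h) and that the program links against the produced torchvision library plus the Torch libraries. The model type and forward call mirror what the test file later in this diff does.

// Hypothetical downstream consumer of the installed library; the include path
// assumes the install rules above, adjust it if your prefix differs.
#include <torchvision/models/models.h>

#include <torch/torch.h>
#include <iostream>

int main() {
  vision::models::ResNet18 network;  // randomly initialized weights
  network->eval();                   // switch dropout/batch norm to inference behavior

  torch::NoGradGuard no_grad;              // no autograd needed for a plain forward pass
  auto x = torch::rand({1, 3, 224, 224});  // dummy NCHW batch, same shape the tests use
  auto scores = network->forward(x);       // [1, 1000] class scores
  std::cout << scores.sizes() << std::endl;
}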
test/test_cpp_models.py
@@ -0,0 +1,122 @@
import torch
import os
import unittest
from torchvision import models, transforms, _C_tests

from PIL import Image
import torchvision.transforms.functional as F


def process_model(model, tensor, func, name):
    model.eval()
    traced_script_module = torch.jit.trace(model, tensor)
    traced_script_module.save("model.pt")

    py_output = model.forward(tensor)
    cpp_output = func("model.pt", tensor)

    assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models'


def read_image1():
    image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
    image = Image.open(image_path)
    image = image.resize((224, 224))
    x = F.to_tensor(image)
    return x.view(1, 3, 224, 224)


def read_image2():
    image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
    image = Image.open(image_path)
    image = image.resize((299, 299))
    x = F.to_tensor(image)
    x = x.view(1, 3, 299, 299)
    return torch.cat([x, x], 0)


class Tester(unittest.TestCase):
    pretrained = False
    image = read_image1()

    def test_alexnet(self):
        process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet')

    def test_vgg11(self):
        process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11')

    def test_vgg13(self):
        process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13')

    def test_vgg16(self):
        process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16')

    def test_vgg19(self):
        process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19')

    def test_vgg11_bn(self):
        process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN')

    def test_vgg13_bn(self):
        process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN')

    def test_vgg16_bn(self):
        process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN')

    def test_vgg19_bn(self):
        process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN')

    def test_resnet18(self):
        process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18')

    def test_resnet34(self):
        process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34')

    def test_resnet50(self):
        process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50')

    def test_resnet101(self):
        process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101')

    def test_resnet152(self):
        process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152')

    def test_resnext50_32x4d(self):
        process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d')

    def test_resnext101_32x8d(self):
        process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')

    def test_squeezenet1_0(self):
        process_model(models.squeezenet1_0(self.pretrained), self.image,
                      _C_tests.forward_squeezenet1_0, 'Squeezenet1.0')

    def test_squeezenet1_1(self):
        process_model(models.squeezenet1_1(self.pretrained), self.image,
                      _C_tests.forward_squeezenet1_1, 'Squeezenet1.1')

    def test_densenet121(self):
        process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121')

    def test_densenet169(self):
        process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169')

    def test_densenet201(self):
        process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201')

    def test_densenet161(self):
        process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161')

    def test_mobilenet_v2(self):
        process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet')

    def test_googlenet(self):
        process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet')

    def test_inception_v3(self):
        self.image = read_image2()
        process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3')


if __name__ == '__main__':
    unittest.main()
test/test_models.cpp
@@ -0,0 +1,173 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>

#include "../torchvision/csrc/models/models.h"

using namespace vision::models;

template <typename Model>
torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) {
  Model network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x);
}

torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) {
  return forward_model<AlexNet>(input_path, x);
}

torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG11>(input_path, x);
}
torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG13>(input_path, x);
}
torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG16>(input_path, x);
}
torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG19>(input_path, x);
}

torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG11BN>(input_path, x);
}
torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG13BN>(input_path, x);
}
torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG16BN>(input_path, x);
}
torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG19BN>(input_path, x);
}

torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet18>(input_path, x);
}
torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet34>(input_path, x);
}
torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet50>(input_path, x);
}
torch::Tensor forward_resnet101(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNet101>(input_path, x);
}
torch::Tensor forward_resnet152(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNet152>(input_path, x);
}
torch::Tensor forward_resnext50_32x4d(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNext50_32x4d>(input_path, x);
}
torch::Tensor forward_resnext101_32x8d(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNext101_32x8d>(input_path, x);
}

torch::Tensor forward_squeezenet1_0(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<SqueezeNet1_0>(input_path, x);
}
torch::Tensor forward_squeezenet1_1(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<SqueezeNet1_1>(input_path, x);
}

torch::Tensor forward_densenet121(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet121>(input_path, x);
}
torch::Tensor forward_densenet169(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet169>(input_path, x);
}
torch::Tensor forward_densenet201(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet201>(input_path, x);
}
torch::Tensor forward_densenet161(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet161>(input_path, x);
}

torch::Tensor forward_mobilenetv2(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<MobileNetV2>(input_path, x);
}

torch::Tensor forward_googlenet(
    const std::string& input_path,
    torch::Tensor x) {
  GoogLeNet network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x).output;
}
torch::Tensor forward_inceptionv3(
    const std::string& input_path,
    torch::Tensor x) {
  InceptionV3 network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x).output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_alexnet", &forward_alexnet, "forward_alexnet");

  m.def("forward_vgg11", &forward_vgg11, "forward_vgg11");
  m.def("forward_vgg13", &forward_vgg13, "forward_vgg13");
  m.def("forward_vgg16", &forward_vgg16, "forward_vgg16");
  m.def("forward_vgg19", &forward_vgg19, "forward_vgg19");

  m.def("forward_vgg11bn", &forward_vgg11bn, "forward_vgg11bn");
  m.def("forward_vgg13bn", &forward_vgg13bn, "forward_vgg13bn");
  m.def("forward_vgg16bn", &forward_vgg16bn, "forward_vgg16bn");
  m.def("forward_vgg19bn", &forward_vgg19bn, "forward_vgg19bn");

  m.def("forward_resnet18", &forward_resnet18, "forward_resnet18");
  m.def("forward_resnet34", &forward_resnet34, "forward_resnet34");
  m.def("forward_resnet50", &forward_resnet50, "forward_resnet50");
  m.def("forward_resnet101", &forward_resnet101, "forward_resnet101");
  m.def("forward_resnet152", &forward_resnet152, "forward_resnet152");
  m.def(
      "forward_resnext50_32x4d",
      &forward_resnext50_32x4d,
      "forward_resnext50_32x4d");
  m.def(
      "forward_resnext101_32x8d",
      &forward_resnext101_32x8d,
      "forward_resnext101_32x8d");

  m.def(
      "forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
  m.def(
      "forward_squeezenet1_1", &forward_squeezenet1_1, "forward_squeezenet1_1");

  m.def("forward_densenet121", &forward_densenet121, "forward_densenet121");
  m.def("forward_densenet169", &forward_densenet169, "forward_densenet169");
  m.def("forward_densenet201", &forward_densenet201, "forward_densenet201");
  m.def("forward_densenet161", &forward_densenet161, "forward_densenet161");

  m.def("forward_mobilenetv2", &forward_mobilenetv2, "forward_mobilenetv2");

  m.def("forward_googlenet", &forward_googlenet, "forward_googlenet");
  m.def("forward_inceptionv3", &forward_inceptionv3, "forward_inceptionv3");
}
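The convertmodels executable declared in the CMakeLists is built from torchvision/csrc/convert_models/convert_models.cpp, which is not shown in this excerpt. As a rough, purely illustrative sketch of the serialization round trip involved: a C++ module saved with torch::save can later be restored with torch::load, which is the same call forward_model uses above. The file name and the "fill in parameters" step below are hypothetical.

// Illustrative only; this is not the convert_models.cpp from this commit.
#include <torch/torch.h>
#include "../torchvision/csrc/models/models.h"

int main() {
  vision::models::ResNet18 network;
  // ... parameters converted from a Python checkpoint would be copied in here ...
  torch::save(network, "resnet18.pt");  // archive that can later be read back with torch::load
}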