diff --git a/README.rst b/README.rst index bca61ce4680..a52c7f2bd3d 100644 --- a/README.rst +++ b/README.rst @@ -106,9 +106,9 @@ otherwise, add the include and library paths in the environment variables ``TORC .. _libjpeg: http://ijg.org/ .. _libjpeg-turbo: https://libjpeg-turbo.org/ -C++ API -======= -TorchVision also offers a C++ API that contains C++ equivalent of python models. +Using the models on C++ +======================= +TorchVision provides an example project for how to use the models on C++ using JIT Script. Installation From source: diff --git a/examples/cpp/hello_world/main.cpp b/examples/cpp/hello_world/main.cpp index 3a75bdec6cb..bcbe68dd07d 100644 --- a/examples/cpp/hello_world/main.cpp +++ b/examples/cpp/hello_world/main.cpp @@ -1,25 +1,44 @@ #include +#include #include #include -#include -int main() -{ - auto model = vision::models::ResNet18(); - model->eval(); +int main() { + torch::DeviceType device_type; + device_type = torch::kCPU; - // Create a random input tensor and run it through the model. - auto in = torch::rand({1, 3, 10, 10}); - auto out = model->forward(in); + torch::jit::script::Module model; + try { + std::cout << "Loading model\n"; + // Deserialize the ScriptModule from a file using torch::jit::load(). + model = torch::jit::load("resnet18.pt"); + std::cout << "Model loaded\n"; + } catch (const torch::Error& e) { + std::cout << "error loading the model\n"; + return -1; + } catch (const std::exception& e) { + std::cout << "Other error: " << e.what() << "\n"; + return -1; + } - std::cout << out.sizes(); + // TorchScript models require a List[IValue] as input + std::vector inputs; + + // Create a random input tensor and run it through the model. + inputs.push_back(torch::rand({1, 3, 10, 10})); + auto out = model.forward(inputs); + std::cout << out << "\n"; if (torch::cuda::is_available()) { // Move model and inputs to GPU - model->to(torch::kCUDA); - auto gpu_in = in.to(torch::kCUDA); - auto gpu_out = model->forward(gpu_in); + model.to(torch::kCUDA); + + // Add GPU inputs + inputs.clear(); + torch::TensorOptions options = torch::TensorOptions{torch::kCUDA}; + inputs.push_back(torch::rand({1, 3, 10, 10}, options)); - std::cout << gpu_out.sizes(); + auto gpu_out = model.forward(inputs); + std::cout << gpu_out << "\n"; } } diff --git a/examples/cpp/hello_world/trace_model.py b/examples/cpp/hello_world/trace_model.py new file mode 100644 index 00000000000..c8b8d6911e7 --- /dev/null +++ b/examples/cpp/hello_world/trace_model.py @@ -0,0 +1,13 @@ +import os.path as osp + +import torch +import torchvision + +HERE = osp.dirname(osp.abspath(__file__)) +ASSETS = osp.dirname(osp.dirname(HERE)) + +model = torchvision.models.resnet18(pretrained=False) +model.eval() + +traced_model = torch.jit.script(model) +traced_model.save("resnet18.pt") diff --git a/packaging/build_cmake.sh b/packaging/build_cmake.sh index 5950c176b5d..f99922bb4d5 100755 --- a/packaging/build_cmake.sh +++ b/packaging/build_cmake.sh @@ -98,13 +98,18 @@ fi # Compile and run the CPP example popd cd examples/cpp/hello_world - mkdir build + +# Trace model +python trace_model.py +cp resnet18.pt build + cd build cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch if [[ "$OSTYPE" == "msys" ]]; then "$script_dir/windows/internal/vc_env_helper.bat" "$script_dir/windows/internal/build_cpp_example.bat" $PARALLELISM + mv resnet18.pt Release cd Release else make -j$PARALLELISM diff --git a/torchvision/csrc/models/alexnet.cpp b/torchvision/csrc/models/alexnet.cpp index e29674b706a..8a6bfd9dacb 100644 --- a/torchvision/csrc/models/alexnet.cpp +++ b/torchvision/csrc/models/alexnet.cpp @@ -32,6 +32,8 @@ AlexNetImpl::AlexNetImpl(int64_t num_classes) { register_module("features", features); register_module("classifier", classifier); + + modelsimpl::deprecation_warning(); } torch::Tensor AlexNetImpl::forward(torch::Tensor x) { diff --git a/torchvision/csrc/models/densenet.cpp b/torchvision/csrc/models/densenet.cpp index 145748b1449..5eff294c1c4 100644 --- a/torchvision/csrc/models/densenet.cpp +++ b/torchvision/csrc/models/densenet.cpp @@ -142,6 +142,8 @@ DenseNetImpl::DenseNetImpl( } else if (auto M = dynamic_cast(module.get())) torch::nn::init::constant_(M->bias, 0); } + + modelsimpl::deprecation_warning(); } torch::Tensor DenseNetImpl::forward(torch::Tensor x) { diff --git a/torchvision/csrc/models/googlenet.cpp b/torchvision/csrc/models/googlenet.cpp index 9e381b1628d..563f75d6380 100644 --- a/torchvision/csrc/models/googlenet.cpp +++ b/torchvision/csrc/models/googlenet.cpp @@ -1,5 +1,7 @@ #include "googlenet.h" +#include "modelsimpl.h" + namespace vision { namespace models { @@ -143,6 +145,8 @@ GoogLeNetImpl::GoogLeNetImpl( if (init_weights) _initialize_weights(); + + modelsimpl::deprecation_warning(); } void GoogLeNetImpl::_initialize_weights() { diff --git a/torchvision/csrc/models/inception.cpp b/torchvision/csrc/models/inception.cpp index 002bf7ee2d1..f94f89778b2 100644 --- a/torchvision/csrc/models/inception.cpp +++ b/torchvision/csrc/models/inception.cpp @@ -1,5 +1,7 @@ #include "inception.h" +#include "modelsimpl.h" + namespace vision { namespace models { @@ -297,6 +299,8 @@ InceptionV3Impl::InceptionV3Impl( register_module("Mixed_7b", Mixed_7b); register_module("Mixed_7c", Mixed_7c); register_module("fc", fc); + + modelsimpl::deprecation_warning(); } InceptionV3Output InceptionV3Impl::forward(torch::Tensor x) { diff --git a/torchvision/csrc/models/mnasnet.cpp b/torchvision/csrc/models/mnasnet.cpp index 8e433c50bd9..7bb5eb9c7da 100644 --- a/torchvision/csrc/models/mnasnet.cpp +++ b/torchvision/csrc/models/mnasnet.cpp @@ -158,6 +158,8 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) { register_module("classifier", classifier); _initialize_weights(); + + modelsimpl::deprecation_warning(); } torch::Tensor MNASNetImpl::forward(torch::Tensor x) { diff --git a/torchvision/csrc/models/mobilenet.cpp b/torchvision/csrc/models/mobilenet.cpp index beeec89653b..54655f76f82 100644 --- a/torchvision/csrc/models/mobilenet.cpp +++ b/torchvision/csrc/models/mobilenet.cpp @@ -146,6 +146,8 @@ MobileNetV2Impl::MobileNetV2Impl( torch::nn::init::zeros_(M->bias); } } + + modelsimpl::deprecation_warning(); } torch::Tensor MobileNetV2Impl::forward(at::Tensor x) { diff --git a/torchvision/csrc/models/modelsimpl.h b/torchvision/csrc/models/modelsimpl.h index f159d1502a3..d4227647804 100644 --- a/torchvision/csrc/models/modelsimpl.h +++ b/torchvision/csrc/models/modelsimpl.h @@ -34,6 +34,13 @@ inline bool double_compare(double a, double b) { return double(std::abs(a - b)) < std::numeric_limits::epsilon(); }; +inline void deprecation_warning() { + TORCH_WARN_ONCE( + "The vision::models namespace is not actively maintained, use at " + "your own discretion. We recommend using Torch Script instead: " + "https://pytorch.org/tutorials/advanced/cpp_export.html"); +} + } // namespace modelsimpl } // namespace models } // namespace vision diff --git a/torchvision/csrc/models/resnet.h b/torchvision/csrc/models/resnet.h index 7e41de6e072..4b32bfc76b8 100644 --- a/torchvision/csrc/models/resnet.h +++ b/torchvision/csrc/models/resnet.h @@ -2,6 +2,7 @@ #include #include "../macros.h" +#include "modelsimpl.h" namespace vision { namespace models { @@ -164,6 +165,8 @@ ResNetImpl::ResNetImpl( else if (auto* M = dynamic_cast<_resnetimpl::BasicBlock*>(module.get())) torch::nn::init::constant_(M->bn2->weight, 0); } + + modelsimpl::deprecation_warning(); } template diff --git a/torchvision/csrc/models/shufflenetv2.cpp b/torchvision/csrc/models/shufflenetv2.cpp index d84c11de42c..0dce3665115 100644 --- a/torchvision/csrc/models/shufflenetv2.cpp +++ b/torchvision/csrc/models/shufflenetv2.cpp @@ -146,6 +146,8 @@ ShuffleNetV2Impl::ShuffleNetV2Impl( register_module("stage4", stage4); register_module("conv2", conv5); register_module("fc", fc); + + modelsimpl::deprecation_warning(); } torch::Tensor ShuffleNetV2Impl::forward(torch::Tensor x) { diff --git a/torchvision/csrc/models/squeezenet.cpp b/torchvision/csrc/models/squeezenet.cpp index 96a9a1800d0..3f0820da3a8 100644 --- a/torchvision/csrc/models/squeezenet.cpp +++ b/torchvision/csrc/models/squeezenet.cpp @@ -93,6 +93,8 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes) if (M->options.bias()) torch::nn::init::constant_(M->bias, 0); } + + modelsimpl::deprecation_warning(); } torch::Tensor SqueezeNetImpl::forward(torch::Tensor x) { diff --git a/torchvision/csrc/models/vgg.cpp b/torchvision/csrc/models/vgg.cpp index 73d32d98214..61c3cb844c1 100644 --- a/torchvision/csrc/models/vgg.cpp +++ b/torchvision/csrc/models/vgg.cpp @@ -69,6 +69,8 @@ VGGImpl::VGGImpl( if (initialize_weights) _initialize_weights(); + + modelsimpl::deprecation_warning(); } torch::Tensor VGGImpl::forward(torch::Tensor x) {