diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..4fa4a56d --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,20 @@ +cmake_minimum_required(VERSION 3.14) + +project( + tensorrtx + VERSION 0.1 + LANGUAGES C CXX CUDA) + +set(TensorRT_7_8_10_TARGETS mlp lenet) + +set(TensorRT_8_TARGETS) + +set(TensorRT_10_TARGETS) + +set(ALL_TARGETS ${TensorRT_7_8_10_TARGETS} ${TensorRT_8_TARGETS} + ${TensorRT_10_TARGETS}) + +foreach(sub_dir ${ALL_TARGETS}) + message(STATUS "Add subdirectory: ${sub_dir}") + add_subdirectory(${sub_dir}) +endforeach() diff --git a/README.md b/README.md index 7e2083ba..a2f1cae9 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ The basic workflow of TensorRTx is: - [A guide for quickly getting started, taking lenet5 as a demo.](./tutorials/getting_started.md) - [The .wts file content format](./tutorials/getting_started.md#the-wts-content-format) - [Frequently Asked Questions (FAQ)](./tutorials/faq.md) -- [Migrating from TensorRT 4 to 7](./tutorials/migrating_from_tensorrt_4_to_7.md) +- [Migration Guide](./tutorials/migration_guide.md) - [How to implement multi-GPU processing, taking YOLOv4 as example](./tutorials/multi_GPU_processing.md) - [Check if Your GPU support FP16/INT8](./tutorials/check_fp16_int8_support.md) - [How to Compile and Run on Windows](./tutorials/run_on_windows.md) @@ -47,21 +47,80 @@ The basic workflow of TensorRTx is: ## Test Environment -1. TensorRT 7.x -2. TensorRT 8.x(Some of the models support 8.x) +1. (**NOT recommended**) TensorRT 7.x +2. (**Recommended**)TensorRT 8.x +3. (**NOT recommended**) TensorRT 10.x + +### Note + +1. For history reason, some of the models are limited to specific TensorRT version, please check the README.md or code for the model you want to use. +2. Currently, TensorRT 8.x has better compatibility and the most of the features supported. ## How to run -Each folder has a readme inside, which explains how to run the models inside. +**Note**: this project support to build each network by the `CMakeLists.txt` in its subfolder, or you can build them together by the `CMakeLists.txt` on top of this project. + +* General procedures before building and running: + +```bash +# 1. generate xxx.wts from https://github.com/wang-xinyu/pytorchx/tree/master/lenet +# ... + +# 2. put xxx.wts on top of this folder +# ... +``` + +* (*Option 1*) To build a single subproject in this project, do: + +```bash +## enter the subfolder +cd tensorrtx/xxx + +## configure & build +cmake -S . -B build +make -C build +``` + +* (*Option 2*) To build many subprojects, firstly, in the top `CMakeLists.txt`, **uncomment** the project you don't want to build or not suppoted by your TensorRT version, e.g., you cannot build subprojects in `${TensorRT_8_Targets}` if your TensorRT is `7.x`. Then: + +```bash +## enter the top of this project +cd tensorrtx + +## configure & build +# you may use "Ninja" rather than "make" to significantly boost the build speed +cmake -G Ninja -S . -B build +ninja -C build +``` + +**WARNING**: This part is still under development, most subprojects are not adapted yet. + +* run the generated executable, e.g.: + +```bash +# serialize model to plan file i.e. 'xxx.engine' +build/xxx -s + +# deserialize plan file and run inference +build/xxx -d + +# (Optional) check if the output is same as pytorchx/lenet +# ... + +# (Optional) customize the project +# ... +``` + +For more details, each subfolder may contain a `README.md` inside, which explains more. ## Models Following models are implemented. -|Name | Description | -|-|-| -|[mlp](./mlp) | the very basic model for starters, properly documented | -|[lenet](./lenet) | the simplest, as a "hello world" of this project | +| Name | Description | Supported TensorRT Version | +|---------------|---------------|---------------| +|[mlp](./mlp) | the very basic model for starters, properly documented | 7.x/8.x/10.x | +|[lenet](./lenet) | the simplest, as a "hello world" of this project | 7.x/8.x/10.x | |[alexnet](./alexnet)| easy to implement, all layers are supported in tensorrt | |[googlenet](./googlenet)| GoogLeNet (Inception v1) | |[inception](./inception)| Inception v3, v4 | diff --git a/docker/README.md b/docker/README.md index b96952a4..2fa21671 100644 --- a/docker/README.md +++ b/docker/README.md @@ -49,11 +49,11 @@ Change the `TAG` on top of the `.dockerfile`. Note: all images are officially ow For more detail of the support matrix, please check [HERE](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html) -### How to customize opencv? +### How to customize the opencv in the image? If prebuilt package from apt cannot meet your requirements, please refer to the demo code in `.dockerfile` to build opencv from source. -### How to solve image build fail issues? +### How to solve failiures when building image? For *443 timeout* or any similar network issues, a proxy may required. To make your host proxy work for building env of docker, please change the `build` node inside docker-compose file like this: ```YAML diff --git a/docker/tensorrtx-docker-compose.yml b/docker/tensorrtx-docker-compose.yml index a1c68864..ca0894f3 100644 --- a/docker/tensorrtx-docker-compose.yml +++ b/docker/tensorrtx-docker-compose.yml @@ -1,6 +1,6 @@ services: tensorrt: - image: tensortx:1.0.0 + image: tensortx:1.0.1 container_name: tensortx environment: - NVIDIA_VISIBLE_DEVICES=all diff --git a/docker/x86_64.dockerfile b/docker/x86_64.dockerfile index 8949c201..00ae5bce 100644 --- a/docker/x86_64.dockerfile +++ b/docker/x86_64.dockerfile @@ -7,13 +7,16 @@ ENV DEBIAN_FRONTEND noninteractive # basic tools RUN apt update && apt-get install -y --fix-missing --no-install-recommends \ sudo wget curl git ca-certificates ninja-build tzdata pkg-config \ -gdb libglib2.0-dev libmount-dev \ +gdb libglib2.0-dev libmount-dev locales \ && rm -rf /var/lib/apt/lists/* RUN pip install --no-cache-dir yapf isort cmake-format pre-commit +## fix a potential pre-commit error +RUN locale-gen "en_US.UTF-8" + ## override older cmake RUN find /usr/local/share -type d -name "cmake-*" -exec rm -rf {} + \ -&& curl -fsSL "https://github.com/Kitware/CMake/releases/download/v3.29.0/cmake-3.29.0-linux-x86_64.sh" \ +&& curl -fsSL "https://github.com/Kitware/CMake/releases/download/v3.30.0/cmake-3.30.0-linux-x86_64.sh" \ -o cmake.sh && bash cmake.sh --skip-license --exclude-subdir --prefix=/usr/local && rm cmake.sh RUN apt update && apt-get install -y \ diff --git a/lenet/CMakeLists.txt b/lenet/CMakeLists.txt index abf00f00..2249c0af 100644 --- a/lenet/CMakeLists.txt +++ b/lenet/CMakeLists.txt @@ -1,29 +1,43 @@ -cmake_minimum_required(VERSION 2.6) - -project(lenet) - -add_definitions(-std=c++11) - -set(TARGET_NAME "lenet") - -option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_BUILD_TYPE Debug) - -include_directories(${PROJECT_SOURCE_DIR}/include) -# include and link dirs of cuda and tensorrt, you need adapt them if yours are different -# cuda -include_directories(/usr/local/cuda/include) -link_directories(/usr/local/cuda/lib64) -# tensorrt -include_directories(/usr/include/x86_64-linux-gnu/) -link_directories(/usr/lib/x86_64-linux-gnu/) - -FILE(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/lenet.cpp ${PROJECT_SOURCE_DIR}/include/*.h) - -add_executable(${TARGET_NAME} ${SRC_FILES}) -target_link_libraries(${TARGET_NAME} nvinfer) -target_link_libraries(${TARGET_NAME} cudart) - -add_definitions(-O2 -pthread) - +cmake_minimum_required(VERSION 3.17.0) + +project( + lenet + VERSION 0.1 + LANGUAGES C CXX CUDA) + +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES + 60 + 70 + 72 + 75 + 80 + 86 + 89) +endif() + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_INCLUDE_CURRENT_DIR TRUE) +set(CMAKE_BUILD_TYPE + "Debug" + CACHE STRING "Build type for this project" FORCE) + +option(CUDA_USE_STATIC_CUDA_RUNTIME "Use static cudaruntime library" OFF) + +find_package(Threads REQUIRED) +find_package(CUDAToolkit REQUIRED) + +if(NOT TARGET TensorRT::TensorRT) + include(FindTensorRT.cmake) +else() + message("TensorRT has been found, skipping for ${PROJECT_NAME}") +endif() + +add_executable(${PROJECT_NAME} lenet.cpp) + +target_link_libraries(${PROJECT_NAME} PUBLIC Threads::Threads CUDA::cudart + TensorRT::TensorRT) diff --git a/lenet/FindTensorRT.cmake b/lenet/FindTensorRT.cmake new file mode 100644 index 00000000..d4b5e719 --- /dev/null +++ b/lenet/FindTensorRT.cmake @@ -0,0 +1,79 @@ +cmake_minimum_required(VERSION 3.17.0) + +set(TRT_VERSION + $ENV{TRT_VERSION} + CACHE STRING + "TensorRT version, e.g. \"8.6.1.6\" or \"8.6.1.6+cuda12.0.1.011\"") + +# find TensorRT include folder +if(NOT TensorRT_INCLUDE_DIR) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") + set(TensorRT_INCLUDE_DIR + "/usr/local/cuda/targets/aarch64-linux/include" + CACHE PATH "TensorRT_INCLUDE_DIR") + else() + set(TensorRT_INCLUDE_DIR + "/usr/include/x86_64-linux-gnu" + CACHE PATH "TensorRT_INCLUDE_DIR") + endif() + message(STATUS "TensorRT: ${TensorRT_INCLUDE_DIR}") +endif() + +# find TensorRT library folder +if(NOT TensorRT_LIBRARY_DIR) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") + set(TensorRT_LIBRARY_DIR + "/usr/lib/aarch64-linux-gnu/tegra" + CACHE PATH "TensorRT_LIBRARY_DIR") + else() + set(TensorRT_LIBRARY_DIR + "/usr/include/x86_64-linux-gnu" + CACHE PATH "TensorRT_LIBRARY_DIR") + endif() + message(STATUS "TensorRT: ${TensorRT_LIBRARY_DIR}") +endif() + +set(TensorRT_LIBRARIES) + +message(STATUS "Found TensorRT lib: ${TensorRT_LIBRARIES}") + +# process for different TensorRT version +if(DEFINED TRT_VERSION AND NOT TRT_VERSION STREQUAL "") + string(REGEX MATCH "([0-9]+)" _match ${TRT_VERSION}) + set(TRT_MAJOR_VERSION "${_match}") + set(_modules nvinfer nvinfer_plugin) + unset(_match) + + if(TRT_MAJOR_VERSION GREATER_EQUAL 8) + list(APPEND _modules nvinfer_vc_plugin nvinfer_dispatch nvinfer_lean) + endif() +else() + message(FATAL_ERROR "Please set a environment variable \"TRT_VERSION\"") +endif() + +# find and add all modules of TensorRT into list +foreach(lib IN LISTS _modules) + find_library( + TensorRT_${lib}_LIBRARY + NAMES ${lib} + HINTS ${TensorRT_LIBRARY_DIR}) + list(APPEND TensorRT_LIBRARIES ${TensorRT_${lib}_LIBRARY}) +endforeach() + +# make the "TensorRT target" +add_library(TensorRT IMPORTED INTERFACE) +add_library(TensorRT::TensorRT ALIAS TensorRT) +target_link_libraries(TensorRT INTERFACE ${TensorRT_LIBRARIES}) + +set_target_properties( + TensorRT + PROPERTIES C_STANDARD 17 + CXX_STANDARD 17 + POSITION_INDEPENDENT_CODE ON + SKIP_BUILD_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH "$\{ORIGIN\}" + INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIR}") + +unset(TRT_MAJOR_VERSION) +unset(_modules) diff --git a/lenet/README.md b/lenet/README.md index c2656dc3..a735cc01 100644 --- a/lenet/README.md +++ b/lenet/README.md @@ -1,36 +1,22 @@ # lenet5 -lenet5 is the simplest net in this tensorrtx project. You can learn the basic procedures of building tensorrt app from API. Including `define network`, `build engine`, `set output`, `do inference`, `serialize model to file`, `deserialize model from file`, etc. +lenet5 is one of the simplest net in this repo. You can learn the basic procedures of building CNN from TensorRT API. This demo includes 2 major steps: -## TensorRT C++ API - -``` -// 1. generate lenet5.wts from https://github.com/wang-xinyu/pytorchx/tree/master/lenet - -// 2. put lenet5.wts into tensorrtx/lenet - -// 3. build and run - -cd tensorrtx/lenet - -mkdir build +1. Build engine + * define network + * set input/output + * serialize model to `.engine` file +2. Do inference + * load and deserialize model from `.engine` file + * run inference -cd build - -cmake .. - -make - -sudo ./lenet -s // serialize model to plan file i.e. 'lenet5.engine' - -sudo ./lenet -d // deserialize plan file and run inference +## TensorRT C++ API -// 4. see if the output is same as pytorchx/lenet -``` +see [HERE](../README.md#how-to-run) ## TensorRT Python API -``` +```bash # 1. generate lenet5.wts from https://github.com/wang-xinyu/pytorchx/tree/master/lenet # 2. put lenet5.wts into tensorrtx/lenet @@ -39,9 +25,11 @@ sudo ./lenet -d // deserialize plan file and run inference cd tensorrtx/lenet -python lenet.py -s # serialize model to plan file, i.e. 'lenet5.engine' +# 4.1 serialize model to plan file, i.e. 'lenet5.engine' +python lenet.py -s -python lenet.py -d # deserialize plan file and run inference +# 4.2 deserialize plan file and run inference +python lenet.py -d -# 4. see if the output is same as pytorchx/lenet +# 5. (Optional) see if the output is same as pytorchx/lenet ``` diff --git a/lenet/lenet.cpp b/lenet/lenet.cpp index b72777f9..814800c6 100644 --- a/lenet/lenet.cpp +++ b/lenet/lenet.cpp @@ -1,103 +1,64 @@ -#include "NvInfer.h" -#include "cuda_runtime_api.h" -#include "logging.h" -#include -#include #include +#include +#include +#include +#include +#include "logging.h" +#include "utils.h" -#define CHECK(status) \ - do\ - {\ - auto ret = (status);\ - if (ret != 0)\ - {\ - std::cerr << "Cuda failure: " << ret << std::endl;\ - abort();\ - }\ - } while (0) - -// stuff we know about the network and the input/output blobs -static const int INPUT_H = 32; -static const int INPUT_W = 32; -static const int OUTPUT_SIZE = 10; - -const char* INPUT_BLOB_NAME = "data"; -const char* OUTPUT_BLOB_NAME = "prob"; +// parameters we know about the lenet-5 +#define INPUT_H 32 +#define INPUT_W 32 +#define INPUT_SIZE (INPUT_H * INPUT_W) +#define OUTPUT_SIZE 10 +#define INPUT_NAME "data" +#define OUTPUT_NAME "prob" using namespace nvinfer1; static Logger gLogger; -// Load weights from files shared with TensorRT samples. -// TensorRT weight files have a simple space delimited format: -// [type] [size] -std::map loadWeights(const std::string file) -{ - std::cout << "Loading weights: " << file << std::endl; - std::map weightMap; - - // Open weights file - std::ifstream input(file); - assert(input.is_open() && "Unable to load weight file."); - - // Read number of weight blobs - int32_t count; - input >> count; - assert(count > 0 && "Invalid weight map file."); - - while (count--) - { - Weights wt{DataType::kFLOAT, nullptr, 0}; - uint32_t size; - - // Read name and type of blob - std::string name; - input >> name >> std::dec >> size; - wt.type = DataType::kFLOAT; - - // Load blob - uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); - for (uint32_t x = 0, y = size; x < y; ++x) - { - input >> std::hex >> val[x]; - } - wt.values = val; - - wt.count = size; - weightMap[name] = wt; - } - - return weightMap; -} - -// Creat the engine using only the API and not any parser. -ICudaEngine* createLenetEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) -{ - INetworkDefinition* network = builder->createNetworkV2(0U); - - // Create input tensor of shape { 1, 32, 32 } with name INPUT_BLOB_NAME - ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims3{1, INPUT_H, INPUT_W}); +/** + * @brief Creat the engine using only the API and not any parser. + * + * @param N max batch size + * @param runtime runtime + * @param builder builder + * @param config config + * @param dt data type + * @return ICudaEngine* + */ +ICudaEngine* createLenetEngine(int32_t N, IRuntime* runtime, IBuilder* builder, IBuilderConfig* config, DataType dt) { + INetworkDefinition* network = builder->createNetworkV2(1u); + + // Create input tensor of shape { 1, 1, 32, 32 } with name INPUT_NAME + ITensor* data = network->addInput(INPUT_NAME, dt, Dims4{N, 1, INPUT_H, INPUT_W}); assert(data); + // clang-format off // Add convolution layer with 6 outputs and a 5x5 filter. - std::map weightMap = loadWeights("../lenet5.wts"); + std::map weightMap = loadWeights("lenet5.wts"); IConvolutionLayer* conv1 = network->addConvolutionNd(*data, 6, DimsHW{5, 5}, weightMap["conv1.weight"], weightMap["conv1.bias"]); assert(conv1); conv1->setStrideNd(DimsHW{1, 1}); + conv1->setName("conv1"); // Add activation layer using the ReLU algorithm. IActivationLayer* relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU); assert(relu1); + relu1->setName("relu1"); // Add max pooling layer with stride of 2x2 and kernel size of 2x2. IPoolingLayer* pool1 = network->addPoolingNd(*relu1->getOutput(0), PoolingType::kAVERAGE, DimsHW{2, 2}); assert(pool1); pool1->setStrideNd(DimsHW{2, 2}); + pool1->setName("pool1"); // Add second convolution layer with 16 outputs and a 5x5 filter. IConvolutionLayer* conv2 = network->addConvolutionNd(*pool1->getOutput(0), 16, DimsHW{5, 5}, weightMap["conv2.weight"], weightMap["conv2.bias"]); assert(conv2); conv2->setStrideNd(DimsHW{1, 1}); + conv2->setName("conv2"); // Add activation layer using the ReLU algorithm. IActivationLayer* relu2 = network->addActivation(*conv2->getOutput(0), ActivationType::kRELU); @@ -107,104 +68,166 @@ ICudaEngine* createLenetEngine(unsigned int maxBatchSize, IBuilder* builder, IBu IPoolingLayer* pool2 = network->addPoolingNd(*relu2->getOutput(0), PoolingType::kAVERAGE, DimsHW{2, 2}); assert(pool2); pool2->setStrideNd(DimsHW{2, 2}); + pool2->setName("pool2"); // Add fully connected layer - IFullyConnectedLayer* fc1 = network->addFullyConnected(*pool2->getOutput(0), 120, weightMap["fc1.weight"], weightMap["fc1.bias"]); - assert(fc1); + auto* flatten = network->addShuffle(*pool2->getOutput(0)); + flatten->setReshapeDimensions(Dims2{-1, 400}); + auto* tensor_fc1w = network->addConstant(Dims2{120, 400}, weightMap["fc1.weight"])->getOutput(0); + auto* fc1w = network->addMatrixMultiply(*tensor_fc1w, MatrixOperation::kNONE, *flatten->getOutput(0), MatrixOperation::kTRANSPOSE); + assert(tensor_fc1w && fc1w); + auto tensor_fc1b = network->addConstant(Dims2{120, 1}, weightMap["fc1.bias"])->getOutput(0); + auto* fc1b = network->addElementWise(*fc1w->getOutput(0), *tensor_fc1b, ElementWiseOperation::kSUM); + fc1b->setName("fc1b"); + assert(tensor_fc1b && fc1b); // Add activation layer using the ReLU algorithm. - IActivationLayer* relu3 = network->addActivation(*fc1->getOutput(0), ActivationType::kRELU); + IActivationLayer* relu3 = network->addActivation(*fc1b->getOutput(0), ActivationType::kRELU); assert(relu3); + auto* flatten_relu3 = network->addShuffle(*relu3->getOutput(0)); + flatten_relu3->setReshapeDimensions(Dims2{-1, 120}); // Add second fully connected layer - IFullyConnectedLayer* fc2 = network->addFullyConnected(*relu3->getOutput(0), 84, weightMap["fc2.weight"], weightMap["fc2.bias"]); - assert(fc2); + auto* tensor_fc2w = network->addConstant(Dims2{84, 120}, weightMap["fc2.weight"])->getOutput(0); + auto* fc2w = network->addMatrixMultiply(*tensor_fc2w, MatrixOperation::kNONE, *flatten_relu3->getOutput(0), MatrixOperation::kTRANSPOSE); + assert(tensor_fc2w && fc2w); + fc2w->setName("fc2w"); + auto* tensor_fc2b = network->addConstant(Dims2{84, 1}, weightMap["fc2.bias"])->getOutput(0); + auto* fc2b = network->addElementWise(*fc2w->getOutput(0), *tensor_fc2b, ElementWiseOperation::kSUM); + assert(tensor_fc2b && fc2b); + fc2b->setName("fc2b"); // Add activation layer using the ReLU algorithm. - IActivationLayer* relu4 = network->addActivation(*fc2->getOutput(0), ActivationType::kRELU); + IActivationLayer* relu4 = network->addActivation(*fc2b->getOutput(0), ActivationType::kRELU); assert(relu4); + auto* flatten_relu4 = network->addShuffle(*relu4->getOutput(0)); + flatten_relu4->setReshapeDimensions(Dims2{-1, 84}); // Add third fully connected layer - IFullyConnectedLayer* fc3 = network->addFullyConnected(*relu4->getOutput(0), OUTPUT_SIZE, weightMap["fc3.weight"], weightMap["fc3.bias"]); - assert(fc3); + auto* tensor_fc3w = network->addConstant(Dims2{10, 84}, weightMap["fc3.weight"])->getOutput(0); + auto* fc3w = network->addMatrixMultiply(*tensor_fc3w, MatrixOperation::kNONE, *flatten_relu4->getOutput(0), MatrixOperation::kTRANSPOSE); + assert(tensor_fc3w && fc3w); + fc3w->setName("fc3w"); + auto* tensor_fc3b = network->addConstant(Dims2{10, 1}, weightMap["fc3.bias"])->getOutput(0); + auto* fc3b = network->addElementWise(*fc3w->getOutput(0), *tensor_fc3b, ElementWiseOperation::kSUM); + assert(tensor_fc3b && fc3b); + fc3b->setName("fc3b"); + // clang-format on // Add softmax layer to determine the probability. - ISoftMaxLayer* prob = network->addSoftMax(*fc3->getOutput(0)); + ISoftMaxLayer* prob = network->addSoftMax(*fc3b->getOutput(0)); assert(prob); - prob->getOutput(0)->setName(OUTPUT_BLOB_NAME); + prob->getOutput(0)->setName(OUTPUT_NAME); network->markOutput(*prob->getOutput(0)); - // Build engine +#if TRT_VERSION >= 8400 + config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, WORKSPACE_SIZE); +#else + config->setMaxWorkspaceSize(WORKSPACE_SIZE); builder->setMaxBatchSize(maxBatchSize); - config->setMaxWorkspaceSize(16 << 20); +#endif + + // Build engine +#if TRT_VERSION >= 8000 + IHostMemory* serialized_mem = builder->buildSerializedNetwork(*network, *config); + ICudaEngine* engine = runtime->deserializeCudaEngine(serialized_mem->data(), serialized_mem->size()); +#else ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); +#endif - // Don't need the network any more +#if TRT_VERSION >= 8000 + delete network; +#else network->destroy(); +#endif // Release host memory - for (auto& mem : weightMap) - { - free((void*) (mem.second.values)); + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); } return engine; } -void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream) -{ +/** + * @brief create a model using the API directly and serialize it to a stream + * + * @param N max batch size + * @param runtime runtime + * @param modelStream + */ +void APIToModel(int32_t N, IRuntime* runtime, IHostMemory** modelStream) { // Create builder IBuilder* builder = createInferBuilder(gLogger); IBuilderConfig* config = builder->createBuilderConfig(); // Create model to populate the network, then set the outputs and create an engine - ICudaEngine* engine = createLenetEngine(maxBatchSize, builder, config, DataType::kFLOAT); + ICudaEngine* engine = createLenetEngine(N, runtime, builder, config, DataType::kFLOAT); assert(engine != nullptr); // Serialize the engine (*modelStream) = engine->serialize(); - // Close everything down +#if TRT_VERSION >= 8000 + delete engine; + delete config; + delete builder; +#else engine->destroy(); + config->destroy(); builder->destroy(); +#endif } -void doInference(IExecutionContext& context, float* input, float* output, int batchSize) -{ - const ICudaEngine& engine = context.getEngine(); +void doInference(IExecutionContext& ctx, float* input, float* output, int batchSize) { + const ICudaEngine& engine = ctx.getEngine(); + + // Find input/output index so we can bind them to the buffers we provide later +#if TRT_VERSION >= 8000 + int32_t nIO = engine.getNbIOTensors(); + const int inputIndex = 0; + const int outputIndex = engine.getNbIOTensors() - 1; +#else + int32_t nIO = engine.getNbBindings(); + const int inputIndex = engine.getBindingIndex(INPUT_NAME); + const int outputIndex = engine.getBindingIndex(OUTPUT_NAME); +#endif + assert(nIO == 2); // lenet-5 contains 1 input and 1 output - // Pointers to input and output device buffers to pass to engine. - // Engine requires exactly IEngine::getNbBindings() number of buffers. - assert(engine.getNbBindings() == 2); - void* buffers[2]; - - // In order to bind the buffers, we need to know the names of the input and output tensors. - // Note that indices are guaranteed to be less than IEngine::getNbBindings() - const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME); - const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME); - - // Create GPU buffers on device - CHECK(cudaMalloc(&buffers[inputIndex], batchSize * INPUT_H * INPUT_W * sizeof(float))); - CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float))); - - // Create stream cudaStream_t stream; CHECK(cudaStreamCreate(&stream)); - // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host - CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream)); - context.enqueue(batchSize, buffers, stream, nullptr); - CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream)); - cudaStreamSynchronize(stream); + // Pointers to input and output cuda buffers to pass to engine + // Note that indices are guaranteed to be less than total I/O number + std::vector buffers(nIO); + CHECK(cudaMallocAsync(&buffers[inputIndex], batchSize * INPUT_SIZE * sizeof(float), stream)); + CHECK(cudaMallocAsync(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), stream)); + CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_SIZE * sizeof(float), cudaMemcpyHostToDevice, + stream)); - // Release stream and buffers - cudaStreamDestroy(stream); - CHECK(cudaFree(buffers[inputIndex])); - CHECK(cudaFree(buffers[outputIndex])); + // Run inference +#if TRT_VERSION >= 8000 + for (int32_t i = 0; i < engine.getNbIOTensors(); i++) { + auto const name = engine.getIOTensorName(i); + auto dims = ctx.getTensorShape(name); + auto total = std::accumulate(dims.d, dims.d + dims.nbDims, 1, std::multiplies()); + std::cout << name << " element size: " << total << std::endl; + ctx.setTensorAddress(name, buffers[i]); + } + assert(ctx.enqueueV3(stream)); +#else + assert(ctx.enqueueV2(buffers.data(), stream, nullptr)); +#endif + // Use async API so that no synchronization is needed + CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, + stream)); + for (auto& buffer : buffers) { + CHECK(cudaFreeAsync(buffer, stream)); + } + CHECK(cudaStreamDestroy(stream)); } -int main(int argc, char** argv) -{ +int main(int argc, char** argv) { if (argc != 2) { std::cerr << "arguments not right!" << std::endl; std::cerr << "./lenet -s // serialize model to plan file" << std::endl; @@ -212,24 +235,31 @@ int main(int argc, char** argv) return -1; } - // create a model using the API directly and serialize it to a stream - char *trtModelStream{nullptr}; + IRuntime* runtime = createInferRuntime(gLogger); + assert(runtime != nullptr); + + char* trtModelStream{nullptr}; size_t size{0}; if (std::string(argv[1]) == "-s") { IHostMemory* modelStream{nullptr}; - APIToModel(1, &modelStream); + APIToModel(1, runtime, &modelStream); assert(modelStream != nullptr); - std::ofstream p("lenet5.engine", std::ios::binary); - if (!p) - { + std::ofstream p("lenet5.engine", std::ios::binary | std::ios::trunc); + if (!p) { std::cerr << "could not open plan output file" << std::endl; return -1; } p.write(reinterpret_cast(modelStream->data()), modelStream->size()); + +#if TRT_VERSION >= 8000 + delete modelStream; +#else modelStream->destroy(); - return 1; +#endif + std::cout << "serialized weights to lenet5.engine" << std::endl; + return 0; } else if (std::string(argv[1]) == "-d") { std::ifstream file("lenet5.engine", std::ios::binary); if (file.good()) { @@ -245,15 +275,14 @@ int main(int argc, char** argv) return -1; } + // Mock data. Align with python demo if you need to validate accuracy + std::vector data(INPUT_H * INPUT_W, 1.f); - // Subtract mean from image - float data[INPUT_H * INPUT_W]; - for (int i = 0; i < INPUT_H * INPUT_W; i++) - data[i] = 1.0; - - IRuntime* runtime = createInferRuntime(gLogger); - assert(runtime != nullptr); +#if TRT_VERSION >= 8000 + ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size); +#else ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr); +#endif assert(engine != nullptr); IExecutionContext* context = engine->createExecutionContext(); assert(context != nullptr); @@ -261,22 +290,27 @@ int main(int argc, char** argv) // Run inference float prob[OUTPUT_SIZE]; for (int i = 0; i < 1000; i++) { - auto start = std::chrono::system_clock::now(); - doInference(*context, data, prob, 1); - auto end = std::chrono::system_clock::now(); - //std::cout << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; + auto start = std::chrono::high_resolution_clock::now(); + doInference(*context, data.data(), prob, 1); + auto end = std::chrono::high_resolution_clock::now(); + auto dur = std::chrono::duration_cast(end - start).count(); + std::cout << "execution time: " << dur << "us" << std::endl; } - // Destroy the engine +#if TRT_VERSION >= 8000 + delete context; + delete engine; + delete runtime; +#else context->destroy(); engine->destroy(); runtime->destroy(); +#endif // Print histogram of the output distribution std::cout << "\nOutput:\n\n"; - for (unsigned int i = 0; i < 10; i++) - { - std::cout << prob[i] << ", "; + for (unsigned int i = 0; i < 10; i++) { + std::cout << prob[i] << ", " << std::flush; } std::cout << std::endl; diff --git a/lenet/logging.h b/lenet/logging.h index 6b79a8b9..3a25d975 100644 --- a/lenet/logging.h +++ b/lenet/logging.h @@ -17,7 +17,6 @@ #ifndef TENSORRT_LOGGING_H #define TENSORRT_LOGGING_H -#include "NvInferRuntimeCommon.h" #include #include #include @@ -25,33 +24,24 @@ #include #include #include +#include "NvInferRuntimeCommon.h" #include "macros.h" using Severity = nvinfer1::ILogger::Severity; -class LogStreamConsumerBuffer : public std::stringbuf -{ -public: +class LogStreamConsumerBuffer : public std::stringbuf { + public: LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mOutput(stream) - , mPrefix(prefix) - , mShouldLog(shouldLog) - { - } + : mOutput(stream), mPrefix(prefix), mShouldLog(shouldLog) {} - LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) - : mOutput(other.mOutput) - { - } + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) : mOutput(other.mOutput) {} - ~LogStreamConsumerBuffer() - { + ~LogStreamConsumerBuffer() { // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence // std::streambuf::pptr() gives a pointer to the current position of the output sequence // if the pointer to the beginning is not equal to the pointer to the current position, // call putOutput() to log the output to the stream - if (pbase() != pptr()) - { + if (pbase() != pptr()) { putOutput(); } } @@ -59,16 +49,13 @@ class LogStreamConsumerBuffer : public std::stringbuf // synchronizes the stream buffer and returns 0 on success // synchronizing the stream buffer consists of inserting the buffer contents into the stream, // resetting the buffer and flushing the stream - virtual int sync() - { + virtual int sync() { putOutput(); return 0; } - void putOutput() - { - if (mShouldLog) - { + void putOutput() { + if (mShouldLog) { // prepend timestamp std::time_t timestamp = std::time(nullptr); tm* tm_local = std::localtime(×tamp); @@ -89,12 +76,9 @@ class LogStreamConsumerBuffer : public std::stringbuf } } - void setShouldLog(bool shouldLog) - { - mShouldLog = shouldLog; - } + void setShouldLog(bool shouldLog) { mShouldLog = shouldLog; } -private: + private: std::ostream& mOutput; std::string mPrefix; bool mShouldLog; @@ -104,15 +88,12 @@ class LogStreamConsumerBuffer : public std::stringbuf //! \class LogStreamConsumerBase //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer //! -class LogStreamConsumerBase -{ -public: +class LogStreamConsumerBase { + public: LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mBuffer(stream, prefix, shouldLog) - { - } + : mBuffer(stream, prefix, shouldLog) {} -protected: + protected: LogStreamConsumerBuffer mBuffer; }; @@ -125,49 +106,49 @@ class LogStreamConsumerBase //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. //! Please do not change the order of the parent classes. //! -class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream -{ -public: +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream { + public: //! \brief Creates a LogStreamConsumer which logs messages with level severity. //! Reportable severity determines if the messages are severe enough to be logged. LogStreamConsumer(Severity reportableSeverity, Severity severity) - : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(severity <= reportableSeverity) - , mSeverity(severity) - { - } + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(severity <= reportableSeverity), + mSeverity(severity) {} LogStreamConsumer(LogStreamConsumer&& other) - : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(other.mShouldLog) - , mSeverity(other.mSeverity) - { - } + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(other.mShouldLog), + mSeverity(other.mSeverity) {} - void setReportableSeverity(Severity reportableSeverity) - { + void setReportableSeverity(Severity reportableSeverity) { mShouldLog = mSeverity <= reportableSeverity; mBuffer.setShouldLog(mShouldLog); } -private: - static std::ostream& severityOstream(Severity severity) - { + private: + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? std::cout : std::cerr; } - static std::string severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static std::string severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } @@ -199,24 +180,19 @@ class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger //! object. -class Logger : public nvinfer1::ILogger -{ -public: - Logger(Severity severity = Severity::kWARNING) - : mReportableSeverity(severity) - { - } +class Logger : public nvinfer1::ILogger { + public: + Logger(Severity severity = Severity::kWARNING) : mReportableSeverity(severity) {} //! //! \enum TestResult //! \brief Represents the state of a given test //! - enum class TestResult - { - kRUNNING, //!< The test is running - kPASSED, //!< The test passed - kFAILED, //!< The test failed - kWAIVED //!< The test was waived + enum class TestResult { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived }; //! @@ -226,10 +202,7 @@ class Logger : public nvinfer1::ILogger //! TODO Once all samples are updated to use this method to register the logger with TensorRT, //! we can eliminate the inheritance of Logger from ILogger //! - nvinfer1::ILogger& getTRTLogger() - { - return *this; - } + nvinfer1::ILogger& getTRTLogger() { return *this; } //! //! \brief Implementation of the nvinfer1::ILogger::log() virtual method @@ -237,8 +210,7 @@ class Logger : public nvinfer1::ILogger //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the //! inheritance from nvinfer1::ILogger //! - void log(Severity severity, const char* msg) TRT_NOEXCEPT override - { + void log(Severity severity, const char* msg) TRT_NOEXCEPT override { LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; } @@ -247,10 +219,7 @@ class Logger : public nvinfer1::ILogger //! //! \param severity The logger will only emit messages that have severity of this level or higher. //! - void setReportableSeverity(Severity severity) - { - mReportableSeverity = severity; - } + void setReportableSeverity(Severity severity) { mReportableSeverity = severity; } //! //! \brief Opaque handle that holds logging information for a particular test @@ -259,20 +228,15 @@ class Logger : public nvinfer1::ILogger //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used //! with Logger::reportTest{Start,End}(). //! - class TestAtom - { - public: + class TestAtom { + public: TestAtom(TestAtom&&) = default; - private: + private: friend class Logger; TestAtom(bool started, const std::string& name, const std::string& cmdline) - : mStarted(started) - , mName(name) - , mCmdline(cmdline) - { - } + : mStarted(started), mName(name), mCmdline(cmdline) {} bool mStarted; std::string mName; @@ -290,8 +254,7 @@ class Logger : public nvinfer1::ILogger // //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). //! - static TestAtom defineTest(const std::string& name, const std::string& cmdline) - { + static TestAtom defineTest(const std::string& name, const std::string& cmdline) { return TestAtom(false, name, cmdline); } @@ -304,8 +267,7 @@ class Logger : public nvinfer1::ILogger //! \param[in] argv The array of command-line arguments (given as C strings) //! //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). - static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) - { + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) { auto cmdline = genCmdlineString(argc, argv); return defineTest(name, cmdline); } @@ -317,8 +279,7 @@ class Logger : public nvinfer1::ILogger //! //! \param[in] testAtom The handle to the test that has started //! - static void reportTestStart(TestAtom& testAtom) - { + static void reportTestStart(TestAtom& testAtom) { reportTestResult(testAtom, TestResult::kRUNNING); assert(!testAtom.mStarted); testAtom.mStarted = true; @@ -333,86 +294,85 @@ class Logger : public nvinfer1::ILogger //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, //! TestResult::kFAILED, TestResult::kWAIVED //! - static void reportTestEnd(const TestAtom& testAtom, TestResult result) - { + static void reportTestEnd(const TestAtom& testAtom, TestResult result) { assert(result != TestResult::kRUNNING); assert(testAtom.mStarted); reportTestResult(testAtom, result); } - static int reportPass(const TestAtom& testAtom) - { + static int reportPass(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kPASSED); return EXIT_SUCCESS; } - static int reportFail(const TestAtom& testAtom) - { + static int reportFail(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kFAILED); return EXIT_FAILURE; } - static int reportWaive(const TestAtom& testAtom) - { + static int reportWaive(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kWAIVED); return EXIT_SUCCESS; } - static int reportTest(const TestAtom& testAtom, bool pass) - { + static int reportTest(const TestAtom& testAtom, bool pass) { return pass ? reportPass(testAtom) : reportFail(testAtom); } - Severity getReportableSeverity() const - { - return mReportableSeverity; - } + Severity getReportableSeverity() const { return mReportableSeverity; } -private: + private: //! //! \brief returns an appropriate string for prefixing a log message with the given severity //! - static const char* severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static const char* severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate string for prefixing a test result message with the given result //! - static const char* testResultString(TestResult result) - { - switch (result) - { - case TestResult::kRUNNING: return "RUNNING"; - case TestResult::kPASSED: return "PASSED"; - case TestResult::kFAILED: return "FAILED"; - case TestResult::kWAIVED: return "WAIVED"; - default: assert(0); return ""; + static const char* testResultString(TestResult result) { + switch (result) { + case TestResult::kRUNNING: + return "RUNNING"; + case TestResult::kPASSED: + return "PASSED"; + case TestResult::kFAILED: + return "FAILED"; + case TestResult::kWAIVED: + return "WAIVED"; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity //! - static std::ostream& severityOstream(Severity severity) - { + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? std::cout : std::cerr; } //! //! \brief method that implements logging test results //! - static void reportTestResult(const TestAtom& testAtom, TestResult result) - { + static void reportTestResult(const TestAtom& testAtom, TestResult result) { severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " << testAtom.mCmdline << std::endl; } @@ -420,11 +380,9 @@ class Logger : public nvinfer1::ILogger //! //! \brief generate a command line string from the given (argc, argv) values //! - static std::string genCmdlineString(int argc, char const* const* argv) - { + static std::string genCmdlineString(int argc, char const* const* argv) { std::stringstream ss; - for (int i = 0; i < argc; i++) - { + for (int i = 0; i < argc; i++) { if (i > 0) ss << " "; ss << argv[i]; @@ -435,8 +393,7 @@ class Logger : public nvinfer1::ILogger Severity mReportableSeverity; }; -namespace -{ +namespace { //! //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE @@ -445,8 +402,7 @@ namespace //! //! LOG_VERBOSE(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) -{ +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); } @@ -457,8 +413,7 @@ inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) //! //! LOG_INFO(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_INFO(const Logger& logger) -{ +inline LogStreamConsumer LOG_INFO(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); } @@ -469,8 +424,7 @@ inline LogStreamConsumer LOG_INFO(const Logger& logger) //! //! LOG_WARN(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_WARN(const Logger& logger) -{ +inline LogStreamConsumer LOG_WARN(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); } @@ -481,8 +435,7 @@ inline LogStreamConsumer LOG_WARN(const Logger& logger) //! //! LOG_ERROR(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_ERROR(const Logger& logger) -{ +inline LogStreamConsumer LOG_ERROR(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); } @@ -494,11 +447,10 @@ inline LogStreamConsumer LOG_ERROR(const Logger& logger) //! //! LOG_FATAL(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_FATAL(const Logger& logger) -{ +inline LogStreamConsumer LOG_FATAL(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); } -} // anonymous namespace +} // anonymous namespace -#endif // TENSORRT_LOGGING_H +#endif // TENSORRT_LOGGING_H diff --git a/lenet/macros.h b/lenet/macros.h index 05551039..4752930f 100644 --- a/lenet/macros.h +++ b/lenet/macros.h @@ -1,12 +1,26 @@ -#ifndef __MACROS_H -#define __MACROS_H +#pragma once -#if NV_TENSORRT_MAJOR >= 8 +#ifdef API_EXPORTS +#if defined(_MSC_VER) +#define API __declspec(dllexport) +#else +#define API __attribute__((visibility("default"))) +#endif +#else + +#if defined(_MSC_VER) +#define API __declspec(dllimport) +#else +#define API +#endif +#endif // API_EXPORTS + +#define TRT_VERSION ((NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH) + +#if TRT_VERSION >= 8000 #define TRT_NOEXCEPT noexcept #define TRT_CONST_ENQUEUE const #else #define TRT_NOEXCEPT #define TRT_CONST_ENQUEUE #endif - -#endif // __MACROS_H diff --git a/lenet/utils.h b/lenet/utils.h new file mode 100644 index 00000000..57bff9a1 --- /dev/null +++ b/lenet/utils.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "macros.h" + +#define WORKSPACE_SIZE (16 << 20) + +#define CHECK(status) \ + do { \ + auto ret = (status); \ + if (ret != cudaSuccess) { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ + } while (0) + +// TensorRT weight files have a simple space delimited format: +// [type] [size] +std::map loadWeights(const std::string file) { + std::cout << "Loading weights: " << file << std::endl; + std::map weightMap; + + // Open weights file + std::ifstream input(file); + assert(input.is_open() && "Unable to load weight file."); + + // Read number of weight blobs + int32_t count; + input >> count; + assert(count > 0 && "Invalid weight map file."); + + while (count--) { + nvinfer1::Weights wt{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + // Read name and type of blob + std::string name; + input >> name >> std::dec >> wt.count; + + // Load blob + uint32_t* val = reinterpret_cast(malloc(sizeof(val) * wt.count)); + for (uint32_t x = 0; x < wt.count; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + weightMap[name] = wt; + } + + return weightMap; +} diff --git a/mlp/CMakeLists.txt b/mlp/CMakeLists.txt index 0dc9c983..6d59cb63 100644 --- a/mlp/CMakeLists.txt +++ b/mlp/CMakeLists.txt @@ -1,24 +1,45 @@ -cmake_minimum_required(VERSION 3.14) # change the version, if asked by compiler -project(mlp) - -set(CMAKE_CXX_STANDARD 14) - -# include and link dirs of tensorrt, you need adapt them if yours are different -include_directories(/usr/include/x86_64-linux-gnu/) -link_directories(/usr/lib/x86_64-linux-gnu/) - -# include and link dirs of cuda for inference -include_directories(/usr/local/cuda/include) -link_directories(/usr/local/cuda/lib64) - -# create link for executable files -add_executable(mlp mlp.cpp) - -# perform linking with nvinfer libraries -target_link_libraries(mlp nvinfer) - -# link with cuda libraries for Inference -target_link_libraries(mlp cudart) - -add_definitions(-O2 -pthread) - +cmake_minimum_required(VERSION 3.17.0) + +project( + mlp + VERSION 0.1 + LANGUAGES C CXX CUDA) + +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES + 60 + 70 + 72 + 75 + 80 + 86 + 89) +endif() + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_INCLUDE_CURRENT_DIR TRUE) +set(CMAKE_BUILD_TYPE + "Debug" + CACHE STRING "Build type for this project" FORCE) + +option(CUDA_USE_STATIC_CUDA_RUNTIME "Use static cudaruntime library" OFF) + +find_package(Threads REQUIRED) +find_package(CUDAToolkit REQUIRED) + +if(NOT TARGET TensorRT::TensorRT) + include(FindTensorRT.cmake) +else() + message("TensorRT has been found, skipping for ${PROJECT_NAME}") +endif() + +add_executable(${PROJECT_NAME} mlp.cpp) + +target_include_directories(${PROJECT_NAME} PUBLIC ${CMAKE_CURRENT_LIST_DIR}) + +target_link_libraries(${PROJECT_NAME} PUBLIC Threads::Threads CUDA::cudart + TensorRT::TensorRT) diff --git a/mlp/FindTensorRT.cmake b/mlp/FindTensorRT.cmake new file mode 100644 index 00000000..d6b97fe8 --- /dev/null +++ b/mlp/FindTensorRT.cmake @@ -0,0 +1,78 @@ +cmake_minimum_required(VERSION 3.17.0) + +set(TRT_VERSION + $ENV{TRT_VERSION} + CACHE STRING + "TensorRT version, e.g. \"8.6.1.6\" or \"8.6.1.6+cuda12.0.1.011\"") + +# find TensorRT include folder +if(NOT TensorRT_INCLUDE_DIR) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") + set(TensorRT_INCLUDE_DIR + "/usr/local/cuda/targets/aarch64-linux/include" + CACHE PATH "TensorRT_INCLUDE_DIR") + else() + set(TensorRT_INCLUDE_DIR + "/usr/include/x86_64-linux-gnu" + CACHE PATH "TensorRT_INCLUDE_DIR") + endif() + message(STATUS "TensorRT: ${TensorRT_INCLUDE_DIR}") +endif() + +# find TensorRT library folder +if(NOT TensorRT_LIBRARY_DIR) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") + set(TensorRT_LIBRARY_DIR + "/usr/lib/aarch64-linux-gnu/tegra" + CACHE PATH "TensorRT_LIBRARY_DIR") + else() + set(TensorRT_LIBRARY_DIR + "/usr/include/x86_64-linux-gnu" + CACHE PATH "TensorRT_LIBRARY_DIR") + endif() + message(STATUS "TensorRT: ${TensorRT_LIBRARY_DIR}") +endif() + +set(TensorRT_LIBRARIES) + +message(STATUS "Found TensorRT lib: ${TensorRT_LIBRARIES}") + +# process for different TensorRT version +if(DEFINED TRT_VERSION AND NOT TRT_VERSION STREQUAL "") + string(REGEX MATCH "([0-9]+)" _match ${TRT_VERSION}) + set(TRT_MAJOR_VERSION "${_match}") + set(_modules nvinfer nvinfer_plugin) + + if(TRT_MAJOR_VERSION GREATER_EQUAL 8) + list(APPEND _modules nvinfer_vc_plugin nvinfer_dispatch nvinfer_lean) + endif() +else() + message(FATAL_ERROR "Please set a environment variable \"TRT_VERSION\"") +endif() + +# find and add all modules of TensorRT into list +foreach(lib IN LISTS _modules) + find_library( + TensorRT_${lib}_LIBRARY + NAMES ${lib} + HINTS ${TensorRT_LIBRARY_DIR}) + list(APPEND TensorRT_LIBRARIES ${TensorRT_${lib}_LIBRARY}) +endforeach() + +# make the "TensorRT target" +add_library(TensorRT IMPORTED INTERFACE) +add_library(TensorRT::TensorRT ALIAS TensorRT) +target_link_libraries(TensorRT INTERFACE ${TensorRT_LIBRARIES}) + +set_target_properties( + TensorRT + PROPERTIES C_STANDARD 17 + CXX_STANDARD 17 + POSITION_INDEPENDENT_CODE ON + SKIP_BUILD_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH "$\{ORIGIN\}" + INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIR}") + +unset(TRT_MAJOR_VERSION) +unset(_modules) diff --git a/mlp/README.md b/mlp/README.md index 5bfc4082..71f58f07 100644 --- a/mlp/README.md +++ b/mlp/README.md @@ -1,57 +1,36 @@ # MLP -MLP is the most basic net in this tensorrtx project for starters. You can learn the basic procedures of building -TensorRT app from the provided APIs. The process of building a TensorRT engine explained in the chart below. +MLP is the most basic net in this tensorrtx project for starters. You can learn the basic procedures of building TensorRT app from the provided APIs. The process of building a TensorRT engine explained in the chart below. ![TensorRT Image](https://user-images.githubusercontent.com/33795294/148565279-795b12da-5243-4e7e-881b-263eb7658683.jpg) -## Helper Files +This demo creates a single-layer MLP with `TensorRT >= 7.x` version support. -`logging.h` : A logger file for using NVIDIA TRT API (mostly same for all models) +## Helper Files -`mlp.wts` : Converted weight file (simple file, you can open and check it) - -## TensorRT C++ API +`logging.h` : A logger file for using NVIDIA TensorRT API (mostly same for all models) +`mlp.wts` : Converted weight file, can be generated from [pytorchx/mlp](https://github.com/wang-xinyu/pytorchx/tree/master/mlp), for mlp, it looks like: +```txt +2 +linear.weight 1 3fff7e32 +linear.bias 1 3c138a5a ``` -// 1. generate mlp.wts from https://github.com/wang-xinyu/pytorchx/tree/master/mlp -- or use the given .wts file - -// 2. put mlp.wts into tensorrtx/mlp (if using the generated weights) - -// 3. build and run - - cd tensorrtx/mlp - - mkdir build +(you can create `mlp.wts` and copy this content into it directly) - cd build - - cmake .. - - make - - sudo ./mlp -s // serialize model to plan file i.e. 'mlp.engine' +## TensorRT C++ API - sudo ./mlp -d // deserialize plan file and run inference -``` +see [HERE](../README.md#how-to-run) ## TensorRT Python API -``` -# 1. Generate mlp.wts from https://github.com/wang-xinyu/pytorchx/tree/master/mlp -- or use the given .wts file - -# 2. Put mlp.wts into tensorrtx/mlp (if using the generated weights) +1. Generate mlp.wts (from `pytorchx` or create on your own) -# 3. Install Python dependencies (tensorrt/pycuda/numpy) +2. Put mlp.wts into tensorrtx/mlp (if using the generated weights) -# 4. Run - +3. Run + ```bash cd tensorrtx/mlp - python mlp.py -s # serialize model to plan file, i.e. 'mlp.engine' - python mlp.py -d # deserialize plan file and run inference -``` - -## Note -It also supports the latest CUDA-11.4 and TensorRT-8.2.x + ``` diff --git a/mlp/logging.h b/mlp/logging.h index 0edb75fa..3a25d975 100644 --- a/mlp/logging.h +++ b/mlp/logging.h @@ -17,7 +17,6 @@ #ifndef TENSORRT_LOGGING_H #define TENSORRT_LOGGING_H -#include "NvInferRuntimeCommon.h" #include #include #include @@ -25,32 +24,24 @@ #include #include #include +#include "NvInferRuntimeCommon.h" +#include "macros.h" using Severity = nvinfer1::ILogger::Severity; -class LogStreamConsumerBuffer : public std::stringbuf -{ -public: +class LogStreamConsumerBuffer : public std::stringbuf { + public: LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mOutput(stream) - , mPrefix(prefix) - , mShouldLog(shouldLog) - { - } + : mOutput(stream), mPrefix(prefix), mShouldLog(shouldLog) {} - LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) - : mOutput(other.mOutput) - { - } + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) : mOutput(other.mOutput) {} - ~LogStreamConsumerBuffer() - { + ~LogStreamConsumerBuffer() { // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence // std::streambuf::pptr() gives a pointer to the current position of the output sequence // if the pointer to the beginning is not equal to the pointer to the current position, // call putOutput() to log the output to the stream - if (pbase() != pptr()) - { + if (pbase() != pptr()) { putOutput(); } } @@ -58,16 +49,13 @@ class LogStreamConsumerBuffer : public std::stringbuf // synchronizes the stream buffer and returns 0 on success // synchronizing the stream buffer consists of inserting the buffer contents into the stream, // resetting the buffer and flushing the stream - virtual int sync() - { + virtual int sync() { putOutput(); return 0; } - void putOutput() - { - if (mShouldLog) - { + void putOutput() { + if (mShouldLog) { // prepend timestamp std::time_t timestamp = std::time(nullptr); tm* tm_local = std::localtime(×tamp); @@ -88,12 +76,9 @@ class LogStreamConsumerBuffer : public std::stringbuf } } - void setShouldLog(bool shouldLog) - { - mShouldLog = shouldLog; - } + void setShouldLog(bool shouldLog) { mShouldLog = shouldLog; } -private: + private: std::ostream& mOutput; std::string mPrefix; bool mShouldLog; @@ -103,15 +88,12 @@ class LogStreamConsumerBuffer : public std::stringbuf //! \class LogStreamConsumerBase //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer //! -class LogStreamConsumerBase -{ -public: +class LogStreamConsumerBase { + public: LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) - : mBuffer(stream, prefix, shouldLog) - { - } + : mBuffer(stream, prefix, shouldLog) {} -protected: + protected: LogStreamConsumerBuffer mBuffer; }; @@ -124,49 +106,49 @@ class LogStreamConsumerBase //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. //! Please do not change the order of the parent classes. //! -class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream -{ -public: +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream { + public: //! \brief Creates a LogStreamConsumer which logs messages with level severity. //! Reportable severity determines if the messages are severe enough to be logged. LogStreamConsumer(Severity reportableSeverity, Severity severity) - : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(severity <= reportableSeverity) - , mSeverity(severity) - { - } + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(severity <= reportableSeverity), + mSeverity(severity) {} LogStreamConsumer(LogStreamConsumer&& other) - : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) - , std::ostream(&mBuffer) // links the stream buffer with the stream - , mShouldLog(other.mShouldLog) - , mSeverity(other.mSeverity) - { - } + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(other.mShouldLog), + mSeverity(other.mSeverity) {} - void setReportableSeverity(Severity reportableSeverity) - { + void setReportableSeverity(Severity reportableSeverity) { mShouldLog = mSeverity <= reportableSeverity; mBuffer.setShouldLog(mShouldLog); } -private: - static std::ostream& severityOstream(Severity severity) - { + private: + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? std::cout : std::cerr; } - static std::string severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static std::string severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } @@ -198,24 +180,19 @@ class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger //! object. -class Logger : public nvinfer1::ILogger -{ -public: - Logger(Severity severity = Severity::kWARNING) - : mReportableSeverity(severity) - { - } +class Logger : public nvinfer1::ILogger { + public: + Logger(Severity severity = Severity::kWARNING) : mReportableSeverity(severity) {} //! //! \enum TestResult //! \brief Represents the state of a given test //! - enum class TestResult - { - kRUNNING, //!< The test is running - kPASSED, //!< The test passed - kFAILED, //!< The test failed - kWAIVED //!< The test was waived + enum class TestResult { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived }; //! @@ -225,10 +202,7 @@ class Logger : public nvinfer1::ILogger //! TODO Once all samples are updated to use this method to register the logger with TensorRT, //! we can eliminate the inheritance of Logger from ILogger //! - nvinfer1::ILogger& getTRTLogger() - { - return *this; - } + nvinfer1::ILogger& getTRTLogger() { return *this; } //! //! \brief Implementation of the nvinfer1::ILogger::log() virtual method @@ -236,8 +210,7 @@ class Logger : public nvinfer1::ILogger //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the //! inheritance from nvinfer1::ILogger //! - void log(Severity severity, const char* msg) noexcept override - { + void log(Severity severity, const char* msg) TRT_NOEXCEPT override { LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; } @@ -246,10 +219,7 @@ class Logger : public nvinfer1::ILogger //! //! \param severity The logger will only emit messages that have severity of this level or higher. //! - void setReportableSeverity(Severity severity) - { - mReportableSeverity = severity; - } + void setReportableSeverity(Severity severity) { mReportableSeverity = severity; } //! //! \brief Opaque handle that holds logging information for a particular test @@ -258,20 +228,15 @@ class Logger : public nvinfer1::ILogger //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used //! with Logger::reportTest{Start,End}(). //! - class TestAtom - { - public: + class TestAtom { + public: TestAtom(TestAtom&&) = default; - private: + private: friend class Logger; TestAtom(bool started, const std::string& name, const std::string& cmdline) - : mStarted(started) - , mName(name) - , mCmdline(cmdline) - { - } + : mStarted(started), mName(name), mCmdline(cmdline) {} bool mStarted; std::string mName; @@ -289,8 +254,7 @@ class Logger : public nvinfer1::ILogger // //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). //! - static TestAtom defineTest(const std::string& name, const std::string& cmdline) - { + static TestAtom defineTest(const std::string& name, const std::string& cmdline) { return TestAtom(false, name, cmdline); } @@ -303,8 +267,7 @@ class Logger : public nvinfer1::ILogger //! \param[in] argv The array of command-line arguments (given as C strings) //! //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). - static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) - { + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) { auto cmdline = genCmdlineString(argc, argv); return defineTest(name, cmdline); } @@ -316,8 +279,7 @@ class Logger : public nvinfer1::ILogger //! //! \param[in] testAtom The handle to the test that has started //! - static void reportTestStart(TestAtom& testAtom) - { + static void reportTestStart(TestAtom& testAtom) { reportTestResult(testAtom, TestResult::kRUNNING); assert(!testAtom.mStarted); testAtom.mStarted = true; @@ -332,86 +294,85 @@ class Logger : public nvinfer1::ILogger //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, //! TestResult::kFAILED, TestResult::kWAIVED //! - static void reportTestEnd(const TestAtom& testAtom, TestResult result) - { + static void reportTestEnd(const TestAtom& testAtom, TestResult result) { assert(result != TestResult::kRUNNING); assert(testAtom.mStarted); reportTestResult(testAtom, result); } - static int reportPass(const TestAtom& testAtom) - { + static int reportPass(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kPASSED); return EXIT_SUCCESS; } - static int reportFail(const TestAtom& testAtom) - { + static int reportFail(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kFAILED); return EXIT_FAILURE; } - static int reportWaive(const TestAtom& testAtom) - { + static int reportWaive(const TestAtom& testAtom) { reportTestEnd(testAtom, TestResult::kWAIVED); return EXIT_SUCCESS; } - static int reportTest(const TestAtom& testAtom, bool pass) - { + static int reportTest(const TestAtom& testAtom, bool pass) { return pass ? reportPass(testAtom) : reportFail(testAtom); } - Severity getReportableSeverity() const - { - return mReportableSeverity; - } + Severity getReportableSeverity() const { return mReportableSeverity; } -private: + private: //! //! \brief returns an appropriate string for prefixing a log message with the given severity //! - static const char* severityPrefix(Severity severity) - { - switch (severity) - { - case Severity::kINTERNAL_ERROR: return "[F] "; - case Severity::kERROR: return "[E] "; - case Severity::kWARNING: return "[W] "; - case Severity::kINFO: return "[I] "; - case Severity::kVERBOSE: return "[V] "; - default: assert(0); return ""; + static const char* severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate string for prefixing a test result message with the given result //! - static const char* testResultString(TestResult result) - { - switch (result) - { - case TestResult::kRUNNING: return "RUNNING"; - case TestResult::kPASSED: return "PASSED"; - case TestResult::kFAILED: return "FAILED"; - case TestResult::kWAIVED: return "WAIVED"; - default: assert(0); return ""; + static const char* testResultString(TestResult result) { + switch (result) { + case TestResult::kRUNNING: + return "RUNNING"; + case TestResult::kPASSED: + return "PASSED"; + case TestResult::kFAILED: + return "FAILED"; + case TestResult::kWAIVED: + return "WAIVED"; + default: + assert(0); + return ""; } } //! //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity //! - static std::ostream& severityOstream(Severity severity) - { + static std::ostream& severityOstream(Severity severity) { return severity >= Severity::kINFO ? std::cout : std::cerr; } //! //! \brief method that implements logging test results //! - static void reportTestResult(const TestAtom& testAtom, TestResult result) - { + static void reportTestResult(const TestAtom& testAtom, TestResult result) { severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " << testAtom.mCmdline << std::endl; } @@ -419,11 +380,9 @@ class Logger : public nvinfer1::ILogger //! //! \brief generate a command line string from the given (argc, argv) values //! - static std::string genCmdlineString(int argc, char const* const* argv) - { + static std::string genCmdlineString(int argc, char const* const* argv) { std::stringstream ss; - for (int i = 0; i < argc; i++) - { + for (int i = 0; i < argc; i++) { if (i > 0) ss << " "; ss << argv[i]; @@ -434,8 +393,7 @@ class Logger : public nvinfer1::ILogger Severity mReportableSeverity; }; -namespace -{ +namespace { //! //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE @@ -444,8 +402,7 @@ namespace //! //! LOG_VERBOSE(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) -{ +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); } @@ -456,8 +413,7 @@ inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) //! //! LOG_INFO(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_INFO(const Logger& logger) -{ +inline LogStreamConsumer LOG_INFO(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); } @@ -468,8 +424,7 @@ inline LogStreamConsumer LOG_INFO(const Logger& logger) //! //! LOG_WARN(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_WARN(const Logger& logger) -{ +inline LogStreamConsumer LOG_WARN(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); } @@ -480,8 +435,7 @@ inline LogStreamConsumer LOG_WARN(const Logger& logger) //! //! LOG_ERROR(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_ERROR(const Logger& logger) -{ +inline LogStreamConsumer LOG_ERROR(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); } @@ -493,11 +447,10 @@ inline LogStreamConsumer LOG_ERROR(const Logger& logger) //! //! LOG_FATAL(logger) << "hello world" << std::endl; //! -inline LogStreamConsumer LOG_FATAL(const Logger& logger) -{ +inline LogStreamConsumer LOG_FATAL(const Logger& logger) { return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); } -} // anonymous namespace +} // anonymous namespace -#endif // TENSORRT_LOGGING_H +#endif // TENSORRT_LOGGING_H diff --git a/mlp/macros.h b/mlp/macros.h new file mode 100644 index 00000000..4752930f --- /dev/null +++ b/mlp/macros.h @@ -0,0 +1,26 @@ +#pragma once + +#ifdef API_EXPORTS +#if defined(_MSC_VER) +#define API __declspec(dllexport) +#else +#define API __attribute__((visibility("default"))) +#endif +#else + +#if defined(_MSC_VER) +#define API __declspec(dllimport) +#else +#define API +#endif +#endif // API_EXPORTS + +#define TRT_VERSION ((NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH) + +#if TRT_VERSION >= 8000 +#define TRT_NOEXCEPT noexcept +#define TRT_CONST_ENQUEUE const +#else +#define TRT_NOEXCEPT +#define TRT_CONST_ENQUEUE +#endif diff --git a/mlp/mlp.cpp b/mlp/mlp.cpp index e2217b4b..1d4833a6 100644 --- a/mlp/mlp.cpp +++ b/mlp/mlp.cpp @@ -1,71 +1,25 @@ -#include "NvInfer.h" // TensorRT library -#include "iostream" // Standard input/output library -#include "logging.h" // logging file -- by NVIDIA -#include // for weight maps -#include // for file-handling -#include // for timing the execution - -// provided by nvidia for using TensorRT APIs +#include +#include +#include +#include +#include "logging.h" +#include "utils.h" + using namespace nvinfer1; +#define INPUT_SIZE 1 +#define OUTPUT_SIZE 1 +#define INPUT_NAME "data" +#define OUTPUT_NAME "out" + // Logger from TRT API static Logger gLogger; -const int INPUT_SIZE = 1; -const int OUTPUT_SIZE = 1; - -/** //////////////////////////// -// DEPLOYMENT RELATED ///////// -////////////////////////////*/ -std::map loadWeights(const std::string file) { - /** - * Parse the .wts file and store weights in dict format. - * - * @param file path to .wts file - * @return weight_map: dictionary containing weights and their values - */ - - std::cout << "[INFO]: Loading weights..." << file << std::endl; - std::map weightMap; - - // Open Weight file - std::ifstream input(file); - assert(input.is_open() && "[ERROR]: Unable to load weight file..."); - - // Read number of weights - int32_t count; - input >> count; - assert(count > 0 && "Invalid weight map file."); - - // Loop through number of line, actually the number of weights & biases - while (count--) { - // TensorRT weights - Weights wt{DataType::kFLOAT, nullptr, 0}; - uint32_t size; - // Read name and type of weights - std::string w_name; - input >> w_name >> std::dec >> size; - wt.type = DataType::kFLOAT; - - uint32_t *val = reinterpret_cast(malloc(sizeof(val) * size)); - for (uint32_t x = 0, y = size; x < y; ++x) { - // Change hex values to uint32 (for higher values) - input >> std::hex >> val[x]; - } - wt.values = val; - wt.count = size; - - // Add weight values against its name (key) - weightMap[w_name] = wt; - } - return weightMap; -} - -ICudaEngine *createMLPEngine(unsigned int maxBatchSize, IBuilder *builder, IBuilderConfig *config, DataType dt) { +ICudaEngine* createMLPEngine(int32_t N, IRuntime* runtime, IBuilder* builder, IBuilderConfig* config, DataType dt) { /** - * Create Multi-Layer Perceptron using the TRT Builder and Configurations + * Create a single-layer "MLP" using the TRT Builder and Configurations * - * @param maxBatchSize: batch size for built TRT model + * @param N: max batch size for built TRT model * @param builder: to build engine and networks * @param config: configuration related to Hardware * @param dt: datatype for model layers @@ -75,49 +29,59 @@ ICudaEngine *createMLPEngine(unsigned int maxBatchSize, IBuilder *builder, IBuil std::cout << "[INFO]: Creating MLP using TensorRT..." << std::endl; // Load Weights from relevant file - std::map weightMap = loadWeights("../mlp.wts"); + std::map weightMap = loadWeights("./mlp.wts"); // Create an empty network - INetworkDefinition *network = builder->createNetworkV2(0U); + INetworkDefinition* network = builder->createNetworkV2(1u); - // Create an input with proper *name - ITensor *data = network->addInput("data", DataType::kFLOAT, Dims3{1, 1, 1}); + // Create an input with proper name + ITensor* data = network->addInput(INPUT_NAME, DataType::kFLOAT, Dims4{N, 1, 1, 1}); assert(data); + // clang-format off // Add layer for MLP - IFullyConnectedLayer *fc1 = network->addFullyConnected(*data, 1, - weightMap["linear.weight"], - weightMap["linear.bias"]); - assert(fc1); - - // set output with *name - fc1->getOutput(0)->setName("out"); + auto* fc1w_tensor = network->addConstant(Dims4{1, 1, 1, 1}, weightMap["linear.weight"])->getOutput(0); + auto fc1b_tensor = network->addConstant(Dims4{1, 1, 1, 1}, weightMap["linear.bias"])->getOutput(0); + assert(fc1w_tensor && fc1b_tensor); + auto* fc1w = network->addMatrixMultiply(*data, MatrixOperation::kNONE, *fc1w_tensor, MatrixOperation::kTRANSPOSE); + auto fc1b = network->addElementWise(*fc1w->getOutput(0), *fc1b_tensor, ElementWiseOperation::kSUM); + assert(fc1w && fc1b); + fc1w->setName("fc1w"); + fc1b->setName("fc1b"); + // clang-format on + + // set output with name + auto* output = fc1b->getOutput(0); + output->setName(OUTPUT_NAME); // mark the output - network->markOutput(*fc1->getOutput(0)); - - // Set configurations - builder->setMaxBatchSize(1); - // Set workspace size - config->setMaxWorkspaceSize(1 << 20); - - // Build CUDA Engine using network and configurations - ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config); + network->markOutput(*output); + +#if TRT_VERSION >= 8000 + IHostMemory* serialized_mem = builder->buildSerializedNetwork(*network, *config); + ICudaEngine* engine = runtime->deserializeCudaEngine(serialized_mem->data(), serialized_mem->size()); +#else + builder->setMaxBatchSize(N); + config->setMaxWorkspaceSize(WORKSPACE_SIZE); + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); +#endif assert(engine != nullptr); - // Don't need the network any more - // free captured memory +#if TRT_VERSION >= 8000 + delete network; +#else network->destroy(); +#endif // Release host memory - for (auto &mem: weightMap) { - free((void *) (mem.second.values)); + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); } return engine; } -void APIToModel(unsigned int maxBatchSize, IHostMemory **modelStream) { +void APIToModel(int32_t maxBatchSize, IRuntime* runtime, IHostMemory** modelStream) { /** * Create engine using TensorRT APIs * @@ -125,199 +89,161 @@ void APIToModel(unsigned int maxBatchSize, IHostMemory **modelStream) { * @param modelStream: shared memory to store serialized model */ - // Create builder with the help of logger - IBuilder *builder = createInferBuilder(gLogger); - - // Create hardware configs - IBuilderConfig *config = builder->createBuilderConfig(); + // Create builder with the logger + IBuilder* builder = createInferBuilder(gLogger); + IBuilderConfig* config = builder->createBuilderConfig(); // Build an engine - ICudaEngine *engine = createMLPEngine(maxBatchSize, builder, config, DataType::kFLOAT); + ICudaEngine* engine = createMLPEngine(maxBatchSize, runtime, builder, config, DataType::kFLOAT); assert(engine != nullptr); // serialize the engine into binary stream (*modelStream) = engine->serialize(); - // free up the memory +#if TRT_VERSION >= 8000 + delete engine; + delete config; + delete builder; +#else engine->destroy(); + config->destroy(); builder->destroy(); +#endif } -void performSerialization() { +void doInference(IExecutionContext& ctx, float* input, float* output, int batchSize = 1) { /** - * Serialization Function - */ - // Shared memory object - IHostMemory *modelStream{nullptr}; - - // Write model into stream - APIToModel(1, &modelStream); - assert(modelStream != nullptr); - - - std::cout << "[INFO]: Writing engine into binary..." << std::endl; - - // Open the file and write the contents there in binary format - std::ofstream p("../mlp.engine", std::ios::binary); - if (!p) { - std::cerr << "could not open plan output file" << std::endl; - return; - } - p.write(reinterpret_cast(modelStream->data()), modelStream->size()); - - // Release the memory - modelStream->destroy(); - - std::cout << "[INFO]: Successfully created TensorRT engine..." << std::endl; - std::cout << "\n\tRun inference using `./mlp -d`" << std::endl; - -} - -/** //////////////////////////// -// INFERENCE RELATED ////////// -////////////////////////////*/ -void doInference(IExecutionContext &context, float *input, float *output, int batchSize) { - /** - * Perform inference using the CUDA context + * Perform inference using the CUDA ctx * - * @param context: context created by engine + * @param ctx: context created by engine * @param input: input from the host * @param output: output to save on host * @param batchSize: batch size for TRT model */ - - // Get engine from the context - const ICudaEngine &engine = context.getEngine(); - - // Pointers to input and output device buffers to pass to engine. - // Engine requires exactly IEngine::getNbBindings() number of buffers. - assert(engine.getNbBindings() == 2); - void *buffers[2]; - - // In order to bind the buffers, we need to know the names of the input and output tensors. - // Note that indices are guaranteed to be less than IEngine::getNbBindings() - const int inputIndex = engine.getBindingIndex("data"); - const int outputIndex = engine.getBindingIndex("out"); - - // Create GPU buffers on device -- allocate memory for input and output - cudaMalloc(&buffers[inputIndex], batchSize * INPUT_SIZE * sizeof(float)); - cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float)); - - // create CUDA stream for simultaneous CUDA operations + // Get engine from the ctx + const ICudaEngine& engine = ctx.getEngine(); + +#if TRT_VERSION >= 8000 + int32_t nIO = engine.getNbIOTensors(); + const int inputIndex = 0; + const int outputIndex = engine.getNbIOTensors() - 1; +#else + int32_t nIO = engine.getNbBindings(); + const int inputIndex = engine.getBindingIndex(INPUT_NAME); + const int outputIndex = engine.getBindingIndex(OUTPUT_NAME); +#endif + assert(nIO == 2); // mlp contains 1 input and 1 output + + // create cuda stream for aync cuda operations cudaStream_t stream; - cudaStreamCreate(&stream); - - // copy input from host (CPU) to device (GPU) in stream - cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_SIZE * sizeof(float), cudaMemcpyHostToDevice, stream); - - // execute inference using context provided by engine - context.enqueue(batchSize, buffers, stream, nullptr); - - // copy output back from device (GPU) to host (CPU) - cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, - stream); - - // synchronize the stream to prevent issues - // (block CUDA and wait for CUDA operations to be completed) - cudaStreamSynchronize(stream); - - // Release stream and buffers (memory) - cudaStreamDestroy(stream); - cudaFree(buffers[inputIndex]); - cudaFree(buffers[outputIndex]); + CHECK(cudaStreamCreate(&stream)); + + // create GPU buffers on cuda device and copy input data from host + std::vector buffers(2, nullptr); + CHECK(cudaMallocAsync(&buffers[inputIndex], batchSize * INPUT_SIZE * sizeof(float), stream)); + CHECK(cudaMallocAsync(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), stream)); + CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_SIZE * sizeof(float), cudaMemcpyHostToDevice, + stream)); + + // execute inference using ctx provided by engine +#if TRT_VERSION >= 8000 + for (int32_t i = 0; i < engine.getNbIOTensors(); i++) { + auto const name = engine.getIOTensorName(i); + auto dims = ctx.getTensorShape(name); + auto total = std::accumulate(dims.d, dims.d + dims.nbDims, 1, std::multiplies()); + std::cout << name << "\t" << total << std::endl; + ctx.setTensorAddress(name, buffers[i]); + } + assert(ctx.enqueueV3(stream)); +#else + assert(ctx.enqueueV2(buffers.data(), stream, nullptr)); +#endif + // Use async API so that no synchronization is needed + CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, + stream)); + for (auto& buffer : buffers) { + CHECK(cudaFreeAsync(buffer, stream)); + } + CHECK(cudaStreamDestroy(stream)); } -void performInference() { - /** - * Get inference using the pre-trained model - */ +int main(int argc, char** argv) { + if (argc != 2) { + std::cerr << "[ERROR]: Arguments not right!" << std::endl; + std::cerr << "./mlp -s // serialize model to plan file" << std::endl; + std::cerr << "./mlp -d // deserialize plan file and run inference" << std::endl; + return 1; + } - // stream to write model - char *trtModelStream{nullptr}; + IRuntime* runtime = createInferRuntime(gLogger); + assert(runtime != nullptr); + char* trtModelStream{nullptr}; size_t size{0}; - // read model from the engine file - std::ifstream file("../mlp.engine", std::ios::binary); - if (file.good()) { - file.seekg(0, file.end); - size = file.tellg(); - file.seekg(0, file.beg); - trtModelStream = new char[size]; - assert(trtModelStream); - file.read(trtModelStream, size); - file.close(); + if (std::string(argv[1]) == "-s") { + IHostMemory* modelStream{nullptr}; + APIToModel(1, runtime, &modelStream); + assert(modelStream != nullptr); + + std::ofstream p("./mlp.engine", std::ios::binary | std::ios::trunc); + if (!p.good()) { + std::cerr << "could not open plan output file" << std::endl; + return 1; + } + p.write(reinterpret_cast(modelStream->data()), modelStream->size()); + +#if TRT_VERSION >= 8000 + delete modelStream; +#else + modelStream->destroy(); +#endif + std::cout << "[INFO]: Successfully created TensorRT engine." << std::endl; + return 0; + } else if (std::string(argv[1]) == "-d") { + std::ifstream file("mlp.engine", std::ios::binary); + if (file.good()) { + file.seekg(0, file.end); + size = file.tellg(); + file.seekg(0, file.beg); + trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + } } - // create a runtime (required for deserialization of model) with NVIDIA's logger - IRuntime *runtime = createInferRuntime(gLogger); - assert(runtime != nullptr); - - // deserialize engine for using the char-stream - ICudaEngine *engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr); + // deserialize engine from the char-stream +#if TRT_VERSION >= 8000 + ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size); +#else + ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr); +#endif assert(engine != nullptr); - // create execution context -- required for inference executions - IExecutionContext *context = engine->createExecutionContext(); - assert(context != nullptr); - - float out[1]; // array for output - float data[1]; // array for input - for (float &i: data) - i = 12.0; // put any value for input - - // time the execution - auto start = std::chrono::system_clock::now(); - - // do inference using the parameters - doInference(*context, data, out, 1); + IExecutionContext* ctx = engine->createExecutionContext(); + assert(ctx != nullptr); - // time the execution - auto end = std::chrono::system_clock::now(); - std::cout << "\n[INFO]: Time taken by execution: " - << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; + float output[1] = {-1.f}; + float input[1] = {12.0f}; + for (int i = 0; i < 1000; i++) { + auto start = std::chrono::high_resolution_clock::now(); + doInference(*ctx, input, output); + auto end = std::chrono::high_resolution_clock::now(); + auto time = std::chrono::duration_cast(end - start).count(); + std::cout << "Execution time: " << time << "us\t" + << "output: " << output[0] << std::endl; + } - // free the captured space - context->destroy(); +#if TRT_VERSION >= 8000 + delete ctx; + delete engine; + delete runtime; +#else + ctx->destroy(); engine->destroy(); runtime->destroy(); +#endif - std::cout << "\nInput:\t" << data[0]; - std::cout << "\nOutput:\t"; - for (float i: out) { - std::cout << i; - } - std::cout << std::endl; -} - -int checkArgs(int argc, char **argv) { - /** - * Parse command line arguments - * - * @param argc: argument count - * @param argv: arguments vector - * @return int: a flag to perform operation - */ - - if (argc != 2) { - std::cerr << "[ERROR]: Arguments not right!" << std::endl; - std::cerr << "./mlp -s // serialize model to plan file" << std::endl; - std::cerr << "./mlp -d // deserialize plan file and run inference" << std::endl; - return -1; - } - if (std::string(argv[1]) == "-s") { - return 1; - } else if (std::string(argv[1]) == "-d") { - return 2; - } - return -1; -} - -int main(int argc, char **argv) { - int args = checkArgs(argc, argv); - if (args == 1) - performSerialization(); - else if (args == 2) - performInference(); return 0; } diff --git a/mlp/mlp.py b/mlp/mlp.py index aad38ea4..88dbc766 100644 --- a/mlp/mlp.py +++ b/mlp/mlp.py @@ -7,7 +7,6 @@ import tensorrt as trt # required for the inference using TRT engine -import pycuda.autoinit import pycuda.driver as cuda # Sizes of input and output for TensorRT model @@ -245,4 +244,3 @@ def get_args(): print("\n\tRun inference using `python mlp.py -d`\n") else: perform_inference(input_val=4.0) - diff --git a/mlp/mlp.wts b/mlp/mlp.wts deleted file mode 100644 index 01ef0db7..00000000 --- a/mlp/mlp.wts +++ /dev/null @@ -1,3 +0,0 @@ -2 -linear.weight 1 3fff7e32 -linear.bias 1 3c138a5a diff --git a/mlp/utils.h b/mlp/utils.h new file mode 100644 index 00000000..57bff9a1 --- /dev/null +++ b/mlp/utils.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "macros.h" + +#define WORKSPACE_SIZE (16 << 20) + +#define CHECK(status) \ + do { \ + auto ret = (status); \ + if (ret != cudaSuccess) { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ + } while (0) + +// TensorRT weight files have a simple space delimited format: +// [type] [size] +std::map loadWeights(const std::string file) { + std::cout << "Loading weights: " << file << std::endl; + std::map weightMap; + + // Open weights file + std::ifstream input(file); + assert(input.is_open() && "Unable to load weight file."); + + // Read number of weight blobs + int32_t count; + input >> count; + assert(count > 0 && "Invalid weight map file."); + + while (count--) { + nvinfer1::Weights wt{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + // Read name and type of blob + std::string name; + input >> name >> std::dec >> wt.count; + + // Load blob + uint32_t* val = reinterpret_cast(malloc(sizeof(val) * wt.count)); + for (uint32_t x = 0; x < wt.count; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + weightMap[name] = wt; + } + + return weightMap; +} diff --git a/tutorials/check_fp16_int8_support.md b/tutorials/check_fp16_int8_support.md index 745ffd10..3c9771e9 100644 --- a/tutorials/check_fp16_int8_support.md +++ b/tutorials/check_fp16_int8_support.md @@ -11,4 +11,3 @@ For example, GTX1080 is 6.1, Tesla T4 is 7.5. visit https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix and check the matrix. For example, compute capability 6.1 supports FP32 and INT8. 7.5 supports FP32, FP16, INT8, FP16 tensor core, etc. - diff --git a/tutorials/contribution.md b/tutorials/contributing.md similarity index 86% rename from tutorials/contribution.md rename to tutorials/contributing.md index 3b57b247..94e5dfd6 100644 --- a/tutorials/contribution.md +++ b/tutorials/contributing.md @@ -1,4 +1,4 @@ -# How to make contribution +# How to Contribute 1. Fork this repo to your github account @@ -10,11 +10,11 @@ 5. Pre-commit check and push, we use clang-format to do coding style checking, and the coding style is following google c++ coding style with 4-space. -``` -pip install pre-commit -pip install clang-format +```bash +pip install pre-commit clang-format -cd tensorrtx/ +cd tensorrtx +pre-commit install git add [files-to-commit] pre-commit run diff --git a/tutorials/faq.md b/tutorials/faq.md index 792ebe00..b4df6560 100644 --- a/tutorials/faq.md +++ b/tutorials/faq.md @@ -2,12 +2,12 @@ ## 1. fatal error: NvInfer.h: No such file or directory -`NvInfer.h` is one of the headers of TensorRT. If you install the tensorrt DEB package, the headers should in `/usr/include/x86_64-linux-gnu/`. If you install tensorrt TAR or ZIP file, the `include_directories` and `link_directories` of tensorrt should be added in `CMakeLists.txt`. +`NvInfer.h` is one of the headers of TensorRT. If you install the tensorrt DEB package, the headers should in `/usr/include/x86_64-linux-gnu/`. If you install tensorrt TAR or ZIP file, it is recommended to manage TensorRT with modern CMake syntax, e.g. [FindTensorRT.cmake](../lenet/FindTensorRT.cmake). `dpkg -L` can print out the contents of a DEB package. ``` -$ dpkg -L libnvinfer-dev +$ dpkg -L libnvinfer-dev /. /usr /usr/lib @@ -38,7 +38,7 @@ $ dpkg -L libnvinfer-dev `cuda_runtime_api.h` is from cuda-cudart. If you met this error, you need find where it is and adapt the `include_directories` and `link_directories` of cuda in `CMakeLists.txt`. ``` -$ dpkg -L cuda-cudart-dev-10-0 +$ dpkg -L cuda-cudart-dev-10-0 /. /usr /usr/local @@ -88,4 +88,3 @@ If you train your own yolo model, you need set the `CLASS_NUM` in `yololayer.h`. void APIToModel(unsigned int, nvinfer1::IHostMemory**): Assertion `engine != nullptr' failed. Aborted (core dumped) ``` - diff --git a/tutorials/from_pytorch_to_trt_stepbystep_hrnet.md b/tutorials/from_pytorch_to_trt_stepbystep_hrnet.md index 51565b7a..c2b54794 100644 --- a/tutorials/from_pytorch_to_trt_stepbystep_hrnet.md +++ b/tutorials/from_pytorch_to_trt_stepbystep_hrnet.md @@ -108,7 +108,7 @@ ResBlock层 ```c++ Dims dim = id_1083->getOutput(0)->getDimensions(); -std::cout << dim[0] << " " << dim[1] << " " << dim[2] << " " << dim[3] << std::endl; +std::cout << dim[0] << " " << dim[1] << " " << dim[2] << " " << dim[3] << std::endl; ``` **一般如果出现生成engine就失败的情况,就从createEngine的第一句开始调试,并且随时关注窗口输出,如果在某一层出现大量提示信息,那么该层就会有问题,就将该层的输入tensor维度和输出tensor维度信息都打印出来,看输出的维度是否正常。** diff --git a/tutorials/getting_started.md b/tutorials/getting_started.md index a9a4e8aa..c3799be0 100644 --- a/tutorials/getting_started.md +++ b/tutorials/getting_started.md @@ -1,33 +1,43 @@ # Getting Started with TensorRTx -We use a lenet5 demo to explain how we implement DL network in TensorRTx. +## 1. Setup the development environment -## 1. Run lenet5 in pytorch +(**RECOMMENDED**) If you prefer to run everything in a docker container, check [HERE](../docker/README.md) -Clone the wang-xinyu/pytorchx in your machine. Enter lenet folder. +If you prefer to install every dependencies locally, check [HERE](./install.md) -And of course you should install pytorch first. +## 2. Run TensorRTx demo -``` -git clone https://github.com/wang-xinyu/pytorchx -cd pytorchx/lenet -``` +It is recommended to go through the [lenet5](https://github.com/wang-xinyu/tensorrtx/tree/master/lenet) or [mlp](https://github.com/wang-xinyu/tensorrtx/tree/master/mlp) first. But if you are proficient in TensorRT, please check the readme file of the model you want directly. -Run lenet5.py to generate lenet5.pth which is the pytorch serialized model. The lenet5 arch is defined in lenet5.py. +We use "lenet5" to explain how we build DL network with TensorRT API. -``` -python lenet5.py -``` +### 2.1. Export lenet5 weights in pytorch -Run inference.py to generate lenet5.wts, which is weights file for tensorrt. +1. Clone the [wang-xinyu/pytorchx](https://github.com/wang-xinyu/pytorchx) in your machine, then enter lenet folder: -``` -python inference.py -``` + ```bash + pip install torch + git clone https://github.com/wang-xinyu/pytorchx + cd pytorchx/lenet + ``` -You should see the output from terminal like this, the output of lenet5 is [[0.0950, 0.0998, 0.1101, 0.0975, 0.0966, 0.1097, 0.0948, 0.1056, 0.0992, 0.0917]], shape is [1, 10]. +2. Run lenet5.py to generate lenet5.pth which is the pytorch serialized model. The lenet5 arch is defined in lenet5.py. + + ```bash + python lenet5.py + ``` + +3. Run inference.py to generate lenet5.wts, which is weights file for tensorrt. + + ```bash + python inference.py + ``` + +The terminal output would be like: +```txt +the output of lenet5 is [[0.0950, 0.0998, 0.1101, 0.0975, 0.0966, 0.1097, 0.0948, 0.1056, 0.0992, 0.0917]], shape is [1, 10]. -``` cuda device count: 2 input: torch.Size([1, 1, 32, 32]) conv1 torch.Size([1, 6, 28, 28]) @@ -37,22 +47,21 @@ pool2 torch.Size([1, 16, 5, 5]) view: torch.Size([1, 400]) fc1: torch.Size([1, 120]) lenet out: tensor([[0.0950, 0.0998, 0.1101, 0.0975, 0.0966, 0.1097, 0.0948, 0.1056, 0.0992, - 0.0917]], device='cuda:0', grad_fn=) + 0.0917]], device='cuda:0', grad_fn=) ``` -## 2. Run lenet5 in tensorrt +### 2.2. Run lenet5 in TensorRT -Clone the wang-xinyu/tensorrtx in your machine. Enter lenet folder, copy lenet5.wts generated above, and cmake&make c++ code. +Clone the wang-xinyu/tensorrtx in your machine. Enter lenet folder, copy lenet5.wts generated above, and cmake&make c++ code. And of course you should install cuda/cudnn/tensorrt first. You might need to adapt the tensorrt path in CMakeLists.txt if you install tensorrt from tar package. -``` +```bash git clone https://github.com/wang-xinyu/tensorrtx cd tensorrtx/lenet cp [PATH-OF-pytorchx]/pytorchx/lenet/lenet5.wts . -mkdir build +cmake -S . -B build cd build -cmake .. make ``` @@ -60,19 +69,19 @@ If the `make` succeed, the executable `lenet` will generated. Run lenet to build tensorrt engine and serialize it to file `lenet5.engine`. -``` +```bash ./lenet -s ``` Deserialize the engine and run inference. -``` +```bash ./lenet -d ``` You should see the output like this, -``` +```txt Output: 0.0949623, 0.0998472, 0.110072, 0.0975036, 0.0965564, 0.109736, 0.0947979, 0.105618, 0.099228, 0.0916792, @@ -84,21 +93,34 @@ As the input to pytorch and tensorrt are same, i.e. a [1,1,32,32] all ones tenso So the output should be same, otherwise there must be something wrong. -``` -The pytorch output is +```txt +The pytorch output is 0.0950, 0.0998, 0.1101, 0.0975, 0.0966, 0.1097, 0.0948, 0.1056, 0.0992, 0.0917 -The tensorrt output is +The tensorrt output is 0.0949623, 0.0998472, 0.110072, 0.0975036, 0.0965564, 0.109736, 0.0947979, 0.105618, 0.099228, 0.0916792 ``` Same! exciting, isn't it? -## The .wts content format +## 4. The `.wts` content format -The .wts is plain text file. +The `.wts` is plain text file, e.g. `lenet5.wts`, part of the contents are: -For example the lenet5.wts, part content are shown below. +```txt +10 +conv1.weight 150 be40ee1b bd20bab8 bdc4bc53 ... +conv1.bias 6 bd327058 ... +conv2.weight 2400 3c6f2220 3c693090 ... +conv2.bias 16 bd183967 bcb1ac8a ... +fc1.weight 48000 3c162c20 bd25196a ... +fc1.bias 120 3d3c3d49 bc64b948 ... +fc2.weight 10080 bce095a4 3d33b9dc ... +fc2.bias 84 bc71eaa0 3d9b276c ... +fc3.weight 840 3c252870 3d855351 ... +fc3.bias 10 bdbe4bb8 3b119ee0 ... +... +``` The first line is a number, indicate how many lines it has, excluding itself. @@ -108,17 +130,6 @@ And then each line is The value is in HEX format. -``` -10 -conv1.weight 150 be40ee1b bd20bab8 bdc4bc53 ....... -conv1.bias 6 bd327058 ....... -conv2.weight 2400 3c6f2220 3c693090 ...... -conv2.bias 16 bd183967 bcb1ac8a ....... -fc1.weight 48000 3c162c20 bd25196a ...... -fc1.bias 120 3d3c3d49 bc64b948 ...... -fc2.weight 10080 bce095a4 3d33b9dc ...... -fc2.bias 84 bc71eaa0 3d9b276c ....... -fc3.weight 840 3c252870 3d855351 ....... -fc3.bias 10 bdbe4bb8 3b119ee0 ...... -``` +## 5. Frequently Asked Questions (FAQ) +check [HERE](./faq.md) for the answers of questions you may encounter. diff --git a/tutorials/install.md b/tutorials/install.md index e7c963fd..c1cf0392 100644 --- a/tutorials/install.md +++ b/tutorials/install.md @@ -1,57 +1,44 @@ # Install the dependencies of tensorrtx -## Ubuntu +Using docker as development environment is strongly recommended, you may check [HERE](../docker/README) for the deployment instructions of docker container and *ignore* the rest of this document. -Ubuntu16.04 / cuda10.0 / cudnn7.6.5 / tensorrt7.0.0 / opencv3.3 would be the example, other versions might also work, just need you to try. +While if this is not your case, we always recommend using major LTS version of your OS, Nvidia driver, CUDA, and so on. -It is strongly recommended to use `apt` to manage software in Ubuntu. +## OS -### 1. Install CUDA +Ubuntu-22.04 is recommended. It is strongly recommended to use `apt` to manage packages in Ubuntu. -Go to [cuda-10.0-download](https://developer.nvidia.com/cuda-10.0-download-archive). Choose `Linux` -> `x86_64` -> `Ubuntu` -> `16.04` -> `deb(local)` and download the .deb package. +## Nidia Related -Then follow the installation instructions. +### Driver -``` -sudo dpkg -i cuda-repo-ubuntu1604-10-0-local-10.0.130-410.48_1.0-1_amd64.deb -sudo apt-key add /var/cuda-repo-/7fa2af80.pub -sudo apt-get update -sudo apt-get install cuda -``` +You should install the nvidia driver first before anything else, go to [Ubuntu Driver Installation Guide](https://docs.nvidia.com/datacenter/tesla/driver-installation-guide/index.html#ubuntu) for more details. -### 2. Install TensorRT +**NOTE**: Since version 560, the installation step is a little different than before, check [HERE](https://docs.nvidia.com/datacenter/tesla/driver-installation-guide/index.html#recent-updates) for more details. -Go to [nvidia-tensorrt-7x-download](https://developer.nvidia.com/nvidia-tensorrt-7x-download). You might need login. +### CUDA -Choose TensorRT 7.0 and `TensorRT 7.0.0.11 for Ubuntu 1604 and CUDA 10.0 DEB local repo packages` +Go to [NVIDIA CUDA Installation Guide for Linux](https://developer.nvidia.com/cuda-10.0-download-archive) for the detailed steps. -Install with following commands, after `apt install tensorrt`, it will automatically install cudnn, nvinfer, nvinfer-plugin, etc. +**NOTE**: +- Do not forget to check [Post-installation Actions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#post-installation-actions) to setup the environment correctly. +- Make your CUDA version comply with your driver version +- If you want multi-version CUDA, docker is strongly recommended. -``` -sudo dpkg -i nv-tensorrt-repo-ubuntu1604-cuda10.0-trt7.0.0.11-ga-20191216_1-1_amd64.deb -sudo apt update -sudo apt install tensorrt -``` +### TensorRT + +check [HERE](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#downloading) to install TensorRT. -### 3. Install OpenCV +### (Optional) OpenCV ``` -sudo add-apt-repository ppa:timsc/opencv-3.3 -sudo apt-get update -sudo apt install libopencv-dev +sudo apt-get update && sudo apt install libgtk-3-dev libopencv-dev ``` -### 4. Check your installation +## Verify installation ``` dpkg -l | grep cuda dpkg -l | grep nvinfer dpkg -l | grep opencv ``` - -### 5. Run tensorrtx - -It is recommended to go through the [getting started guide, lenet5 as a demo.](https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/getting_started.md) first. - -But if you are proficient in tensorrt, please check the readme of the model you want directly. - diff --git a/tutorials/measure_performance.md b/tutorials/measure_performance.md index 864daa34..cfaf8a1e 100644 --- a/tutorials/measure_performance.md +++ b/tutorials/measure_performance.md @@ -110,4 +110,3 @@ context->setProfiler(&sp); context->enqueue(...); gLogInfo << sp << std::endl; ``` - diff --git a/tutorials/migrating_from_tensorrt_4_to_7.md b/tutorials/migrating_from_tensorrt_4_to_7.md deleted file mode 100644 index cebc2a0b..00000000 --- a/tutorials/migrating_from_tensorrt_4_to_7.md +++ /dev/null @@ -1,13 +0,0 @@ -# Migrating from TensorRT 4 to 7 - -The following APIs are deprecated and replaced in TensorRT 7. - -- `DimsCHW`, replaced by `Dims3` -- `addConvolution()`, replaced by `addConvolutionNd()` -- `addPooling()`, replaced by `addPoolingNd()` -- `addDeconvolution()`, replaced by `addDeconvolutionNd()` -- `createNetwork()`, replaced by `createNetworkV2()` -- `buildCudaEngine()`, replaced by `buildEngineWithConfig()` -- `createPReLUPlugin()`, replaced by `addActivation()` with `ActivationType::kLEAKY_RELU` -- `IPlugin` and `IPluginExt` class, replaced by `IPluginV2IOExt` or `IPluginV2DynamicExt` -- Use the new `Logger` class defined in logging.h diff --git a/tutorials/migration_guide.md b/tutorials/migration_guide.md new file mode 100644 index 00000000..37ca536b --- /dev/null +++ b/tutorials/migration_guide.md @@ -0,0 +1,22 @@ +# Migration Guide + +## Newest Migration Guide + +Please check this [Doc](https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Migration-Guide.pdf) or this [Page](https://docs.nvidia.com/deeplearning/tensorrt/migration-guide/index.html) + +For any archives version, please check this [Page](https://docs.nvidia.com/deeplearning/tensorrt/archives/index.html) + +## (DEPRECATED) Migrating from TensorRT 4.x to 7.x + +**NOTE**: Both TensorRT 4.x and 7.x are **DEPRECATED** by NVIDIA officially, so this part is **outdated**. + +The following APIs are deprecated and replaced in TensorRT 7. +- `DimsCHW`, replaced by `Dims3` +- `addConvolution()`, replaced by `addConvolutionNd()` +- `addPooling()`, replaced by `addPoolingNd()` +- `addDeconvolution()`, replaced by `addDeconvolutionNd()` +- `createNetwork()`, replaced by `createNetworkV2()` +- `buildCudaEngine()`, replaced by `buildEngineWithConfig()` +- `createPReLUPlugin()`, replaced by `addActivation()` with `ActivationType::kLEAKY_RELU` +- `IPlugin` and `IPluginExt` class, replaced by `IPluginV2IOExt` or `IPluginV2DynamicExt` +- Use the new `Logger` class defined in `logging.h` diff --git a/tutorials/multi_GPU_processing.md b/tutorials/multi_GPU_processing.md index b4af794a..e9f23808 100644 --- a/tutorials/multi_GPU_processing.md +++ b/tutorials/multi_GPU_processing.md @@ -9,9 +9,9 @@ For example, in function ` forwardGpu()` of **yololayer.cu**, you need to do the 1) Change `cudaMemset(output + idx*outputElem, 0, sizeof(float))` to `cudaMemsetAsync(output + idx*outputElem, 0, sizeof(float), stream)` 2) Change `CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount>>>(inputs[i],output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount ,outputElem)` to `CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>(inputs[i],output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount ,outputElem)` - + ## 2. Create an engine for each device you want to use. - + Maybe it is a good idea to create a struct to store the engine, context and buffer for each device individually. For example, ``` struct Plan{ @@ -23,7 +23,7 @@ For example, in function ` forwardGpu()` of **yololayer.cu**, you need to do the }; ``` And then use `cudaSetDevice()` to make each engine you create running on specific device. Moreover, to maximize performance, make sure that the engine file you are using to deserialize is the one tensor RT optimized for this device. - + ## 3. Use function wisely Here are some knowledge I learned when trying to parallelize the inference. 1) Do not use synchronized function , like `cudaFree()`, during inference.