This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit 606933f
TensorRT: add int8 with calibration (#19011)
Signed-off-by: Serge Panev <spanev@nvidia.com>
Kh4L committed Sep 15, 2020
1 parent 9dfac79 commit 606933f
Showing 14 changed files with 854 additions and 101 deletions.
16 changes: 16 additions & 0 deletions ci/docker/runtime_functions.sh
@@ -1069,6 +1069,22 @@ unittest_ubuntu_python3_gpu_nocudnn() {
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
}

unittest_ubuntu_tensorrt_gpu() {
set -ex
export PYTHONPATH=./python/
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export MXNET_SUBGRAPH_VERBOSE=0
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
export CUDNN_VERSION=${CUDNN_VERSION:-7.0.3}
export MXNET_ENABLE_CYTHON=0
export DMLC_LOG_STACK_TRACE_DEPTH=10
pip3 install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100==0.24
wget -nc http://data.mxnet.io/data/val_256_q90.rec
python3.6 tests/python/tensorrt/rec2idx.py val_256_q90.rec val_256_q90.idx
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/
rm val_256_q90*
}

# quantization gpu currently only runs on P3 instances
# need to separate it from unittest_ubuntu_python3_gpu()
unittest_ubuntu_python3_quantization_gpu() {
18 changes: 18 additions & 0 deletions ci/jenkins/Jenkins_steps.groovy
@@ -853,6 +853,24 @@ def test_unix_python3_mkldnn_nocudnn_gpu() {
}]
}

def test_unix_python3_tensorrt_gpu() {
return ['Python3: TensorRT GPU': {
node(NODE_LINUX_GPU_P3) {
ws('workspace/build-tensorrt') {
timeout(time: max_time, unit: 'MINUTES') {
try {
utils.unpack_and_init('tensorrt', mx_tensorrt_lib)
utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true)
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_trt_gpu.xml', 'nosetests_python3_tensorrt_gpu.xml')
}
}
}
}
}]
}

def test_unix_python3_integration_gpu() {
return ['Python Integration GPU': {
node(NODE_LINUX_GPU_G4) {
1 change: 1 addition & 0 deletions ci/jenkins/Jenkinsfile_unix_gpu
@@ -51,6 +51,7 @@ core_logic: {
custom_steps.test_unix_python3_mkldnn_gpu(),
custom_steps.test_unix_python3_mkldnn_nocudnn_gpu(),
custom_steps.test_unix_perl_gpu(),
custom_steps.test_unix_python3_tensorrt_gpu(),
custom_steps.test_unix_r_gpu(),
custom_steps.test_unix_cpp_gpu(),
custom_steps.test_unix_cpp_mkldnn_gpu(),
1 change: 0 additions & 1 deletion python/mxnet/contrib/__init__.py
@@ -32,4 +32,3 @@
from . import io
from . import quantization
from . import quantization as quant
from . import tensorrt
67 changes: 0 additions & 67 deletions python/mxnet/contrib/tensorrt.py

This file was deleted.

82 changes: 76 additions & 6 deletions src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -18,10 +18,10 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* Copyright (c) 2019-2020 by Contributors
* \file onnx_to_tensorrt.cc
* \brief TensorRT integration with the MXNet executor
* \author Marek Kolodziej, Clement Fuji Tsang
* \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev
*/

#if MXNET_USE_TENSORRT
@@ -38,6 +38,8 @@
#include <dmlc/logging.h>
#include <dmlc/parameter.h>

#include <future>

using std::cout;
using std::cerr;
using std::endl;
@@ -64,10 +66,13 @@ void PrintVersion() {

std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
unique_ptr<nvonnxparser::IParser>,
std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
std::unique_ptr<TRT_Logger>,
std::future<onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> > > onnxToTrtCtx(
const std::string& onnx_model,
bool fp16_mode,
int32_t max_batch_size,
size_t max_workspace_size,
TRTInt8Calibrator* calibrator,
nvinfer1::ILogger::Severity verbosity,
bool debug_builder) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
@@ -112,18 +117,83 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
}
throw dmlc::Error("Cannot parse ONNX into TensorRT Engine");
}
if (dmlc::GetEnv("MXNET_TENSORRT_USE_FP16", true)) {
trt_builder->setMaxBatchSize(max_batch_size);
std::future<onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine>> future_int8_engine;
#if NV_TENSORRT_MAJOR > 6
auto builder_config = InferObject(trt_builder->createBuilderConfig());

if (fp16_mode) {
if (trt_builder->platformHasFastFp16()) {
builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
} else {
LOG(WARNING) << "TensorRT can't use fp16 on this platform";
}
}

builder_config->setMaxWorkspaceSize(max_workspace_size);
if (debug_builder) {
builder_config->setFlag(nvinfer1::BuilderFlag::kDEBUG);
}

auto trt_engine = InferObject(trt_builder->buildEngineWithConfig(*trt_network, *builder_config));

if (calibrator != nullptr) {
if (trt_builder->platformHasFastInt8()) {
builder_config->setFlag(nvinfer1::BuilderFlag::kINT8);
builder_config->setInt8Calibrator(calibrator);
} else {
LOG(WARNING) << "TensorRT can't use int8 on this platform";
calibrator = nullptr;
}
}

// if the calibration cache is empty, we are in calibration mode
if (calibrator != nullptr && calibrator->isCacheEmpty()) {
future_int8_engine = std::async([trt_builder = std::move(trt_builder),
trt_network = std::move(trt_network),
builder_config = std::move(builder_config)]() {
// Calibration is blocking, so it has to run in a different thread.
// The builder keeps consuming batches fed through calibrator.setBatch
// until setBatch returns false.
auto int8_engine = InferObject(trt_builder->buildEngineWithConfig(*trt_network,
*builder_config));
return std::move(int8_engine);
});
}
#else
if (fp16_mode) {
if (trt_builder->platformHasFastFp16()) {
trt_builder->setFp16Mode(true);
} else {
LOG(WARNING) << "TensorRT can't use fp16 on this platform";
}
}
trt_builder->setMaxBatchSize(max_batch_size);

trt_builder->setMaxWorkspaceSize(max_workspace_size);
trt_builder->setDebugSync(debug_builder);

if (calibrator != nullptr) {
if (trt_builder->platformHasFastInt8()) {
trt_builder->setInt8Mode(true);
trt_builder->setInt8Calibrator(calibrator);
} else {
LOG(WARNING) << "TensorRT can't use int8 on this platform";
calibrator = nullptr;
}
}
auto trt_engine = InferObject(trt_builder->buildCudaEngine(*trt_network));
return std::make_tuple(std::move(trt_engine), std::move(trt_parser), std::move(trt_logger));
// if the calibration cache is empty, we are in calibration mode
if (calibrator != nullptr && calibrator->isCacheEmpty()) {
future_int8_engine = std::async([trt_builder = std::move(trt_builder),
trt_network = std::move(trt_network)]() {
// Calibration is blocking, so it has to run in a different thread.
// The builder keeps consuming batches fed through calibrator.setBatch
// until setBatch returns false.
auto int8_engine = InferObject(trt_builder->buildCudaEngine(*trt_network));
return std::move(int8_engine);
});
}
#endif
return std::make_tuple(std::move(trt_engine), std::move(trt_parser),
std::move(trt_logger), std::move(future_int8_engine));
}

} // namespace onnx_to_tensorrt
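
The setBatch-driven handoff described in the comments above is the heart of the calibration flow: TensorRT's builder thread blocks inside the calibrator's getBatch() until the inference thread hands it device buffers. MXNet's TRTInt8Calibrator itself is not shown in this diff; below is a minimal illustrative sketch of such a blocking calibrator (class and member names are assumptions, not taken from the commit):

#include <NvInfer.h>

#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>

class BlockingInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 {
 public:
  explicit BlockingInt8Calibrator(int batch_size) : batch_size_(batch_size) {}

  int getBatchSize() const noexcept override { return batch_size_; }

  // Builder thread: block until setBatch() supplies device pointers, or
  // until done() signals that the calibration data is exhausted.
  bool getBatch(void* bindings[], const char* names[], int nb) noexcept override {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return batch_ready_ || done_; });
    if (done_) return false;  // returning false ends calibration
    for (int i = 0; i < nb; ++i) bindings[i] = buffers_.at(names[i]);
    batch_ready_ = false;
    cv_.notify_all();
    return true;
  }

  // Inference thread: hand one batch of device buffers to the builder and
  // wait until it has been consumed.
  void setBatch(const std::unordered_map<std::string, void*>& buffers) {
    std::unique_lock<std::mutex> lock(mu_);
    buffers_ = buffers;
    batch_ready_ = true;
    cv_.notify_all();
    cv_.wait(lock, [this] { return !batch_ready_ || done_; });
  }

  void done() {
    std::lock_guard<std::mutex> lock(mu_);
    done_ = true;
    cv_.notify_all();
  }

  // Cache persistence is elided here; MXNet's calibrator additionally
  // exposes isCacheEmpty(), as used in onnxToTrtCtx above.
  const void* readCalibrationCache(std::size_t& length) noexcept override {
    length = 0;
    return nullptr;
  }
  void writeCalibrationCache(const void*, std::size_t) noexcept override {}

 private:
  int batch_size_;
  std::mutex mu_;
  std::condition_variable cv_;
  std::unordered_map<std::string, void*> buffers_;
  bool batch_ready_ = false;
  bool done_ = false;
};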
14 changes: 10 additions & 4 deletions src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
@@ -20,10 +20,10 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* Copyright (c) 2019-2020 by Contributors
* \file onnx_to_tensorrt.h
* \brief TensorRT integration with the MXNet executor
* \author Marek Kolodziej, Clement Fuji Tsang
* \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev
*/

#if MXNET_USE_TENSORRT
@@ -32,13 +32,16 @@
#include <NvInfer.h>

#include <fstream>
#include <future>
#include <memory>
#include <iostream>
#include <sstream>
#include <string>
#include <ctime>
#include <tuple>

#include "./tensorrt_int8_calibrator.h"

namespace onnx_to_tensorrt {

struct InferDeleter {
@@ -56,7 +59,7 @@ using unique_ptr = std::unique_ptr<T, InferDeleter>;
template<typename T>
inline unique_ptr<T> InferObject(T* obj) {
if ( !obj ) {
throw std::runtime_error("Failed to create object");
throw std::runtime_error("Failed to create TensorRT object");
}
return unique_ptr<T>(obj, InferDeleter());
}
@@ -85,10 +88,13 @@ class TRT_Logger : public nvinfer1::ILogger {

std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
unique_ptr<nvonnxparser::IParser>,
std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
std::unique_ptr<TRT_Logger>,
std::future<onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> > > onnxToTrtCtx(
const std::string& onnx_model,
bool fp16_mode,
int32_t max_batch_size = 32,
size_t max_workspace_size = 1L << 30,
TRTInt8Calibrator* calibrator = nullptr,
nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING,
bool debug_builder = false);
} // namespace onnx_to_tensorrt
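
To make the new contract concrete, here is a hedged caller-side sketch of consuming the tuple declared above: the std::future in the fourth slot is valid only when an empty-cache calibrator was passed in (i.e. a calibration build was actually started), and getting it blocks until calibration completes. The function name and surrounding plumbing are illustrative, not from this commit:

#include <string>
#include <tuple>
#include <utility>

#include "onnx_to_tensorrt.h"

onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> BuildEngine(
    const std::string& onnx_model, TRTInt8Calibrator* calibrator) {
  auto ctx = onnx_to_tensorrt::onnxToTrtCtx(onnx_model,
                                            /*fp16_mode=*/true,
                                            /*max_batch_size=*/32,
                                            /*max_workspace_size=*/1L << 30,
                                            calibrator);
  auto engine = std::move(std::get<0>(ctx));
  auto& future_int8_engine = std::get<3>(ctx);

  // Serve inference with `engine` while the forward passes feed batches to
  // the calibrator through setBatch(); once calibration has consumed enough
  // batches, the future resolves and the int8 engine replaces the fp32/fp16 one.
  if (future_int8_engine.valid()) {
    engine = future_int8_engine.get();  // blocks until calibration finishes
  }
  return engine;
}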
(Diff view truncated; the remaining changed files are not shown.)