From 606933f2c27121e1d33693a41468ed5547a72eed Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 15 Sep 2020 15:58:54 -0700 Subject: [PATCH] TensorRT: add int8 with calibration (#19011) Signed-off-by: Serge Panev --- ci/docker/runtime_functions.sh | 16 ++ ci/jenkins/Jenkins_steps.groovy | 18 ++ ci/jenkins/Jenkinsfile_unix_gpu | 1 + python/mxnet/contrib/__init__.py | 1 - python/mxnet/contrib/tensorrt.py | 67 ------ .../subgraph/tensorrt/onnx_to_tensorrt.cc | 82 ++++++- .../subgraph/tensorrt/onnx_to_tensorrt.h | 14 +- src/operator/subgraph/tensorrt/tensorrt-inl.h | 121 +++++++++-- src/operator/subgraph/tensorrt/tensorrt.cc | 45 +++- src/operator/subgraph/tensorrt/tensorrt.cu | 27 ++- .../tensorrt/tensorrt_int8_calibrator.cc | 149 +++++++++++++ .../tensorrt/tensorrt_int8_calibrator.h | 105 +++++++++ tests/python/tensorrt/rec2idx.py | 107 ++++++++++ tests/python/tensorrt/test_tensorrt.py | 202 ++++++++++++++++++ 14 files changed, 854 insertions(+), 101 deletions(-) delete mode 100644 python/mxnet/contrib/tensorrt.py create mode 100644 src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc create mode 100644 src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h create mode 100644 tests/python/tensorrt/rec2idx.py create mode 100644 tests/python/tensorrt/test_tensorrt.py diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index f8f2b570a32d..0351d0feddac 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1069,6 +1069,22 @@ unittest_ubuntu_python3_gpu_nocudnn() { nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu } +unittest_ubuntu_tensorrt_gpu() { + set -ex + export PYTHONPATH=./python/ + export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 + export MXNET_SUBGRAPH_VERBOSE=0 + export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH + export CUDNN_VERSION=${CUDNN_VERSION:-7.0.3} + export MXNET_ENABLE_CYTHON=0 + export DMLC_LOG_STACK_TRACE_DEPTH=10 + pip3 install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100==0.24 + wget -nc http://data.mxnet.io/data/val_256_q90.rec + python3.6 tests/python/tensorrt/rec2idx.py val_256_q90.rec val_256_q90.idx + nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/ + rm val_256_q90* +} + # quantization gpu currently only runs on P3 instances # need to separte it from unittest_ubuntu_python3_gpu() unittest_ubuntu_python3_quantization_gpu() { diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index 1cc91e4f4247..2c98e0ff96cf 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -853,6 +853,24 @@ def test_unix_python3_mkldnn_nocudnn_gpu() { }] } +def test_unix_python3_tensorrt_gpu() { + return ['Python3: TensorRT GPU': { + node(NODE_LINUX_GPU_P3) { + ws('workspace/build-tensorrt') { + timeout(time: max_time, unit: 'MINUTES') { + try { + utils.unpack_and_init('tensorrt', mx_tensorrt_lib) + utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true) + utils.publish_test_coverage() + } finally { + utils.collect_test_results_unix('nosetests_trt_gpu.xml', 'nosetests_python3_tensorrt_gpu.xml') + } + } + } + } + }] +} + def test_unix_python3_integration_gpu() { return ['Python Integration GPU': { node(NODE_LINUX_GPU_G4) { diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu index
f21944084a72..163a0a02e0fd 100644 --- a/ci/jenkins/Jenkinsfile_unix_gpu +++ b/ci/jenkins/Jenkinsfile_unix_gpu @@ -51,6 +51,7 @@ core_logic: { custom_steps.test_unix_python3_mkldnn_gpu(), custom_steps.test_unix_python3_mkldnn_nocudnn_gpu(), custom_steps.test_unix_perl_gpu(), + custom_steps.test_unix_python3_tensorrt_gpu(), custom_steps.test_unix_r_gpu(), custom_steps.test_unix_cpp_gpu(), custom_steps.test_unix_cpp_mkldnn_gpu(), diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py index 606bb0ada54f..fbfd3469678b 100644 --- a/python/mxnet/contrib/__init__.py +++ b/python/mxnet/contrib/__init__.py @@ -32,4 +32,3 @@ from . import io from . import quantization from . import quantization as quant -from . import tensorrt diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py deleted file mode 100644 index 2676f40a35e2..000000000000 --- a/python/mxnet/contrib/tensorrt.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" Module to enable the use of TensorRT optimized graphs.""" -import os - -def set_use_fp16(status): - """ - Set an environment variable which will enable or disable the use of FP16 precision in - TensorRT - Note: The mode FP16 force the whole TRT node to be executed in FP16 - :param status: Boolean, True if TensorRT should run in FP16, False for FP32 - """ - os.environ["MXNET_TENSORRT_USE_FP16"] = str(int(status)) - -def get_use_fp16(): - """ - Get an environment variable which describes if TensorRT is currently running in FP16 - :return: Boolean, true if TensorRT is running in FP16, False for FP32 - """ - return bool(int(os.environ.get("MXNET_TENSORRT_USE_FP16", 1)) == 1) - -def init_tensorrt_params(sym, arg_params, aux_params): - """ - Set weights in attributes of TensorRT nodes - :param sym: Symbol, the symbol graph should contains some TensorRT nodes - :param arg_params: arg_params - :param aux_params: aux_params - :return arg_params, aux_params: remaining params that are not in TensorRT nodes - """ - arg_params = arg_params.copy() - aux_params = aux_params.copy() - for s in sym.get_internals(): - new_params_names = "" - tensorrt_params = {} - if 'subgraph_params_names' in s.list_attr(): - keys = s.list_attr()['subgraph_params_names'].split(';') - for k in keys: - if k in arg_params: - new_params_names += k + ";" - tensorrt_params['subgraph_param_' + k] = arg_params[k] - arg_params.pop(k) - elif k in aux_params: - new_params_names += k + ";" - tensorrt_params['subgraph_param_' + k] = aux_params[k] - aux_params.pop(k) - new_attrs = {} - for k, v in tensorrt_params.items(): - new_attrs[k] = str(v.handle.value) - if len(new_attrs) > 0: - s._set_attr(**new_attrs) - s._set_attr(subgraph_params_names=new_params_names[:-1]) - return arg_params, aux_params diff --git 
a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc index 4f5bdcb8561c..fc4809d7f1cb 100644 --- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc +++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc @@ -18,10 +18,10 @@ */ /*! - * Copyright (c) 2019 by Contributors + * Copyright (c) 2019-2020 by Contributors * \file onnx_to_tensorrt.cc * \brief TensorRT integration with the MXNet executor - * \author Marek Kolodziej, Clement Fuji Tsang + * \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev */ #if MXNET_USE_TENSORRT @@ -38,6 +38,8 @@ #include #include +#include + using std::cout; using std::cerr; using std::endl; @@ -64,10 +66,13 @@ void PrintVersion() { std::tuple, unique_ptr, - std::unique_ptr > onnxToTrtCtx( + std::unique_ptr, + std::future > > onnxToTrtCtx( const std::string& onnx_model, + bool fp16_mode, int32_t max_batch_size, size_t max_workspace_size, + TRTInt8Calibrator* calibrator, nvinfer1::ILogger::Severity verbosity, bool debug_builder) { GOOGLE_PROTOBUF_VERIFY_VERSION; @@ -112,18 +117,83 @@ std::tuple, } throw dmlc::Error("Cannot parse ONNX into TensorRT Engine"); } - if (dmlc::GetEnv("MXNET_TENSORRT_USE_FP16", true)) { + trt_builder->setMaxBatchSize(max_batch_size); + std::future> future_int8_engine; +#if NV_TENSORRT_MAJOR > 6 + auto builder_config = InferObject(trt_builder->createBuilderConfig()); + + if (fp16_mode) { + if (trt_builder->platformHasFastFp16()) { + builder_config->setFlag(nvinfer1::BuilderFlag::kFP16); + } else { + LOG(WARNING) << "TensorRT can't use fp16 on this platform"; + } + } + + builder_config->setMaxWorkspaceSize(max_workspace_size); + if (debug_builder) { + builder_config->setFlag(nvinfer1::BuilderFlag::kDEBUG); + } + + auto trt_engine = InferObject(trt_builder->buildEngineWithConfig(*trt_network, *builder_config)); + + if (calibrator != nullptr) { + if (trt_builder->platformHasFastInt8()) { + builder_config->setFlag(nvinfer1::BuilderFlag::kINT8); + builder_config->setInt8Calibrator(calibrator); + } else { + LOG(WARNING) << "TensorRT can't use int8 on this platform"; + calibrator = nullptr; + } + } + + // if the cache is null, we are in calibration mode + if (calibrator != nullptr && calibrator->isCacheEmpty()) { + future_int8_engine = std::async([trt_builder = std::move(trt_builder), + trt_network = std::move(trt_network), + builder_config = std::move(builder_config)]() { + // Calibration is blocking so we need to have it in a different thread. 
+ // The engine will be calling calibrator.setBatch until it returns false + auto int8_engine = InferObject(trt_builder->buildEngineWithConfig(*trt_network, + *builder_config)); + return std::move(int8_engine); + }); + } +#else + if (fp16_mode) { if (trt_builder->platformHasFastFp16()) { trt_builder->setFp16Mode(true); } else { LOG(WARNING) << "TensorRT can't use fp16 on this platform"; } } - trt_builder->setMaxBatchSize(max_batch_size); + trt_builder->setMaxWorkspaceSize(max_workspace_size); trt_builder->setDebugSync(debug_builder); + + if (calibrator != nullptr) { + if (trt_builder->platformHasFastInt8()) { + trt_builder->setInt8Mode(true); + trt_builder->setInt8Calibrator(calibrator); + } else { + LOG(WARNING) << "TensorRT can't use int8 on this platform"; + calibrator = nullptr; + } + } auto trt_engine = InferObject(trt_builder->buildCudaEngine(*trt_network)); - return std::make_tuple(std::move(trt_engine), std::move(trt_parser), std::move(trt_logger)); + // if the cache is null, we are in calibration mode + if (calibrator != nullptr && calibrator->isCacheEmpty()) { + future_int8_engine = std::async([trt_builder = std::move(trt_builder), + trt_network = std::move(trt_network)]() { + // Calibration is blocking so we need to have it in a different thread. + // The engine will be calling calibrator.setBatch until it returns false + auto int8_engine = InferObject(trt_builder->buildCudaEngine(*trt_network)); + return std::move(int8_engine); + }); + } +#endif + return std::make_tuple(std::move(trt_engine), std::move(trt_parser), + std::move(trt_logger), std::move(future_int8_engine)); } } // namespace onnx_to_tensorrt diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h index b89422f59069..0484d2f725a5 100644 --- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h +++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.h @@ -20,10 +20,10 @@ */ /*! - * Copyright (c) 2019 by Contributors + * Copyright (c) 2019-2020 by Contributors * \file onnx_to_tensorrt.h * \brief TensorRT integration with the MXNet executor - * \author Marek Kolodziej, Clement Fuji Tsang + * \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev */ #if MXNET_USE_TENSORRT @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -39,6 +40,8 @@ #include #include +#include "./tensorrt_int8_calibrator.h" + namespace onnx_to_tensorrt { struct InferDeleter { @@ -56,7 +59,7 @@ using unique_ptr = std::unique_ptr; template inline unique_ptr InferObject(T* obj) { if ( !obj ) { - throw std::runtime_error("Failed to create object"); + throw std::runtime_error("Failed to create TensorRT object"); } return unique_ptr(obj, InferDeleter()); } @@ -85,10 +88,13 @@ class TRT_Logger : public nvinfer1::ILogger { std::tuple, unique_ptr, - std::unique_ptr > onnxToTrtCtx( + std::unique_ptr, + std::future > > onnxToTrtCtx( const std::string& onnx_model, + bool fp16_mode, int32_t max_batch_size = 32, size_t max_workspace_size = 1L << 30, + TRTInt8Calibrator* calibrator = nullptr, nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING, bool debug_builder = false); } // namespace onnx_to_tensorrt diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h index c5022cb2b74c..cdc8a49e9d13 100644 --- a/src/operator/subgraph/tensorrt/tensorrt-inl.h +++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h @@ -20,18 +20,23 @@ */ /*! 
- * Copyright (c) 2019 by Contributors + * Copyright (c) 2019-2020 by Contributors * \file tensorrt-inl.h * \brief TensorRT operation registration - * \author Marek Kolodziej, Clement Fuji Tsang + * \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev */ #if MXNET_USE_TENSORRT #include +#include +#include +#include #include +#include #include +#include #include #include "../../nn/activation-inl.h" @@ -46,6 +51,7 @@ #include "../subgraph_property.h" #include "nnvm_to_onnx-inl.h" #include "./onnx_to_tensorrt.h" +#include "./tensorrt_int8_calibrator.h" namespace mxnet { namespace op { @@ -56,17 +62,29 @@ struct TRTParam { std::unordered_map inputs_to_idx; std::unordered_map outputs_to_idx; std::unordered_map params_map; + bool fp16_mode; + bool int8_mode; + int calibration_iters; }; struct TRTEngineParam { - TRTEngineParam(onnx_to_tensorrt::unique_ptr _trt_engine, + using EnginePtr = onnx_to_tensorrt::unique_ptr; + TRTEngineParam(EnginePtr _trt_engine, onnx_to_tensorrt::unique_ptr _trt_parser, std::unique_ptr _trt_logger, const std::unordered_map& input_map, - const std::unordered_map& output_map) { + const std::unordered_map& output_map, + int _max_batch_size, + std::unique_ptr<::onnx_to_tensorrt::TRTInt8Calibrator> _calibrator = {}, + std::future _future_int8_engine = {}) { trt_engine = std::move(_trt_engine); trt_logger = std::move(_trt_logger); trt_parser = std::move(_trt_parser); + calibrator = std::move(_calibrator); + future_int8_engine = std::move(_future_int8_engine); + calibration_mode = future_int8_engine.valid(); + max_batch_size = _max_batch_size; + input_name_to_idx = input_map; binding_order = std::make_shared > >(); bindings = std::make_shared >(); binding_order->reserve(trt_engine->getNbBindings()); @@ -82,12 +100,36 @@ struct TRTEngineParam { trt_executor = onnx_to_tensorrt::InferObject(trt_engine->createExecutionContext()); } - onnx_to_tensorrt::unique_ptr trt_engine; + void ResetEngine(EnginePtr _trt_engine, + bool _calibration_mode = false) { + trt_executor.reset(); + trt_engine.reset(); + trt_engine = std::move(_trt_engine); + trt_executor = onnx_to_tensorrt::InferObject(trt_engine->createExecutionContext()); + calibration_mode = _calibration_mode; + } + + ~TRTEngineParam() { + if (future_int8_engine.valid()) { + calibrator->waitAndSetDone(); + future_int8_engine.wait(); + } + } + + EnginePtr trt_engine; onnx_to_tensorrt::unique_ptr trt_executor; onnx_to_tensorrt::unique_ptr trt_parser; std::unique_ptr trt_logger; std::shared_ptr > > binding_order; std::shared_ptr > bindings; + + // needed by the int8 calibrator + std::unique_ptr<::onnx_to_tensorrt::TRTInt8Calibrator> calibrator; + std::future future_int8_engine; + bool calibration_mode; + int max_batch_size; + std::unordered_map input_name_to_idx; + std::unordered_map params_map; }; class TensorrtSelector : public SubgraphSelector { @@ -270,18 +312,47 @@ class TensorrtProperty : public SubgraphProperty { void PrePartition(const nnvm::Graph& g, const std::unordered_map& options_map) override { + auto it_precision = options_map.find("precision"); + if (it_precision != options_map.end()) { + auto precision_string = it_precision->second; + std::replace(precision_string.begin(), precision_string.end(), '_', ' '); + std::istringstream iss(precision_string); + std::unordered_set precision_list((std::istream_iterator(iss)), + std::istream_iterator()); + for (auto &precision : precision_list) { + if (precision == "fp16") { + fp16_mode_ = true; + } else if (precision == "int8") { + int8_mode_ = true; + } else { + 
CHECK(precision == "fp32") + << "TensorRT op: `precision` only accepts a combination of 'fp32', 'fp16' and 'int8', " + "e.g. precision='fp16_int8', precision='int8', precision='fp32_int8'.\n" + "Notes:\n" + "Omitting 'fp16' and 'fp32' is equivalent to 'fp32' ('fp32_int8' <=> 'int8').\n" + "'fp16' overrides 'fp32' ('fp32_fp16' <=> 'fp16')."; + } + } + } + + if (int8_mode_) { + auto it_iters = options_map.find("calibration_iters"); + CHECK(it_iters != options_map.end()) + << "TensorRT op: `calibration_iters` has to be set when using int8 precision."; + calibration_iters_ = std::stoi(it_iters->second); + } auto& in_arg_names = g.GetAttr>("in_arg_names"); auto& in_aux_names = g.GetAttr>("in_aux_names"); NDArray **in_args_ptr = g.GetAttr("in_args"); NDArray **in_aux_ptr = g.GetAttr("in_aux"); - in_args_dict.clear(); - in_aux_dict.clear(); + in_args_dict_.clear(); + in_aux_dict_.clear(); // we trust the Python API, len(in_arg_names) == len(in_args_ptr) for (unsigned i = 0; i < in_arg_names.size(); ++i) { - in_args_dict[in_arg_names[i]] = in_args_ptr[i]; + in_args_dict_[in_arg_names[i]] = in_args_ptr[i]; } for (unsigned i = 0; i < in_aux_names.size(); ++i) { - in_aux_dict[in_aux_names[i]] = in_aux_ptr[i]; + in_aux_dict_[in_aux_names[i]] = in_aux_ptr[i]; } } @@ -301,15 +372,18 @@ class TensorrtProperty : public SubgraphProperty { // Mapping subgraph params with NDArrays TRTParam param; + param.fp16_mode = fp16_mode_; + param.int8_mode = int8_mode_; + param.calibration_iters = calibration_iters_; std::ostringstream params_oss; for (auto &param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) { NDArray *cache = nullptr; - auto it_args = in_args_dict.find(param_name); - if (it_args != in_args_dict.end()) { + auto it_args = in_args_dict_.find(param_name); + if (it_args != in_args_dict_.end()) { cache = it_args->second; } else { - auto it_aux = in_aux_dict.find(param_name); - if (it_aux != in_aux_dict.end()) { + auto it_aux = in_aux_dict_.find(param_name); + if (it_aux != in_aux_dict_.end()) { cache = it_aux->second; } } @@ -364,7 +438,26 @@ class TensorrtProperty : public SubgraphProperty { subgraph_node->attrs.parsed = std::move(_params); } - std::unordered_map in_args_dict, in_aux_dict; + void PostPartition(const nnvm::Graph& g) override { + if (int8_mode_) { + int n_trt_engines = 0; + nnvm::DFSVisit(g.outputs, [&n_trt_engines](const std::shared_ptr& n) { + if (n->attrs.op != nullptr && n->attrs.op->name == "_TensorRT") { + n_trt_engines++; + } + }); + LOG(INFO) << "[TensorRT op] " << n_trt_engines << " INT8 engines have been created. " + << "They are set in calibration mode for the next " << calibration_iters_ + << " iterations. Please feed calibration data representative of the inference data " + << "(performance is low during calibration)."; + } + } + + private: + std::unordered_map in_args_dict_, in_aux_dict_; + bool fp16_mode_ = false; + bool int8_mode_ = false; + int calibration_iters_ = 0; }; diff --git a/src/operator/subgraph/tensorrt/tensorrt.cc b/src/operator/subgraph/tensorrt/tensorrt.cc index 8395fb43f1d9..1c60d026afd0 100644 --- a/src/operator/subgraph/tensorrt/tensorrt.cc +++ b/src/operator/subgraph/tensorrt/tensorrt.cc @@ -18,10 +18,10 @@ */ /*!
- * Copyright (c) 2019 by Contributors + * Copyright (c) 2019-2020 by Contributors * \file tensorrt.cc * \brief TensorRT operation registration - * \author Marek Kolodziej, Clement Fuji Tsang + * \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev */ #if MXNET_USE_TENSORRT @@ -266,6 +266,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx, const std::vector& in_shape, const std::vector& in_type) { const auto& node_param = nnvm::get(attrs.parsed); + const bool tensorrt_int8 = node_param.int8_mode; nnvm::Graph graph; graph.outputs = attrs.subgraphs[0]->outputs; uint32_t max_batch_size = dmlc::GetEnv("MXNET_TENSORRT_MAX_BATCH_SIZE", in_shape[0][0]); @@ -279,9 +280,12 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx, const auto& outputs_to_idx = node_param.outputs_to_idx; const auto& idx_g = graph.indexed_graph(); const auto& input_nids = idx_g.input_nodes(); + + // needed by the int8 calibrator + std::unordered_map> input_buffers; mxnet::ShapeVector shape_inputs(input_nids.size()); nnvm::DTypeVector dtype_inputs(input_nids.size()); - for (int i = 0; i < input_nids.size(); ++i) { + for (size_t i = 0; i < input_nids.size(); ++i) { auto node = idx_g[input_nids[i]].source; auto it_params = params_map.find(node->attrs.name); auto it_inputs = inputs_to_idx.find(node->attrs.name); @@ -291,6 +295,21 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx, } else if (it_inputs != inputs_to_idx.end()) { shape_inputs[i] = in_shape[it_inputs->second]; dtype_inputs[i] = in_type[it_inputs->second]; + if (tensorrt_int8) { + int dtype_size; + if (dtype_inputs[i] == mshadow::kFloat32) { + dtype_size = 4; + } else if (dtype_inputs[i] == mshadow::kFloat16) { + dtype_size = 2; + } else { + LOG(FATAL) << "TensorRT op supports only float32 and float16 inputs."; + } + size_t buffer_size = shape_inputs[i].Size() * dtype_size; + void *ptr; + MSHADOW_CUDA_CALL(cudaMalloc(&ptr, buffer_size)); + input_buffers.emplace(node->attrs.name, + std::make_pair(ptr, buffer_size)); + } } else { LOG(FATAL) << node->attrs.name << " attribute is missing for attributes inference"; } @@ -303,7 +322,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx, TRTInferType(attrs, &_in_type, &out_type); nnvm::DTypeVector dtypes(idx_g.num_node_entries()); mxnet::ShapeVector shapes(idx_g.num_node_entries()); - for (int i = 0; i < graph.outputs.size(); ++i) { + for (size_t i = 0; i < graph.outputs.size(); ++i) { auto eid = idx_g.entry_id(graph.outputs[i]); dtypes[eid] = out_type[i]; shapes[eid] = out_shape[i]; @@ -312,6 +331,13 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx, graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); graph.attrs["dtype"] = std::make_shared(std::move(dtypes)); graph.attrs["shape"] = std::make_shared(std::move(shapes)); + + std::unique_ptr<::onnx_to_tensorrt::TRTInt8Calibrator> calibrator; + if (tensorrt_int8) { + calibrator.reset( + new ::onnx_to_tensorrt::TRTInt8Calibrator(params_map, std::move(input_buffers), + max_batch_size, node_param.calibration_iters)); + } auto onnx_graph = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(graph, ¶ms_map); uint32_t verbose = dmlc::GetEnv("MXNET_TENSORRT_VERBOSE", 0); auto log_lvl = nvinfer1::ILogger::Severity::kWARNING; @@ -319,11 +345,18 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx, log_lvl = nvinfer1::ILogger::Severity::kVERBOSE; } - auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, max_batch_size, 1 << 30, log_lvl); + auto 
trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, node_param.fp16_mode, + max_batch_size, 1 << 30, + calibrator.get(), + log_lvl); + return OpStatePtr::Create(std::move(std::get<0>(trt_tuple)), std::move(std::get<1>(trt_tuple)), std::move(std::get<2>(trt_tuple)), - inputs_to_idx, outputs_to_idx); + inputs_to_idx, outputs_to_idx, + max_batch_size, + std::move(calibrator), + std::move(std::get<3>(trt_tuple))); } NNVM_REGISTER_OP(_TensorRT) diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu b/src/operator/subgraph/tensorrt/tensorrt.cu index 826f9a5876b6..625efeef7377 100644 --- a/src/operator/subgraph/tensorrt/tensorrt.cu +++ b/src/operator/subgraph/tensorrt/tensorrt.cu @@ -18,14 +18,17 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019-2020 by Contributors * \file tensorrt.cu * \brief TensorRT GPU operation registration - * \author Marek Kolodziej, Clement Fuji Tsang + * \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev */ #if MXNET_USE_TENSORRT +#include +#include + #include "./tensorrt-inl.h" namespace mxnet { @@ -47,7 +50,14 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx, using namespace mshadow; using namespace mshadow::expr; cudaStream_t cuda_s = Stream::GetStream(ctx.get_stream()); - const auto& param = state.get_state(); + auto& param = state.get_state(); + if (param.calibration_mode) { + std::unordered_map input_ptr_map; + for (auto it : param.input_name_to_idx) { + input_ptr_map.emplace(it.first, inputs[it.second].dptr_); + } + param.calibrator->setBatch(input_ptr_map, cuda_s); + } for (size_t i = 0; i < param.binding_order->size(); ++i) { auto& p = param.binding_order->at(i); if (p.second == true) { @@ -57,6 +67,17 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx, } } param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr); + + if (param.calibration_mode && param.calibrator->lastIter()) { + param.calibrator->waitAndSetDone(); + // calibrator is fully calibrated, the calibration tables are ready + cudaStreamSynchronize(cuda_s); + // create the new engine + auto int8_engine = param.future_int8_engine.get(); + LOG(INFO) << "[TensorRT op] Calibration done, setting inference engine to INT8."; + param.ResetEngine(std::move(int8_engine), + /* calibration_mode=*/ false); + } } NNVM_REGISTER_OP(_TensorRT) diff --git a/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc new file mode 100644 index 000000000000..8ba7a3aecb63 --- /dev/null +++ b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2020 by Contributors + * \file tensorrt_int8_calibrator.cc + * \brief TensorRT INT8 calibrator implementation + * \author Serge Panev +*/ + +#if MXNET_USE_TENSORRT + +#include "./tensorrt_int8_calibrator.h" + +#include +#include + +namespace onnx_to_tensorrt { + +// Returns the batch size used for calibration (fixed before the engine-building thread starts). +int TRTInt8Calibrator::getBatchSize() const { return batch_size_; } + +TRTInt8Calibrator::TRTInt8Calibrator( + std::unordered_map params_map, + std::unordered_map> input_buffers, + int batch_size, int n_iter) + : batch_size_(batch_size), + done_(false), + params_map_(params_map), + input_buffers_(std::move(input_buffers)), + // Make sure setBatch() waits until getBatch() is called (the first time). + calib_running_(true), + batch_is_set_(false), + n_iter_(n_iter) {} + +bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, + const cudaStream_t stream) { + std::unique_lock lk(mutex_); + // Wait while the queue is full or calibration is running. + cv_.wait(lk, [&]{ return (!calib_running_ && !batch_is_set_) || done_; }); + if (done_) + return false; + n_iter_--; + + for (const auto& it : data) { + auto in_it = input_buffers_.find(it.first); + if (in_it == input_buffers_.end()) { + LOG(FATAL) << "TensorRT op input name '" << it.first + << "' does not match any calibration buffer name"; + } + const auto& buff_and_size = in_it->second; + auto status = cudaMemcpyAsync(buff_and_size.first, it.second, buff_and_size.second, + cudaMemcpyDeviceToDevice, stream); + if (status != cudaSuccess) { + LOG(FATAL) << "cudaMemcpy in TensorRT op for '" << it.first + << "' failed with " << status; + } + } + // TODO(spanev): see if we can use something like cudaStreamAddCallback here + MSHADOW_CUDA_CALL(cudaStreamSynchronize(stream)); + batch_is_set_ = true; + cv_.notify_all(); + return true; +} + +bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, + int num_bindings) { + // Wait until new batch arrives + std::unique_lock lk(mutex_); + calib_running_ = false; + cv_.notify_all(); + + cv_.wait(lk, [&]{ return batch_is_set_ || done_; }); + if (done_) + return false; + + for (int i = 0; i < num_bindings; i++) { + auto it = input_buffers_.find(names[i]); + if (it == input_buffers_.end()) { + LOG(FATAL) << "Calibration engine asked for unknown tensor name '" + << names[i] << "' at position " << i; + } + bindings[i] = it->second.first; + } + batch_is_set_ = false; + calib_running_ = true; + return true; +} + +const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { + if (calibration_table_.empty()) { + return nullptr; + } + length = calibration_table_.size(); + return calibration_table_.data(); +} + +void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, + std::size_t length) { + calibration_table_ = std::string(static_cast(ptr), length); + LOG(INFO) << "[TensorRT op] Got calibration data for TensorRT op @" << ptr + << " length=" << length; +} + +void TRTInt8Calibrator::waitAndSetDone() { + std::unique_lock lk(mutex_); + cv_.wait(lk, [&]{ return (!batch_is_set_ && !calib_running_) || done_; }); + if (!done_) { + done_ = true; + cv_.notify_all(); + input_buffers_.clear(); + } +} + +bool TRTInt8Calibrator::isCacheEmpty() { + return calibration_table_.empty(); +} + +bool TRTInt8Calibrator::lastIter() { + return n_iter_ == 0; +} + +TRTInt8Calibrator::~TRTInt8Calibrator() { + waitAndSetDone(); + for (auto it : input_buffers_) { + auto ptr_and_size = it.second; + MSHADOW_CUDA_CALL(cudaFree(ptr_and_size.first)); + } +} + +} // namespace onnx_to_tensorrt
+ +#endif // MXNET_USE_TENSORRT diff --git a/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h new file mode 100644 index 000000000000..e6a5efbdd8c4 --- /dev/null +++ b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h @@ -0,0 +1,105 @@ +#ifndef MXNET_OPERATOR_SUBGRAPH_TENSORRT_TENSORRT_INT8_CALIBRATOR_H_ +#define MXNET_OPERATOR_SUBGRAPH_TENSORRT_TENSORRT_INT8_CALIBRATOR_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file tensorrt_int8_calibrator.h + * \brief TensorRT INT8 calibrator + * \author Serge Panev +*/ + +#if MXNET_USE_TENSORRT + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../common.h" + +namespace onnx_to_tensorrt { + +// This class provides a one-element queue to match MXNet's push model to +// TensorRT's pull model for calibration. When TensorRT implements a means for +// push calibration, this class should be updated accordingly. + +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 { + public: + // Construct a calibrator for future calibration. + TRTInt8Calibrator( + std::unordered_map params_map, + std::unordered_map> input_buffers_, + int batch_size, int n_iter); + + ~TRTInt8Calibrator(); + + int getBatchSize() const override; + + bool getBatch(void* bindings[], const char* names[], + int num_bindings) override; + + // Feed calibration data to the calibrator, and return true if the data is + // accepted. Return false if the calibrator has been terminated. + bool setBatch(const std::unordered_map& data, + const cudaStream_t stream); + + // If this returns a non-null pointer, calibration is skipped. + const void* readCalibrationCache(std::size_t& length) override; + + void writeCalibrationCache(const void* ptr, std::size_t length) override; + + // TODO(spanev): determine if we need to serialize it + const std::string& getCalibrationTableAsString() { return calibration_table_; } + + void waitAndSetDone(); + + bool isCacheEmpty(); + + bool lastIter(); + + private: + const int batch_size_; + + // Is calibration finished?
+ bool done_; + std::unordered_map params_map_; + std::unordered_map> input_buffers_; + bool calib_running_; + bool batch_is_set_; + + int n_iter_; + + std::string calibration_table_; + + std::mutex mutex_; + std::condition_variable cv_; +}; + +} // namespace onnx_to_tensorrt + +#endif // MXNET_USE_TENSORRT +#endif // MXNET_OPERATOR_SUBGRAPH_TENSORRT_TENSORRT_INT8_CALIBRATOR_H_ diff --git a/tests/python/tensorrt/rec2idx.py b/tests/python/tensorrt/rec2idx.py new file mode 100644 index 000000000000..82149d795e5b --- /dev/null +++ b/tests/python/tensorrt/rec2idx.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import time +import ctypes +from mxnet.base import _LIB +from mxnet.base import check_call +import mxnet as mx +import argparse + +class IndexCreator(mx.recordio.MXRecordIO): + """Reads the `RecordIO` data format and creates an index file + that enables random access. + + Example usage: + ---------- + >>> creator = IndexCreator('data/test.rec','data/test.idx') + >>> creator.create_index() + >>> creator.close() + >>> !ls data/ + test.rec test.idx + + Parameters + ---------- + uri : str + Path to the record file. + idx_path : str + Path to the index file, that will be created/overwritten. + key_type : type + Data type for keys (optional, default = int). + """ + def __init__(self, uri, idx_path, key_type=int): + self.key_type = key_type + self.fidx = None + self.idx_path = idx_path + super(IndexCreator, self).__init__(uri, 'r') + + def open(self): + super(IndexCreator, self).open() + self.fidx = open(self.idx_path, 'w') + + def close(self): + """Closes the record and index files.""" + if not self.is_open: + return + super(IndexCreator, self).close() + self.fidx.close() + + def tell(self): + """Returns the current position of read head.
+ """ + pos = ctypes.c_size_t() + check_call(_LIB.MXRecordIOReaderTell(self.handle, ctypes.byref(pos))) + return pos.value + + def create_index(self): + """Creates the index file from open record file + """ + self.reset() + counter = 0 + pre_time = time.time() + while True: + if counter % 1000 == 0: + cur_time = time.time() + print('time:', cur_time - pre_time, ' count:', counter) + pos = self.tell() + cont = self.read() + if cont is None: + break + key = self.key_type(counter) + self.fidx.write('%s\t%d\n'%(str(key), pos)) + counter = counter + 1 + +def parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description='Create an index file from .rec file') + parser.add_argument('record', help='path to .rec file.') + parser.add_argument('index', help='path to index file.') + args = parser.parse_args() + args.record = os.path.abspath(args.record) + args.index = os.path.abspath(args.index) + return args + +def main(): + args = parse_args() + creator = IndexCreator(args.record, args.index) + creator.create_index() + creator.close() + +if __name__ == '__main__': + main() diff --git a/tests/python/tensorrt/test_tensorrt.py b/tests/python/tensorrt/test_tensorrt.py new file mode 100644 index 000000000000..c7e5f01018db --- /dev/null +++ b/tests/python/tensorrt/test_tensorrt.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os +import ctypes +import mxnet as mx +from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str, mx_real_t +from mxnet.symbol import Symbol +import numpy as np +from mxnet.test_utils import assert_almost_equal +from mxnet import gluon +from mxnet.gluon import nn +from mxnet import nd +from mxnet.gluon.model_zoo import vision + +#################################### +######### FP32/FP16 tests ########## +#################################### + +# Using RN50 to test TRT integration +def get_model(batch_shape, gluon_model=False): + if not gluon_model: + path = 'resnet50_v2' + if not os.path.exists(path + '-symbol.json'): + model = vision.resnet50_v2(pretrained=True) + model.hybridize() + model.forward(mx.nd.zeros(batch_shape)) + model.export(path) + sym, arg_params, aux_params = mx.model.load_checkpoint(path, 0) + return sym, arg_params, aux_params + else: + model = vision.resnet50_v2(pretrained=True) + model.hybridize() + return model + + +def get_default_executor(input_data): + sym, arg_params, aux_params = get_model(batch_shape=input_data.shape) + executor = sym.simple_bind(ctx=mx.gpu(0), data=input_data.shape, grad_req='null', force_rebind=True) + executor.copy_params_from(arg_params, aux_params) + return executor + +def get_baseline(input_data): + executor = get_default_executor(input_data) + output = executor.forward(is_train=False, data=input_data) + return output + + +def check_tensorrt_symbol(baseline, input_data, fp16_mode, tol): + sym, arg_params, aux_params = get_model(batch_shape=input_data.shape) + trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params, ctx=mx.gpu(0), + precision='fp16' if fp16_mode else 'fp32') + + executor = trt_sym.simple_bind(ctx=mx.gpu(), data=input_data.shape, + grad_req='null', force_rebind=True) + + output = executor.forward(is_train=False, data=input_data) + assert_almost_equal(output[0].asnumpy(), baseline[0].asnumpy(), atol=tol[0], rtol=tol[1]) + +def test_tensorrt_symbol(): + batch_shape = (32, 3, 224, 224) + input_data = mx.nd.random.uniform(shape=batch_shape, ctx=mx.gpu(0)) + baseline = get_baseline(input_data) + print("Testing resnet50 with TensorRT backend numerical accuracy...") + print("FP32") + check_tensorrt_symbol(baseline, input_data, fp16_mode=False, tol=(1e-4, 1e-4)) + print("FP16") + check_tensorrt_symbol(baseline, input_data, fp16_mode=True, tol=(1e-1, 1e-2)) + +############################### +######### INT8 tests ########## +############################### + +def get_dali_iter(): + from nvidia.dali.pipeline import Pipeline + import nvidia.dali.ops as ops + import nvidia.dali.types as types + from nvidia.dali.plugin.mxnet import DALIGluonIterator as GluonIterator + + val_rec='val_256_q90.rec' + val_idx='val_256_q90.idx' + + class RecordIOPipeline(Pipeline): + def __init__(self, batch_size, num_threads, device_id): + super(RecordIOPipeline, self).__init__(batch_size, + num_threads, + device_id) + self.input = ops.MXNetReader(path = val_rec, index_path = val_idx) + + self.decode = ops.ImageDecoder(device = "mixed", output_type = types.RGB) + self.uniform = ops.Uniform(range = (0.0, 1.0)) + self.res = ops.Resize(device="gpu", + resize_shorter=224, + interp_type=types.INTERP_TRIANGULAR) + self.cmnp = ops.CropMirrorNormalize(device="gpu", + dtype=types.FLOAT, + output_layout=types.NCHW, + crop=(224, 224), + mean=[0.485 * 255,0.456 * 255,0.406 * 255], + std=[0.229 * 255,0.224 * 255,0.225 * 255]) + self.iter = 0 + + + def define_graph(self): + inputs, labels = self.input(name="Reader") + images = self.decode(inputs) + res = self.res(images) + output = self.cmnp(res) + return (output, labels) + + def iter_setup(self): + pass + pipe = RecordIOPipeline(1, 4, 0) + pipe.build() + return GluonIterator(pipe, pipe.epoch_size("Reader"), fill_last_batch=True) + +def get_top1(logits): + prob = logits.squeeze() + sorted_prob = mx.nd.argsort(prob) + return sorted_prob[-1] + + +def test_tensorrt_symbol_int8(): + # INT8 engine outputs are not lossless, so we don't expect numerical equality; + # instead we compare the TOP1 metric + + batch_shape=(1,3,224,224) + sym, arg_params, aux_params = get_model(batch_shape=batch_shape) + calibration_iters = 700 + trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params, ctx=mx.gpu(0), + precision='int8', + calibration_iters=calibration_iters) + + executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape, + grad_req='null', force_rebind=True) + + dali_val_iter = get_dali_iter() + + # Calibration phase + for i,it in enumerate(dali_val_iter): + data, _ = it[0] # gpu 0 + if i == calibration_iters: + break + y_gen = executor.forward(is_train=False, data=data) + + y_gen[0].wait_to_read() + + executor_fp32 = get_default_executor(data) + + top1_accuracy_similarity = 0 + top1_accuracy_default = 0 + top1_accuracy_int8 = 0 + + iters = 1000 + for i,it in enumerate(dali_val_iter): + if i == iters: + break + input_data, label = it[0] # gpu 0 + + output = executor.forward(is_train=False, data=input_data) + baseline = executor_fp32.forward(is_train=False, data=input_data) + + top1_output = get_top1(output[0]) + top1_baseline = get_top1(baseline[0]) + + label = label.squeeze().as_in_context(top1_baseline.context) + top1_accuracy_similarity += (top1_output == top1_baseline).asscalar() + top1_accuracy_default += (top1_baseline == label).asscalar() + top1_accuracy_int8 += (top1_output == label).asscalar() + + + top1_accuracy_similarity = (top1_accuracy_similarity / iters) + + top1_accuracy_default = (top1_accuracy_default / iters) + top1_accuracy_int8 = (top1_accuracy_int8 / iters) + delta_top1_accuracy = abs(top1_accuracy_default - top1_accuracy_int8) + + # These values are provided by the TensorRT team and reflect the expected accuracy loss when using INT8 + expected_max_delta_top1_accuracy = 0.02 # accuracy gap measured with TRT7; TRT7.1 can be at 0.01 + expected_min_similarity = 0.92 + print('Delta between FP32 and INT8 TOP1 accuracies: {}'.format(delta_top1_accuracy)) + print('TOP1 similarity accuracy (when top1_fp32 == top1_int8): {}'.format(top1_accuracy_similarity)) + assert(delta_top1_accuracy < expected_max_delta_top1_accuracy) + assert(top1_accuracy_similarity > expected_min_similarity) + +if __name__ == '__main__': + import nose + nose.runmodule()
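
A minimal sketch of the user-facing INT8 flow this patch enables, distilled from test_tensorrt_symbol_int8 above. It assumes a previously exported 'resnet50_v2' checkpoint and a data iterator named calib_iter yielding representative GPU batches; calib_iter and the shapes are illustrative, while optimize_for with the precision and calibration_iters options is the API exercised by this patch:

import mxnet as mx

# Load a symbolic checkpoint (exported as in get_model() above).
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet50_v2', 0)

# Partition the graph for TensorRT and request an INT8 engine; the TensorRT op
# stays in calibration mode for the next `calibration_iters` forward passes.
calibration_iters = 700
trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params,
                           ctx=mx.gpu(0), precision='int8',
                           calibration_iters=calibration_iters)
executor = trt_sym.simple_bind(ctx=mx.gpu(0), data=(1, 3, 224, 224),
                               grad_req='null', force_rebind=True)

# Feed data representative of inference traffic; after the last calibration
# batch, the op builds and swaps in the calibrated INT8 engine automatically.
for i, batch in enumerate(calib_iter):  # calib_iter: illustrative data iterator
    if i == calibration_iters:
        break
    executor.forward(is_train=False, data=batch)

# Subsequent calls to executor.forward() run on the calibrated INT8 engine.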