This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit a4760ce

[Quantization speedup] Support TensorRT 8.0.0 (#3866)

linbinskn authored Jul 9, 2021
1 parent 4b1f46a commit a4760ce
Showing 2 changed files with 42 additions and 10 deletions.
4 changes: 4 additions & 0 deletions docs/en_US/Compression/QuantizationSpeedup.rst
@@ -50,6 +50,10 @@ CUDA version >= 11.0

TensorRT version >= 7.2

Note

* If you haven't installed TensorRT before, or are using an older version, please refer to the `TensorRT Installation Guide <https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html>`__.
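
A quick way to confirm the installed version (a minimal sketch, assuming the ``tensorrt`` Python package is importable)::

    import tensorrt as trt

    # The quantization speedup module expects TensorRT 7.2 or newer.
    print(trt.__version__)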

Usage
-----
quantization aware training:
48 changes: 38 additions & 10 deletions nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py
@@ -12,7 +12,8 @@
from . import trt_pycuda as common
from .backend import BaseModelSpeedup

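# Major TensorRT versions supported by this module; used below to branch
# between the TensorRT 7 and TensorRT 8 builder APIs.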
TRT8 = 8
TRT7 = 7
TRT_LOGGER = trt.Logger()
logger = logging.getLogger(__name__)

@@ -120,22 +121,43 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
        An ICudaEngine for executing inference on a built network
    """
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as trt_config:
        # Note that max_batch_size must be set to 1 because of the implementation of allocate_buffer.
        trt_version = int(trt.__version__[0])
        assert trt_version == TRT8 or trt_version == TRT7, \
            "Version of TensorRT is too old, please update TensorRT to version >= 7.0"
        if trt_version == TRT7:
            logger.warning("TensorRT7 is deprecated and may be removed in a future release.")

        builder.max_batch_size = 1
        if trt_version == TRT8:
            trt_config.max_workspace_size = common.GiB(4)
        else:
            builder.max_workspace_size = common.GiB(4)

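        # Select precision flags according to extra_layer_bit: TensorRT 8 sets
        # BuilderFlag values on the builder config, TensorRT 7 sets modes on the builder.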
        if extra_layer_bit == 32 and config is None:
            pass
        elif extra_layer_bit == 16 and config is None:
            if trt_version == TRT8:
                trt_config.set_flag(trt.BuilderFlag.FP16)
            else:
                builder.fp16_mode = True
        elif extra_layer_bit == 8 and config is None:
            # entire model in 8bit mode
            if trt_version == TRT8:
                trt_config.set_flag(trt.BuilderFlag.INT8)
            else:
                builder.int8_mode = True
        else:
            if trt_version == TRT8:
                trt_config.set_flag(trt.BuilderFlag.INT8)
                trt_config.set_flag(trt.BuilderFlag.FP16)
                if strict_datatype:
                    trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
            else:
                builder.int8_mode = True
                builder.fp16_mode = True
                builder.strict_type_constraints = strict_datatype

        valid_config(config)

@@ -148,7 +170,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
                return None

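        # Attach the INT8 calibrator when post-training calibration is requested;
        # TensorRT 8 expects it on the builder config rather than the builder.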
        if calib is not None:
            if trt_version == TRT8:
                trt_config.int8_calibrator = calib
            else:
                builder.int8_calibrator = calib
        # This design may not be correct if the network has more than one output.
        for i in range(network.num_layers):
            if config is None:
@@ -196,7 +221,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
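                # Propagate the tracked activation statistics into TensorRT as the
                # tensor's quantization (dynamic) range.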
                out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)

        # Build engine and do int8 calibration.
        if trt_version == TRT8:
            engine = builder.build_engine(network, trt_config)
        else:
            engine = builder.build_cuda_engine(network)
        return engine
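
A minimal usage sketch of the version-aware builder above (the file name is hypothetical; it assumes an ONNX model exported with explicit batch size 1):

    from nni.compression.pytorch.quantization_speedup.integrated_tensorrt import build_engine

    # Build an FP16 engine without a per-layer quantization config.
    engine = build_engine('model.onnx', extra_layer_bit=16)
    if engine is None:
        raise RuntimeError('Failed to parse the ONNX model.')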

class ModelSpeedupTensorRT(BaseModelSpeedup):
