
[Quantization speedup]support TensorRT8.0.0 #3866

Merged · 3 commits · Jul 9, 2021
4 changes: 4 additions & 0 deletions docs/en_US/Compression/QuantizationSpeedup.rst
@@ -50,6 +50,10 @@ CUDA version >= 11.0

TensorRT version >= 7.2

Note

* If you have not installed TensorRT before, or are using an older version, please refer to the `TensorRT Installation Guide <https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html>`__
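
A quick way to confirm which TensorRT version is currently installed (a minimal sketch; it only assumes the tensorrt Python bindings are importable):

    import tensorrt as trt  # TensorRT Python bindings
    print(trt.__version__)  # e.g. "8.0.0.3"; this PR supports major versions 7 and 8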

Usage
-----
quantization aware training:
48 changes: 38 additions & 10 deletions nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py
@@ -12,7 +12,8 @@
from . import trt_pycuda as common
from .backend import BaseModelSpeedup

# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
TRT8 = 8
TRT7 = 7
TRT_LOGGER = trt.Logger()
logger = logging.getLogger(__name__)
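
The code below derives the major version with int(trt.__version__[0]). As a side note, a slightly more defensive parse (an illustrative alternative, not part of this PR) would avoid breaking on future two-digit majors:

    import tensorrt as trt

    # Split on '.' instead of indexing the first character, so "10.0.1" parses as 10, not 1.
    trt_major = int(trt.__version__.split('.')[0])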

@@ -120,22 +121,43 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
An ICudaEngine for executing inference on a built network
"""
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \
trt.OnnxParser(network, TRT_LOGGER) as parser:
trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as trt_config:
# Note that max_batch_size must be set to 1 because of the implementation of allocate_buffer
trt_version = int(trt.__version__[0])
assert trt_version == TRT8 or trt_version == TRT7, \
"TensorRT version is too old; please update TensorRT to version >= 7.2"
if trt_version == TRT7:
logger.warning("TensorRT7 is deprecated and may be removed in the following release.")

builder.max_batch_size = 1
builder.max_workspace_size = common.GiB(4)
if trt_version == TRT8:
trt_config.max_workspace_size = common.GiB(4)
else:
builder.max_workspace_size = common.GiB(4)

if extra_layer_bit == 32 and config is None:
pass
elif extra_layer_bit == 16 and config is None:
builder.fp16_mode = True
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.FP16)
else:
builder.fp16_mode = True
elif extra_layer_bit == 8 and config is None:
# entire model in 8bit mode
builder.int8_mode = True
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.INT8)
else:
builder.int8_mode = True
else:
builder.int8_mode = True
builder.fp16_mode = True
builder.strict_type_constraints = strict_datatype
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.INT8)
trt_config.set_flag(trt.BuilderFlag.FP16)
if strict_datatype:
trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
else:
builder.int8_mode = True
builder.fp16_mode = True
builder.strict_type_constraints = strict_datatype

valid_config(config)
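
To summarize the TensorRT 8 pattern used throughout this branch: precision is requested through flags on the IBuilderConfig rather than boolean attributes on the builder. A condensed sketch, assuming TensorRT 8 is installed:

    import tensorrt as trt

    TRT_LOGGER = trt.Logger()
    builder = trt.Builder(TRT_LOGGER)
    trt_config = builder.create_builder_config()
    trt_config.max_workspace_size = 1 << 32            # 4 GiB of builder scratch space
    trt_config.set_flag(trt.BuilderFlag.FP16)          # permit FP16 kernels
    trt_config.set_flag(trt.BuilderFlag.INT8)          # permit INT8 kernels
    trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)  # enforce the requested precisions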

@@ -148,7 +170,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
return None

if calib is not None:
builder.int8_calibrator = calib
if trt_version == TRT8:
trt_config.int8_calibrator = calib
else:
builder.int8_calibrator = calib
# This design may not be correct if there is more than one output
for i in range(network.num_layers):
if config is None:
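
The calib object assigned above must implement TensorRT's INT8 calibrator interface. A minimal, hypothetical sketch of such a calibrator (the class name and data handling are illustrative, not part of this PR):

    import numpy as np
    import pycuda.driver as cuda
    import pycuda.autoinit  # noqa: F401  (initializes a CUDA context)
    import tensorrt as trt

    class SimpleEntropyCalibrator(trt.IInt8EntropyCalibrator2):
        def __init__(self, batches, batch_size):
            trt.IInt8EntropyCalibrator2.__init__(self)
            self.batches = iter(batches)   # iterable of numpy arrays, one per batch
            self.batch_size = batch_size
            self.device_input = None

        def get_batch_size(self):
            return self.batch_size

        def get_batch(self, names):
            try:
                batch = next(self.batches).astype(np.float32)
            except StopIteration:
                return None                # signals the end of calibration data
            if self.device_input is None:
                self.device_input = cuda.mem_alloc(batch.nbytes)
            cuda.memcpy_htod(self.device_input, batch)
            return [int(self.device_input)]

        def read_calibration_cache(self):
            return None                    # no cache reuse in this sketch

        def write_calibration_cache(self, cache):
            pass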
@@ -196,7 +221,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)

# Build engine and do int8 calibration.
engine = builder.build_cuda_engine(network)
if trt_version == TRT8:
engine = builder.build_engine(network, trt_config)
else:
engine = builder.build_cuda_engine(network)
return engine
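
TensorRT 8 removed builder.build_cuda_engine in favor of build_engine(network, config), which is why the call is branched on version. A hedged usage sketch of the function above ('model.onnx' is a placeholder path):

    # 'model.onnx' is an illustrative placeholder; extra_layer_bit=16 requests FP16 mode.
    engine = build_engine('model.onnx', extra_layer_bit=16)
    if engine is not None:
        context = engine.create_execution_context()  # ready for execute_v2 / enqueue calls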

class ModelSpeedupTensorRT(BaseModelSpeedup):
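
For end-to-end context, the NNI docs drive ModelSpeedupTensorRT roughly as follows (a sketch based on the project documentation; variable values are illustrative and argument names may differ by release):

    from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

    # 'model', 'input_shape', and 'calibration_config' come out of the preceding
    # quantization-aware training step; the batch size here is illustrative.
    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
    engine.compress()
    output, latency = engine.inference(test_data)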