Address PR comments
Michalis Papadimitriou committed Mar 9, 2022
1 parent 5bdd0ed commit 2eb104b
Showing 2 changed files with 102 additions and 20 deletions.
114 changes: 96 additions & 18 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -28,6 +28,20 @@
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

logger = logging.getLogger("TensorRT")
supported_types = ["float32", "float16"]


def is_supported_trt_dtype(args):
    """Check whether the input tensor dtypes are supported by the TensorRT BYOC integration.

    Returns
    -------
    ret: bool
        True if supported, False if not.
    """
    if any([x.checked_type.dtype not in supported_types for x in args]):
        logger.info("Only float32 and float16 inputs are supported for TensorRT BYOC.")
        return False
    return True


def is_tensorrt_runtime_enabled():
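The new dtype gate has all-inputs semantics: an op is offloadable only when every argument is float32 or float16. A minimal standalone sketch of that behavior (the `_Arg`/`_Type` namedtuples are hypothetical stand-ins for type-checked Relay expressions):

```python
from collections import namedtuple

# Hypothetical stand-ins for Relay expressions carrying an inferred type.
_Type = namedtuple("_Type", ["dtype"])
_Arg = namedtuple("_Arg", ["checked_type"])

def _dtype_gate(args, supported=("float32", "float16")):
    # Mirrors is_supported_trt_dtype: reject if any input falls outside
    # the supported set; accept only when all inputs are float32/float16.
    return not any(a.checked_type.dtype not in supported for a in args)

assert _dtype_gate([_Arg(_Type("float32")), _Arg(_Type("float16"))])
assert not _dtype_gate([_Arg(_Type("float32")), _Arg(_Type("int8"))])
```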
@@ -113,8 +127,10 @@ def partition_for_tensorrt(
How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
See TensorRT documentation for more info.
use_fp16: Optional[bool]
Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled if FP16 inputs tensors and weights are used.
Note that TensorRT will still choose a higher-precision kernel if it results in overall lower runtime, or if no low-precision implementation exists.
Allows TRT to automatically convert FP32 inputs to FP16. This must also be enabled
if FP16 input tensors and weights are used.
Note that TensorRT will still choose a higher-precision kernel if it results in overall
lower runtime, or if no low-precision implementation exists.
use_uint8: Optional[bool]
Allows TRT to automatically convert FP32 inputs to UINT8.
Returns
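For context, a sketch of how these flags are typically passed when partitioning a model for TensorRT, following the pattern in the TVM docs; `mod` and `params` are assumed to come from a Relay frontend importer, and the exact return signature may vary across TVM versions:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# Assumed inputs: `mod` and `params` produced by a frontend importer,
# e.g. relay.frontend.from_onnx(onnx_model, shape_dict).
mod, config = partition_for_tensorrt(mod, params, use_fp16=True)

# Build with the TensorRT options returned by the partitioner.
with tvm.transform.PassContext(
    opt_level=3, config={"relay.ext.tensorrt.options": config}
):
    lib = relay.build(mod, target="cuda", params=params)
```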
@@ -209,6 +225,8 @@ def _register_external_op_helper_with_checker(op_name, checker):
def _func_wrapper(expr):
attrs, args = expr.attrs, expr.args
# ops with dynamic shapes are offloaded to VM
if not is_supported_trt_dtype(args):
return False
if check_dynamism(args, op_name):
return False
if op_name == "multiply":
@@ -321,7 +339,8 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if add is supported by TensorRT."""

args = expr.args

if not is_supported_trt_dtype(args):
return False
shapes = [
[int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
for arg in args
@@ -350,6 +369,8 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.batch_norm is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
return False
@@ -366,7 +387,9 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
def softmax_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.softmax is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
logger.info("nn.softmax: can't modify batch dimension.")
return False
@@ -377,7 +400,9 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable
def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv1d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if attrs.data_layout != "NCW":
logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
return False
@@ -391,7 +416,9 @@ def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
def conv2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if attrs.data_layout != "NCHW":
logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
return False
@@ -409,6 +436,8 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if dense is supported by TensorRT."""

args = expr.args
if not is_supported_trt_dtype(args):
return False
input_rank = len(args[0].checked_type.shape)
weight_rank = len(args[1].checked_type.shape)
if input_rank not in (2, 3, 4):
@@ -424,6 +453,9 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
def batch_matmul_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

args = expr.args
if not is_supported_trt_dtype(args):
return False
if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
expr.args[1].checked_type.shape
):
@@ -436,6 +468,9 @@ def batch_matmul_annotate_fn(expr):
def layer_norm_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

args = expr.args
if not is_supported_trt_dtype(args):
return False
if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
logger.info("nn.layer_norm: requires use_implict_batch=False.")
return False
@@ -446,7 +481,9 @@ def layer_norm_annotate_fn(expr):
def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.bias_add is supported by TensorRT."""

args = expr.args
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
input_rank = len(args[0].checked_type.shape)
if input_rank not in (2, 3, 4):
logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank)
@@ -458,7 +495,9 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.max_pool2d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if attrs.layout != "NCHW":
logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
@@ -472,7 +511,9 @@ def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.avg_pool2d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if attrs.layout != "NCHW":
logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout)
return False
@@ -499,7 +540,9 @@ def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.global_max_pool2d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if attrs.layout != "NCHW":
logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
@@ -510,7 +553,9 @@ def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.global_avg_pool2d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if attrs.layout != "NCHW":
logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
@@ -521,7 +566,9 @@ def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if expand_dims is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
logger.info("expand_dims: can't modify batch dimension.")
return False
@@ -532,7 +579,9 @@ def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
def squeeze_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if squeeze is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if not attrs.axis:
logger.info("squeeze: must explicitly set axis.")
return False
@@ -547,6 +596,8 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if concatenate is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if not get_tensorrt_use_implicit_batch_mode():
return True
if int(attrs.axis) == 0:
@@ -564,6 +615,9 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
def split_annotate_fn(expr):
"""Check if split is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
logger.info("split: can't modify batch dimension.")
return False
@@ -574,7 +628,9 @@ def split_annotate_fn(expr):
def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d_transpose is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if attrs.data_layout != "NCHW":
logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout)
return False
@@ -596,7 +652,9 @@ def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
def transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if transpose is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
logger.info("transpose: can't modify batch dimension.")
return False
@@ -607,7 +665,9 @@ def transpose_annotate_fn(expr): # pylint: disable=unused-variable
def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if layout_transform is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if (attrs.src_layout, attrs.dst_layout) not in [
("NCHW", "NHWC"),
("NHWC", "NCHW"),
@@ -625,6 +685,8 @@ def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
def reshape_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if reshape is supported by TensorRT."""
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if any([x < -1 for x in map(int, attrs.newshape)]):
logger.info("reshape: new shape dims must be explicit.")
return False
@@ -680,6 +742,8 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.pad is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
pad_value = args[1]
assert isinstance(pad_value, relay.Constant)
pad_value = pad_value.data.numpy().item()
Expand All @@ -706,6 +770,8 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if strided_slice is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"):
return False
if get_tensorrt_use_implicit_batch_mode():
@@ -750,7 +816,9 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.adaptive_max_pool2d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).")
return False
@@ -761,7 +829,9 @@ def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
def adaptive_avg_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.adaptive_avg_pool2d is supported by TensorRT."""

attrs = expr.attrs
attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).")
return False
@@ -773,6 +843,8 @@ def conv3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"):
return False
if attrs.data_layout != "NCDHW":
@@ -792,6 +864,8 @@ def max_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.max_pool3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"):
return False
if attrs.layout != "NCDHW":
@@ -805,6 +879,8 @@ def avg_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.avg_pool3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"):
return False
if attrs.layout != "NCDHW":
@@ -818,6 +894,8 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv3d_transpose is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if not is_supported_trt_dtype(args):
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"):
return False
if attrs.data_layout != "NCDHW":
8 changes: 6 additions & 2 deletions src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -85,6 +85,9 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
shape.erase(shape.begin());
}
nvinfer1::Dims dims = VectorToTrtDims(shape);
ICHECK(dtypes[i].bits == 16 || dtypes[i].bits == 32)
    << "Invalid input tensor type. Only float16 and float32 are supported";

auto tensor_dtype =
(dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
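The disjunction in the check above has to use `==`: a `!=`-based form like `bits != 16 || bits != 32` holds for every bit width (16 always differs from 32), so such a check can never fire. A quick plain-Python truth table illustrating the difference:

```python
# For each bit width: the `!=`/`or` form is always True (a no-op check),
# while the `==`/`or` form accepts only 16- and 32-bit floats.
for bits in (8, 16, 32, 64):
    never_fires = (bits != 16) or (bits != 32)
    intended = (bits == 16) or (bits == 32)
    print(f"bits={bits:2d}  !=-form={never_fires}  ==-form={intended}")
```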

@@ -210,7 +213,8 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
DLDeviceType src_device) {
ICHECK_EQ(dptr->device.device_type, src_device);

ICHECK(dptr->dtype.bits == 16 || dptr->dtype.bits == 32)
    << "Invalid input tensor type. Only float16 and float32 are supported";
const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT;

@@ -253,7 +257,7 @@ void TensorRTBuilder::CleanUp() {
#endif
builder_->destroy();
for (auto weight : trt_weights_) {
if (static_cast<int>(weight.type) <= 1) {
if (weight.type == nvinfer1::DataType::kFLOAT || weight.type == nvinfer1::DataType::kHALF) {
delete[] static_cast<const float*>(weight.values);
} else {
delete[] static_cast<const uint16_t*>(weight.values);
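The CleanUp hunk replaces a magic-number test (`static_cast<int>(weight.type) <= 1`, which relied on kFLOAT being 0 and kHALF being 1) with explicit enum comparisons. A small sketch of why the explicit form is more robust, using a hypothetical Python mirror of the enum values:

```python
from enum import IntEnum

class DataType(IntEnum):
    # Hypothetical mirror of nvinfer1::DataType's ordering.
    kFLOAT = 0
    kHALF = 1
    kINT8 = 2
    kINT32 = 3

def is_float_weight(t):
    # Explicit membership stays correct even if new enum members are
    # inserted or values are reordered; a raw `t <= 1` test would not.
    return t in (DataType.kFLOAT, DataType.kHALF)

assert is_float_weight(DataType.kHALF)
assert not is_float_weight(DataType.kINT32)
```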
