FP16 support for TRT
Michalis Papapdimitriou committed Mar 1, 2022
1 parent d101c50 commit e36ceb0
Showing 6 changed files with 442 additions and 324 deletions.
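With the float32-only guards removed below, a float16 Relay module should now pass TensorRT annotation. The following is a minimal usage sketch, not part of this commit: partition_for_tensorrt and the "relay.ext.tensorrt.options" config key are the existing TVM BYOC entry points, while the float16 conv2d module is purely illustrative.

# Minimal sketch, assuming the standard TVM BYOC flow for TensorRT.
# The float16 module below is illustrative, not taken from this commit.
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

data = relay.var("data", shape=(1, 3, 224, 224), dtype="float16")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float16")
conv = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Partition supported subgraphs for TensorRT; before this commit the annotators
# rejected these ops because only float32 inputs were accepted.
mod, config = partition_for_tensorrt(mod)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    lib = relay.build(mod, target="cuda")
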
118 changes: 14 additions & 104 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -202,9 +202,6 @@ def _func_wrapper(expr):
# ops with dynamic shapes are offloaded to VM
if check_dynamism(args, op_name):
return False
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if op_name == "multiply":
shapes = [
[
@@ -325,9 +322,6 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]):
return False

if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if (
not get_tensorrt_use_implicit_batch_mode()
and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
@@ -347,9 +341,6 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.batch_norm is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
return False
@@ -366,10 +357,7 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
def softmax_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.softmax is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
logger.info("nn.softmax: can't modify batch dimension.")
return False
@@ -380,10 +368,7 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable
def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv1d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.data_layout != "NCW":
logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
return False
@@ -397,10 +382,7 @@ def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
def conv2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.data_layout != "NCHW":
logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
return False
@@ -418,9 +400,6 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if dense is supported by TensorRT."""

args = expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
input_rank = len(args[0].checked_type.shape)
weight_rank = len(args[1].checked_type.shape)
if input_rank not in (2, 3, 4):
@@ -436,9 +415,6 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
def batch_matmul_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
expr.args[1].checked_type.shape
):
@@ -451,9 +427,6 @@ def batch_matmul_annotate_fn(expr):
def layer_norm_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
logger.info("nn.layer_norm: requires use_implict_batch=False.")
return False
@@ -465,9 +438,6 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.bias_add is supported by TensorRT."""

args = expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
input_rank = len(args[0].checked_type.shape)
if input_rank not in (2, 3, 4):
logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank)
@@ -479,10 +449,7 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.max_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
@@ -496,10 +463,7 @@ def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.avg_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout)
return False
@@ -526,10 +490,7 @@ def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.global_max_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
@@ -540,10 +501,7 @@ def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.global_avg_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
@@ -554,10 +512,7 @@ def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if expand_dims is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
logger.info("expand_dims: can't modify batch dimension.")
return False
@@ -568,10 +523,7 @@ def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
def squeeze_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if squeeze is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if not attrs.axis:
logger.info("squeeze: must explicitly set axis.")
return False
@@ -586,9 +538,6 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if concatenate is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not get_tensorrt_use_implicit_batch_mode():
return True
if int(attrs.axis) == 0:
@@ -606,9 +555,6 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
def split_annotate_fn(expr):
"""Check if split is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
logger.info("split: can't modify batch dimension.")
return False
@@ -619,10 +565,7 @@ def split_annotate_fn(expr):
def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d_transpose is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.data_layout != "NCHW":
logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout)
return False
@@ -644,10 +587,7 @@ def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
def transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if transpose is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
logger.info("transpose: can't modify batch dimension.")
return False
@@ -658,10 +598,7 @@ def transpose_annotate_fn(expr): # pylint: disable=unused-variable
def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if layout_transform is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if (attrs.src_layout, attrs.dst_layout) not in [
("NCHW", "NHWC"),
("NHWC", "NCHW"),
@@ -679,9 +616,6 @@ def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
def reshape_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if reshape is supported by TensorRT."""
attrs, args = expr.attrs, expr.args
if args[0].checked_type.dtype != "float32":
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if any([x < -1 for x in map(int, attrs.newshape)]):
logger.info("reshape: new shape dims must be explicit.")
return False
@@ -740,9 +674,6 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
pad_value = args[1]
assert isinstance(pad_value, relay.Constant)
pad_value = pad_value.data.numpy().item()
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if attrs.pad_mode != "constant":
logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
return False
@@ -766,9 +697,6 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if strided_slice is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if args[0].checked_type.dtype != "float32":
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"):
return False
if get_tensorrt_use_implicit_batch_mode():
@@ -813,10 +741,7 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.adaptive_max_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).")
return False
@@ -827,10 +752,7 @@ def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
def adaptive_avg_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.adaptive_avg_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).")
return False
@@ -842,9 +764,6 @@ def conv3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"):
return False
if attrs.data_layout != "NCDHW":
@@ -864,9 +783,6 @@ def max_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.max_pool3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"):
return False
if attrs.layout != "NCDHW":
@@ -880,9 +796,6 @@ def avg_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.avg_pool3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"):
return False
if attrs.layout != "NCDHW":
@@ -896,9 +809,6 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv3d_transpose is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"):
return False
if attrs.data_layout != "NCDHW":
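The diff above removes the dtype guard entirely rather than widening it. If a guard were kept, a relaxed check that accepts FP16 alongside FP32 might look like the sketch below; the helper name and constant are hypothetical and are not part of this commit.

# Hypothetical helper, for illustration only: validate argument dtypes against
# the types TensorRT can consume instead of dropping the check altogether.
SUPPORTED_TRT_DTYPES = ("float32", "float16")

def has_supported_dtypes(args):
    """Return True if every argument dtype is one TensorRT can consume."""
    return all(arg.checked_type.dtype in SUPPORTED_TRT_DTYPES for arg in args)
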
23 changes: 14 additions & 9 deletions src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -85,8 +85,10 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
shape.erase(shape.begin());
}
nvinfer1::Dims dims = VectorToTrtDims(shape);
ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
auto tensor_dtype =
(dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;

auto input_tensor = network_->addInput(name.c_str(), tensor_dtype, dims);
node_output_map_[nid].push_back(TensorRTOpInput(input_tensor));
network_input_names_.push_back(name);
entry_id_map_[name] = entry_id + i;
@@ -139,6 +141,7 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
<< " requires weights but got a tensor.";
}
}
VLOG(1) << "INT " << input.type;
params.inputs.push_back(input);
}
ICHECK(converter->variable_input_count || converter->input_types.size() == params.inputs.size())
@@ -150,6 +153,10 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
// Get outputs.
node_output_map_[nid] = {};
for (auto out : params.outputs) {
// out->setType(params.inputs.at(1).weight.type);
// out->setType(nvinfer1::DataType::kFLOAT);
out->setType(nvinfer1::DataType::kHALF);

node_output_map_[nid].push_back(TensorRTOpInput(out));
}
}
@@ -205,18 +212,16 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
DLDeviceType src_device) {
ICHECK_EQ(dptr->device.device_type, src_device);
ICHECK(static_cast<int>(dptr->dtype.code) == kDLFloat ||
static_cast<int>(dptr->dtype.code) == kDLInt);
const auto trt_dtype = static_cast<int>(dptr->dtype.code) == kDLFloat
? nvinfer1::DataType::kFLOAT
: nvinfer1::DataType::kINT32;

const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT;

const size_t weight_bytes = GetDataSize(*dptr);
nvinfer1::Weights weight{trt_dtype, nullptr, 0};
size_t count = 1;
for (tvm_index_t i = 0; i < dptr->ndim; ++i) {
count *= dptr->shape[i];
}
ICHECK_EQ(count * 4, weight_bytes);
weight.count = count;
weight.values = new float[count];
ICHECK_EQ(TVMArrayCopyToBytes(const_cast<DLTensor*>(dptr), const_cast<void*>(weight.values),
Expand Down Expand Up @@ -250,7 +255,7 @@ void TensorRTBuilder::CleanUp() {
#endif
builder_->destroy();
for (auto weight : trt_weights_) {
if (weight.type == nvinfer1::DataType::kFLOAT) {
if (static_cast<int>(weight.type) <= 1) {
delete[] static_cast<const float*>(weight.values);
} else {
delete[] static_cast<const uint16_t*>(weight.values);
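On the runtime side, AddInput and GetDLTensorAsWeights now select kHALF whenever the incoming DLDataType has 16 bits, AddLayer forces layer outputs to kHALF, and CleanUp frees weights as float or uint16_t based on the enum value. One way this path could be exercised from Python is to cast an existing FP32 model to FP16 before partitioning, as in the sketch below; ToMixedPrecision and partition_for_tensorrt are existing TVM APIs, but the overall flow is illustrative rather than taken from this commit.

# Sketch only: cast a float32 Relay module to float16, then offload the
# supported subgraphs to TensorRT. Assumes partition_for_tensorrt returns a
# (module, config) pair, as it did in TVM versions contemporary with this commit.
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

def build_trt_fp16(mod, params, target="cuda"):
    """Cast a float32 module to float16 and build it through the TensorRT BYOC path."""
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.ToMixedPrecision(mixed_precision_type="float16")(mod)
    mod, config = partition_for_tensorrt(mod, params)
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.ext.tensorrt.options": config}
    ):
        return relay.build(mod, target=target, params=params)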